# jskinner215's picture
# Update app.py
# 862e59b
# raw
# history blame
# 9.76 kB
import logging
from copy import deepcopy
from io import StringIO

import pandas as pd
import streamlit as st
from langchain.callbacks import StreamlitCallbackHandler

from weaviate_utils import *
from tapas_utils import *
from ui_utils import *
# NOTE: a stray copy of the UI wiring that used to run here is disabled.
# It called `ui_utils.<function>(...)`, but the file only does
# `from ui_utils import *`, so `ui_utils` is (presumably) not a bound name.
# It also passed `client` before this script assigns it a few lines below.
# Either problem made the script raise NameError before rendering anything.
# The UI helpers are still invoked once, after initialization, below.
# The lines are kept for reference because they pass `selected_class` to
# handle_new_class_selection / csv_upload_and_ingestion, unlike the calls
# below. TODO: confirm which signature ui_utils actually expects.
# selected_class = ui_utils.display_class_dropdown(client)
# ui_utils.handle_new_class_selection(selected_class)
# ui_utils.csv_upload_and_ingestion(selected_class)
# ui_utils.display_query_input()
# Initialize Weaviate client.
# initialize_weaviate_client, initialize_tapas and the display_*/handle_*/
# csv_upload_and_ingestion helpers are not defined in this file; presumably
# they come from the weaviate_utils / tapas_utils / ui_utils star imports.
client = initialize_weaviate_client()
# Initialize TAPAS; the result is unpacked into a tokenizer and a model.
tokenizer, model = initialize_tapas()
# UI components.
# NOTE(review): the script below also builds its own class selectbox, CSV
# uploader and question box. Confirm these helpers do not render duplicate
# widgets.
display_initial_buttons()
# NOTE(review): selected_class is reassigned by the st.selectbox further down.
# Nothing in this file reads this value before then.
selected_class = display_class_dropdown(client)
# NOTE(review): called without selected_class. Verify against ui_utils.
handle_new_class_selection()
csv_upload_and_ingestion()
display_query_input()
# Session-state defaults: debug mode starts switched off unless a previous
# run already set it.
st.session_state.setdefault("debug", False)
# At this point `StreamlitCallbackHandler` is still LangChain's callback
# handler; it renders into a freshly created Streamlit container.
st_callback = StreamlitCallbackHandler(st.container())
# NOTE: this class shadows the `StreamlitCallbackHandler` imported from
# langchain.callbacks. From here on, the name refers to this logging handler.
class StreamlitCallbackHandler(logging.Handler):
    """A ``logging.Handler`` that writes each formatted record to the page."""

    def emit(self, record):
        """Format *record* and render it with ``st.write``.

        This follows the ``logging.Handler`` contract: any failure while
        formatting or writing goes to ``handleError`` and is not raised into
        the code that called the logger.
        """
        try:
            log_entry = self.format(record)
            st.write(log_entry)
        except Exception:
            self.handleError(record)
# Initialize TAPAS model and tokenizer
#tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
#model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")
# Initialize Weaviate client for the embedded instance
#client = weaviate.Client(
# embedded_options=EmbeddedOptions()
#)
# Global list of debug messages recorded by log_debug_info(). It is rebuilt
# each time this script runs, so it holds the messages of the current run.
DEBUG_LOGS = []


def log_debug_info(message):
    """Record a debug message and, in debug mode, log it to the page.

    Every message is appended to ``DEBUG_LOGS``, which the
    "Show Debugging Information" checkbox displays. When
    ``st.session_state.debug`` is truthy, the message is also sent through
    the module logger, which renders it via ``StreamlitCallbackHandler``.

    Args:
        message: Text to record.
    """
    DEBUG_LOGS.append(message)
    # .get() so an unset "debug" key means "off" instead of raising.
    if not st.session_state.get("debug", False):
        return
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    # The logger object persists for the whole process. The
    # StreamlitCallbackHandler class, however, is re-created each time
    # Streamlit re-runs this script. An isinstance() check against the current
    # class therefore misses handlers added in earlier runs, and duplicates
    # pile up. Tag our handler instead and look for the tag.
    if not any(getattr(h, "_streamlit_debug_handler", False) for h in logger.handlers):
        handler = StreamlitCallbackHandler()
        handler._streamlit_debug_handler = True
        logger.addHandler(handler)
    logger.debug(message)
# Function to check if a class already exists in Weaviate
#def class_exists(class_name):
# try:
# client.schema.get_class(class_name)
# return True
# except:
# return False
#def map_dtype_to_weaviate(dtype):
## """
# Map pandas data types to Weaviate data types.
# """
# if "int" in str(dtype):
# return "int"
# elif "float" in str(dtype):
# return "number"
# elif "bool" in str(dtype):
# return "boolean"
# else:
# return "string"
# def ingest_data_to_weaviate(dataframe, class_name, class_description):
# # Create class schema
# class_schema = {
# "class": class_name,
# "description": class_description,
# "properties": [] # Start with an empty properties list
# }
#
# # Try to create the class without properties first
# try:
# client.schema.create({"classes": [class_schema]})
# except weaviate.exceptions.SchemaValidationException:
# # Class might already exist, so we can continue
# pass#
# # Now, let's add properties to the class
# for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
# property_schema = {
# "name": column_name,
# "description": f"Property for {column_name}",
# "dataType": [map_dtype_to_weaviate(data_type)]
# }
# try:
# client.schema.property.create(class_name, property_schema)
# except weaviate.exceptions.SchemaValidationException:
# # Property might already exist, so we can continue
# pass
#
# # Ingest data
# for index, row in dataframe.iterrows():
# obj = {
# "class": class_name,
# "id": str(index),
# "properties": row.to_dict()
# }
# client.data_object.create(obj)
# Log data ingestion
# log_debug_info(f"Data ingested into Weaviate for class: {class_name}")
def query_weaviate(question, target_class=None):
    """Run a ``nearText`` query for *question* against a Weaviate class.

    Args:
        question: Free-text query.
        target_class: Class to query. Defaults to the module-level
            ``class_name`` chosen in the UI, which was the only behavior before.

    Returns:
        The raw response returned by the query builder's ``.do()``.
    """
    if target_class is None:
        target_class = class_name
    # The query builder's with_near_text takes a dict such as
    # {"concepts": [...]}, not a bare string.
    near_text = {"concepts": [question]}
    return client.query.get(target_class).with_near_text(near_text).do()
#def ask_llm_chunk(chunk, questions):
# chunk = chunk.astype(str)
# try:
# inputs = tokenizer(table=chunk, queries=questions, padding="max_length", truncation=True, return_tensors="pt")
# except Exception as e:
# log_debug_info(f"Tokenization error: {e}")
# st.write(f"An error occurred: {e}")
# return ["Error occurred while tokenizing"] * len(questions)
#
## if inputs["input_ids"].shape[1] > 512:
# log_debug_info("Token limit exceeded for chunk")
# st.warning("Token limit exceeded for chunk")
# return ["Token limit exceeded for chunk"] * len(questions)#
#
# outputs = model(**inputs)
# predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
# inputs,
# outputs.logits.detach(),
# outputs.logits_aggregation.detach()
# )
#
# answers = []
# for coordinates in predicted_answer_coordinates:
# if len(coordinates) == 1:
# row, col = coordinates[0]
# try:
# value = chunk.iloc[row, col]
# log_debug_info(f"Accessed value for row {row}, col {col}: {value}")
# answers.append(value)
# except Exception as e:
# log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
# st.write(f"An error occurred: {e}")
# else:
# cell_values = []
# for coordinate in coordinates:
# row, col = coordinate
# try:
# value = chunk.iloc[row, col]
# cell_values.append(value)
# except Exception as e:
# log_debug_info(f"Error accessing value for row {row}, col {col}: {e}")
# st.write(f"An error occurred: {e}")
# answers.append(", ".join(map(str, cell_values)))
#
# return answers
# MAX_ROWS_PER_CHUNK = 200
# def summarize_map_reduce(data, questions):
# dataframe = pd.read_csv(StringIO(data))
# num_chunks = len(dataframe) // MAX_ROWS_PER_CHUNK + 1
# dataframe_chunks = [deepcopy(chunk) for chunk in np.array_split(dataframe, num_chunks)]
# all_answers = []
# for chunk in dataframe_chunks:
# chunk_answers = ask_llm_chunk(chunk, questions)
# all_answers.extend(chunk_answers)
# return all_answers
def get_class_schema(class_name):
    """Return the schema dict for *class_name*, or ``None`` if it is absent.

    Scans the class definitions listed under ``"classes"`` in
    ``client.schema.get()``. Returns the first one whose ``"class"`` field
    equals *class_name*.
    """
    schema_classes = client.schema.get()["classes"]
    return next(
        (definition for definition in schema_classes if definition["class"] == class_name),
        None,
    )
st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate
existing_classes = [cls["class"] for cls in client.schema.get()["classes"]]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)
if selected_class == "New Class":
    # Strip stray whitespace so " Foo " and "Foo" name the same class.
    class_name = st.text_input("Enter the new class name:").strip()
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])

# Display the schema if an existing class is selected.
# class_schema stays None for "New Class", or when get_class_schema finds
# no such class.
class_schema = None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(selected_class)
    if class_schema:
        properties = class_schema.get("properties") or []
        if properties:
            schema_df = pd.DataFrame(properties)
            st.table(schema_df[["name", "dataType"]])  # Display only the name and dataType columns
        else:
            # A DataFrame built from no properties has no "name"/"dataType"
            # columns; indexing it used to raise KeyError.
            st.info("This class has no properties yet.")
# Before ingesting data into Weaviate, check if CSV columns match the class schema.
# `data` is bound up front because the Submit handler below reads it even when
# no file has been uploaded. Previously that raised NameError.
data = None
if csv_file is not None:
    # getvalue() returns the whole upload regardless of the stream position.
    # "utf-8-sig" decodes plain UTF-8 unchanged but drops a leading BOM (as
    # written by e.g. Excel). Otherwise the BOM would end up in the first
    # column name and break the schema comparison.
    data = csv_file.getvalue().decode("utf-8-sig")
    dataframe = pd.read_csv(StringIO(data))
    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")
    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    # Only an existing class has a schema to validate against. A new class
    # has none yet (class_schema is None), so its CSV is ingested as-is.
    if class_schema and set(dataframe.columns) != {
        prop["name"] for prop in class_schema.get("properties") or []
    }:
        st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
    elif not class_name:
        st.warning("Enter a class name before ingesting the CSV.")
    else:
        # Streamlit re-runs this whole script on every widget interaction
        # (including the Submit click below) while the upload stays in place.
        # Remember what was ingested last so the same CSV is not re-ingested.
        ingest_key = (class_name, csv_file.name, hash(data))
        if st.session_state.get("last_ingested_csv") != ingest_key:
            # Ingest data into Weaviate
            ingest_data_to_weaviate(dataframe, class_name, class_description)
            st.session_state["last_ingested_csv"] = ingest_key

# Input for questions: one per line. Blank and whitespace-only lines are dropped.
raw_questions = st.text_area("Enter your questions (one per line)")
questions = [line.strip() for line in raw_questions.splitlines() if line.strip()]

if st.button("Submit"):
    if data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")
    else:
        st.warning("Upload a CSV file and enter at least one question before submitting.")
# Optional debug panel: when ticked, render every entry of the module-level
# DEBUG_LOGS list, one st.write call per entry.
show_debug = st.checkbox("Show Debugging Information")
if show_debug:
    st.write("Debugging Logs:")
    for entry in DEBUG_LOGS:
        st.write(entry)
# Add Ctrl+Enter functionality for submitting the questions.
# st.markdown does not execute <script> tags, and DOMContentLoaded has long
# fired by the time Streamlit renders, so the previous snippet never ran.
# components.html runs the script in a zero-height iframe that reaches the
# app page through window.parent.
import streamlit.components.v1 as components

components.html(
    """
<script>
const doc = window.parent.document;
// The Streamlit script re-runs on each interaction; register the listener once.
if (!window.parent.__tapasCtrlEnterSubmit) {
    window.parent.__tapasCtrlEnterSubmit = true;
    doc.addEventListener("keydown", function (event) {
        if (event.ctrlKey && event.key === "Enter") {
            // Target the "Submit" button specifically: other st.button widgets
            // (e.g. from display_initial_buttons) may appear earlier on the page.
            const submit = Array.from(doc.querySelectorAll(".stButton button"))
                .find((button) => button.innerText.trim() === "Submit");
            if (submit) {
                submit.click();
            }
        }
    });
}
</script>
""",
    height=0,
)