File size: 5,725 Bytes
45ee012
b0e4f45
 
 
 
512f2de
46ad3c2
 
dfcdc4f
b0e4f45
 
 
 
 
0e62360
27ba167
 
 
0e62360
b9d05c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e62360
b9d05c0
0e62360
 
b9d05c0
 
0e62360
 
 
b9d05c0
0e62360
 
b0e4f45
 
414bc96
 
 
 
 
512f2de
66f9f66
 
 
c1ca766
 
 
 
 
 
 
515be2e
c1ca766
 
45ee012
 
c1ca766
45ee012
 
072f6c1
515be2e
66f9f66
45ee012
c1ca766
66f9f66
45ee012
 
 
 
 
 
 
 
 
 
c1ca766
515be2e
b0e4f45
512f2de
b0e4f45
 
512f2de
b0e4f45
45ee012
b0e4f45
 
 
 
5e71278
b0e4f45
b9d05c0
 
 
 
 
b0e4f45
 
 
 
 
0e62360
b0e4f45
0e62360
b0e4f45
b9d05c0
 
 
 
 
b0e4f45
 
 
 
 
 
 
b9d05c0
512f2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e62360
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from copy import deepcopy
from io import StringIO

import numpy as np
import pandas as pd
import streamlit as st
import torch
import weaviate
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
from weaviate import Client, ObjectsBatchRequest
from weaviate.embedded import EmbeddedOptions

# Initialize TAPAS model and tokenizer.
# NOTE(review): loaded eagerly at import time — downloads/loads the large
# checkpoint on every cold start; consider st.cache_resource so Streamlit
# reruns don't reload the model.
tokenizer = AutoTokenizer.from_pretrained("google/tapas-large-finetuned-wtq")
model = AutoModelForTableQuestionAnswering.from_pretrained("google/tapas-large-finetuned-wtq")

# Initialize Weaviate client for the embedded instance
# (spawns a local embedded Weaviate process; no external server needed).
client = weaviate.Client(
  embedded_options=EmbeddedOptions()
)

def ingest_data_to_weaviate(dataframe, class_name, class_description):
    """Create a Weaviate class mirroring *dataframe*'s columns and ingest its rows.

    Args:
        dataframe: pandas DataFrame whose columns become class properties.
        class_name: Weaviate class name (Weaviate expects CapitalizedCase).
        class_description: human-readable description stored on the class.

    Raises:
        Propagates weaviate client errors (e.g. if the class already exists).
    """
    # Map pandas dtypes to Weaviate primitive data types (default: string).
    properties = []
    for column in dataframe.columns:
        data_type = "string"
        if dataframe[column].dtype == "float64":
            data_type = "float"
        elif dataframe[column].dtype == "int64":
            data_type = "int"
        properties.append({
            "name": column,
            "description": column,
            "dataType": [data_type],
        })

    schema = {
        "classes": [
            {
                "class": class_name,
                "description": class_description,
                "properties": properties,
            }
        ]
    }

    # Create schema in Weaviate.
    client.schema.create(schema)

    # Ingest data. ObjectsBatchRequest.add() takes (data_object, class_name) —
    # the previous code passed a single {"class": ..., "properties": ...} dict,
    # which the weaviate v3 client rejects.
    batch_request = weaviate.ObjectsBatchRequest()
    for _, row in dataframe.iterrows():
        batch_request.add(row.to_dict(), class_name)
    client.batch.create(batch_request)

def query_weaviate(question, target_class=None):
    """Run a nearText semantic search for *question* against Weaviate.

    Args:
        question: free-text query string.
        target_class: Weaviate class to search; defaults to the module-level
            ``class_name`` captured from the Streamlit text input (the original
            behavior, kept for backward compatibility).

    Returns:
        The raw GraphQL response dict from Weaviate.
    """
    if target_class is None:
        # Fall back to the UI-provided global, as the original code did.
        target_class = class_name
    # with_near_text expects a {"concepts": [...]} dict, not a bare string.
    results = (
        client.query
        .get(target_class)
        .with_near_text({"concepts": [question]})
        .do()
    )
    return results

def ask_llm_chunk(chunk, questions):
    """Answer *questions* against one DataFrame chunk with TAPAS.

    Args:
        chunk: pandas DataFrame slice of the full table.
        questions: list of question strings.

    Returns:
        List of answer strings, always len(questions) long — placeholders are
        inserted on failure so answers stay aligned with questions (the
        original dropped entries when a lookup failed or no cell was
        predicted, desynchronizing the question/answer pairing downstream).
    """
    # TAPAS requires every table cell to be a string.
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(
            table=chunk,
            queries=questions,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # TAPAS has a hard 512-token sequence limit.
    if inputs["input_ids"].shape[1] > 512:
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    # Inference only — no gradients needed; saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach(),
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        if not coordinates:
            # No cell predicted for this question — keep the list aligned.
            answers.append("No answer found")
            continue
        cell_values = []
        for row, col in coordinates:
            try:
                # Coordinates are positional within this chunk, so iloc works.
                cell_values.append(chunk.iloc[row, col])
            except Exception as e:
                st.write(f"An error occurred: {e}")
                cell_values.append("Error")
        # A single coordinate joins to just its own value, matching the
        # original single-cell behavior.
        answers.append(", ".join(map(str, cell_values)))

    return answers

# Keep each chunk small enough to have a chance of fitting TAPAS's 512-token limit.
MAX_ROWS_PER_CHUNK = 200

def summarize_map_reduce(data, questions):
    """Split CSV *data* into row chunks and run TAPAS on each chunk.

    Args:
        data: raw CSV text.
        questions: list of question strings.

    Returns:
        Flat list of answers: len(questions) entries per chunk, concatenated
        in chunk order.
    """
    dataframe = pd.read_csv(StringIO(data))
    # Ceiling division. The original `len // N + 1` produced an extra, empty
    # chunk whenever the row count was an exact multiple of the chunk size.
    num_chunks = max(1, -(-len(dataframe) // MAX_ROWS_PER_CHUNK))
    all_answers = []
    for chunk in np.array_split(dataframe, num_chunks):
        # Reset to a clean positional index per chunk (defensive; answer
        # coordinates are positional anyway).
        all_answers.extend(ask_llm_chunk(chunk.reset_index(drop=True), questions))
    return all_answers

st.title("TAPAS Table Question Answering with Weaviate Integration")

# UI Input for Class and Description
class_name = st.text_input("Enter the class name for your CSV data:")
class_description = st.text_input("Enter a description for your class:")

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))
    st.write("CSV Data Preview:")
    st.write(dataframe.head())

    # Ingest data to Weaviate
    # NOTE(review): class_name may be empty if the user has not filled the
    # text input above — Weaviate will reject the schema; validate first.
    if st.button("Ingest to Weaviate"):
        ingest_data_to_weaviate(dataframe, class_name, class_description)
        st.write("Data ingested successfully!")

    # Input for questions
    questions = st.text_area("Enter your questions (one per line)")
    questions = questions.split("\n")  # split questions by line
    questions = [q for q in questions if q]  # remove empty strings

    if st.button("Submit"):
        if data and questions:
            # NOTE(review): summarize_map_reduce returns len(questions)
            # answers PER CHUNK; zip() below silently truncates to the first
            # chunk's answers when the CSV spans multiple chunks.
            answers = summarize_map_reduce(data, questions)
            st.write("Answers:")
            for q, a in zip(questions, answers):
                st.write(f"Question: {q}")
                st.write(f"Answer: {a}")

# Add Ctrl+Enter functionality for submitting the questions
# NOTE(review): st.markdown sanitizes <script> tags even with
# unsafe_allow_html=True, so this handler most likely never executes —
# verify in the target Streamlit version; streamlit.components.v1.html
# (iframe) is the usual workaround for injecting runnable JS.
st.markdown("""
    <script>
    document.addEventListener("DOMContentLoaded", function(event) {
        document.addEventListener("keydown", function(event) {
            if (event.ctrlKey && event.key === "Enter") {
                document.querySelector(".stButton button").click();
            }
        });
    });
    </script>
    """, unsafe_allow_html=True)