File size: 3,619 Bytes
61a1426
6fe62f6
 
 
e5547b9
6fe62f6
 
 
 
 
 
 
 
 
 
 
a91a7d7
 
 
 
6fe62f6
577efd3
6fe62f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61a1426
6fe62f6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
import pandas as pd
import json
import io
import os

from langchain.llms import OpenAI
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA

import PyPDF2
from docx import Document

from dotenv import load_dotenv, find_dotenv

# Load variables from a .env file (if present) into the process environment.
_ = load_dotenv(find_dotenv())

# Read the OpenAI key from the environment (populated by load_dotenv above).
# NOTE(review): despite the original comment, this does NOT use st.secrets.
API_KEY = os.getenv("OPENAI_API_KEY")

# Initialize a persistent Chroma vector store backed by OpenAI embeddings.
# The "db" directory holds the on-disk collection between runs.
embeddings_model = OpenAIEmbeddings(openai_api_key=API_KEY)
persist_directory = "db"
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings_model)


def create_agent(file_content, file_type):
    """Build a pandas-dataframe agent for an uploaded file.

    Parameters
    ----------
    file_content : bytes
        Raw bytes of the uploaded file.
    file_type : str
        One of "csv", "xlsx", "json", "pdf", "docx".

    Returns
    -------
    A LangChain agent created by ``create_pandas_dataframe_agent`` over
    the parsed data.

    Raises
    ------
    ValueError
        If ``file_type`` is not one of the supported extensions.
    """
    if file_type == "csv":
        df = pd.read_csv(io.StringIO(file_content.decode("utf-8")), header=0)
    elif file_type == "xlsx":
        # pd.read_excel needs a path or file-like object; wrap the raw
        # bytes in BytesIO (passing bytes directly fails on modern pandas).
        df = pd.read_excel(io.BytesIO(file_content), header=0)
    elif file_type == "json":
        df = pd.DataFrame(json.loads(file_content.decode("utf-8")))
    elif file_type in ["pdf", "docx"]:
        text = extract_text_from_file(file_content, file_type)
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text(text)
        df = pd.DataFrame({"text": texts})
    else:
        raise ValueError(f"Unsupported file type: {file_type}")

    # Index chunks into Chroma on first use. Only the pdf/docx branch
    # produces a "text" column, so guard on it — the unconditional
    # df['text'] access used to raise KeyError for csv/xlsx/json uploads.
    if "text" in df.columns and not vectorstore._collection.count():
        vectorstore.add_texts(texts=df['text'].tolist(), metadatas=[{'source': file_type}] * len(df))

    llm = OpenAI(openai_api_key=API_KEY)
    return create_pandas_dataframe_agent(llm, df, verbose=False)

def extract_text_from_file(file_content, file_type):
    """Extract plain text from a PDF or Word (.docx) document.

    Parameters
    ----------
    file_content : bytes
        Raw bytes of the document.
    file_type : str
        Either "pdf" or "docx".

    Returns
    -------
    str
        The document's concatenated text content.

    Raises
    ------
    ValueError
        If ``file_type`` is neither "pdf" nor "docx".
    """
    if file_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
        text = ""
        for page in pdf_reader.pages:
            # extract_text() may return None for pages with no extractable
            # text (e.g. scanned images); avoid concatenating None.
            text += page.extract_text() or ""
    elif file_type == "docx":
        doc = Document(io.BytesIO(file_content))
        text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
    else:
        raise ValueError(f"Unsupported file type: {file_type}")
    return text


def query_agent(query):
    """Answer *query* using a RetrievalQA chain over the Chroma store.

    Parameters
    ----------
    query : str
        The user's natural-language question.

    Returns
    -------
    str
        The chain's answer text.
    """
    # Build a fresh "stuff" chain on every call, retrieving the 5 most
    # similar chunks from the shared vector store.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
    qa_chain = RetrievalQA.from_chain_type(
        llm=OpenAI(openai_api_key=API_KEY),
        chain_type="stuff",
        retriever=retriever,
    )
    return qa_chain({"query": query})['result']

# --- Streamlit app ---
st.title("👨‍💻 Chat with your data")
st.write("Please upload your data file below.")

uploaded_file = st.file_uploader("Upload a file", type=["csv", "xlsx", "json", "pdf", "docx"])

if uploaded_file is not None:
    raw_bytes = uploaded_file.read()
    # Use the filename's last extension segment as the file type.
    extension = uploaded_file.name.split(".")[-1]

    query = st.text_area("Type your query here")

    if st.button("Submit Query", type="primary"):
        # On the first query only, parse the upload, index it into
        # Chroma, and persist the collection to disk.
        if not vectorstore._collection.count():
            create_agent(raw_bytes, extension)
            vectorstore.persist()
            st.write("Data loaded and persisted to Chroma.")

        st.write(query_agent(query))