Spaces:
Build error
Build error
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +59 -44
src/streamlit_app.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
-
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import json
|
| 4 |
-
import io
|
| 5 |
import os
|
| 6 |
-
|
| 7 |
os.environ["STREAMLIT_HOME"] = "/tmp"
|
| 8 |
os.environ["XDG_CONFIG_HOME"] = "/tmp"
|
| 9 |
os.environ["XDG_DATA_HOME"] = "/tmp"
|
| 10 |
|
| 11 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
from langchain.llms import OpenAI
|
| 14 |
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
|
|
@@ -27,22 +26,25 @@ _ = load_dotenv(find_dotenv())
|
|
| 27 |
# Get API key from Streamlit secrets
|
| 28 |
API_KEY = os.getenv("OPENAI_API_KEY")
|
| 29 |
|
| 30 |
-
# Initialize embedding model
|
| 31 |
embeddings_model = OpenAIEmbeddings(openai_api_key=API_KEY)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
st.set_page_config(page_title="RAG File Chat", layout="centered")
|
| 35 |
st.title("π§ Chat with Your Uploaded File")
|
| 36 |
|
| 37 |
-
#
|
| 38 |
if "vectorstore" not in st.session_state:
|
| 39 |
st.session_state.vectorstore = None
|
| 40 |
if "agent_created" not in st.session_state:
|
| 41 |
st.session_state.agent_created = False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def extract_text_from_file(file_content, file_type):
|
| 45 |
-
"""Extract
|
| 46 |
if file_type == "pdf":
|
| 47 |
reader = PyPDF2.PdfReader(io.BytesIO(file_content))
|
| 48 |
return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
|
|
@@ -53,46 +55,41 @@ def extract_text_from_file(file_content, file_type):
|
|
| 53 |
|
| 54 |
|
| 55 |
def create_agent_and_index(file_content, file_type):
|
| 56 |
-
"""
|
| 57 |
if file_type == "csv":
|
| 58 |
df = pd.read_csv(io.StringIO(file_content.decode("utf-8")))
|
| 59 |
st.success("π CSV file loaded into DataFrame.")
|
|
|
|
|
|
|
|
|
|
| 60 |
elif file_type == "xlsx":
|
| 61 |
df = pd.read_excel(file_content)
|
| 62 |
st.success("π Excel file loaded into DataFrame.")
|
|
|
|
|
|
|
|
|
|
| 63 |
elif file_type == "json":
|
| 64 |
df = pd.DataFrame(json.loads(file_content.decode("utf-8")))
|
| 65 |
st.success("π JSON file loaded into DataFrame.")
|
|
|
|
|
|
|
|
|
|
| 66 |
elif file_type in ["pdf", "docx"]:
|
| 67 |
text = extract_text_from_file(file_content, file_type)
|
| 68 |
st.success(f"π Extracted text from {file_type.upper()}.")
|
| 69 |
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
| 70 |
texts = text_splitter.split_text(text)
|
| 71 |
-
|
| 72 |
-
st.success("
|
| 73 |
-
st.session_state.vectorstore = FAISS.from_texts(
|
| 74 |
-
texts=df['text'].tolist(),
|
| 75 |
-
embedding=embeddings_model
|
| 76 |
-
)
|
| 77 |
-
st.success("π§ Text embedded and stored in FAISS (in-memory).")
|
| 78 |
else:
|
| 79 |
st.error("β Unsupported file type.")
|
| 80 |
-
return
|
| 81 |
-
|
| 82 |
-
# Create agent for tabular data
|
| 83 |
-
if file_type in ["csv", "xlsx", "json"]:
|
| 84 |
-
llm = OpenAI(openai_api_key=API_KEY)
|
| 85 |
-
agent = create_pandas_dataframe_agent(llm, df, verbose=False)
|
| 86 |
-
st.session_state.agent_created = True
|
| 87 |
-
st.success("π€ Tabular data agent created.")
|
| 88 |
-
return agent
|
| 89 |
-
|
| 90 |
st.session_state.agent_created = True
|
| 91 |
-
|
| 92 |
|
| 93 |
|
| 94 |
def query_vectorstore(query):
|
| 95 |
-
"""
|
| 96 |
qa_chain = RetrievalQA.from_chain_type(
|
| 97 |
llm=OpenAI(openai_api_key=API_KEY),
|
| 98 |
chain_type="stuff",
|
|
@@ -102,18 +99,21 @@ def query_vectorstore(query):
|
|
| 102 |
return result["result"]
|
| 103 |
|
| 104 |
|
| 105 |
-
# UI
|
| 106 |
-
uploaded_file = st.file_uploader("π
|
| 107 |
|
| 108 |
-
if uploaded_file
|
| 109 |
-
st.
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
st.session_state.file_type = file_type
|
| 116 |
|
|
|
|
| 117 |
query = st.text_area("π¬ Ask a question about your uploaded file")
|
| 118 |
|
| 119 |
if st.button("Submit Query"):
|
|
@@ -121,12 +121,27 @@ if st.button("Submit Query"):
|
|
| 121 |
st.warning("β οΈ Please enter a valid question.")
|
| 122 |
elif not st.session_state.agent_created:
|
| 123 |
st.warning("π Please upload and process a file first.")
|
| 124 |
-
elif st.session_state.file_type in ["pdf", "docx"]:
|
| 125 |
-
with st.spinner("π‘ Thinking..."):
|
| 126 |
-
response = query_vectorstore(query)
|
| 127 |
-
st.subheader("π Answer")
|
| 128 |
-
st.write(response)
|
| 129 |
else:
|
| 130 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
|
|
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fix permission issue on Hugging Face Spaces
|
|
|
|
|
|
|
|
|
|
| 2 |
import os
|
|
|
|
| 3 |
os.environ["STREAMLIT_HOME"] = "/tmp"
|
| 4 |
os.environ["XDG_CONFIG_HOME"] = "/tmp"
|
| 5 |
os.environ["XDG_DATA_HOME"] = "/tmp"
|
| 6 |
|
| 7 |
import streamlit as st
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import json
|
| 10 |
+
import io
|
| 11 |
|
| 12 |
from langchain.llms import OpenAI
|
| 13 |
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
|
|
|
|
| 26 |
# Get API key from Streamlit secrets
|
| 27 |
API_KEY = os.getenv("OPENAI_API_KEY")
|
| 28 |
|
|
|
|
| 29 |
embeddings_model = OpenAIEmbeddings(openai_api_key=API_KEY)
|
| 30 |
|
| 31 |
+
# Streamlit app setup
|
| 32 |
st.set_page_config(page_title="RAG File Chat", layout="centered")
|
| 33 |
st.title("π§ Chat with Your Uploaded File")
|
| 34 |
|
| 35 |
+
# Initialize session state
|
| 36 |
# Seed each session key the app reads later; values that already exist from
# an earlier rerun are left untouched.
for _key, _default in (
    ("vectorstore", None),
    ("agent_created", False),
    ("file_type", None),
    ("agent", None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
| 44 |
|
| 45 |
|
| 46 |
def extract_text_from_file(file_content, file_type):
|
| 47 |
+
"""Extract text from PDF or DOCX."""
|
| 48 |
if file_type == "pdf":
|
| 49 |
reader = PyPDF2.PdfReader(io.BytesIO(file_content))
|
| 50 |
return "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def create_agent_and_index(file_content, file_type):
    """Load an uploaded file and prepare it for querying.

    Tabular files (csv, xlsx, json) are read into a pandas DataFrame and
    wrapped in a pandas DataFrame agent stored in ``st.session_state.agent``.
    Text documents (pdf, docx) are split into chunks, embedded with
    ``embeddings_model`` and stored in ``st.session_state.vectorstore``.

    Args:
        file_content: Raw bytes of the uploaded file.
        file_type: File extension without the leading dot. Matched
            case-insensitively.

    Returns:
        None. On success ``st.session_state.agent_created`` is set to True
        and ``st.session_state.file_type`` holds the lower-cased extension.
        For an unsupported extension, or a document that yields no text
        chunks, an error is displayed and those two flags are not modified.
    """
    # The extension comes straight from the file name, so "CSV" must match "csv".
    file_type = file_type.lower()
    df = None

    if file_type == "csv":
        df = pd.read_csv(io.StringIO(file_content.decode("utf-8")))
        st.success("π CSV file loaded into DataFrame.")
    elif file_type == "xlsx":
        # Wrap the bytes in a file-like object rather than handing raw bytes
        # to read_excel.
        df = pd.read_excel(io.BytesIO(file_content))
        st.success("π Excel file loaded into DataFrame.")
    elif file_type == "json":
        df = pd.DataFrame(json.loads(file_content.decode("utf-8")))
        st.success("π JSON file loaded into DataFrame.")
    elif file_type in ("pdf", "docx"):
        # Treat a missing extraction result the same as empty text.
        text = extract_text_from_file(file_content, file_type) or ""
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text(text)
        if not texts:
            # Nothing to embed (e.g. a scanned PDF without a text layer);
            # stop before building an index from an empty chunk list.
            st.error("No extractable text found in the uploaded document.")
            return
        st.success(f"π Extracted text from {file_type.upper()}.")
        st.session_state.vectorstore = FAISS.from_texts(texts, embeddings_model)
        st.success("π§ Text embedded and stored in FAISS.")
    else:
        st.error("β Unsupported file type.")
        return

    if df is not None:
        # One shared agent-construction path for every tabular format.
        llm = OpenAI(openai_api_key=API_KEY)
        st.session_state.agent = create_pandas_dataframe_agent(llm, df, verbose=False)
        st.success("π€ Agent created for tabular data.")

    st.session_state.agent_created = True
    st.session_state.file_type = file_type
|
| 89 |
|
| 90 |
|
| 91 |
def query_vectorstore(query):
|
| 92 |
+
"""Query FAISS vectorstore using RAG."""
|
| 93 |
qa_chain = RetrievalQA.from_chain_type(
|
| 94 |
llm=OpenAI(openai_api_key=API_KEY),
|
| 95 |
chain_type="stuff",
|
|
|
|
| 99 |
return result["result"]
|
| 100 |
|
| 101 |
|
| 102 |
+
# --- UI Section: File Upload ---
|
| 103 |
+
uploaded_file = st.file_uploader("π Browse and select a file", type=["csv", "xlsx", "json", "pdf", "docx"])
|
| 104 |
|
| 105 |
+
if uploaded_file:
    # Single-line f-string: the message literal must not span physical lines.
    st.info(f"✅ File selected: `{uploaded_file.name}` ({uploaded_file.size / 1024:.1f} KB)")
    if st.button("π€ Upload File"):
        # getvalue() returns the full buffer regardless of the current read
        # position, so processing the same upload twice does not see b"".
        # NOTE(review): assumes the uploaded object is BytesIO-like - confirm.
        file_content = uploaded_file.getvalue()
        # Lower-case the extension so "REPORT.PDF" is handled like "report.pdf".
        file_type = uploaded_file.name.rsplit(".", 1)[-1].lower()
        with st.spinner("π Uploading and processing..."):
            create_agent_and_index(file_content, file_type)
|
| 112 |
|
| 113 |
+
# --- Output Format ---
|
| 114 |
+
output_format = st.selectbox("π Select Output Format", ["Plain Text", "Markdown", "Tabular View"])
|
|
|
|
| 115 |
|
| 116 |
+
# --- UI Section: Query ---
|
| 117 |
query = st.text_area("π¬ Ask a question about your uploaded file")
|
| 118 |
|
| 119 |
if st.button("Submit Query"):
|
|
|
|
| 121 |
st.warning("β οΈ Please enter a valid question.")
|
| 122 |
elif not st.session_state.agent_created:
|
| 123 |
st.warning("π Please upload and process a file first.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
else:
|
| 125 |
+
with st.spinner("π‘ Thinking..."):
|
| 126 |
+
if st.session_state.file_type in ["pdf", "docx"]:
|
| 127 |
+
response = query_vectorstore(query)
|
| 128 |
+
else:
|
| 129 |
+
response = st.session_state.agent.run(query)
|
| 130 |
|
| 131 |
+
st.subheader("π Answer")
|
| 132 |
|
| 133 |
+
if output_format == "Plain Text":
|
| 134 |
+
st.text(response)
|
| 135 |
+
elif output_format == "Markdown":
|
| 136 |
+
st.markdown(response)
|
| 137 |
+
elif output_format == "Tabular View":
|
| 138 |
+
# Try to parse tabular response (tab or comma-separated)
|
| 139 |
+
rows = [line.split("\t") for line in response.split("\n") if "\t" in line]
|
| 140 |
+
if not rows or len(rows[0]) == 1:
|
| 141 |
+
rows = [line.split(",") for line in response.split("\n") if "," in line]
|
| 142 |
+
try:
|
| 143 |
+
df = pd.DataFrame(rows[1:], columns=rows[0])
|
| 144 |
+
st.dataframe(df)
|
| 145 |
+
except Exception:
|
| 146 |
+
st.warning("β οΈ Could not render a table. Showing raw output instead.")
|
| 147 |
+
st.text(response)
|