# RAG — src/streamlit_app.py
# Hugging Face Space file header: updated by lantzmurray, commit 99379ec (verified).
import io
import os
import shutil
import zipfile

import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
# Cache the QA initialization so ingestion runs once per session
@st.cache_resource
def init_qa(zip_bytes):
    """Build the retrieval + QA engine from an uploaded ZIP of PDFs.

    Cached via ``st.cache_resource`` keyed on the ZIP content, so PDF
    ingestion, embedding, and model loading run once per upload/session.

    Args:
        zip_bytes: Raw bytes of the uploaded ZIP archive containing PDFs.

    Returns:
        Tuple ``(vector_store, qa_pipeline)``: a FAISS index over the
        chunked PDF text, and a HuggingFace extractive-QA pipeline.
    """
    tmp_dir = "tmp_pdfs"
    # Start from a clean slate. shutil.rmtree handles nested directories
    # left by a previous extractall; the old per-file os.remove loop would
    # raise on any subdirectory inside the ZIP.
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir)

    # Extract uploaded ZIP
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z:
        z.extractall(tmp_dir)

    # Load all PDFs, walking recursively so PDFs inside ZIP subfolders
    # (the common "zipped folder" layout) are found too.
    docs = []
    for root, _dirs, files in os.walk(tmp_dir):
        for fname in files:
            if fname.lower().endswith(".pdf"):
                loader = PyPDFLoader(os.path.join(root, fname))
                docs.extend(loader.load())

    # Split into manageable chunks (overlap keeps answers that straddle
    # a chunk boundary retrievable)
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = splitter.split_documents(docs)

    # Build vector store over the chunk embeddings
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store = FAISS.from_documents(split_docs, embeddings)

    # Load the extractive QA model and tokenizer (SQuAD2-tuned RoBERTa)
    tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
    model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
    qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)

    return vector_store, qa_pipeline
# ---- Streamlit UI ----
st.title("RoBERTa QA Streamlit App")
st.write("Upload a ZIP of PDFs to initialize the QA engine.")

uploaded = st.file_uploader("ZIP file", type=["zip"])
if not uploaded:
    # Nothing to do until the user supplies an archive.
    st.info("Awaiting ZIP upload.")
else:
    # Ingestion is cached, so re-runs with the same ZIP are cheap.
    vector_store, qa = init_qa(uploaded.read())
    question = st.text_input("Ask a question:")
    if question:
        # Retrieve the four most similar chunks and stitch them into
        # one context string for the extractive reader.
        hits = vector_store.similarity_search(question, k=4)
        context = "\n\n".join(doc.page_content for doc in hits)
        result = qa(question=question, context=context)
        st.write(result.get("answer", "No answer found."))