lantzmurray committed on
Commit
42a870f
·
verified ·
1 Parent(s): 5a49bae

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -13
src/streamlit_app.py CHANGED
@@ -5,7 +5,7 @@ from langchain.text_splitter import CharacterTextSplitter
5
  from langchain.schema import Document
6
  from langchain.embeddings import SentenceTransformerEmbeddings
7
  from langchain.vectorstores import FAISS
8
- from transformers import pipeline
9
 
10
  # Cache the QA initialization so ingestion runs once per session
11
  @st.cache_resource
@@ -37,28 +37,27 @@ def init_qa(zip_bytes):
37
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
38
  vector_store = FAISS.from_documents(split_docs, embeddings)
39
 
40
- # Load the RAG model
41
- generator = pipeline(
42
- "text2text-generation",
43
- model="PleIAs/Pleias-RAG-350M",
44
- tokenizer="PleIAs/Pleias-RAG-350M"
45
- )
46
 
47
- return vector_store, generator
48
 
49
  # Streamlit UI
50
- st.title("Pleias-RAG 350M Streamlit App")
51
- st.write("Upload a ZIP of PDFs to initialize the RAG engine.")
52
  zip_file = st.file_uploader("ZIP file", type=["zip"])
53
 
54
  if zip_file:
55
- vector_store, generator = init_qa(zip_file.read())
56
  query = st.text_input("Ask a question:")
57
  if query:
58
  docs = vector_store.similarity_search(query, k=4)
59
  context = "\n\n".join([doc.page_content for doc in docs])
60
- prompt = f"question: {query}\ncontext: {context}"
61
- answer = generator(prompt, max_length=512, do_sample=False)[0]["generated_text"]
 
62
  st.write(answer)
63
  else:
64
  st.info("Awaiting ZIP upload.")
 
5
  from langchain.schema import Document
6
  from langchain.embeddings import SentenceTransformerEmbeddings
7
  from langchain.vectorstores import FAISS
8
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
9
 
10
  # Cache the QA initialization so ingestion runs once per session
11
  @st.cache_resource
 
37
  embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
38
  vector_store = FAISS.from_documents(split_docs, embeddings)
39
 
40
+ # Load the QA model and tokenizer
41
+ tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
42
+ model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
43
+ qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)
 
 
44
 
45
+ return vector_store, qa_pipeline
46
 
47
  # Streamlit UI
48
+ st.title("RoBERTa QA Streamlit App")
49
+ st.write("Upload a ZIP of PDFs to initialize the QA engine.")
50
  zip_file = st.file_uploader("ZIP file", type=["zip"])
51
 
52
  if zip_file:
53
+ vector_store, qa = init_qa(zip_file.read())
54
  query = st.text_input("Ask a question:")
55
  if query:
56
  docs = vector_store.similarity_search(query, k=4)
57
  context = "\n\n".join([doc.page_content for doc in docs])
58
+ # Run QA
59
+ result = qa(question=query, context=context)
60
+ answer = result.get("answer", "No answer found.")
61
  st.write(answer)
62
  else:
63
  st.info("Awaiting ZIP upload.")