Hissen committed on
Commit c94a869 · verified · 1 Parent(s): f74d054

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +76 -25
src/streamlit_app.py CHANGED
@@ -1,20 +1,39 @@
 import streamlit as st
+import tempfile
+import os
+import numpy as np
+
+# LangChain community loaders & FAISS
 from langchain_community.document_loaders import (
     PyPDFLoader,
     TextLoader,
     UnstructuredWordDocumentLoader,
     CSVLoader
 )
-from langchain.vectorstores import FAISS
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_huggingface import HuggingFaceEndpoint
-from langchain.prompts import ChatPromptTemplate
-from langchain.chains import RetrievalQA
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
 
+st.set_page_config(page_title="Ask RAG - HF Space", layout="wide")
 st.title("Ask RAG - HuggingFace Space")
 
-# File uploader
+# HuggingFace API key (set via Space Secrets)
+HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
+if not HF_TOKEN:
+    st.error("Please set the HuggingFace API key in your Space secrets!")
+    st.stop()
+
+# Wrapper embeddings via HF API
+embeddings = HuggingFaceEmbeddings(
+    model_name="intfloat/multilingual-e5-large-instruct",
+    task="feature-extraction",
+    model_kwargs={"use_auth_token": HF_TOKEN}
+)
+
+# Upload files
 uploaded_files = st.file_uploader(
     "Upload files (PDF, DOCX, TXT, CSV)",
     type=["pdf", "docx", "txt", "csv"],
@@ -24,52 +43,77 @@ uploaded_files = st.file_uploader(
 @st.cache_resource
 def load_files(files):
     if not files:
-        return None
+        return None, []
 
     loaders = []
+    temp_files = []
+
     for file in files:
+        # Save temp file
+        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[-1]) as tmp:
+            tmp.write(file.read())
+            temp_files.append(tmp.name)
+
+        # Choose loader
         if file.name.endswith(".pdf"):
-            loaders.append(PyPDFLoader(file))
+            loaders.append(PyPDFLoader(tmp.name))
         elif file.name.endswith(".txt"):
-            loaders.append(TextLoader(file))
+            loaders.append(TextLoader(tmp.name))
         elif file.name.endswith(".docx"):
-            loaders.append(UnstructuredWordDocumentLoader(file))
+            loaders.append(UnstructuredWordDocumentLoader(tmp.name))
        elif file.name.endswith(".csv"):
-            loaders.append(CSVLoader(file))
+            loaders.append(CSVLoader(tmp.name))
 
+    # Load documents
     docs = []
     for loader in loaders:
         docs.extend(loader.load())
 
+    # Split documents
     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
     split_docs = splitter.split_documents(docs)
 
-    # Embeddings via HF
-    embeddings = HuggingFaceEmbeddings(
-        model_name="intfloat/multilingual-e5-large-instruct"
-    )
+    # Create FAISS vectorstore
+    vectorstore = FAISS.from_documents(split_docs, embeddings)
 
-    return FAISS.from_documents(split_docs, embeddings)
+    return vectorstore, temp_files
 
-vectorstore = load_files(uploaded_files) if uploaded_files else None
+vectorstore, temp_files = load_files(uploaded_files) if uploaded_files else (None, [])
 
 if vectorstore:
     retriever = vectorstore.as_retriever()
-
+
+    # Chat prompt
+    chat_prompt = ChatPromptTemplate.from_template(
+        """Use the context below to answer the question.
+
+Context:
+{context}
+
+Question:
+{question}"""
+    )
+
+    # LLM via HF Inference
     llm = HuggingFaceEndpoint(
         repo_id="AI-Sweden-Models/Llama-3-8B-instruct",
         task="text-generation",
         temperature=0.2,
-        max_new_tokens=512
+        max_new_tokens=512,
+        model_kwargs={"use_auth_token": HF_TOKEN}
     )
 
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        retriever=retriever,
-        return_source_documents=False,
-        chain_type="stuff"
+    # Build RAG chain
+    rag_chain = (
+        {
+            "context": retriever | (lambda docs: "\n\n".join(d.page_content for d in docs)),
+            "question": RunnablePassthrough(),
+        }
+        | chat_prompt
+        | llm
     )
 
+    # Session state
     if "messages" not in st.session_state:
         st.session_state.messages = []
 
@@ -83,13 +127,20 @@ if vectorstore:
         st.chat_message("user").markdown(user_input)
         st.session_state.messages.append({"role": "user", "content": user_input})
 
-        answer = qa_chain.run(user_input)
+        response = rag_chain.invoke(user_input)
+        answer = response.content
         st.chat_message("assistant").markdown(answer)
         st.session_state.messages.append({"role": "assistant", "content": answer})
+
 else:
     st.warning("Upload files to start querying.")
 
 # Clear chat button
 if st.button("Clear Chat"):
     st.session_state.messages = []
+    for path in temp_files:
+        try:
+            os.remove(path)
+        except:
+            pass
    st.experimental_rerun()
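
For reference, the chain added in this commit follows the standard LCEL retrieval pattern: a dict of runnables produces {context} and {question}, then pipes into a prompt and an LLM. Below is a minimal, self-contained sketch of that pattern, assuming langchain-core, langchain-community and faiss-cpu are installed; the sample document, FakeEmbeddings and FakeListLLM are illustrative stand-ins for the Space's HuggingFace embedding model and endpoint, not part of this app.

# Minimal sketch of the LCEL retrieval chain pattern used above (stand-ins, not the app's models).
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import FakeEmbeddings   # stand-in for HuggingFaceEmbeddings
from langchain_community.llms.fake import FakeListLLM       # stand-in for HuggingFaceEndpoint

# Tiny corpus -> FAISS index -> retriever
docs = [Document(page_content="FAISS is a library for efficient vector similarity search.")]
vectorstore = FAISS.from_documents(docs, FakeEmbeddings(size=32))
retriever = vectorstore.as_retriever()

prompt = ChatPromptTemplate.from_template(
    "Use the context below to answer the question.\n\nContext:\n{context}\n\nQuestion:\n{question}"
)
llm = FakeListLLM(responses=["FAISS does vector similarity search."])

# Dict of runnables -> prompt -> LLM, the same shape as rag_chain in the commit
rag_chain = (
    {
        "context": retriever | (lambda ds: "\n\n".join(d.page_content for d in ds)),
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
)

# Plain LLM runnables (FakeListLLM here, HuggingFaceEndpoint in the app) return a string from invoke()
print(rag_chain.invoke("What does FAISS do?"))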