|
|
import streamlit as st |
|
|
import tempfile |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from langchain_community.document_loaders import ( |
|
|
PyPDFLoader, |
|
|
TextLoader, |
|
|
UnstructuredWordDocumentLoader, |
|
|
CSVLoader |
|
|
) |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langchain_huggingface import HuggingFaceEndpoint |
|
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.runnables import RunnablePassthrough |
|
|
|
|
|
st.set_page_config(page_title="Ask RAG - HF Space", layout="wide")
st.title("Ask RAG - HuggingFace Space")

# The HuggingFace token comes from the Space's secrets; the app cannot
# embed documents or call the endpoint without it, so stop early.
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    st.error("Please set the HuggingFace API key in your Space secrets!")
    st.stop()

# Embedding model used to vectorize document chunks.
# FIX: HuggingFaceEmbeddings (sentence-transformers backend) has no `task`
# field -- passing task="feature-extraction" raises a pydantic
# ValidationError at startup, so it is omitted. model_kwargs is forwarded
# to the SentenceTransformer constructor, which accepts use_auth_token.
embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large-instruct",
    model_kwargs={"use_auth_token": HF_TOKEN},
)

# Multi-file uploader; load_files() below turns these into a FAISS index.
uploaded_files = st.file_uploader(
    "Upload files (PDF, DOCX, TXT, CSV)",
    type=["pdf", "docx", "txt", "csv"],
    accept_multiple_files=True,
)
|
|
|
|
|
@st.cache_resource
def load_files(files):
    """Spill uploaded files to temp paths, load and chunk them, and build a FAISS index.

    Args:
        files: list of Streamlit UploadedFile objects (or a falsy value).

    Returns:
        (vectorstore, temp_file_paths) -- ``(None, [])`` when no files were
        provided. Temp files are NOT deleted here because the loaders read
        from disk; callers clean them up via the returned paths.
    """
    if not files:
        return None, []

    # Dispatch table: lowercase extension -> loader class. Replaces the
    # previous if/elif chain and makes the supported set explicit.
    loader_for_ext = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": UnstructuredWordDocumentLoader,
        ".csv": CSVLoader,
    }

    loaders = []
    temp_files = []
    for file in files:
        # FIX: normalize the extension -- the old endswith(".pdf") checks
        # were case-sensitive, so "report.PDF" was silently dropped.
        suffix = os.path.splitext(file.name)[-1].lower()

        # Loaders need a real filesystem path, so persist the upload.
        # getvalue() is used instead of read() so a previously-consumed
        # buffer (e.g. after a rerun) still yields the full content.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(file.getvalue())
            temp_files.append(tmp.name)

        loader_cls = loader_for_ext.get(suffix)
        if loader_cls is None:
            # FIX: previously unsupported extensions were skipped with no
            # feedback; surface it so the user knows the file was ignored.
            st.warning(f"Skipping unsupported file type: {file.name}")
            continue
        loaders.append(loader_cls(tmp.name))

    docs = []
    for loader in loaders:
        docs.extend(loader.load())

    # Chunk documents so each embedding covers a focused span of text.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = splitter.split_documents(docs)

    vectorstore = FAISS.from_documents(split_docs, embeddings)
    return vectorstore, temp_files
|
|
|
|
|
# Build (or fetch from cache) the index only when files have been uploaded.
vectorstore, temp_files = load_files(uploaded_files) if uploaded_files else (None, [])

if vectorstore:
    retriever = vectorstore.as_retriever()

    # Prompt that grounds the LLM's answer in the retrieved context.
    chat_prompt = ChatPromptTemplate.from_template(
        """Use the context below to answer the question.

Context:
{context}

Question:
{question}"""
    )

    # Text-generation endpoint.
    # FIX: the token must be passed as `huggingfacehub_api_token`; stuffing
    # it into model_kwargs forwards it to the generation API as an invalid
    # sampling parameter instead of authenticating the request.
    llm = HuggingFaceEndpoint(
        repo_id="AI-Sweden-Models/Llama-3-8B-instruct",
        task="text-generation",
        temperature=0.2,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # retrieval -> prompt -> LLM pipeline; retrieved documents are
    # flattened into one context string before templating.
    rag_chain = (
        {
            "context": retriever | (lambda docs: "\n\n".join(d.page_content for d in docs)),
            "question": RunnablePassthrough(),
        }
        | chat_prompt
        | llm
    )

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay prior turns so the conversation survives Streamlit reruns.
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).markdown(msg["content"])

    user_input = st.chat_input("Ask something...")
    if user_input:
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        response = rag_chain.invoke(user_input)
        # FIX: HuggingFaceEndpoint is an LLM (not a chat model) and returns
        # a plain string, which has no `.content` attribute -- the old
        # `response.content` raised AttributeError on every answer. Fall
        # back to the value itself, while still supporting message objects.
        answer = getattr(response, "content", response)
        st.chat_message("assistant").markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})

else:
    st.warning("Upload files to start querying.")
|
|
|
|
|
|
|
|
# Reset the conversation and remove the temp files backing the loaders.
if st.button("Clear Chat"):
    st.session_state.messages = []
    for path in temp_files:
        try:
            os.remove(path)
        except OSError:
            # Best effort: the file may already be gone or locked.
            # FIX: the old bare `except:` swallowed everything, including
            # SystemExit/KeyboardInterrupt.
            pass
    # FIX: st.experimental_rerun() is deprecated and removed in current
    # Streamlit releases; st.rerun() is the supported API (>= 1.27).
    st.rerun()