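"""Streamlit RAG demo for HuggingFace Spaces.

Uploads PDF, DOCX, TXT, and CSV files, chunks and indexes them in FAISS with
multilingual-e5 embeddings, and answers chat questions through a Llama-3
endpoint on the HF Inference API.
"""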
import os
import tempfile

import streamlit as st

# LangChain community loaders & FAISS
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredWordDocumentLoader,
    CSVLoader,
)
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEndpointEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
st.set_page_config(page_title="Ask RAG - HF Space", layout="wide")
st.title("Ask RAG - HuggingFace Space")
# HuggingFace API key (set via Space Secrets)
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    st.error("Please set the HuggingFace API key in your Space secrets!")
    st.stop()
# Embeddings via the HF Inference API (feature-extraction endpoint)
embeddings = HuggingFaceEndpointEmbeddings(
    model="intfloat/multilingual-e5-large-instruct",
    task="feature-extraction",
    huggingfacehub_api_token=HF_TOKEN,
)
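# Note (model-card detail, not in the original code): E5-instruct models expect
# an instruction-style prefix on queries, e.g. "Instruct: ...\nQuery: ...", while
# documents are embedded as-is; retrieval quality may improve if user questions
# are formatted that way before embedding.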
# Upload files
uploaded_files = st.file_uploader(
    "Upload files (PDF, DOCX, TXT, CSV)",
    type=["pdf", "docx", "txt", "csv"],
    accept_multiple_files=True,
)
@st.cache_resource
def load_files(files):
    if not files:
        return None, []
    loaders = []
    temp_files = []
    for file in files:
        # Save upload to a temp file so the loaders can read it from disk;
        # getvalue() is safe across Streamlit reruns (read() returns b"" once exhausted)
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[-1]) as tmp:
            tmp.write(file.getvalue())
            temp_files.append(tmp.name)
        # Choose loader by extension
        if file.name.endswith(".pdf"):
            loaders.append(PyPDFLoader(tmp.name))
        elif file.name.endswith(".txt"):
            loaders.append(TextLoader(tmp.name))
        elif file.name.endswith(".docx"):
            loaders.append(UnstructuredWordDocumentLoader(tmp.name))
        elif file.name.endswith(".csv"):
            loaders.append(CSVLoader(tmp.name))
    # Load documents
    docs = []
    for loader in loaders:
        docs.extend(loader.load())
    # Split documents into overlapping chunks for retrieval
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = splitter.split_documents(docs)
    # Embed chunks and index them in a FAISS vectorstore
    vectorstore = FAISS.from_documents(split_docs, embeddings)
    return vectorstore, temp_files
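# Caching note (assumption about your Streamlit version): st.cache_resource keys
# on the uploaded-file arguments, so the index is rebuilt only when the uploads
# change; older Streamlit releases cannot hash UploadedFile objects and would
# need the files' names/bytes as the cache key instead.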
vectorstore, temp_files = load_files(uploaded_files) if uploaded_files else (None, [])
if vectorstore:
    retriever = vectorstore.as_retriever()

    # Chat prompt
    chat_prompt = ChatPromptTemplate.from_template(
        """Use the context below to answer the question.

Context:
{context}

Question:
{question}"""
    )

    # LLM via the HF Inference API
    llm = HuggingFaceEndpoint(
        repo_id="AI-Sweden-Models/Llama-3-8B-instruct",
        task="text-generation",
        temperature=0.2,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # Build RAG chain: retrieve docs, stuff them into the prompt, call the LLM
    rag_chain = (
        {
            "context": retriever | (lambda docs: "\n\n".join(d.page_content for d in docs)),
            "question": RunnablePassthrough(),
        }
        | chat_prompt
        | llm
    )
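    # Sketch of an alternative (not in the original code): wrapping the endpoint
    # in langchain_huggingface.ChatHuggingFace would make the chain return an
    # AIMessage whose text lives in .content:
    #
    #     from langchain_huggingface import ChatHuggingFace
    #     chat_llm = ChatHuggingFace(llm=llm)
    #     answer = ({...} | chat_prompt | chat_llm).invoke(question).content
    #
    # The plain HuggingFaceEndpoint LLM used here returns a string directly.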
    # Session state
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display previous messages
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).markdown(msg["content"])

    # User input
    user_input = st.chat_input("Ask something...")
    if user_input:
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        # The chain ends in an LLM (not a chat model), so invoke() returns a string
        answer = rag_chain.invoke(user_input)
        st.chat_message("assistant").markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})
else:
    st.warning("Upload files to start querying.")
# Clear chat button
if st.button("Clear Chat"):
    st.session_state.messages = []
    for path in temp_files:
        try:
            os.remove(path)
        except OSError:
            pass
    st.rerun()