import streamlit as st
import tempfile
import os
import numpy as np
# LangChain community loaders & FAISS
from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
UnstructuredWordDocumentLoader,
CSVLoader
)
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
st.set_page_config(page_title="Ask RAG - HF Space", layout="wide")
st.title("Ask RAG - HuggingFace Space")

# HuggingFace API key (set via Space Secrets)
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    st.error("Please set the HuggingFace API key in your Space secrets!")
    st.stop()

# Sentence-transformers embeddings (run locally inside the Space).
# NOTE: HuggingFaceEmbeddings accepts no `task` parameter — passing one fails
# pydantic validation, so it is omitted. The token is forwarded to
# SentenceTransformer via model_kwargs so gated/private models can be fetched.
embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large-instruct",
    model_kwargs={"use_auth_token": HF_TOKEN},
)

# Upload files
uploaded_files = st.file_uploader(
    "Upload files (PDF, DOCX, TXT, CSV)",
    type=["pdf", "docx", "txt", "csv"],
    accept_multiple_files=True,
)
@st.cache_resource
def load_files(files):
    """Load, split, and index uploaded files into a FAISS vectorstore.

    Parameters
    ----------
    files : list
        Streamlit UploadedFile objects (PDF, DOCX, TXT, or CSV).

    Returns
    -------
    tuple
        (FAISS vectorstore or None, list of temp-file paths for later cleanup).
    """
    if not files:
        return None, []

    # Dispatch on the lowercased extension so matching is case-insensitive
    # ("report.PDF" previously fell through the endswith() chain silently).
    loader_by_ext = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": UnstructuredWordDocumentLoader,
        ".csv": CSVLoader,
    }

    loaders = []
    temp_files = []
    for file in files:
        ext = os.path.splitext(file.name)[1].lower()
        # Persist the upload to disk: every loader here takes a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
            tmp.write(file.read())
            temp_files.append(tmp.name)
        loader_cls = loader_by_ext.get(ext)
        if loader_cls is not None:
            loaders.append(loader_cls(tmp.name))

    # Load all documents from every loader.
    docs = []
    for loader in loaders:
        docs.extend(loader.load())

    # Split into overlapping chunks before embedding.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = splitter.split_documents(docs)

    # Embed the chunks and build the FAISS index (uses the global embeddings).
    vectorstore = FAISS.from_documents(split_docs, embeddings)
    return vectorstore, temp_files
vectorstore, temp_files = load_files(uploaded_files) if uploaded_files else (None, [])

if vectorstore:
    retriever = vectorstore.as_retriever()

    # Prompt: stuff the retrieved context ahead of the user's question.
    chat_prompt = ChatPromptTemplate.from_template(
        """Use the context below to answer the question.
Context:
{context}
Question:
{question}"""
    )

    # Text-generation endpoint on the HF Inference API.
    # NOTE: the token must be passed as `huggingfacehub_api_token`; entries in
    # `model_kwargs` are forwarded as generation parameters, so putting
    # `use_auth_token` there is rejected by the endpoint.
    llm = HuggingFaceEndpoint(
        repo_id="AI-Sweden-Models/Llama-3-8B-instruct",
        task="text-generation",
        temperature=0.2,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # RAG chain: question -> retrieved docs -> joined context -> prompt -> LLM.
    rag_chain = (
        {
            "context": retriever | (lambda docs: "\n\n".join(d.page_content for d in docs)),
            "question": RunnablePassthrough(),
        }
        | chat_prompt
        | llm
    )

    # Conversation history persists across Streamlit reruns in session state.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay previous messages
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).markdown(msg["content"])

    # User input
    user_input = st.chat_input("Ask something...")
    if user_input:
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})
        response = rag_chain.invoke(user_input)
        # HuggingFaceEndpoint is a plain LLM and returns a string, not a chat
        # message — `response.content` raised AttributeError. Accept either
        # shape so swapping in a chat model later keeps working.
        answer = response.content if hasattr(response, "content") else response
        st.chat_message("assistant").markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})
else:
    st.warning("Upload files to start querying.")
# Clear chat button: wipe the history and remove the temp copies on disk.
if st.button("Clear Chat"):
    st.session_state.messages = []
    # Best-effort cleanup — a file may already have been removed; never let
    # cleanup crash the UI, but don't swallow non-OS errors with a bare except.
    for path in temp_files:
        try:
            os.remove(path)
        except OSError:
            pass
    # st.experimental_rerun() was removed in recent Streamlit releases;
    # st.rerun() is the supported replacement.
    st.rerun()