File size: 4,177 Bytes
8a088a6
c94a869
 
 
 
 
f74d054
 
 
 
 
 
c94a869
 
4230a34
c94a869
 
 
7185776
c94a869
a56d871
7185776
c94a869
 
 
 
 
 
 
 
 
 
 
 
 
 
7185776
a56d871
 
7185776
 
 
 
 
 
c94a869
7185776
 
c94a869
 
7185776
c94a869
 
 
 
 
 
7185776
c94a869
7185776
c94a869
7185776
c94a869
7185776
c94a869
8a088a6
c94a869
a56d871
7185776
a56d871
 
c94a869
a56d871
 
7185776
c94a869
 
7185776
c94a869
7185776
c94a869
7185776
 
 
c94a869
 
 
 
 
 
 
 
 
 
 
 
 
4230a34
 
 
 
c94a869
 
7185776
 
c94a869
 
 
 
 
 
 
 
7185776
 
c94a869
a56d871
7185776
 
f74d054
a56d871
 
7185776
f74d054
a56d871
7185776
 
 
 
c94a869
 
7185776
 
c94a869
7185776
a56d871
8a088a6
f74d054
a56d871
7185776
c94a869
 
 
 
 
7185776
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
import tempfile
import os
import numpy as np

# LangChain community loaders & FAISS
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredWordDocumentLoader,
    CSVLoader
)
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

st.set_page_config(page_title="Ask RAG - HF Space", layout="wide")
st.title("Ask RAG - HuggingFace Space")

# HuggingFace API key (set via Space Secrets); abort the app run if missing.
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
if not HF_TOKEN:
    st.error("Please set the HuggingFace API key in your Space secrets!")
    st.stop()

# Sentence-transformers embeddings (model downloaded from the Hub, run locally).
# NOTE(review): the previous `task="feature-extraction"` kwarg is not a valid
# HuggingFaceEmbeddings field (it is a pydantic model and rejects unknown
# kwargs), and `use_auth_token` inside model_kwargs is deprecated in favour of
# `token` -- confirm against the installed langchain/sentence-transformers
# versions.
embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large-instruct",
    model_kwargs={"token": HF_TOKEN},
)

# File uploader; the `type` list pre-filters what the browser lets through.
uploaded_files = st.file_uploader(
    "Upload files (PDF, DOCX, TXT, CSV)",
    type=["pdf", "docx", "txt", "csv"],
    accept_multiple_files=True
)

@st.cache_resource
def load_files(files):
    """Persist uploads to temp paths, then load, split and index them.

    Parameters
    ----------
    files : list of Streamlit ``UploadedFile`` objects, or None/empty.

    Returns
    -------
    tuple
        ``(vectorstore, temp_files)`` -- a FAISS index over the split
        documents plus the temp-file paths created (caller owns cleanup),
        or ``(None, [])`` when nothing was uploaded.
    """
    if not files:
        return None, []

    # Lowercase extension -> loader class. Lowercasing fixes the previous
    # case-sensitive endswith() checks (e.g. "REPORT.PDF" was silently skipped).
    loader_for_ext = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": UnstructuredWordDocumentLoader,
        ".csv": CSVLoader,
    }

    loaders = []
    temp_files = []

    for file in files:
        ext = os.path.splitext(file.name)[-1].lower()
        # delete=False so the path survives the with-block; paths are
        # collected so the caller can remove them later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
            tmp.write(file.read())
            temp_files.append(tmp.name)

        loader_cls = loader_for_ext.get(ext)
        if loader_cls is not None:
            loaders.append(loader_cls(tmp.name))

    # Load every document from every loader.
    docs = []
    for loader in loaders:
        docs.extend(loader.load())

    # Chunk for retrieval: 500-char chunks with 50-char overlap.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = splitter.split_documents(docs)

    # Embed the chunks and build the FAISS index (uses module-level `embeddings`).
    vectorstore = FAISS.from_documents(split_docs, embeddings)

    return vectorstore, temp_files

vectorstore, temp_files = load_files(uploaded_files) if uploaded_files else (None, [])

if vectorstore:
    retriever = vectorstore.as_retriever()

    # Prompt that grounds the answer in the retrieved context.
    chat_prompt = ChatPromptTemplate.from_template(
        """Use the context below to answer the question.

Context:
{context}

Question:
{question}"""
    )

    # LLM via the HF Inference API. The token is a first-class constructor
    # argument (`huggingfacehub_api_token`); passing `use_auth_token` through
    # model_kwargs is deprecated and may be rejected.
    llm = HuggingFaceEndpoint(
        repo_id="AI-Sweden-Models/Llama-3-8B-instruct",
        task="text-generation",
        temperature=0.2,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # RAG chain: fan the question into {context, question}, fill the prompt,
    # then call the LLM. Retrieved docs are joined into one context string.
    rag_chain = (
        {
            "context": retriever | (lambda docs: "\n\n".join(d.page_content for d in docs)),
            "question": RunnablePassthrough(),
        }
        | chat_prompt
        | llm
    )

    # Initialise chat history once per session.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay previous turns.
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).markdown(msg["content"])

    # Handle a new user turn.
    user_input = st.chat_input("Ask something...")
    if user_input:
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        response = rag_chain.invoke(user_input)
        # HuggingFaceEndpoint is a plain LLM, so the chain returns a str; the
        # old `response.content` raised AttributeError. Keep a fallback for
        # chat-model responses that do carry `.content`.
        answer = getattr(response, "content", response)
        st.chat_message("assistant").markdown(answer)
        st.session_state.messages.append({"role": "assistant", "content": answer})

else:
    st.warning("Upload files to start querying.")

# Clear chat button: wipe history, delete the temp files, rerun the script.
if st.button("Clear Chat"):
    st.session_state.messages = []
    for path in temp_files:
        try:
            os.remove(path)
        except OSError:
            # Best effort: the file may already be gone or not removable.
            pass
    # st.experimental_rerun() was removed in recent Streamlit releases;
    # st.rerun() is the supported API.
    st.rerun()