File size: 3,237 Bytes
736448d
837c8fa
1a7b2d4
837c8fa
 
2d6ed01
837c8fa
09aa142
837c8fa
0d5b491
 
 
1a7b2d4
0d5b491
5d4a40e
 
837c8fa
0d5b491
 
 
 
 
 
 
837c8fa
0d5b491
837c8fa
 
 
0d5b491
ea7b8ea
2ae095c
 
3099672
ea7b8ea
2ae095c
 
 
837c8fa
b7b493d
0d5b491
837c8fa
0d5b491
837c8fa
0d5b491
 
 
 
 
 
 
 
837c8fa
 
0d5b491
 
 
 
 
 
 
 
736448d
0d5b491
 
837c8fa
0d5b491
 
 
 
 
 
 
 
 
 
837c8fa
736448d
0d5b491
837c8fa
 
 
0d5b491
 
 
837c8fa
 
b7b493d
837c8fa
736448d
0d5b491
837c8fa
 
b7b493d
837c8fa
 
 
 
09aa142
0d5b491
837c8fa
0d5b491
 
837c8fa
 
736448d
0d5b491
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import tempfile
import streamlit as st

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
# from langchain_groq import GroqLLM
from langchain_groq import ChatGroq

# --- Environment Variables ---
# The placeholder defaults let the app start without the variables set; Groq
# requests will presumably fail authentication until GROQ_API_KEY is provided.
# NOTE(review): HUGGINGFACE_API_KEY is read but not passed to anything below.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "your-groq-api-key")
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "your-huggingface-api-key")


# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource builds each client once and reuses it on later reruns,
# instead of reloading the embedding model every time a widget changes.
@st.cache_resource
def _load_llm(api_key):
    """Return the Groq chat model used to answer questions.

    Args:
        api_key: Groq API key; also part of the cache key, so a different key
            produces a fresh client.
    """
    # ChatGroq accepts either `model` or its alias `model_name`.
    # NOTE(review): Groq retires models over time; confirm this one is still served.
    return ChatGroq(
        api_key=api_key,
        model_name="llama3-8b-8192",
        temperature=0.1,  # low temperature keeps answers close to the context
    )


@st.cache_resource
def _load_embeddings():
    """Return the sentence-transformers embedding model, with model files stored under ./hf_cache."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        cache_folder="./hf_cache",
    )


# --- Initialize Groq LLM ---
llm = _load_llm(GROQ_API_KEY)

# --- HuggingFace Embeddings ---
embedding = _load_embeddings()

# --- Streamlit UI ---
# Title literal previously contained mojibake (UTF-8 emoji bytes decoded as cp1253).
st.title("📄📥 Chat with PDF or Text using Groq + RAG")

# Input source 1: a PDF upload (None until the user selects a file).
uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])

# Input source 2: raw text, used only when no PDF is uploaded.
pasted_text = st.text_area("Or paste some text below:")

# The question to answer against whichever source was provided.
user_query = st.text_input("Ask a question about the content")

# True only on the rerun triggered by clicking the button.
submit_button = st.button("Submit")

def _load_pdf_documents(pdf_file):
    """Parse an uploaded PDF into chunked Documents via PyPDFLoader.

    Args:
        pdf_file: the object returned by ``st.file_uploader``.

    Returns:
        The list produced by ``PyPDFLoader.load_and_split()``; it may be empty
        when no text could be extracted.
    """
    # PyPDFLoader needs a filesystem path. delete=False lets the loader reopen
    # the file after this handle closes (required on Windows), so the file is
    # removed explicitly in the finally clause instead of leaking one per submit.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        # getvalue() returns the whole upload regardless of stream position,
        # unlike read(), which yields b"" once the buffer has been consumed.
        tmp_file.write(pdf_file.getvalue())
        tmp_path = tmp_file.name
    try:
        # load_and_split() returns a fully materialized list, so the file can
        # be deleted as soon as it returns.
        return PyPDFLoader(tmp_path).load_and_split()
    finally:
        os.remove(tmp_path)


def _build_qa_chain(documents):
    """Index ``documents`` with FAISS and wrap the retriever in a "stuff" RetrievalQA chain."""
    retriever = FAISS.from_documents(documents, embedding).as_retriever()

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
        You are an AI assistant. Use the following context to answer the question.
        Be concise, accurate, and helpful.

        Context: {context}
        Question: {question}
        Answer:"""
    )

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # all retrieved chunks are inserted into one prompt
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt_template},
    )


if submit_button:
    # Validate cheap preconditions before any PDF parsing, embedding or LLM call.
    if not uploaded_file and not pasted_text.strip():
        st.warning("Please upload a PDF or paste some text.")
        st.stop()

    if not user_query.strip():
        st.warning("Please enter a question about the content.")
        st.stop()

    # A PDF takes precedence over pasted text.
    if uploaded_file:
        documents = _load_pdf_documents(uploaded_file)
    else:
        documents = [Document(page_content=pasted_text)]

    # FAISS cannot build an index from zero documents (e.g. a PDF with no text layer).
    if not documents:
        st.warning("No text could be extracted from this PDF (it may contain only scanned images).")
        st.stop()

    qa_chain = _build_qa_chain(documents)

    # invoke() replaces the deprecated Chain.__call__ and returns the same output dict.
    result = qa_chain.invoke({"query": user_query})

    st.markdown("### 💬 Answer")
    st.write(result["result"])

    # Show sources (only if from PDF)
    if uploaded_file:
        with st.expander("📄 Sources"):
            for doc in result["source_documents"]:
                # Use the chunk's own 0-based "page" metadata rather than its rank
                # in the retrieval results. The "source" metadata would be the
                # deleted temp path, so show the uploaded file's name instead.
                page = doc.metadata.get("page")
                page_label = page + 1 if isinstance(page, int) else "?"
                st.write(f"**Page {page_label}** — {uploaded_file.name}")