MBilal-72 committed on
Commit
b7b493d
·
verified ·
1 Parent(s): c5e15d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -9,68 +9,69 @@ from langchain.chains import RetrievalQA
9
  from langchain.prompts import PromptTemplate
10
  from langchain_groq import GroqLLM
11
 
12
- # Set environment variables (You can also use os.environ or Streamlit secrets)
13
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
14
- HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
15
 
16
- # Initialize Groq LLM
17
  llm = GroqLLM(
18
  api_key=GROQ_API_KEY,
19
- model="llama3-8b-8192", # <- correct param
20
  temperature=0.1
21
  )
22
 
23
- # HuggingFace Embeddings
24
- embedding = HuggingFaceEmbeddings()
 
 
 
 
25
 
 
26
  st.title("📄 RAG Chat with Groq + HuggingFace")
27
 
28
- # Upload PDF
29
  uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
30
-
31
  user_query = st.text_input("Ask something about the document")
32
  submit_button = st.button("Submit")
33
 
34
  if uploaded_file and submit_button:
35
- # Save PDF temporarily
36
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
37
  tmp_file.write(uploaded_file.read())
38
  tmp_path = tmp_file.name
39
 
40
- # Load and split
41
  loader = PyPDFLoader(tmp_path)
42
  pages = loader.load_and_split()
43
 
44
- # Create FAISS vectorstore
45
  vectorstore = FAISS.from_documents(pages, embedding)
46
  retriever = vectorstore.as_retriever()
47
 
48
- # Custom prompt (optional)
49
  prompt_template = PromptTemplate(
50
  input_variables=["context", "question"],
51
  template="""
52
- Use the following context to answer the question. Be concise and accurate.
53
-
54
  Context: {context}
55
-
56
  Question: {question}
57
- """
58
  )
59
 
60
- # Create QA chain
61
  qa_chain = RetrievalQA.from_chain_type(
62
  llm=llm,
 
63
  retriever=retriever,
64
  return_source_documents=True,
65
  chain_type_kwargs={"prompt": prompt_template}
66
  )
67
 
68
- # Run QA
69
  result = qa_chain({"query": user_query})
70
  st.markdown("### 💬 Answer")
71
  st.write(result["result"])
72
 
73
- # Optional: Show sources
74
  with st.expander("📄 Sources"):
75
- for doc in result["source_documents"]:
76
- st.write(doc.metadata["source"])
 
9
  from langchain.prompts import PromptTemplate
10
  from langchain_groq import GroqLLM
11
 
12
+ # --- Environment Variable Setup ---
13
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", "your-groq-api-key")
14
+ HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "your-huggingface-api-key")
15
 
16
+ # --- Groq LLM Initialization ---
17
  llm = GroqLLM(
18
  api_key=GROQ_API_KEY,
19
+ model="llama3-8b-8192",
20
  temperature=0.1
21
  )
22
 
23
+ # --- HuggingFace Embeddings (add a default model name if needed) ---
24
+ embedding = HuggingFaceEmbeddings(
25
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
26
+ cache_folder="./hf_cache",
27
+ huggingfacehub_api_token=HUGGINGFACE_API_KEY
28
+ )
29
 
30
+ # --- Streamlit UI ---
31
  st.title("📄 RAG Chat with Groq + HuggingFace")
32
 
 
33
  uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
 
34
  user_query = st.text_input("Ask something about the document")
35
  submit_button = st.button("Submit")
36
 
37
  if uploaded_file and submit_button:
 
38
  with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
39
  tmp_file.write(uploaded_file.read())
40
  tmp_path = tmp_file.name
41
 
42
+ # --- Load and Split PDF ---
43
  loader = PyPDFLoader(tmp_path)
44
  pages = loader.load_and_split()
45
 
46
+ # --- FAISS Vector Store ---
47
  vectorstore = FAISS.from_documents(pages, embedding)
48
  retriever = vectorstore.as_retriever()
49
 
50
+ # --- Optional Custom Prompt ---
51
  prompt_template = PromptTemplate(
52
  input_variables=["context", "question"],
53
  template="""
54
+ You are an intelligent assistant. Use the following context to answer the question accurately.
 
55
  Context: {context}
 
56
  Question: {question}
57
+ Answer:"""
58
  )
59
 
60
+ # --- RetrievalQA Chain ---
61
  qa_chain = RetrievalQA.from_chain_type(
62
  llm=llm,
63
+ chain_type="stuff",
64
  retriever=retriever,
65
  return_source_documents=True,
66
  chain_type_kwargs={"prompt": prompt_template}
67
  )
68
 
69
+ # --- Run the Chain ---
70
  result = qa_chain({"query": user_query})
71
  st.markdown("### 💬 Answer")
72
  st.write(result["result"])
73
 
74
+ # --- Optional: Show Source Documents ---
75
  with st.expander("📄 Sources"):
76
+ for i, doc in enumerate(result["source_documents"]):
77
+ st.write(f"**Page {i+1}** — {doc.metadata.get('source', 'Unknown')}")