AnwinMJ commited on
Commit
bf4298a
Β·
verified Β·
1 Parent(s): c491ca3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -1,20 +1,20 @@
1
  import os
2
  import gradio as gr
 
 
 
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import Chroma
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.document_loaders import PyPDFLoader
7
  from langchain.chains import RetrievalQA
8
  from langchain.llms.base import LLM
9
- from typing import List, Optional
10
  from groq import Groq
11
- import tempfile
12
- import shutil
13
 
14
- # Custom LLM using Groq
15
  class GroqLLM(LLM):
16
  model: str = "llama3-8b-8192"
17
- api_key: str = os.environ.get("GROQ_API_KEY") # Use env var for security
18
  temperature: float = 0.7
19
 
20
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
@@ -34,64 +34,68 @@ class GroqLLM(LLM):
34
  def _llm_type(self) -> str:
35
  return "groq-llm"
36
 
37
- # Global cache to reuse vectorstore during the session
38
- vectorstore_cache = {}
 
 
 
 
 
39
 
40
- def process_pdf(file_obj):
41
- # Save uploaded PDF to temp directory
42
  with tempfile.TemporaryDirectory() as temp_dir:
43
- file_path = os.path.join(temp_dir, file_obj.name)
44
- with open(file_path, "wb") as f:
45
- f.write(file_obj.read())
46
 
47
- # Load and split
48
- loader = PyPDFLoader(file_path)
49
  documents = loader.load()
50
 
51
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
52
- docs = text_splitter.split_documents(documents)
53
 
 
54
  embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
55
-
56
- # Create persistent Chroma DB
57
- persist_dir = os.path.join(temp_dir, "chroma_db")
58
- vectorstore = Chroma.from_documents(docs, embedding, persist_directory=persist_dir)
59
  vectorstore.persist()
60
 
61
- # Store for session use
62
- vectorstore_cache["retriever"] = vectorstore.as_retriever()
63
-
64
- return "PDF processed and ready. You can now ask questions."
65
 
 
 
 
66
  def ask_question(query):
67
- if "retriever" not in vectorstore_cache:
68
- return "Please upload a PDF first."
 
69
 
70
  llm = GroqLLM()
71
  qa_chain = RetrievalQA.from_chain_type(
72
  llm=llm,
73
- retriever=vectorstore_cache["retriever"],
74
  return_source_documents=True
75
  )
 
76
  result = qa_chain({"query": query})
77
  answer = result["result"]
78
- sources = "\n".join([doc.metadata.get("source", "No metadata") for doc in result["source_documents"]])
79
- return f"### Answer:\n{answer}\n\n### Sources:\n{sources}"
80
 
 
81
  with gr.Blocks() as demo:
82
- gr.Markdown("## πŸ“„ PDF Question Answering Bot (Groq + HuggingFace + LangChain)")
83
 
84
  with gr.Row():
85
- pdf_file = gr.File(label="Upload your PDF")
86
  upload_btn = gr.Button("Process PDF")
 
87
 
88
- upload_output = gr.Textbox(label="Status", interactive=False)
89
- upload_btn.click(process_pdf, inputs=pdf_file, outputs=upload_output)
90
 
91
- query = gr.Textbox(label="Ask a question")
 
92
  answer_output = gr.Markdown()
93
- query_btn = gr.Button("Get Answer")
94
 
95
- query_btn.click(ask_question, inputs=query, outputs=answer_output)
96
 
97
  demo.launch()
 
1
  import os
2
  import gradio as gr
3
+ import tempfile
4
+ from typing import List, Optional
5
+
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import Chroma
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.document_loaders import PyPDFLoader
10
  from langchain.chains import RetrievalQA
11
  from langchain.llms.base import LLM
 
12
  from groq import Groq
 
 
13
 
14
+ # ---- Custom GroqLLM class using LangChain LLM base ----
15
  class GroqLLM(LLM):
16
  model: str = "llama3-8b-8192"
17
+ api_key: str = os.environ.get("GROQ_API_KEY") # Load from HF secrets
18
  temperature: float = 0.7
19
 
20
  def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
 
34
  def _llm_type(self) -> str:
35
  return "groq-llm"
36
 
37
+ # Global cache for vectorstore
38
+ rag_context = {"retriever": None}
39
+
40
+ # ---- Step 1: Upload & Embed PDF ----
41
+ def process_pdf(file):
42
+ if file is None:
43
+ return "❌ Please upload a PDF."
44
 
 
 
45
  with tempfile.TemporaryDirectory() as temp_dir:
46
+ temp_pdf_path = os.path.join(temp_dir, file.name)
47
+ with open(temp_pdf_path, "wb") as f:
48
+ f.write(file.read())
49
 
50
+ # Load and split PDF
51
+ loader = PyPDFLoader(temp_pdf_path)
52
  documents = loader.load()
53
 
54
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
55
+ chunks = text_splitter.split_documents(documents)
56
 
57
+ # Embeddings and vectorstore
58
  embedding = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
59
+ vectorstore = Chroma.from_documents(chunks, embedding, persist_directory=temp_dir)
 
 
 
60
  vectorstore.persist()
61
 
62
+ # Store retriever in session
63
+ rag_context["retriever"] = vectorstore.as_retriever()
 
 
64
 
65
+ return "βœ… PDF processed and ready. Ask your questions!"
66
+
67
+ # ---- Step 2: Ask questions to the RAG chain ----
68
  def ask_question(query):
69
+ retriever = rag_context.get("retriever")
70
+ if retriever is None:
71
+ return "❌ Please upload and process a PDF first."
72
 
73
  llm = GroqLLM()
74
  qa_chain = RetrievalQA.from_chain_type(
75
  llm=llm,
76
+ retriever=retriever,
77
  return_source_documents=True
78
  )
79
+
80
  result = qa_chain({"query": query})
81
  answer = result["result"]
82
+ return f"### Answer:\n{answer}"
 
83
 
84
+ # ---- Gradio UI ----
85
  with gr.Blocks() as demo:
86
+ gr.Markdown("# πŸ“š RAG Chatbot with Groq & LangChain\nUpload a PDF, then ask questions about it!")
87
 
88
  with gr.Row():
89
+ pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
90
  upload_btn = gr.Button("Process PDF")
91
+ upload_status = gr.Textbox(label="Status", interactive=False)
92
 
93
+ upload_btn.click(process_pdf, inputs=pdf_input, outputs=upload_status)
 
94
 
95
+ query_input = gr.Textbox(label="Ask a question")
96
+ ask_btn = gr.Button("Get Answer")
97
  answer_output = gr.Markdown()
 
98
 
99
+ ask_btn.click(ask_question, inputs=query_input, outputs=answer_output)
100
 
101
  demo.launch()