imeesam committed on
Commit
cbcf4a5
·
verified ·
1 Parent(s): ba45220

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — PDF upload version
2
+
3
+ import os
4
+ import gradio as gr
5
+ from langchain_community.document_loaders import PyPDFLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain_groq import ChatGroq
10
+ from langchain.prompts import ChatPromptTemplate
11
+ from langchain.schema.runnable import RunnablePassthrough
12
+ from langchain.schema.output_parser import StrOutputParser
13
+
14
+ # ── Config ─────────────────────────────────────────────────────────────────────
15
# Model name handed to HuggingFaceEmbeddings when the embedder is created.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Model name handed to ChatGroq when the answer chain is built.
GROQ_MODEL = "llama-3.1-8b-instant"
CHUNK_SIZE = 800  # larger chunks work better for dense PDFs
CHUNK_OVERLAP = 100  # forwarded to the text splitter as chunk_overlap
TOP_K = 4  # "k" used by the retriever: chunks fetched per question

# Prompt template with two placeholders: {context} (the formatted retrieved
# chunks) and {question} (the user's message). It tells the model to answer
# only from the supplied context.
RAG_PROMPT = ChatPromptTemplate.from_template("""
You are a helpful assistant. Answer the question using ONLY the context below.
If the answer is not in the context, say "I don't have enough information."

Context:
{context}

Question: {question}

Answer:
""")
32
+
33
+ # ── Load embedding model once at startup (slow, ~30s) ─────────────────────────
34
# The embedder is created at import time and shared by every call that
# builds an index, so the model is only loaded once per process.
print("Loading embedding model...")
embeddings = HuggingFaceEmbeddings(
    model_name=EMBED_MODEL,
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)
print("Embeddings ready.")

# Global state — replaced whenever new PDFs are uploaded.
# rag_chain stays None until PDFs have been processed; chat() checks for that.
rag_chain = None
vectorstore = None
45
+
46
+ # ── Core logic ─────────────────────────────────────────────────────────────────
47
def _format_docs(docs):
    """Join retrieved chunks into one context string tagged with source and page."""
    sections = []
    for doc in docs:
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page")
        # The stored page index is shifted by one for display. When the key is
        # missing, show "?" instead of failing on "?" + 1.
        page_label = page + 1 if isinstance(page, int) else "?"
        sections.append(
            f"[Source: {source}, Page {page_label}]\n{doc.page_content}"
        )
    return "\n\n".join(sections)


def process_pdfs(pdf_files):
    """
    Build the FAISS index and the RAG chain from the uploaded PDFs.

    Called when the user clicks 'Process PDFs'.

    pdf_files: list of uploaded files as provided by Gradio. Each item may be
        a path (str / os.PathLike) or an object exposing the path as ``.name``
        (NOTE(review): older Gradio releases appear to pass tempfile wrappers;
        confirm against the installed version).

    Returns a human-readable status message. Failures are reported through
    this message rather than raised, so the status box always gets text.

    Side effects: on success, replaces the module-level ``vectorstore`` and
    ``rag_chain``. On any failure both globals are left untouched.
    """
    global rag_chain, vectorstore

    if not pdf_files:
        return "No files uploaded. Please upload at least one PDF."

    # Fail fast: without the key the chain cannot be built, so do not spend
    # time parsing and embedding first.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return (
            "GROQ_API_KEY is not set. Add it to the environment "
            "and process the PDFs again."
        )

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        separators=["\n\n", "\n", ".", " ", ""],
    )

    all_chunks = []
    total_pages = 0  # counted while loading, so no PDF is parsed twice

    for pdf_file in pdf_files:
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name

        try:
            pages = PyPDFLoader(pdf_path).load()  # one Document per page

            # Replace the temp path with the bare filename for traceability.
            filename = os.path.basename(pdf_path)
            for page in pages:
                page.metadata["source"] = filename

            chunks = splitter.split_documents(pages)
        except Exception as e:
            # UI boundary: report the problem in the status box.
            return f"Error loading {pdf_path}: {e}"

        total_pages += len(pages)
        all_chunks.extend(chunks)
        print(f"Loaded {filename}: {len(pages)} pages → {len(chunks)} chunks")

    if not all_chunks:
        return "No text could be extracted. Check if the PDFs contain selectable text (not scanned images)."

    # Build everything into locals first; the globals are swapped only once
    # the whole pipeline exists, so a failure cannot leave them half-updated.
    print(f"Indexing {len(all_chunks)} chunks...")
    new_store = FAISS.from_documents(all_chunks, embeddings)
    retriever = new_store.as_retriever(search_kwargs={"k": TOP_K})

    llm = ChatGroq(
        model=GROQ_MODEL,
        temperature=0.2,
        max_tokens=1024,
        groq_api_key=api_key,
    )

    new_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | RAG_PROMPT
        | llm
        | StrOutputParser()
    )

    vectorstore = new_store
    rag_chain = new_chain

    return (
        f"Ready! Indexed {len(pdf_files)} PDF(s), "
        f"{total_pages} pages, "
        f"{len(all_chunks)} chunks. You can now ask questions."
    )
122
+
123
+
124
def chat(message, history):
    """
    Answer one user message and return the updated chat state.

    message: text submitted from the question box.
    history: list of [user, assistant] pairs shown in the chatbot; None is
        treated as an empty history.

    Returns ("", new_history). The empty string clears the question box.
    The incoming history list is never mutated; a new list is returned.

    Blank or whitespace-only messages are ignored. If no PDFs have been
    processed yet (rag_chain is None) the reply asks the user to do that
    first. Errors raised by the chain are shown as the assistant reply
    instead of propagating.
    """
    turns = list(history or [])

    # Ignore empty submissions before anything else so they never add an
    # empty bubble to the conversation.
    if not message or not message.strip():
        return "", turns

    if rag_chain is None:
        reply = "Please upload and process PDFs first."
    else:
        try:
            reply = rag_chain.invoke(message)
        except Exception as e:
            # UI boundary: surface the failure in the chat window.
            reply = f"Error: {e}"

    return "", turns + [[message, reply]]
135
+
136
+
137
+ # ── Gradio UI ──────────────────────────────────────────────────────────────────
138
# Page layout: upload and indexing controls in the narrow left column,
# the conversation in the wider right column.
with gr.Blocks(title="PDF RAG Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## PDF RAG Chatbot\nUpload your PDFs, then ask questions about them.")

    with gr.Row():
        with gr.Column(scale=1):
            # Several PDFs can be selected in one upload.
            pdf_input = gr.File(
                label="Upload PDFs",
                file_types=[".pdf"],
                file_count="multiple",
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            # Read-only box that displays the message returned by process_pdfs.
            status_box = gr.Textbox(
                label="Status",
                interactive=False,
                placeholder="Upload PDFs and click Process...",
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=450, label="Chat")
            msg = gr.Textbox(
                placeholder="Ask a question about your PDFs...",
                label="Question",
            )
            clear = gr.Button("Clear chat")

    # Event wiring.
    # The button runs process_pdfs and shows its status message.
    process_btn.click(fn=process_pdfs, inputs=[pdf_input], outputs=[status_box])
    # Pressing Enter in the question box runs chat, which returns the new
    # textbox value and the new chat history.
    msg.submit(fn=chat, inputs=[msg, chatbot], outputs=[msg, chatbot])
    # Clearing resets the chatbot to an empty history.
    clear.click(fn=lambda: [], outputs=[chatbot])

if __name__ == "__main__":
    demo.launch()