simran40 committed on
Commit
0fd613a
·
verified ·
1 Parent(s): 720e7c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -57
app.py CHANGED
@@ -2,31 +2,26 @@ import gradio as gr
2
  import fitz # PyMuPDF
3
  import re
4
  import faiss
5
- import torch
6
  import numpy as np
7
 
8
  from sentence_transformers import SentenceTransformer
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
 
11
 
12
  # =================================================
13
  # MODEL LOADING (ONCE AT STARTUP)
14
  # =================================================
15
 
16
- # Better embedding model for Q&A
17
  embedding_model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
18
 
19
- # Lightweight open-source LLM (CPU friendly)
20
- LLM_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
21
-
22
- tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
23
- llm = AutoModelForCausalLM.from_pretrained(
24
- LLM_NAME,
25
- torch_dtype=torch.float32
26
  )
27
 
28
- llm.eval()
29
-
30
 
31
  # =================================================
32
  # PDF PROCESSING
@@ -41,13 +36,13 @@ def extract_text_from_pdf(pdf_path):
41
 
42
 
43
  def clean_text(text):
44
- # remove extra spaces
45
  text = re.sub(r"\s+", " ", text)
46
 
47
- # remove table of contents noise
48
  text = re.sub(r"Table of contents.*?Introduction", "", text, flags=re.I)
49
 
50
- # remove page numbers
51
  text = re.sub(r"\bPage \d+\b", "", text)
52
 
53
  return text.strip()
@@ -55,7 +50,7 @@ def clean_text(text):
55
 
56
  def chunk_text(text, chunk_size=350, overlap=80):
57
  """
58
- Smaller overlapping chunks improve semantic retrieval accuracy
59
  """
60
  chunks = []
61
  start = 0
@@ -85,7 +80,7 @@ def build_faiss_index(chunks):
85
 
86
  def retrieve_relevant_chunks(query, index, chunks, top_k=5):
87
  """
88
- Retrieve top-K chunks + re-rank by distance
89
  """
90
  query_embedding = embedding_model.encode([query]).astype("float32")
91
  distances, indices = index.search(query_embedding, top_k)
@@ -94,56 +89,41 @@ def retrieve_relevant_chunks(query, index, chunks, top_k=5):
94
  for rank, idx in enumerate(indices[0]):
95
  results.append((chunks[idx], distances[0][rank]))
96
 
97
- # re-rank: smaller distance = more relevant
98
  results.sort(key=lambda x: x[1])
99
 
100
  return [r[0] for r in results]
101
 
102
 
103
  # =================================================
104
- # ANSWER GENERATION (LLM)
105
  # =================================================
106
 
107
  def generate_answer(question, context_chunks):
108
- context = "\n\n".join(context_chunks)
109
-
110
- prompt = f"""
111
- You are a precise academic assistant.
112
-
113
- RULES:
114
- - Answer ONLY from the given context.
115
- - Do NOT add external knowledge.
116
- - Be concise and factual.
117
- - If the answer is missing, reply exactly:
118
- "Information not found in the document."
119
-
120
- CONTEXT:
121
- {context}
122
 
123
- QUESTION:
124
- {question}
125
-
126
- FINAL ANSWER:
127
- """
128
 
129
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
 
 
130
 
131
- with torch.no_grad():
132
- output = llm.generate(
133
- **inputs,
134
- max_new_tokens=180,
135
- temperature=0.1
136
- )
137
 
138
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
139
- return decoded.split("FINAL ANSWER:")[-1].strip()
140
 
141
 
142
  # =================================================
143
- # MAIN RAG PIPELINE
144
  # =================================================
145
 
146
- def pdf_rag_chat(pdf_file, question):
147
  if pdf_file is None or question.strip() == "":
148
  return "Please upload a PDF and enter a valid question."
149
 
@@ -154,13 +134,13 @@ def pdf_rag_chat(pdf_file, question):
154
  # 2. Chunking
155
  chunks = chunk_text(cleaned_text)
156
 
157
- # 3. Vector DB
158
  index, chunks = build_faiss_index(chunks)
159
 
160
- # 4. Retrieval
161
  relevant_chunks = retrieve_relevant_chunks(question, index, chunks)
162
 
163
- # 5. Answer generation
164
  return generate_answer(question, relevant_chunks)
165
 
166
 
@@ -171,11 +151,13 @@ def pdf_rag_chat(pdf_file, question):
171
  with gr.Blocks() as demo:
172
 
173
  gr.Markdown("""
174
- # 📄 PDF RAG Chatbot (Open-Source AI)
175
 
176
- Upload a **PDF document** and ask questions based **only on its content**.
177
- This system implements an **accuracy-optimized Retrieval Augmented Generation (RAG)** pipeline
178
- using **open-source Hugging Face models**, running on **free CPU**.
 
 
179
  """)
180
 
181
  with gr.Row():
@@ -200,7 +182,7 @@ with gr.Blocks() as demo:
200
  )
201
 
202
  submit_btn.click(
203
- fn=pdf_rag_chat,
204
  inputs=[pdf_input, question_input],
205
  outputs=answer_output
206
  )
 
2
  import fitz # PyMuPDF
3
  import re
4
  import faiss
 
5
  import numpy as np
6
 
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import pipeline
9
 
10
 
11
  # =================================================
12
  # MODEL LOADING (ONCE AT STARTUP)
13
  # =================================================
14
 
15
+ # Embedding model (good for question-answer retrieval)
16
  embedding_model = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
17
 
18
+ # Extractive Question Answering model (HIGH ACCURACY)
19
+ qa_pipeline = pipeline(
20
+ "question-answering",
21
+ model="deepset/roberta-base-squad2",
22
+ tokenizer="deepset/roberta-base-squad2"
 
 
23
  )
24
 
 
 
25
 
26
  # =================================================
27
  # PDF PROCESSING
 
36
 
37
 
38
  def clean_text(text):
39
+ # Remove extra spaces
40
  text = re.sub(r"\s+", " ", text)
41
 
42
+ # Remove table of contents noise
43
  text = re.sub(r"Table of contents.*?Introduction", "", text, flags=re.I)
44
 
45
+ # Remove page numbers
46
  text = re.sub(r"\bPage \d+\b", "", text)
47
 
48
  return text.strip()
 
50
 
51
  def chunk_text(text, chunk_size=350, overlap=80):
52
  """
53
+ Smaller overlapping chunks improve accuracy
54
  """
55
  chunks = []
56
  start = 0
 
80
 
81
  def retrieve_relevant_chunks(query, index, chunks, top_k=5):
82
  """
83
+ Retrieve top-K chunks and re-rank by distance
84
  """
85
  query_embedding = embedding_model.encode([query]).astype("float32")
86
  distances, indices = index.search(query_embedding, top_k)
 
89
  for rank, idx in enumerate(indices[0]):
90
  results.append((chunks[idx], distances[0][rank]))
91
 
92
+ # Re-rank (lower distance = more relevant)
93
  results.sort(key=lambda x: x[1])
94
 
95
  return [r[0] for r in results]
96
 
97
 
98
  # =================================================
99
+ # ANSWER GENERATION (EXTRACTIVE QA – ACCURATE)
100
  # =================================================
101
 
102
  def generate_answer(question, context_chunks):
103
+ best_answer = ""
104
+ best_score = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ for chunk in context_chunks:
107
+ result = qa_pipeline(
108
+ question=question,
109
+ context=chunk
110
+ )
111
 
112
+ if result["score"] > best_score:
113
+ best_score = result["score"]
114
+ best_answer = result["answer"]
115
 
116
+ if best_score < 0.25 or best_answer.strip() == "":
117
+ return "Information not found in the document."
 
 
 
 
118
 
119
+ return best_answer
 
120
 
121
 
122
  # =================================================
123
+ # MAIN PIPELINE
124
  # =================================================
125
 
126
+ def pdf_qa_chat(pdf_file, question):
127
  if pdf_file is None or question.strip() == "":
128
  return "Please upload a PDF and enter a valid question."
129
 
 
134
  # 2. Chunking
135
  chunks = chunk_text(cleaned_text)
136
 
137
+ # 3. Vector database
138
  index, chunks = build_faiss_index(chunks)
139
 
140
+ # 4. Retrieve relevant chunks
141
  relevant_chunks = retrieve_relevant_chunks(question, index, chunks)
142
 
143
+ # 5. Extractive QA
144
  return generate_answer(question, relevant_chunks)
145
 
146
 
 
151
  with gr.Blocks() as demo:
152
 
153
  gr.Markdown("""
154
+ # 📄 PDF Question Answering System (Accurate AI)
155
 
156
+ Upload a **PDF document** and ask questions.
157
+ The system uses **semantic retrieval + extractive AI**, ensuring
158
+ **accurate answers strictly from the document text**.
159
+
160
+ ---
161
  """)
162
 
163
  with gr.Row():
 
182
  )
183
 
184
  submit_btn.click(
185
+ fn=pdf_qa_chat,
186
  inputs=[pdf_input, question_input],
187
  outputs=answer_output
188
  )