chippyjolly committed on
Commit
a961c7a
·
verified ·
1 Parent(s): 91df8d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -98
app.py CHANGED
@@ -4,38 +4,39 @@ from PyPDF2 import PdfReader
4
  from langchain_text_splitters import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
7
  from langchain.chains.retrieval_qa.base import RetrievalQA
8
  from langchain.prompts import PromptTemplate
9
  from langchain_core.language_models.llms import LLM
10
  from langchain_core.callbacks import CallbackManagerForLLMRun
 
11
  from typing import Optional, List, Dict, Any
12
- import requests
13
  from dotenv import load_dotenv
14
  from groq import Groq
 
15
  import urllib.parse
16
- import feedparser # Added for the new function
17
 
18
  from numpy import dot
19
- from numpy.linalg import norm #newly added - to let the similar paper work
20
-
21
 
22
  # Load environment variables
23
  load_dotenv()
24
-
25
-
26
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
27
 
28
 
29
- # Custom wrapper for Groq to make it LangChain compatible
 
 
30
  class GroqWrapper(LLM):
31
  client: Any
32
  model_name: str = "llama-3.3-70b-versatile"
33
  temperature: float = 0.7
34
-
35
  @property
36
  def _llm_type(self) -> str:
37
  return "groq"
38
-
39
  def _call(
40
  self,
41
  prompt: str,
@@ -44,18 +45,22 @@ class GroqWrapper(LLM):
44
  **kwargs: Any,
45
  ) -> str:
46
  response = self.client.chat.completions.create(
47
- messages=[{"role": "user", "content": prompt}],
48
  model=self.model_name,
 
49
  temperature=self.temperature,
50
- **kwargs
51
  )
52
  return response.choices[0].message.content
53
 
54
- # Initialize global variables
 
55
  vectorstore = None
56
  qa_chain = None
57
  groq_llm = None
58
 
 
 
 
 
59
  def upload_pdf(file):
60
  global vectorstore, qa_chain, groq_llm
61
 
@@ -64,176 +69,202 @@ def upload_pdf(file):
64
  if groq_llm is None:
65
  groq_llm = GroqWrapper(client=Groq(api_key=GROQ_API_KEY))
66
 
67
- # Extract text
68
  text = "".join(page.extract_text() or "" for page in PdfReader(file).pages)
69
  if not text.strip():
70
  return "Error: No readable text found in PDF"
71
 
72
- # Chunk text
73
- text_splitter = RecursiveCharacterTextSplitter(
74
  chunk_size=1000,
75
  chunk_overlap=150,
76
  separators=["\n\n", "\n", ".", "?", "!"]
77
  )
78
- texts = text_splitter.split_text(text)
79
 
80
- # Embeddings
81
  embeddings = HuggingFaceEmbeddings(
82
  model_name="sentence-transformers/msmarco-MiniLM-L-12-v3"
83
  )
 
84
 
85
- vectorstore = FAISS.from_texts(texts, embeddings)
 
 
 
 
 
 
 
86
 
87
- # Custom prompt
88
- prompt_template = """
89
- Use only the following context to answer the question.
90
- Do NOT make up information. If the answer is not present, say "I don't know."
91
  Context:
92
  {context}
 
93
  Question: {question}
94
- Answer:
 
95
  """
96
- custom_prompt = PromptTemplate(
97
- template=prompt_template,
98
- input_variables=["context", "question"]
99
  )
100
 
101
- # QA chain with custom prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  qa_chain = RetrievalQA.from_chain_type(
103
  llm=groq_llm,
104
- chain_type="refine",
105
  retriever=vectorstore.as_retriever(),
 
106
  return_source_documents=True,
107
- chain_type_kwargs={"prompt": custom_prompt} # pass prompt here
 
 
 
108
  )
109
 
110
  return "PDF processed successfully!"
 
111
  except Exception as e:
112
  return f"Error: {str(e)}"
113
 
114
 
115
- # --- Ask questions ---
 
 
116
  def ask_question(query):
117
  global qa_chain
 
118
  if qa_chain is None:
119
  return "Please upload a PDF first.", ""
120
 
121
  try:
122
- # Simply call the chain, no need to override prompt
123
- result = qa_chain({"query": query}, return_only_outputs=False)
124
  answer = result["result"]
125
- sources = result.get("source_documents", [])
126
 
 
 
127
  if sources:
128
- source_text = "\n\n---\n".join([
129
- f"Source {i+1} (excerpt):\n{doc.page_content[:500]}{'...' if len(doc.page_content) > 500 else ''}"
130
  for i, doc in enumerate(sources)
131
- ])
132
  else:
133
- source_text = "No sources cited"
134
 
135
  return answer, source_text
136
 
137
  except Exception as e:
138
- return f"Error processing your question: {str(e)}", ""
139
 
140
-
141
 
142
- # --- Summarize PDF ---
143
- def summarize_pdf(num_points: int = 6):
144
- global vectorstore, groq_llm
 
 
145
  if vectorstore is None:
146
  return "Please upload a PDF first."
147
 
148
  try:
149
  docs = vectorstore.similarity_search("summary", k=5)
150
- context = "\n\n".join([doc.page_content for doc in docs])
151
 
152
  prompt = f"""
153
- Imagine you are a passionate science communicator.
154
- Summarize the following research paper in {num_points} bullet points.
155
- Highlight core discoveries and significance. Keep it engaging, insightful, and clear.
156
 
157
- Paper Content:
158
  {context}
159
 
160
  Summary:
161
  """
 
162
  if groq_llm is None:
163
  groq_llm = GroqWrapper(client=Groq(api_key=GROQ_API_KEY))
164
 
165
- summary = groq_llm(prompt)
166
- return summary.strip()
167
 
168
  except Exception as e:
169
- return f"Error during summarization: {str(e)}"
170
-
171
- # summary = groq_llm(prompt)
172
- # return summary.strip()
173
-
174
- # except Exception as e:
175
- # return f"Error during summarization: {str(e)}"
176
-
177
- # *** Modified find_similar_papers function ONLY ***
178
 
179
 
180
- # --- Find similar papers (with embedding rerank) ---
 
 
181
  def find_similar_papers():
 
 
182
  if vectorstore is None:
183
  return "Please upload a PDF first."
184
 
185
  try:
186
- # Get top chunks from uploaded PDF
187
  top_chunks = vectorstore.similarity_search("", k=5)
188
- paper_text = " ".join([doc.page_content for doc in top_chunks])
189
- if not paper_text.strip():
190
- return "PDF content too small for similarity search."
191
-
192
- # Extract keywords for arXiv query
193
- keywords = " ".join(paper_text.split()[:20])
194
- encoded_query = urllib.parse.quote(keywords)
195
- arxiv_url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results=5"
196
- feed = feedparser.parse(arxiv_url)
 
 
197
  entries = feed.entries
 
198
  if not entries:
199
- return "No similar papers found on arXiv."
200
 
201
- # Embeddings for reranking
202
- embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-MiniLM-L-12-v3")
203
- paper_embedding = embeddings_model.embed_query(paper_text)
 
 
204
 
205
- ranked_results = []
206
  for entry in entries:
207
- arxiv_text = f"{entry.title} {entry.summary}"
208
- arxiv_embedding = embeddings_model.embed_query(arxiv_text)
209
- similarity = dot(paper_embedding, arxiv_embedding) / (norm(paper_embedding) * norm(arxiv_embedding))
210
- ranked_results.append({
 
211
  "title": entry.title,
212
  "summary": entry.summary.replace("\n", " ").strip(),
213
  "link": entry.link,
214
- "similarity": similarity
215
  })
216
 
217
- # Sort by similarity
218
- ranked_results.sort(key=lambda x: x["similarity"], reverse=True)
219
 
220
- # Format top 3
221
- results = []
222
- for paper in ranked_results[:3]:
223
- results.append(
224
- f"**{paper['title']}**\n{paper['summary']}\n🔗 {paper['link']}\nSimilarity: {paper['similarity']:.2f}"
 
 
225
  )
226
 
227
- return "\n\n".join(results)
228
 
229
  except Exception as e:
230
- return f"Error fetching similar papers: {str(e)}"
231
- # results.append(f"**{title}**\n{summary}\n🔗 {link}")
232
-
233
- # return "\n\n".join(results)
234
 
235
- # except Exception as e:
236
- # return f"Error fetching similar papers: {str(e)}"
237
 
238
 
239
 
@@ -534,8 +565,4 @@ with gr.Blocks(css=css) as demo:
534
  similar_btn.click(find_similar_papers, outputs=similar_output)
535
 
536
  if __name__ == "__main__":
537
- demo.launch()
538
-
539
-
540
-
541
-
 
4
  from langchain_text_splitters import RecursiveCharacterTextSplitter
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
+
8
  from langchain.chains.retrieval_qa.base import RetrievalQA
9
  from langchain.prompts import PromptTemplate
10
  from langchain_core.language_models.llms import LLM
11
  from langchain_core.callbacks import CallbackManagerForLLMRun
12
+
13
  from typing import Optional, List, Dict, Any
 
14
  from dotenv import load_dotenv
15
  from groq import Groq
16
+
17
  import urllib.parse
18
+ import feedparser
19
 
20
  from numpy import dot
21
+ from numpy.linalg import norm
 
22
 
23
  # Load environment variables
24
  load_dotenv()
 
 
25
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
26
 
27
 
28
+ # -----------------------------------------------------------
29
+ # GROQ WRAPPER
30
+ # -----------------------------------------------------------
31
  class GroqWrapper(LLM):
32
  client: Any
33
  model_name: str = "llama-3.3-70b-versatile"
34
  temperature: float = 0.7
35
+
36
  @property
37
  def _llm_type(self) -> str:
38
  return "groq"
39
+
40
  def _call(
41
  self,
42
  prompt: str,
 
45
  **kwargs: Any,
46
  ) -> str:
47
  response = self.client.chat.completions.create(
 
48
  model=self.model_name,
49
+ messages=[{"role": "user", "content": prompt}],
50
  temperature=self.temperature,
 
51
  )
52
  return response.choices[0].message.content
53
 
54
+
55
+ # Globals
56
  vectorstore = None
57
  qa_chain = None
58
  groq_llm = None
59
 
60
+
61
+ # -----------------------------------------------------------
62
+ # PROCESS PDF
63
+ # -----------------------------------------------------------
64
  def upload_pdf(file):
65
  global vectorstore, qa_chain, groq_llm
66
 
 
69
  if groq_llm is None:
70
  groq_llm = GroqWrapper(client=Groq(api_key=GROQ_API_KEY))
71
 
72
+ # Extract text from PDF
73
  text = "".join(page.extract_text() or "" for page in PdfReader(file).pages)
74
  if not text.strip():
75
  return "Error: No readable text found in PDF"
76
 
77
+ # Chunk the text
78
+ splitter = RecursiveCharacterTextSplitter(
79
  chunk_size=1000,
80
  chunk_overlap=150,
81
  separators=["\n\n", "\n", ".", "?", "!"]
82
  )
83
+ chunks = splitter.split_text(text)
84
 
85
+ # Create Vectorstore
86
  embeddings = HuggingFaceEmbeddings(
87
  model_name="sentence-transformers/msmarco-MiniLM-L-12-v3"
88
  )
89
+ vectorstore = FAISS.from_texts(chunks, embeddings)
90
 
91
+ # --- CUSTOM REFINE PROMPTS ---
92
+ initial_prompt = PromptTemplate(
93
+ input_variables=["context", "question"],
94
+ template="""
95
+ You are an expert researcher.
96
+
97
+ Use ONLY the given context to answer the question.
98
+ If the answer is not in the context, say "I don't know".
99
 
 
 
 
 
100
  Context:
101
  {context}
102
+
103
  Question: {question}
104
+
105
+ Initial Answer:
106
  """
 
 
 
107
  )
108
 
109
+ refine_prompt = PromptTemplate(
110
+ input_variables=["context", "question", "existing_answer"],
111
+ template="""
112
+ We have an existing answer:
113
+ {existing_answer}
114
+
115
+ Using the additional context below, refine the answer.
116
+
117
+ Additional Context:
118
+ {context}
119
+
120
+ Question: {question}
121
+
122
+ Refined Answer:
123
+ """
124
+ )
125
+
126
+ # --- BUILD QA CHAIN ---
127
  qa_chain = RetrievalQA.from_chain_type(
128
  llm=groq_llm,
 
129
  retriever=vectorstore.as_retriever(),
130
+ chain_type="refine",
131
  return_source_documents=True,
132
+ chain_type_kwargs={
133
+ "initial_response_prompt": initial_prompt,
134
+ "refine_prompt": refine_prompt
135
+ }
136
  )
137
 
138
  return "PDF processed successfully!"
139
+
140
  except Exception as e:
141
  return f"Error: {str(e)}"
142
 
143
 
144
# -----------------------------------------------------------
# QUESTION ANSWERING
# -----------------------------------------------------------
def ask_question(query):
    """Answer a question about the uploaded PDF via the RetrievalQA chain.

    Parameters:
        query: The user's natural-language question.

    Returns:
        A ``(answer, source_text)`` tuple of plain strings suitable for
        the two Gradio outputs. On any failure the first element is an
        error message and the second is empty.
    """
    global qa_chain

    # Guard clause: the chain is only built after a PDF is processed.
    if qa_chain is None:
        return "Please upload a PDF first.", ""

    try:
        result = qa_chain({"query": query})
        answer = result["result"]

        # Format the retrieved source chunks for display.
        sources = result.get("source_documents", [])
        if sources:
            source_text = "\n\n---\n".join(
                # FIX: append the ellipsis only when the excerpt was
                # actually truncated; the previous code added "..." to
                # every excerpt, even ones shorter than 500 characters.
                f"Source {i+1}:\n{doc.page_content[:500]}"
                + ("..." if len(doc.page_content) > 500 else "")
                for i, doc in enumerate(sources)
            )
        else:
            source_text = "No sources found."

        return answer, source_text

    except Exception as e:
        return f"Error: {str(e)}", ""
171
 
 
172
 
173
# -----------------------------------------------------------
# SUMMARIZE PDF
# -----------------------------------------------------------
def summarize_pdf(num_points=6):
    """Produce a bullet-point summary of the uploaded PDF.

    Parameters:
        num_points: How many bullet points to request from the LLM.

    Returns:
        The summary string, or a status/error message string.
    """
    global groq_llm, vectorstore

    # Guard clause: nothing to summarize until a PDF has been indexed.
    if vectorstore is None:
        return "Please upload a PDF first."

    try:
        # Pull a handful of representative chunks to ground the summary.
        retrieved = vectorstore.similarity_search("summary", k=5)
        context = "\n\n".join(chunk.page_content for chunk in retrieved)

        prompt = f"""
        Summarize the research paper in {num_points} bullet points.
        Make it clear, meaningful, and highlight key contributions.

        Content:
        {context}

        Summary:
        """

        # Lazily create the Groq LLM wrapper on first use.
        if groq_llm is None:
            groq_llm = GroqWrapper(client=Groq(api_key=GROQ_API_KEY))

        return groq_llm(prompt).strip()

    except Exception as e:
        return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
202
 
203
 
204
# -----------------------------------------------------------
# FIND SIMILAR PAPERS (arXiv)
# -----------------------------------------------------------
def find_similar_papers():
    """Search arXiv for papers similar to the uploaded PDF.

    Builds a keyword query from the PDF's top chunks, fetches up to
    five arXiv entries, reranks them by embedding cosine similarity,
    and returns the top three formatted as a markdown string.

    Returns:
        A markdown string of the top matches, or a status/error message.
    """
    global vectorstore

    # Guard clause: requires an indexed PDF.
    if vectorstore is None:
        return "Please upload a PDF first."

    try:
        # Gather representative content from the indexed PDF.
        top_chunks = vectorstore.similarity_search("", k=5)
        pdf_text = " ".join(doc.page_content for doc in top_chunks)

        if not pdf_text.strip():
            return "PDF content too small."

        # The first 20 words serve as the arXiv search keywords.
        keywords = " ".join(pdf_text.split()[:20])
        encoded = urllib.parse.quote(keywords)
        url = f"http://export.arxiv.org/api/query?search_query=all:{encoded}&start=0&max_results=5"

        feed = feedparser.parse(url)
        entries = feed.entries

        if not entries:
            return "No arXiv results found."

        # Embed the PDF text once; candidates are embedded per entry.
        embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/msmarco-MiniLM-L-12-v3"
        )
        pdf_emb = embedding_model.embed_query(pdf_text)

        results = []
        for entry in entries:
            emb = embedding_model.embed_query(f"{entry.title} {entry.summary}")
            results.append({
                "title": entry.title,
                "summary": entry.summary.replace("\n", " ").strip(),
                "link": entry.link,
                # Cosine similarity between PDF and candidate embeddings.
                "similarity": dot(pdf_emb, emb) / (norm(pdf_emb) * norm(emb)),
            })

        # Best matches first.
        results.sort(key=lambda item: item["similarity"], reverse=True)

        formatted = [
            f"**{paper['title']}**\n"
            f"{paper['summary']}\n"
            f"🔗 {paper['link']}\n"
            f"Similarity Score: {paper['similarity']:.2f}"
            for paper in results[:3]
        ]

        return "\n\n".join(formatted)

    except Exception as e:
        return f"Error: {str(e)}"
 
 
 
267
 
 
 
268
 
269
 
270
 
 
565
  similar_btn.click(find_similar_papers, outputs=similar_output)
566
 
567
  if __name__ == "__main__":
568
+ demo.launch()