YashB1 commited on
Commit
1b56b89
·
verified ·
1 Parent(s): 65fd348

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -0
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, Form, Request, HTTPException, Depends
2
+ from fastapi.security import HTTPBasic, HTTPBasicCredentials
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from typing import List
5
+ import uvicorn
6
+ from io import BytesIO
7
+ from dotenv import load_dotenv
8
+ import os, re, requests, arxiv, secrets
9
+
10
+ from PyPDF2 import PdfReader
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain.vectorstores import FAISS
13
+ from langchain_groq import ChatGroq
14
+ from langchain.chains import LLMChain, ConversationalRetrievalChain
15
+ from langchain.prompts import PromptTemplate
16
+ from langchain_community.embeddings import HuggingFaceEmbeddings
17
+ from langchain.retrievers import EnsembleRetriever
18
+ from langchain.memory import ConversationBufferMemory
19
+ from pydantic import BaseModel
20
+
21
+ # -------------------------------
22
+ # Utils
23
+ # -------------------------------
24
+ load_dotenv()
25
+ GROQ_API_KEY = None
26
+ hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
27
+ security = HTTPBasic()
28
+ users_db = {"username" : "password"}
29
+ user_objects = {}
30
+
31
+ class ApiKeyRequest(BaseModel):
32
+ api_key: str
33
+
34
+ class RegisterRequest(BaseModel):
35
+ username: str
36
+ password: str
37
+
38
+ # ✅ Pydantic model for API key request
39
+ def get_current_user(credentials: HTTPBasicCredentials = Depends(security)):
40
+ username = credentials.username
41
+ password = credentials.password
42
+
43
+ if username not in users_db:
44
+ raise HTTPException(status_code=401, detail="Invalid username")
45
+
46
+ # Secure password check
47
+ correct_password = secrets.compare_digest(password, users_db[username])
48
+ if not correct_password:
49
+ raise HTTPException(status_code=401, detail="Invalid password")
50
+
51
+ # Create User() object if not exists
52
+ if username not in user_objects:
53
+ user_objects[username] = User()
54
+
55
+ return user_objects[username]
56
+
57
+ def get_pdf_text(pdf_docs):
58
+ text = ""
59
+ for pdf in pdf_docs:
60
+ pdf_reader = PdfReader(pdf)
61
+ for page in pdf_reader.pages:
62
+ text += page.extract_text()
63
+ return text
64
+
65
+ def get_text_chunks(text):
66
+ text_splitter = RecursiveCharacterTextSplitter(
67
+ chunk_size=4000, chunk_overlap=400, length_function=len
68
+ )
69
+ return text_splitter.split_text(text)
70
+
71
+ # -------------------------------
72
+ # Paper Class
73
+ # -------------------------------
74
+ class Paper:
75
+ def __init__(self, mode, input_data):
76
+ global GROQ_API_KEY
77
+ self.pdf = None
78
+ self.text = None
79
+ self.title = ""
80
+ self.arxiv_id = None
81
+ self.references = []
82
+ self.title_extractor_LLM = ChatGroq(api_key=GROQ_API_KEY, model_name="openai/gpt-oss-120b")
83
+ self.references_titles_extractor_LLM = ChatGroq(api_key=GROQ_API_KEY, model_name="openai/gpt-oss-120b")
84
+ self.req_session = requests.Session()
85
+
86
+ if mode == "pdf":
87
+ self.pdf = BytesIO(input_data) if isinstance(input_data, bytes) else input_data
88
+ self.text = self.load_pdf(self.pdf)
89
+ self.title = self.extract_title(self.text)
90
+ else:
91
+ self.arxiv_id = self.fetch_arxiv_id(input_data)
92
+ arxiv_url = f"https://export.arxiv.org/pdf/{self.arxiv_id}.pdf"
93
+ res = self.req_session.get(arxiv_url)
94
+ pdf = BytesIO(res.content)
95
+ self.pdf = pdf
96
+ self.text = self.load_pdf(pdf)
97
+ self.title = self.extract_title(self.text)
98
+
99
+ print("Loaded Paper:", self.title)
100
+
101
+ def load_pdf(self, pdf):
102
+ return get_pdf_text([pdf])
103
+
104
+ def fetch_arxiv_id(self, url_id):
105
+ if re.match(r'^\d{4}\.\d{5}$', url_id): # arXiv ID
106
+ return url_id
107
+ else: # extract from URL
108
+ match = re.search(r'arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{5})', url_id)
109
+ return match.group(1)
110
+
111
+ def extract_title(self, text):
112
+ prompt_template = """
113
+ You are given the full text of a scientific paper.
114
+ Extract and return the TITLE of the paper.
115
+
116
+ Example:
117
+ Input:
118
+ "3D Gaussian Splatting for Real-Time Radiance Field Rendering
119
+ BERNHARD KERBL, Inria, Université Côte dAzur, France
120
+ GEORGIOS KOPANAS, Inria, Université Côte dAzur, France
121
+ THOMAS LEIMKÜHLER, Max-Planck-Institut für Informatik, Germany...."
122
+
123
+
124
+ Output:
125
+ "3D Gaussian Splatting for Real-Time Radiance Field Rendering"
126
+
127
+ Now process the following text:
128
+ {paper_text}
129
+ """
130
+ prompt = PromptTemplate(template=prompt_template, input_variables=["paper_text"])
131
+ chain = LLMChain(llm=self.title_extractor_LLM, prompt=prompt)
132
+ response = chain.run({"paper_text": text[:500]})
133
+ return response.strip().strip('"')
134
+
135
+ def get_references(self):
136
+ ref_text = self.extract_reference_section()
137
+ print("Reference Section Extracted")
138
+ self.references_titles = self.extract_references(ref_text)
139
+ print(f"Extracted {len(self.references_titles)} reference titles")
140
+ self.references_arxiv_ids = self.search_arxiv_ids(self.references_titles)
141
+ print(f"Found {len(self.references_arxiv_ids)} arXiv IDs for references")
142
+ for ref_arx_id in list(self.references_arxiv_ids.values())[:2]: # limit to 2
143
+ self.references.append(Paper("arxiv_id", ref_arx_id))
144
+
145
+ def extract_reference_section(self):
146
+ ref_match = re.split(r"(?i)\breferences\b", self.text)
147
+ return ref_match[-1] if len(ref_match) >= 2 else ""
148
+
149
+ def chunk_references(self, ref_text, max_refs=10):
150
+ lines = [line.strip() for line in ref_text.split("\n") if line.strip()]
151
+ for i in range(0, len(lines), max_refs):
152
+ yield "\n".join(lines[i:i + max_refs])
153
+
154
+ def extract_references(self, references_text):
155
+ prompt_template = """
156
+ You are given raw reference entries from a scientific paper.
157
+ Extract only the TITLE of the referenced work.
158
+ Ignore authors, year, venue, volume, etc.
159
+ Provide results as a list of strings.
160
+
161
+ Example:
162
+ Input:
163
+ - Smith, J., 2020. Deep learning for images. IEEE CVPR.
164
+ - Brown, L. & Green, P., 2019. X-ray scattering tensor tomography based finite element modelling of heterogeneous materials. Nature Materials.
165
+
166
+ Output:
167
+ ["Deep learning for images",
168
+ "X-ray scattering tensor tomography based finite element modelling of heterogeneous materials"]
169
+
170
+ Now process the following references:
171
+ {references}
172
+ """
173
+ prompt = PromptTemplate(template=prompt_template, input_variables=["references"])
174
+ chain = LLMChain(llm=self.references_titles_extractor_LLM, prompt=prompt)
175
+
176
+ all_titles = []
177
+ for chunk in self.chunk_references(references_text):
178
+ response = chain.run({"references": chunk})
179
+ try:
180
+ titles = eval(response.strip())
181
+ except :
182
+ titles = [line.strip() for line in response.split("\n") if line.strip()]
183
+ all_titles.extend(titles)
184
+ return all_titles
185
+
186
+ def search_arxiv_ids(self, ref_titles):
187
+ client = arxiv.Client(page_size=100, delay_seconds=3, num_retries=5)
188
+ arxiv_ids = {}
189
+ for title in ref_titles:
190
+ try:
191
+ search = arxiv.Search(query=title, max_results=100, sort_by=arxiv.SortCriterion.Relevance)
192
+ results = list(client.results(search))
193
+ for r in results:
194
+ if title.lower() == r.title.lower():
195
+ arxiv_ids[title] = re.sub(r'v\d+$', '', r.entry_id.split("/")[-1])
196
+ print(title, "->", arxiv_ids[title])
197
+ break
198
+ except Exception as e:
199
+ print(f"Could not extract {title}, due to Error: {e}")
200
+ continue
201
+ return arxiv_ids
202
+
203
+ # -------------------------------
204
+ # User Class
205
+ # -------------------------------
206
+ class User:
207
+ def __init__(self):
208
+ global GROQ_API_KEY
209
+ self.papers = []
210
+ self.context_papers = []
211
+ self.retriever = None
212
+ self.QA_LLM = None
213
+ self.QA_Chain = None
214
+ self.dense_embeddings = HuggingFaceEmbeddings()
215
+ self.sparse_embeddings = HuggingFaceEmbeddings(model_name="naver/splade-cocondenser-ensembledistil")
216
+ self.memory = ConversationBufferMemory(
217
+ memory_key="chat_history", return_messages=True,
218
+ input_key="question", output_key="answer"
219
+ )
220
+
221
+ def set_API_key(self,api_key):
222
+ global GROQ_API_KEY
223
+ GROQ_API_KEY = api_key
224
+ self.QA_LLM = ChatGroq(api_key=GROQ_API_KEY, model_name="openai/gpt-oss-120b")
225
+
226
+ def add_paper(self, mode, input_data):
227
+ print("Adding paper...")
228
+ paper = Paper(mode, input_data)
229
+ self.papers.append(paper)
230
+ self.context_papers.append(paper.title)
231
+ self._update_retriever_with_new_paper(-1)
232
+ print("Paper added:", paper.title)
233
+
234
+ def add_reference_papers(self, index):
235
+ print("Adding reference papers...")
236
+ if self.papers[index].references:
237
+ return
238
+ self.papers[index].get_references()
239
+ for ref in self.papers[index].references:
240
+ self.context_papers.append(ref.title)
241
+ self._update_retriever_with_new_paper(index, ref=True)
242
+ return [ref.title for ref in self.papers[index].references]
243
+
244
+ def _update_retriever_with_new_paper(self, index, ref=False):
245
+ paper = self.papers[index]
246
+ if not self.retriever:
247
+ chunks = get_text_chunks(paper.text)
248
+ sparse_vs = FAISS.from_texts(chunks, self.sparse_embeddings)
249
+ dense_vs = FAISS.from_texts(chunks, self.dense_embeddings)
250
+ self.retriever = EnsembleRetriever(
251
+ retrievers=[sparse_vs.as_retriever(search_kwargs={"k": 3}),
252
+ dense_vs.as_retriever(search_kwargs={"k": 3})],
253
+ weights=[0.5, 0.5]
254
+ )
255
+ elif ref:
256
+ for ref_paper in paper.references:
257
+ ref_chunks = get_text_chunks(ref_paper.text)
258
+ self.retriever.retrievers[0].vectorstore.add_texts(ref_chunks, embedding=self.sparse_embeddings)
259
+ self.retriever.retrievers[1].vectorstore.add_texts(ref_chunks, embedding=self.dense_embeddings)
260
+ else:
261
+ chunks = get_text_chunks(paper.text)
262
+ self.retriever.retrievers[0].vectorstore.add_texts(chunks, embedding=self.sparse_embeddings)
263
+ self.retriever.retrievers[1].vectorstore.add_texts(chunks, embedding=self.dense_embeddings)
264
+ self.QA_Chain = self.get_conversational_chain()
265
+
266
+ def get_conversational_chain(self):
267
+ prompt_template = """Use the following pieces of context to answer the question at the end.
268
+ Whenever you are asked a question, only answer in reference to the context papers {context_papers}.
269
+ If you don't know the answer or the answer is not in the context papers, just say that you don't know, don't try to make up an answer.
270
+ {context}
271
+ Question: {question}
272
+ Answer in a concise manner.
273
+ """
274
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "context_papers"])
275
+ return ConversationalRetrievalChain.from_llm(
276
+ llm=self.QA_LLM,
277
+ retriever=self.retriever,
278
+ memory=self.memory,
279
+ combine_docs_chain_kwargs={"prompt": prompt},
280
+ return_source_documents=True
281
+ )
282
+
283
+ def ask_question(self, question):
284
+ if not self.QA_Chain:
285
+ return "Please add a paper first."
286
+ response = self.QA_Chain({"question": question, "context_papers": ", ".join(self.context_papers)}, return_only_outputs=True)
287
+ return response["answer"]
288
+
289
+ # -------------------------------
290
+ # FastAPI Setup
291
+ # -------------------------------
292
+ app = FastAPI()
293
+ app.add_middleware(
294
+ CORSMiddleware, allow_origins=["*"], allow_credentials=True,
295
+ allow_methods=["*"], allow_headers=["*"],
296
+ )
297
+
298
+ # ✅ Register endpoint
299
+ @app.post("/register/")
300
+ async def register(body: RegisterRequest):
301
+ if body.username in users_db:
302
+ raise HTTPException(status_code=400, detail="Username already exists")
303
+
304
+ if not body.username or not body.password:
305
+ raise HTTPException(status_code=400, detail="Username and password are required")
306
+
307
+ if len(body.username) < 3:
308
+ raise HTTPException(status_code=400, detail="Username must be at least 3 characters")
309
+
310
+ if len(body.password) < 6:
311
+ raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
312
+
313
+ # Add user to the users database
314
+ users_db[body.username] = body.password
315
+
316
+ return {"message": "User registered successfully"}
317
+
318
+ # ✅ Set API key endpoint
319
+ @app.post("/set_api_key/")
320
+ async def set_api_key(body: ApiKeyRequest, user: User = Depends(get_current_user)):
321
+ user.set_API_key(body.api_key)
322
+ return {"message": "API key stored for user"}
323
+
324
+ @app.post("/upload_pdf/")
325
+ async def upload_pdf(file: UploadFile, user: User = Depends(get_current_user)):
326
+ pdf_bytes = await file.read()
327
+ user.add_paper("pdf", pdf_bytes)
328
+ return {"message": "PDF added", "context_papers": user.context_papers}
329
+
330
+ @app.post("/add_arxiv/")
331
+ async def add_arxiv(arxiv_id: str = Form(...), user: User = Depends(get_current_user)):
332
+ user.add_paper("arxiv_id", arxiv_id)
333
+ return {"message": f"Arxiv paper {arxiv_id} added", "context_papers": user.context_papers}
334
+
335
+ @app.post("/add_references/")
336
+ async def add_references(index: int = Form(...), user: User = Depends(get_current_user)):
337
+ refs = user.add_reference_papers(index)
338
+ return {"message": "References added", "references": refs, "context_papers": user.context_papers}
339
+
340
+ @app.get("/ask/")
341
+ async def ask_question(q: str, user: User = Depends(get_current_user)):
342
+ answer = user.ask_question(q)
343
+ return {"question": q, "answer": answer}
344
+
345
+ if __name__ == "__main__":
346
+ uvicorn.run(app, host="0.0.0.0", port=8000)