Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, Form, Request, HTTPException, Depends | |
| from fastapi.security import HTTPBasic, HTTPBasicCredentials | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import List | |
| import uvicorn | |
| from io import BytesIO | |
| from dotenv import load_dotenv | |
| import os, re, requests, arxiv, secrets | |
| from PyPDF2 import PdfReader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.vectorstores import FAISS | |
| from langchain_groq import ChatGroq | |
| from langchain.chains import LLMChain, ConversationalRetrievalChain | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.memory import ConversationBufferMemory | |
| from pydantic import BaseModel | |
# -------------------------------
# Configuration & global state
# -------------------------------
# Send every Hugging Face / XDG cache to /tmp/hf_cache (presumably because the
# default cache location is not writable where this runs — confirm).
# NOTE(review): these are set after the imports above; confirm the libraries
# involved read them at model-load time rather than at import time.
os.environ.update(
    dict.fromkeys(("HF_HOME", "TRANSFORMERS_CACHE", "XDG_CACHE_HOME"), "/tmp/hf_cache")
)
load_dotenv()

# Groq key read by Paper when it builds its LLM clients; overwritten by
# User.set_API_key, so it holds whichever key was set most recently.
GROQ_API_KEY = None
# NOTE(review): not referenced anywhere else in this file as shown.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

security = HTTPBasic()

# In-memory account store: username -> password (kept in plaintext).
# WARNING(review): seeded with the trivially guessable account
# "username"/"password", which anyone can use to log in — confirm this is intended.
users_db = {"username": "password"}
# username -> User session object, created on first successful login.
user_objects = {}
class ApiKeyRequest(BaseModel):
    """JSON body carrying the Groq API key to store for the authenticated user."""

    api_key: str  # handed to User.set_API_key, which builds the ChatGroq client
class RegisterRequest(BaseModel):
    """JSON body for creating an account: the desired username and password."""

    username: str  # must be unused and at least 3 characters (checked by register)
    password: str  # at least 6 characters (checked by register)
# Authentication dependency used by every per-user endpoint.
def get_current_user(credentials: HTTPBasicCredentials = Depends(security)):
    """Authenticate HTTP Basic credentials and return the caller's ``User``.

    The ``User`` session object is created lazily on the first successful
    login and cached in ``user_objects`` for later requests.

    Raises:
        HTTPException: 401 if the username is unknown or the password is wrong.
    """
    username = credentials.username
    if username not in users_db:
        raise HTTPException(status_code=401, detail="Invalid username")

    # compare_digest raises TypeError for str arguments containing non-ASCII
    # characters (e.g. a password stored via registration). Comparing UTF-8
    # bytes keeps the check constant-time and turns that case into a normal 401
    # instead of a 500.
    password_ok = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        users_db[username].encode("utf-8"),
    )
    if not password_ok:
        raise HTTPException(status_code=401, detail="Invalid password")

    if username not in user_objects:
        user_objects[username] = User()
    return user_objects[username]
def get_pdf_text(pdf_docs):
    """Return the extracted text of every page of every PDF in ``pdf_docs``.

    Args:
        pdf_docs: Iterable of PDFs that ``PdfReader`` can open (this file
            passes in-memory binary streams).

    Returns:
        All page texts concatenated in order, with no separator added.
    """
    page_texts = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            # Defensive: treat a page that yields no text (None) as empty
            # rather than failing the whole document.
            page_texts.append(page.extract_text() or "")
    # A single join instead of repeated += avoids quadratic string building.
    return "".join(page_texts)
def get_text_chunks(text, chunk_size=4000, chunk_overlap=400):
    """Split ``text`` into overlapping chunks for embedding.

    Args:
        text: The document text to split.
        chunk_size: Maximum chunk length in characters (length is measured with ``len``).
        chunk_overlap: Number of characters shared by consecutive chunks.

    Returns:
        A list of text chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    return text_splitter.split_text(text)
# -------------------------------
# Paper Class
# -------------------------------
class Paper:
    """One research paper: its PDF, extracted text, title and optional references.

    Construction reads (or downloads from arXiv) the PDF, extracts its text and
    asks a Groq-hosted LLM for the title. ``get_references`` later resolves the
    reference list to arXiv papers and loads up to two of them as ``Paper``s.

    Args:
        mode: ``"pdf"`` to treat ``input_data`` as a PDF (raw bytes or a
            binary file-like object); any other value treats ``input_data``
            as an arXiv ID or arxiv.org URL.
        input_data: The PDF bytes / stream, or the arXiv ID / URL.
    """

    # Seconds to wait for arXiv before giving up on a PDF download.
    ARXIV_DOWNLOAD_TIMEOUT = 60

    # arXiv identifier: new-style "YYMM.NNNN" / "YYMM.NNNNN", or old-style
    # "archive[.SC]/YYMMNNN", optionally followed by a version such as "v2".
    ARXIV_ID_PATTERN = r"(?:\d{4}\.\d{4,5}|[a-z\-]+(?:\.[A-Z]{2})?/\d{7})(?:v\d+)?"

    def __init__(self, mode, input_data):
        self.pdf = None
        self.text = None
        self.title = ""
        self.arxiv_id = None
        self.references = []
        # Both clients are built with whatever key GROQ_API_KEY holds right now.
        self.title_extractor_LLM = ChatGroq(api_key=GROQ_API_KEY, model_name="openai/gpt-oss-120b")
        self.references_titles_extractor_LLM = ChatGroq(api_key=GROQ_API_KEY, model_name="openai/gpt-oss-120b")
        self.req_session = requests.Session()

        if mode == "pdf":
            self.pdf = BytesIO(input_data) if isinstance(input_data, bytes) else input_data
        else:
            self.arxiv_id = self.fetch_arxiv_id(input_data)
            self.pdf = self._download_arxiv_pdf(self.arxiv_id)
        self.text = self.load_pdf(self.pdf)
        self.title = self.extract_title(self.text)
        print("Loaded Paper:", self.title)

    def _download_arxiv_pdf(self, arxiv_id):
        """Download the PDF for ``arxiv_id`` from export.arxiv.org into memory.

        Raises:
            requests.HTTPError: arXiv answered with an error status.
            requests.Timeout: arXiv did not respond within ARXIV_DOWNLOAD_TIMEOUT seconds.
        """
        arxiv_url = f"https://export.arxiv.org/pdf/{arxiv_id}.pdf"
        res = self.req_session.get(arxiv_url, timeout=self.ARXIV_DOWNLOAD_TIMEOUT)
        # Without this check an error page would be handed to PdfReader and
        # fail later with an obscure parse error.
        res.raise_for_status()
        return BytesIO(res.content)

    def load_pdf(self, pdf):
        """Return the full text of a single PDF."""
        return get_pdf_text([pdf])

    def fetch_arxiv_id(self, url_id):
        """Return the arXiv identifier in ``url_id``.

        Accepts a bare ID (``"1706.03762"``, ``"2101.00001v2"``,
        ``"hep-th/9901001"``) or an arxiv.org ``/abs/`` or ``/pdf/`` URL.
        A version suffix is kept when present.

        Raises:
            ValueError: if no arXiv identifier can be found.
        """
        candidate = url_id.strip()
        if re.fullmatch(self.ARXIV_ID_PATTERN, candidate):
            return candidate
        match = re.search(r"arxiv\.org/(?:abs|pdf)/(" + self.ARXIV_ID_PATTERN + ")", candidate)
        if match is None:
            raise ValueError(f"Could not find an arXiv ID in {url_id!r}")
        return match.group(1)

    def extract_title(self, text):
        """Ask the LLM for the paper's title, based on the first 500 characters of ``text``."""
        prompt_template = """
You are given the full text of a scientific paper.
Extract and return the TITLE of the paper.
Example:
Input:
"3D Gaussian Splatting for Real-Time Radiance Field Rendering
BERNHARD KERBL, Inria, Université Côte dAzur, France
GEORGIOS KOPANAS, Inria, Université Côte dAzur, France
THOMAS LEIMKÜHLER, Max-Planck-Institut für Informatik, Germany...."
Output:
"3D Gaussian Splatting for Real-Time Radiance Field Rendering"
Now process the following text:
{paper_text}
"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["paper_text"])
        chain = LLMChain(llm=self.title_extractor_LLM, prompt=prompt)
        response = chain.run({"paper_text": text[:500]})
        return response.strip().strip('"')

    def get_references(self):
        """Resolve this paper's references and load up to two of them from arXiv.

        Sets ``references_titles`` and ``references_arxiv_ids`` and appends the
        loaded ``Paper`` objects to ``references``.
        """
        ref_text = self.extract_reference_section()
        print("Reference Section Extracted")
        self.references_titles = self.extract_references(ref_text)
        print(f"Extracted {len(self.references_titles)} reference titles")
        self.references_arxiv_ids = self.search_arxiv_ids(self.references_titles)
        print(f"Found {len(self.references_arxiv_ids)} arXiv IDs for references")
        # Several extracted titles can resolve to the same arXiv paper; load each once.
        unique_ids = list(dict.fromkeys(self.references_arxiv_ids.values()))
        for ref_arx_id in unique_ids[:2]:  # limit to 2
            self.references.append(Paper("arxiv_id", ref_arx_id))

    def extract_reference_section(self):
        """Return the text after the last case-insensitive word "references", or "" if absent."""
        ref_match = re.split(r"(?i)\breferences\b", self.text)
        return ref_match[-1] if len(ref_match) >= 2 else ""

    def chunk_references(self, ref_text, max_refs=10):
        """Yield groups of up to ``max_refs`` non-blank, stripped lines joined by newlines."""
        lines = [line.strip() for line in ref_text.split("\n") if line.strip()]
        for i in range(0, len(lines), max_refs):
            yield "\n".join(lines[i:i + max_refs])

    def extract_references(self, references_text):
        """Use the LLM to pull reference titles out of ``references_text``, chunk by chunk."""
        prompt_template = """
You are given raw reference entries from a scientific paper.
Extract only the TITLE of the referenced work.
Ignore authors, year, venue, volume, etc.
Provide results as a list of strings.
Example:
Input:
- Smith, J., 2020. Deep learning for images. IEEE CVPR.
- Brown, L. & Green, P., 2019. X-ray scattering tensor tomography based finite element modelling of heterogeneous materials. Nature Materials.
Output:
["Deep learning for images",
"X-ray scattering tensor tomography based finite element modelling of heterogeneous materials"]
Now process the following references:
{references}
"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["references"])
        chain = LLMChain(llm=self.references_titles_extractor_LLM, prompt=prompt)
        all_titles = []
        for chunk in self.chunk_references(references_text):
            response = chain.run({"references": chunk})
            all_titles.extend(self._parse_titles(response))
        return all_titles

    @staticmethod
    def _parse_titles(response):
        """Turn the LLM's reply into a list of title strings.

        The reply is expected to be a list literal. It is parsed with
        ``ast.literal_eval``, never ``eval``: the reply is derived from
        untrusted PDF text, so evaluating it as code would allow injected code
        to run. If the reply is not a list/tuple literal, every non-blank line
        of the reply is used as a title instead.
        """
        import ast

        try:
            parsed = ast.literal_eval(response.strip())
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            return [item for item in parsed if isinstance(item, str)]
        return [line.strip() for line in response.split("\n") if line.strip()]

    @staticmethod
    def _normalize_title(title):
        """Lower-case ``title``, collapse runs of whitespace and drop a trailing period."""
        return " ".join(title.split()).lower().rstrip(".")

    def search_arxiv_ids(self, ref_titles):
        """Map each title to the arXiv ID of a search result with the same (normalized) title.

        Titles without an exact match, or whose search fails, are left out.
        Returned IDs have their version suffix removed.
        """
        client = arxiv.Client(page_size=100, delay_seconds=3, num_retries=5)
        arxiv_ids = {}
        for title in ref_titles:
            try:
                wanted = self._normalize_title(title)
                search = arxiv.Search(query=title, max_results=100, sort_by=arxiv.SortCriterion.Relevance)
                for r in client.results(search):
                    if self._normalize_title(r.title) != wanted:
                        continue
                    # Keep everything after "/abs/" so old-style IDs keep their
                    # archive prefix (e.g. "hep-th/9901001"); fall back to the
                    # last path segment if the URL has another shape.
                    _, sep, tail = r.entry_id.partition("/abs/")
                    short_id = tail if sep else r.entry_id.split("/")[-1]
                    arxiv_ids[title] = re.sub(r'v\d+$', '', short_id)
                    print(title, "->", arxiv_ids[title])
                    break
            except Exception as e:
                # Best-effort: one failed lookup must not abort the others.
                print(f"Could not extract {title}, due to Error: {e}")
                continue
        return arxiv_ids
# -------------------------------
# User Class
# -------------------------------
class User:
    """Per-login session state: loaded papers, the retriever over their chunks and the QA chain.

    Chunks of every loaded paper are indexed in two FAISS stores (one per
    embedding model), combined with equal weight by an ``EnsembleRetriever``.
    Conversation history is kept in ``memory`` across questions.
    """

    def __init__(self):
        self.papers = []          # main papers, in the order they were added
        self.context_papers = []  # titles of every indexed paper (main + references)
        self.retriever = None
        self.QA_LLM = None
        self.QA_Chain = None
        self.api_key = None       # this user's own Groq key, set by set_API_key
        self.dense_embeddings = HuggingFaceEmbeddings()
        # NOTE(review): SPLADE is a sparse model, but it is loaded here through
        # HuggingFaceEmbeddings into a FAISS store like the dense one — confirm
        # this really gives the intended sparse/lexical retrieval.
        self.sparse_embeddings = HuggingFaceEmbeddings(model_name="naver/splade-cocondenser-ensembledistil")
        # The chain returns both an answer and source documents, so memory is
        # told explicitly which input and output to record.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True,
            input_key="question", output_key="answer"
        )

    def set_API_key(self, api_key):
        """Store this user's Groq key and build the chat model used for QA."""
        global GROQ_API_KEY
        GROQ_API_KEY = api_key
        self.api_key = api_key
        self.QA_LLM = ChatGroq(api_key=api_key, model_name="openai/gpt-oss-120b")
        # If papers are already indexed, rebuild the chain so it uses the new
        # model; otherwise the previous key would keep answering questions.
        if self.retriever is not None:
            self.QA_Chain = self.get_conversational_chain()

    def _use_own_api_key(self):
        """Point the shared GROQ_API_KEY (read when a Paper is built) at this user's key.

        GROQ_API_KEY holds whichever key any user set last. Without this, a
        paper added by one user would be processed with another user's key.
        NOTE(review): this relies on no other request changing the global
        between this call and the Paper construction that follows (the
        visible endpoints are ``async def`` and call in synchronously) —
        passing the key to Paper explicitly would be more robust.
        """
        global GROQ_API_KEY
        if self.api_key is not None:
            GROQ_API_KEY = self.api_key

    def add_paper(self, mode, input_data):
        """Load a paper (see Paper for ``mode``/``input_data``), index it and refresh the QA chain."""
        print("Adding paper...")
        self._use_own_api_key()
        paper = Paper(mode, input_data)
        self.papers.append(paper)
        self.context_papers.append(paper.title)
        self._update_retriever_with_new_paper(-1)
        print("Paper added:", paper.title)

    def add_reference_papers(self, index):
        """Load and index the references of ``self.papers[index]``.

        Returns:
            Titles of the loaded reference papers, or None if that paper's
            references were already loaded.
        """
        print("Adding reference papers...")
        paper = self.papers[index]
        if paper.references:
            return
        self._use_own_api_key()
        paper.get_references()
        for ref in paper.references:
            self.context_papers.append(ref.title)
        self._update_retriever_with_new_paper(index, ref=True)
        return [ref.title for ref in paper.references]

    def _update_retriever_with_new_paper(self, index, ref=False):
        """Index new text into both vector stores and rebuild the QA chain.

        Args:
            index: Position in ``self.papers`` of the affected main paper.
            ref: If True, index that paper's loaded references; otherwise
                index the paper's own text.
        """
        paper = self.papers[index]
        if ref:
            chunks = [chunk for ref_paper in paper.references for chunk in get_text_chunks(ref_paper.text)]
        else:
            chunks = get_text_chunks(paper.text)

        if self.retriever is None:
            # First text for this user: build both stores and merge their
            # top-3 hits with equal weight.
            sparse_vs = FAISS.from_texts(chunks, self.sparse_embeddings)
            dense_vs = FAISS.from_texts(chunks, self.dense_embeddings)
            self.retriever = EnsembleRetriever(
                retrievers=[sparse_vs.as_retriever(search_kwargs={"k": 3}),
                            dense_vs.as_retriever(search_kwargs={"k": 3})],
                weights=[0.5, 0.5]
            )
        elif chunks:  # nothing to add for an empty batch
            self.retriever.retrievers[0].vectorstore.add_texts(chunks, embedding=self.sparse_embeddings)
            self.retriever.retrievers[1].vectorstore.add_texts(chunks, embedding=self.dense_embeddings)
        self.QA_Chain = self.get_conversational_chain()

    def get_conversational_chain(self):
        """Build a ConversationalRetrievalChain over the current retriever, LLM and memory."""
        prompt_template = """Use the following pieces of context to answer the question at the end.
Whenever you are asked a question, only answer in reference to the context papers {context_papers}.
If you don't know the answer or the answer is not in the context papers, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Answer in a concise manner.
"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "context_papers"])
        return ConversationalRetrievalChain.from_llm(
            llm=self.QA_LLM,
            retriever=self.retriever,
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": prompt},
            return_source_documents=True
        )

    def ask_question(self, question):
        """Answer ``question`` from the indexed papers, or ask the user to add a paper first."""
        if not self.QA_Chain:
            return "Please add a paper first."
        response = self.QA_Chain({"question": question, "context_papers": ", ".join(self.context_papers)}, return_only_outputs=True)
        return response["answer"]
# -------------------------------
# FastAPI Setup
# -------------------------------
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider listing the frontend's origin explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
async def health():
    """Liveness probe: always reports the service as up."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    payload = {"status": "ok"}
    return payload
# Register endpoint
async def register(body: RegisterRequest):
    """Create a new account in ``users_db``.

    Checks run in this order, and the first failure raises HTTP 400: the
    username is not taken, both fields are non-empty, the username has at
    least 3 characters, the password has at least 6 characters.
    """
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    name, secret = body.username, body.password
    problem = None
    if name in users_db:
        problem = "Username already exists"
    elif not (name and secret):
        problem = "Username and password are required"
    elif len(name) < 3:
        problem = "Username must be at least 3 characters"
    elif len(secret) < 6:
        problem = "Password must be at least 6 characters"
    if problem is not None:
        raise HTTPException(status_code=400, detail=problem)
    # WARNING(review): the password is stored in plaintext; consider hashing it.
    users_db[name] = secret
    return {"message": "User registered successfully"}
# Set API key endpoint
async def set_api_key(body: ApiKeyRequest, user: User = Depends(get_current_user)):
    """Store the Groq API key from the request body on the authenticated user."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    new_key = body.api_key
    user.set_API_key(new_key)
    return dict(message="API key stored for user")
async def upload_pdf(file: UploadFile, user: User = Depends(get_current_user)):
    """Read an uploaded PDF, add it to the caller's papers and return all context titles."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    contents = await file.read()
    user.add_paper("pdf", contents)
    return dict(message="PDF added", context_papers=user.context_papers)
async def add_arxiv(arxiv_id: str = Form(...), user: User = Depends(get_current_user)):
    """Add the arXiv paper named by ``arxiv_id`` (bare ID or arxiv.org URL) to the caller's papers."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    user.add_paper("arxiv_id", arxiv_id)
    confirmation = f"Arxiv paper {arxiv_id} added"
    return dict(message=confirmation, context_papers=user.context_papers)
async def add_references(index: int = Form(...), user: User = Depends(get_current_user)):
    """Load arXiv references of the caller's paper at ``index`` into the QA index.

    Raises:
        HTTPException: 400 if ``index`` is not a valid position in the caller's
            papers; 500 if extracting or loading the references fails.
    """
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    paper_count = len(user.papers)
    print(f"Received request to add references for index: {index}")
    print(f"User has {paper_count} main papers")
    print(f"Paper titles: {[paper.title for paper in user.papers]}")
    if not 0 <= index < paper_count:
        # Avoid reporting a nonsensical "0--1" range when there are no papers.
        valid = f"valid indices: 0-{paper_count - 1}" if paper_count else "add a paper first"
        raise HTTPException(
            status_code=400,
            detail=f"Invalid paper index: {index}. User has {paper_count} papers ({valid})"
        )
    try:
        refs = user.add_reference_papers(index)
    except Exception as e:
        print(f"Error adding references: {str(e)}")
        # Chain the original error so server-side tracebacks show the real cause.
        raise HTTPException(status_code=500, detail=f"Failed to add references: {str(e)}") from e
    return {"message": "References added", "references": refs or [], "context_papers": user.context_papers}
async def ask_question(q: str, user: User = Depends(get_current_user)):
    """Answer question ``q`` against the caller's indexed papers."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    reply = user.ask_question(q)
    return dict(question=q, answer=reply)
async def get_user_data(user: User = Depends(get_current_user)):
    """Return a snapshot of the caller's session: paper titles, per-paper details and API key status."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.

    def describe(paper):
        # has_references and references_loaded carry the same value.
        loaded = bool(paper.references)
        return {
            "title": paper.title,
            "type": "arxiv" if paper.arxiv_id else "pdf",
            "has_references": loaded,
            "references_loaded": loaded,
            "references": [ref.title for ref in paper.references],
        }

    return {
        "papers": user.context_papers,  # flat title list, kept for backward compatibility
        "detailed_papers": list(map(describe, user.papers)),
        "has_api_key": user.QA_LLM is not None,
        "paper_count": len(user.papers),
    }
async def check_api_key(user: User = Depends(get_current_user)):
    """Report whether the caller has configured a Groq API key (i.e. has a QA model)."""
    # NOTE(review): no @app route decorator is visible for this handler; confirm it is registered.
    has_key = user.QA_LLM is not None
    if has_key:
        message = "API key found"
    else:
        message = "No API key found"
    return {"has_api_key": has_key, "message": message}
if __name__ == "__main__":
    # Serve the API on all interfaces, port 8000, when run as a script.
    uvicorn.run(app, port=8000, host="0.0.0.0")