Spaces:
Sleeping
Sleeping
import os
import re

import faiss
import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import requests
from faster_whisper import WhisperModel
from groq import Groq
from sentence_transformers import SentenceTransformer
# =========================
# INITIALIZE MODELS
# =========================
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # sentence embedder for retrieval
whisper_model = WhisperModel("base", compute_type="int8")  # speech-to-text; int8 keeps RAM low

# SECURITY: an API key was previously hard-coded here. Any key committed to a
# repository is leaked and must be rotated; read it from the environment instead.
client = Groq(api_key=os.environ.get("GROQ_API_KEY", ""))

# Stable chat model used for both summaries and RAG answers.
MODEL_NAME = "llama-3.3-70b-versatile"

# Per-process paper state, rebuilt on every load_paper() call.
sections = {}        # section title -> section text
section_texts = []   # list of texts; row i corresponds to FAISS vector i
index = None         # FAISS L2 index over section embeddings (None until a paper loads)
# =========================
# PDF FUNCTIONS
# =========================
def download_arxiv_pdf(arxiv_id):
    """Download the PDF for *arxiv_id* from arxiv.org.

    Returns the local file path on success, or None on any network/HTTP/file
    failure (the UI maps None to an "Invalid arXiv ID" message).
    """
    try:
        url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
        # Timeout prevents the Gradio handler from hanging forever on a
        # stalled connection (the original had none).
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        file_path = f"{arxiv_id}.pdf"
        with open(file_path, "wb") as f:
            f.write(response.content)
        return file_path
    except (requests.RequestException, OSError):
        # Narrowed from a bare `except:` that also swallowed KeyboardInterrupt
        # and SystemExit.
        return None
def extract_text_from_pdf(pdf_path):
    """Return the concatenated plain text of every page of the PDF at *pdf_path*."""
    # Context manager closes the document handle (the original leaked it),
    # and join avoids quadratic string concatenation across pages.
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)
# =========================
# ROBUST SECTION EXTRACTION
# =========================
def extract_sections(text):
    """Split raw paper text into a {heading: content} dict via heading heuristics.

    Several numbering styles are tried (Roman numerals, 1.1.1, 1.1, 1., bare
    number, ALL-CAPS line). A section's content runs from the end of its
    heading to the start of the next heading and is truncated to 4000 chars
    to stay within prompt limits.
    """
    patterns = [
        r"\n([IVX]+\.\s+[A-Z][A-Z\s]+)",   # Roman: "IV. RESULTS"
        r"\n(\d+\.\d+\.\d+\s+[^\n]+)",     # 1.1.1 Sub-subsection
        r"\n(\d+\.\d+\s+[^\n]+)",          # 1.1 Subsection
        r"\n(\d+\.\s+[^\n]+)",             # 1. Section
        r"\n(\d+\s+[^\n]+)",               # 1 Section (no dot)
        r"\n([A-Z][A-Z\s]{4,})\n",         # ALL-CAPS heading on its own line
    ]

    matches = []
    for p in patterns:
        matches.extend(re.finditer(p, text))

    # Deduplicate: more than one pattern can fire on the same heading, which
    # produced duplicate/overlapping sections. Keep the first match seen at
    # each start offset, in document order.
    by_start = {}
    for m in sorted(matches, key=lambda m: m.start()):
        by_start.setdefault(m.start(), m)
    ordered = list(by_start.values())

    extracted = {}
    for i, match in enumerate(ordered):
        title = match.group(1).strip()
        start = match.end()
        end = ordered[i + 1].start() if i + 1 < len(ordered) else len(text)
        extracted[title] = text[start:end].strip()[:4000]

    # The abstract usually has no numbered heading, so grab it separately.
    # The original non-greedy `Abstract(.*?)\n` stopped at the very first
    # newline and typically captured an empty string; capture up to the first
    # blank line (or end of text) instead.
    abstract_match = re.search(r"Abstract\s*(.+?)(?:\n\s*\n|\Z)", text, re.DOTALL)
    if abstract_match:
        extracted["Abstract"] = abstract_match.group(1).strip()
    return extracted
# =========================
# VECTOR STORE
# =========================
def build_vector_store(sections_dict):
    """Embed every section and (re)build the global FAISS L2 index.

    Rebinds the module-level `index` and `section_texts`. An empty dict
    resets `index` to None so the chat handlers refuse to answer until a
    paper is loaded.
    """
    global index, section_texts

    section_texts = list(sections_dict.values())
    if not section_texts:
        index = None
        return

    vectors = np.asarray(embedding_model.encode(section_texts), dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
# =========================
# LOAD PAPER
# =========================
def load_paper(arxiv_id):
    """Fetch an arXiv paper, split it into sections, and index it for RAG.

    Returns (dropdown update carrying the section titles, status message).
    """
    global sections

    pdf_path = download_arxiv_pdf(arxiv_id)
    if pdf_path is None:
        return gr.update(choices=[]), "β Invalid arXiv ID"

    sections = extract_sections(extract_text_from_pdf(pdf_path))
    build_vector_store(sections)
    return gr.update(choices=list(sections.keys())), "β Paper Loaded Successfully"
# =========================
# SUMMARY FUNCTION
# =========================
def summarize_section(section_title):
    """Ask the LLM for a structured summary of one paper section.

    Returns the summary text, or an error string (prefixed the way the UI
    expects) when no paper is loaded or the section is missing/empty.
    """
    try:
        if not sections:
            return "β Load paper first"
        if section_title not in sections:
            return "β Section not found"

        content = sections[section_title]
        if not content:
            return "β Empty section"

        # Truncate to keep the request inside the model's context budget.
        prompt = f"""
Summarize this research section:
- Main idea
- Key concepts
- Simple explanation
- Importance
Section: {section_title}
Content:
{content[:4000]}
"""
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"β Error:\n{str(e)}"
# =========================
# RAG CHAT
# =========================
def rag_chat(message, history):
    """Answer *message* from the loaded paper via retrieval-augmented generation.

    Embeds the question, retrieves the 3 nearest sections from the FAISS
    index, asks the LLM to answer from that context only, and appends the
    exchange to *history* (messages-format dicts). Returns (history, "") —
    the empty string clears the input textbox.
    """
    try:
        global index
        if index is None:
            # Fix: record the user's question here too, so the transcript
            # stays consistent with the success path below (the original
            # showed an assistant reply with no preceding user turn).
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "β Load paper first"})
            return history, ""

        query_embedding = np.asarray(
            embedding_model.encode([message]), dtype="float32"
        )
        # Distances are unused; only the neighbor indices matter.
        _, neighbors = index.search(query_embedding, k=3)
        retrieved = "\n\n".join(section_texts[i] for i in neighbors[0])

        prompt = f"""
Answer using ONLY this context.
Context:
{retrieved}
Question:
{message}
"""
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        answer = response.choices[0].message.content

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": answer})
        return history, ""
    except Exception as e:
        history.append({"role": "assistant", "content": f"β Error:\n{str(e)}"})
        return history, ""
# =========================
# VOICE CHAT
# =========================
def voice_chat(audio, history):
    """Transcribe a recorded question with Whisper and route it to rag_chat.

    *audio* is a filepath (or None when nothing was recorded); returns the
    same (history, "") pair as rag_chat.
    """
    try:
        if audio is None:
            return history, ""
        segments, _ = whisper_model.transcribe(audio)
        question = " ".join(seg.text for seg in segments)
        return rag_chat(question, history)
    except Exception as e:
        history.append({"role": "assistant", "content": f"β Error:\n{str(e)}"})
        return history, ""
# =========================
# UI
# =========================
# Fix: `theme` and `css` are gr.Blocks() constructor arguments, not launch()
# arguments — current Gradio releases reject them on launch().
with gr.Blocks(theme=gr.themes.Soft(), css="#title {text-align:center}") as demo:
    gr.Markdown("# π ArXiv Research Assistant", elem_id="title")

    # Paper loading controls.
    with gr.Row():
        arxiv_input = gr.Textbox(label="Enter arXiv ID", scale=4)
        load_btn = gr.Button("Load Paper", variant="primary", scale=1)
    status = gr.Markdown()

    # Section selection + summary.
    with gr.Row():
        section_dropdown = gr.Dropdown(label="Sections", scale=3)
        summarize_btn = gr.Button("Generate Summary", variant="secondary", scale=1)
    summary_output = gr.Markdown()

    gr.Markdown("## π¬ Chat with Paper")
    # The handlers append {"role": ..., "content": ...} dicts, so declare the
    # messages format explicitly (the legacy default is tuple pairs).
    chatbot = gr.Chatbot(height=400, type="messages")
    with gr.Row():
        msg = gr.Textbox(label="Ask a question", scale=4)
        send_btn = gr.Button("Send", variant="primary", scale=1)

    gr.Markdown("## π Voice Query")
    with gr.Row():
        audio = gr.Audio(type="filepath", scale=4)
        voice_btn = gr.Button("Ask via Voice", scale=1)

    # Event wiring (must happen inside the Blocks context).
    load_btn.click(load_paper, inputs=arxiv_input, outputs=[section_dropdown, status])
    summarize_btn.click(summarize_section, inputs=section_dropdown, outputs=summary_output)
    send_btn.click(rag_chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
    voice_btn.click(voice_chat, inputs=[audio, chatbot], outputs=[chatbot, msg])

demo.launch()