# MalikShehram's picture
# Update app.py
# db63cbd verified
import os
import re

import faiss
import fitz
import gradio as gr
import numpy as np
import requests
from faster_whisper import WhisperModel
from groq import Groq
from sentence_transformers import SentenceTransformer
# =========================
# INITIALIZE MODELS
# =========================
# Sentence-embedding model used for both section bodies and chat queries.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# Speech-to-text model used for voice queries.
whisper_model = WhisperModel("base", compute_type="int8")

# 🔑 The Groq API key is read from the GROQ_API_KEY environment variable so
# that no credential is stored in the source file. Any key that was ever
# committed to the repository must be treated as leaked and rotated.
# If the variable is unset, the Groq client raises at construction time.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Use stable model
MODEL_NAME = "llama-3.3-70b-versatile"

# Global storage
sections = {}  # section title -> section body for the loaded paper
section_texts = []  # bodies, in the same order as the vectors stored in `index`
index = None  # FAISS index over `section_texts`; None until a paper is loaded
# =========================
# PDF FUNCTIONS
# =========================
def download_arxiv_pdf(arxiv_id):
    """Download the PDF for *arxiv_id* from arxiv.org into the working directory.

    Args:
        arxiv_id: arXiv identifier such as ``"1706.03762"`` or the old-style
            ``"hep-th/9901001"``. Surrounding whitespace is ignored.

    Returns:
        The path of the saved ``.pdf`` file, or ``None`` when the identifier
        is malformed, the HTTP request fails, the response is not a PDF, or
        the file cannot be written.
    """
    if not isinstance(arxiv_id, str):
        return None
    arxiv_id = arxiv_id.strip()

    # The identifier is interpolated into a URL and a file name, so accept
    # only identifier-like characters with at most one "/" and no "..".
    looks_valid = re.fullmatch(
        r"[A-Za-z0-9._-]+(?:/[A-Za-z0-9._-]+)?", arxiv_id
    )
    if not looks_valid or ".." in arxiv_id:
        return None

    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    # Old-style identifiers contain "/", which open() would otherwise treat
    # as a (non-existent) directory.
    safe_name = arxiv_id.replace("/", "_")
    file_path = f"{safe_name}.pdf"

    try:
        # A timeout keeps the UI from hanging forever on a stalled connection.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Reject error pages served with a 200 status: a PDF carries its
        # "%PDF" signature near the start of the file.
        if b"%PDF" not in response.content[:1024]:
            return None
        with open(file_path, "wb") as f:
            f.write(response.content)
    except (requests.RequestException, OSError):
        return None
    return file_path
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, in page order.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A single string made of each page's ``get_text()`` output
        concatenated without a separator.

    Raises:
        Whatever ``fitz.open`` raises when the file is missing or is not a
        readable document.
    """
    # The context manager closes the document even when a page fails to
    # extract; the original left it open.
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)
# =========================
# ROBUST SECTION EXTRACTION
# =========================
# Heading patterns searched in the extracted text. The upper-case forms use
# "[ \t]" instead of "\s" so a heading can never run past the end of its own
# line and swallow the first capital letter(s) of the section body.
_HEADING_PATTERNS = (
    r"\n([IVX]+\.[ \t]+[A-Z][A-Z \t]+)",  # Roman: "II. RELATED WORK"
    r"\n(\d+\.\d+\.\d+\s+[^\n]+)",  # 1.1.1
    r"\n(\d+\.\d+\s+[^\n]+)",  # 1.1
    r"\n(\d+\.\s+[^\n]+)",  # 1.
    r"\n(\d+\s+[^\n]+)",  # 1
    r"\n([A-Z][A-Z \t]{4,})(?=\n)",  # ALL CAPS line
)

# Upper bound on the number of characters kept per section body.
_MAX_SECTION_CHARS = 4000


def extract_sections(text):
    """Split paper text into a ``{heading: body}`` mapping.

    Headings are found with ``_HEADING_PATTERNS``. Each body runs from the end
    of its heading to the start of the next detected heading (or the end of
    the text), is stripped, and is cut to ``_MAX_SECTION_CHARS`` characters.
    When two headings normalise to the same title, the later one wins.

    An ``"Abstract"`` entry is added when the word "Abstract" occurs: its body
    is everything from that word up to the next detected heading.

    Args:
        text: Plain text of the paper.

    Returns:
        Dict mapping section title to section body. Empty when *text* is
        empty or no heading is found.
    """
    if not text:
        return {}

    # sorted() is stable, so ties on start position keep pattern order.
    headings = sorted(
        (m for pattern in _HEADING_PATTERNS for m in re.finditer(pattern, text)),
        key=lambda m: m.start(),
    )

    extracted = {}
    for current, following in zip(headings, headings[1:] + [None]):
        # Numbered headings may span a line break ("1\nIntroduction");
        # collapse whitespace so the title is a single clean line.
        title = " ".join(current.group(1).split())
        stop = following.start() if following is not None else len(text)
        extracted[title] = text[current.end():stop].strip()[:_MAX_SECTION_CHARS]

    # Add abstract manually. The previous pattern "Abstract(.*?)\n" stopped at
    # the first newline and therefore captured at most the heading line.
    abstract = re.search(r"\bAbstract\b", text)
    if abstract:
        begin = abstract.end()
        stop = next(
            (m.start() for m in headings if m.start() >= begin), len(text)
        )
        # Drop separators such as ":" or an em dash that follow the word.
        body = text[begin:stop].strip().lstrip(":.-\u2013\u2014").strip()
        if body:
            extracted["Abstract"] = body[:_MAX_SECTION_CHARS]

    return extracted
# =========================
# VECTOR STORE
# =========================
def build_vector_store(sections_dict):
    """Rebuild the global FAISS index from the bodies in *sections_dict*.

    Blank bodies (for example a heading immediately followed by a
    sub-heading) are skipped so that empty strings never occupy a retrieval
    slot. ``section_texts`` and ``index`` are replaced together, which keeps
    vector positions and texts aligned.

    Args:
        sections_dict: Mapping of section title to section body.

    Returns:
        None. Sets the globals ``section_texts`` and ``index``; ``index`` is
        ``None`` when there is nothing to embed.
    """
    global index, section_texts

    texts = [body for body in sections_dict.values() if body and body.strip()]
    if not texts:
        section_texts = []
        index = None
        return

    # FAISS expects a contiguous float32 matrix of shape (n_texts, dim).
    embeddings = np.ascontiguousarray(
        embedding_model.encode(texts), dtype="float32"
    )
    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)

    # Publish both globals only after the index was built successfully, so a
    # failure above cannot leave them describing different papers.
    section_texts = texts
    index = new_index
# =========================
# LOAD PAPER
# =========================
def load_paper(arxiv_id):
    """Download, parse and index the paper identified by *arxiv_id*.

    On success the global ``sections`` holds the parsed sections and the
    vector store is rebuilt from them. On any failure the previously loaded
    paper is cleared as well, so the dropdown, ``sections`` and the vector
    index always describe the same (possibly empty) state.

    Args:
        arxiv_id: arXiv identifier typed by the user.

    Returns:
        A ``(dropdown_update, status_message)`` pair for the Gradio outputs.
        The dropdown selection is reset so a title from the previous paper
        cannot remain selected.
    """
    global sections

    pdf_path = download_arxiv_pdf(arxiv_id)
    if pdf_path is None:
        sections = {}
        build_vector_store(sections)
        return gr.update(choices=[], value=None), "❌ Invalid arXiv ID"

    try:
        text = extract_text_from_pdf(pdf_path)
        parsed = extract_sections(text)
        build_vector_store(parsed)
    except Exception as e:
        # UI boundary: report the failure instead of letting it propagate,
        # and drop any stale state from an earlier paper.
        sections = {}
        build_vector_store(sections)
        return gr.update(choices=[], value=None), f"❌ Error:\n{str(e)}"

    sections = parsed
    return (
        gr.update(choices=list(sections), value=None),
        "✅ Paper Loaded Successfully",
    )
# =========================
# SUMMARY FUNCTION
# =========================
def summarize_section(section_title):
    """Ask the chat model for a structured summary of one loaded section.

    Args:
        section_title: Key into the global ``sections`` mapping, as selected
            in the dropdown (``None`` when nothing is selected).

    Returns:
        The model's summary text, or a message starting with "❌" when no
        paper is loaded, the title is unknown, the section body is blank, or
        the request fails.
    """
    try:
        if not sections:
            return "❌ Load paper first"
        if section_title not in sections:
            return "❌ Section not found"

        content = sections[section_title]
        if not content or not content.strip():
            return "❌ Empty section"

        # Keep the prompt bounded regardless of how `sections` was filled.
        content = content[:4000]
        prompt = (
            "\n"
            "Summarize this research section:\n"
            "- Main idea\n"
            "- Key concepts\n"
            "- Simple explanation\n"
            "- Importance\n"
            f"Section: {section_title}\n"
            "Content:\n"
            f"{content}\n"
        )
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )
        return response.choices[0].message.content
    except Exception as e:
        # UI boundary: surface the error text in the summary panel.
        return f"❌ Error:\n{str(e)}"
# =========================
# RAG CHAT
# =========================
def rag_chat(message, history):
    """Answer *message* from the loaded paper and append the turn to *history*.

    The question is embedded, the closest section bodies (at most three) are
    retrieved from the global FAISS ``index``, and the chat model is asked to
    answer using only that context.

    Args:
        message: The user's question. Blank questions are ignored.
        history: Chat transcript as a list of ``{"role", "content"}`` dicts;
            it is extended in place. ``None`` is treated as an empty list.

    Returns:
        ``(history, "")`` where the empty string clears the input textbox.
    """
    if history is None:
        history = []

    question = (message or "").strip()
    if not question:
        return history, ""

    # Record the user's turn up front so it is visible even when answering
    # fails below.
    history.append({"role": "user", "content": question})

    try:
        if index is None:
            history.append({"role": "assistant", "content": "❌ Load paper first"})
            return history, ""

        query_embedding = np.array(embedding_model.encode([question])).astype(
            "float32"
        )

        # Never ask for more neighbours than there are vectors: FAISS pads
        # missing results with -1, which as a Python index would silently
        # select the last section.
        k = min(3, index.ntotal)
        _, neighbours = index.search(query_embedding, k=k)
        retrieved = "\n\n".join(
            section_texts[i] for i in neighbours[0] if 0 <= i < len(section_texts)
        )

        prompt = (
            "\n"
            "Answer using ONLY this context.\n"
            "Context:\n"
            f"{retrieved}\n"
            "Question:\n"
            f"{question}\n"
        )
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        answer = response.choices[0].message.content
        history.append({"role": "assistant", "content": answer})
    except Exception as e:
        # UI boundary: show the failure in the chat instead of raising.
        history.append({"role": "assistant", "content": f"❌ Error:\n{str(e)}"})

    return history, ""
# =========================
# VOICE CHAT
# =========================
def voice_chat(audio, history):
    """Transcribe a recorded question and answer it through ``rag_chat``.

    Args:
        audio: Path of the recorded audio file, or ``None`` when nothing was
            recorded.
        history: Chat transcript as a list of ``{"role", "content"}`` dicts;
            ``None`` is treated as an empty list.

    Returns:
        ``(history, "")``, the same shape ``rag_chat`` returns. The history
        is unchanged when *audio* is ``None``.
    """
    if history is None:
        history = []
    if audio is None:
        return history, ""

    try:
        segments, _ = whisper_model.transcribe(audio)
        # Strip each segment so joining does not produce doubled spaces.
        text = " ".join(seg.text.strip() for seg in segments).strip()
        if not text:
            # Nothing intelligible was recorded; do not send an empty
            # question to the model.
            history.append({"role": "assistant", "content": "❌ No speech detected"})
            return history, ""
        return rag_chat(text, history)
    except Exception as e:
        # UI boundary: show the failure in the chat instead of raising.
        history.append({"role": "assistant", "content": f"❌ Error:\n{str(e)}"})
        return history, ""
# =========================
# UI
# =========================
# Page layout: paper loader, section summariser, text chat and voice query.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 ArXiv Research Assistant", elem_id="title")

    # Paper loader: ID box and button share a row, status line sits below.
    with gr.Row():
        arxiv_input = gr.Textbox(label="Enter arXiv ID", scale=4)
        load_btn = gr.Button("Load Paper", variant="primary", scale=1)
    status = gr.Markdown()

    # Section picker and its generated summary.
    with gr.Row():
        section_dropdown = gr.Dropdown(label="Sections", scale=3)
        summarize_btn = gr.Button("Generate Summary", variant="secondary", scale=1)
    summary_output = gr.Markdown()

    gr.Markdown("## 💬 Chat with Paper")
    chatbot = gr.Chatbot(height=400)
    with gr.Row():
        msg = gr.Textbox(label="Ask a question", scale=4)
        send_btn = gr.Button("Send", variant="primary", scale=1)

    gr.Markdown("## 🎙 Voice Query")
    with gr.Row():
        audio = gr.Audio(type="filepath", scale=4)
        voice_btn = gr.Button("Ask via Voice", scale=1)

    # Actions. Both chat handlers return (history, "") so the second output
    # clears the question textbox.
    load_btn.click(load_paper, inputs=arxiv_input, outputs=[section_dropdown, status])
    summarize_btn.click(summarize_section, inputs=section_dropdown, outputs=summary_output)
    send_btn.click(rag_chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
    voice_btn.click(voice_chat, inputs=[audio, chatbot], outputs=[chatbot, msg])

# Launched at import time, exactly as before; "#title" centres the heading.
demo.launch(
    theme=gr.themes.Soft(),
    css="""
#title {text-align:center}
""",
)