SohaAyub committed on
Commit
ea62d58
·
verified ·
1 Parent(s): 479da54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -55
app.py CHANGED
@@ -8,52 +8,43 @@ from sentence_transformers import SentenceTransformer
8
  from groq import Groq
9
  from faster_whisper import WhisperModel
10
  import os
11
- import logging
12
-
13
- logging.basicConfig(level=logging.INFO)
14
 
15
  # =========================
16
  # INITIALIZE MODELS
17
  # =========================
 
18
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
19
  whisper_model = WhisperModel("base", compute_type="int8")
20
 
21
- # Groq API client
22
  groq_api_key = os.environ.get("GROQ_API_KEY")
23
- if not groq_api_key:
24
- raise ValueError("GROQ_API_KEY environment variable not set!")
25
-
26
- client = Groq(api_key=groq_api_key)
27
- MODEL_NAME = "llama-3.3-70b-versatile" # Use exactly this model
28
 
29
  # Global storage
30
  sections = {}
31
  section_texts = []
32
  index = None
33
 
 
34
  # =========================
35
- # ARXIV PDF FUNCTIONS
36
  # =========================
37
- def is_valid_arxiv_id(arxiv_id):
38
- pattern = r"^\d{4}\.\d{4,5}$"
39
- return re.match(pattern, arxiv_id)
40
 
41
  def download_arxiv_pdf(arxiv_id):
42
  try:
43
  url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
44
- response = requests.get(url, timeout=10)
45
- if response.status_code != 200:
46
- url = f"https://arxiv.org/e-print/{arxiv_id}"
47
- response = requests.get(url, timeout=10)
48
  response.raise_for_status()
 
49
  file_path = f"{arxiv_id}.pdf"
50
  with open(file_path, "wb") as f:
51
  f.write(response.content)
 
52
  return file_path
53
- except Exception as e:
54
- logging.error(f"Failed to download PDF for {arxiv_id}: {e}")
55
  return None
56
 
 
57
  def extract_text_from_pdf(pdf_path):
58
  doc = fitz.open(pdf_path)
59
  text = ""
@@ -61,12 +52,14 @@ def extract_text_from_pdf(pdf_path):
61
  text += page.get_text()
62
  return text
63
 
 
64
  def extract_sections(text):
 
65
  patterns = [
66
- r"\n([IVX]+\.\s+[A-Z][A-Z\s]+)",
67
- r"\n(\d+\.\s+[A-Z][^\n]+)",
68
- r"\n(\d+\s+[A-Z][^\n]+)",
69
- r"\n([A-Z][A-Z\s]{3,})\n"
70
  ]
71
 
72
  matches = []
@@ -74,51 +67,66 @@ def extract_sections(text):
74
  matches.extend(list(re.finditer(pattern, text)))
75
 
76
  matches = sorted(matches, key=lambda x: x.start())
 
77
  sections = {}
78
  for i, match in enumerate(matches):
79
  title = match.group(1).strip()
80
  start = match.end()
81
  end = matches[i+1].start() if i+1 < len(matches) else len(text)
82
  sections[title] = text[start:end].strip()
 
83
  return sections
84
 
 
85
  # =========================
86
  # VECTOR STORE
87
  # =========================
 
88
  def build_vector_store(sections_dict):
89
  global index, section_texts
 
90
  section_texts = list(sections_dict.values())
 
91
  if len(section_texts) == 0:
92
  index = None
93
  return
 
94
  embeddings = embedding_model.encode(section_texts)
95
  embeddings = np.array(embeddings).astype("float32")
 
96
  dimension = embeddings.shape[1]
97
  index = faiss.IndexFlatL2(dimension)
98
  index.add(embeddings)
99
 
 
100
  # =========================
101
  # LOAD PAPER
102
  # =========================
 
103
  def load_paper(arxiv_id):
104
  global sections, index
105
- arxiv_id = arxiv_id.strip()
106
- if not is_valid_arxiv_id(arxiv_id):
107
- return gr.update(choices=[]), "❌ Invalid arXiv ID format"
108
  pdf_path = download_arxiv_pdf(arxiv_id)
 
109
  if pdf_path is None:
110
- return gr.update(choices=[]), "❌ Could not download PDF"
 
111
  text = extract_text_from_pdf(pdf_path)
112
  sections = extract_sections(text)
 
113
  build_vector_store(sections)
 
114
  return gr.update(choices=list(sections.keys())), "✅ Paper Loaded Successfully"
115
 
 
116
  # =========================
117
  # SUMMARIZATION
118
  # =========================
 
119
  def summarize_section(section_title):
120
  if section_title not in sections:
121
  return "Please load paper first."
 
122
  content = sections[section_title]
123
 
124
  prompt = f"""
@@ -132,36 +140,35 @@ Generate a structured scientific summary:
132
 
133
  Section Title: {section_title}
134
  Section Content:
135
- {content[:2500]} # truncate to avoid exceeding model context
136
  """
137
 
138
- try:
139
- response = client.chat.completions.create(
140
- model=MODEL_NAME,
141
- messages=[{"role": "user", "content": prompt}],
142
- temperature=0.3
143
- )
144
- answer = response.choices[0].message.content
145
- except Exception as e:
146
- logging.error("❌ Summarization failed", exc_info=True)
147
- answer = f"Error generating summary: {e}"
148
 
149
- return answer
150
 
151
  # =========================
152
  # RAG CHAT
153
  # =========================
 
154
  def rag_chat(message, history):
155
  global index
 
156
  if index is None:
157
  history.append((message, "Please load a paper first."))
158
- return history, gr.Textbox.update(value="")
159
 
160
  query_embedding = embedding_model.encode([message])
161
  query_embedding = np.array(query_embedding).astype("float32")
162
- D, I = index.search(query_embedding, k=min(3, len(section_texts)))
163
 
164
- retrieved = "\n\n".join([section_texts[i] for i in I[0] if i != -1])
 
 
165
 
166
  prompt = f"""
167
  Answer strictly using the provided research paper context.
@@ -174,33 +181,36 @@ Context:
174
  Question:
175
  {message}
176
  """
177
- try:
178
- response = client.chat.completions.create(
179
- model=MODEL_NAME,
180
- messages=[{"role": "user", "content": prompt}],
181
- temperature=0.2
182
- )
183
- answer = response.choices[0].message.content
184
- except Exception as e:
185
- logging.error("❌ RAG chat failed", exc_info=True)
186
- answer = f"Error generating answer: {e}"
187
 
 
 
 
 
 
 
 
188
  history.append((message, answer))
189
- return history, gr.Textbox.update(value="")
 
190
 
191
  # =========================
192
  # VOICE CHAT
193
  # =========================
 
194
  def voice_chat(audio, history):
195
  if audio is None:
196
- return history, gr.Textbox.update(value="")
 
197
  segments, _ = whisper_model.transcribe(audio)
198
  text = "".join([segment.text for segment in segments])
 
199
  return rag_chat(text, history)
200
 
 
201
  # =========================
202
  # GRADIO UI
203
  # =========================
 
204
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
205
  gr.Markdown("# 📚 ArXiv RAG Research Assistant")
206
 
@@ -209,6 +219,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
209
  load_button = gr.Button("Load Paper")
210
 
211
  load_status = gr.Markdown()
 
212
  section_dropdown = gr.Dropdown(label="Select Section")
213
  summarize_button = gr.Button("Generate Summary")
214
  summary_output = gr.Markdown()
@@ -228,4 +239,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
228
  send.click(rag_chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
229
  voice_button.click(voice_chat, inputs=[audio_input, chatbot], outputs=[chatbot, msg])
230
 
231
- demo.launch(debug=True)
 
8
  from groq import Groq
9
  from faster_whisper import WhisperModel
10
  import os
 
 
 
11
 
12
  # =========================
13
  # INITIALIZE MODELS
14
  # =========================
15
+
16
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
17
  whisper_model = WhisperModel("base", compute_type="int8")
18
 
19
# Retrieve Groq API key from environment variables
groq_api_key = os.environ.get("GROQ_API_KEY")

# NOTE(review): downstream code calls `client.chat.completions.create(...)`,
# but this commit removed the `client = Groq(...)` line without re-adding it,
# so those calls would raise NameError. Recreate the client here.
client = Groq(api_key=groq_api_key)

# Chat model used for both summarization and RAG answers.
MODEL_NAME = "llama-3.3-70b-versatile"
 
 
 
 
22
 
23
  # Global storage
24
  sections = {}
25
  section_texts = []
26
  index = None
27
 
28
+
29
  # =========================
30
+ # PDF FUNCTIONS
31
  # =========================
 
 
 
32
 
33
def download_arxiv_pdf(arxiv_id):
    """Download the PDF of an arXiv paper and save it to the working directory.

    Args:
        arxiv_id: arXiv identifier string, e.g. "2301.12345".

    Returns:
        Path to the saved "<arxiv_id>.pdf" file, or None if the download
        or the local write failed.
    """
    try:
        url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
        # Bound the request so a stalled connection cannot hang the app forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        file_path = f"{arxiv_id}.pdf"
        with open(file_path, "wb") as f:
            f.write(response.content)

        return file_path
    except requests.RequestException:
        # Network failure, HTTP error status, or timeout — callers treat
        # None as "could not fetch this paper".
        return None
    except OSError:
        # Local file-system error while saving the PDF.
        return None
46
 
47
+
48
  def extract_text_from_pdf(pdf_path):
49
  doc = fitz.open(pdf_path)
50
  text = ""
 
52
  text += page.get_text()
53
  return text
54
 
55
+
56
  def extract_sections(text):
57
+
58
  patterns = [
59
+ r"\n([IVX]+\.\s+[A-Z][A-Z\s]+)", # Roman numeral ALL CAPS
60
+ r"\n(\d+\.\s+[A-Z][^\n]+)", # 1. Introduction
61
+ r"\n(\d+\s+[A-Z][^\n]+)", # 1 Introduction
62
+ r"\n([A-Z][A-Z\s]{3,})\n" # ALL CAPS standalone
63
  ]
64
 
65
  matches = []
 
67
  matches.extend(list(re.finditer(pattern, text)))
68
 
69
  matches = sorted(matches, key=lambda x: x.start())
70
+
71
  sections = {}
72
  for i, match in enumerate(matches):
73
  title = match.group(1).strip()
74
  start = match.end()
75
  end = matches[i+1].start() if i+1 < len(matches) else len(text)
76
  sections[title] = text[start:end].strip()
77
+
78
  return sections
79
 
80
+
81
  # =========================
82
  # VECTOR STORE
83
  # =========================
84
+
85
def build_vector_store(sections_dict):
    """Embed all section texts and build a FAISS L2 index over them.

    Populates the module-level `section_texts` and `index`. When the paper
    yielded no sections, `index` is reset to None so callers can detect
    that retrieval is unavailable.
    """
    global index, section_texts

    section_texts = list(sections_dict.values())

    # Nothing to index — disable retrieval rather than building an empty index.
    if not section_texts:
        index = None
        return

    vectors = np.array(embedding_model.encode(section_texts)).astype("float32")

    # Flat (exact) L2 index sized to the embedding dimensionality.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
100
 
101
+
102
  # =========================
103
  # LOAD PAPER
104
  # =========================
105
+
106
def load_paper(arxiv_id):
    """Fetch an arXiv paper, split it into sections, and build the vector store.

    Args:
        arxiv_id: arXiv identifier entered by the user (e.g. "2301.12345").

    Returns:
        Tuple of (gr.update populating the section dropdown, status message).
    """
    global sections, index

    arxiv_id = (arxiv_id or "").strip()
    # Modern arXiv IDs look like "YYMM.NNNNN" (optionally "vN"); validate
    # before building a URL from user input. Restores the check this commit dropped.
    if not re.match(r"^\d{4}\.\d{4,5}(v\d+)?$", arxiv_id):
        return gr.update(choices=[]), "❌ Invalid arXiv ID"

    pdf_path = download_arxiv_pdf(arxiv_id)

    if pdf_path is None:
        # The ID was well-formed but the fetch failed — report that accurately.
        return gr.update(choices=[]), "❌ Could not download PDF"

    text = extract_text_from_pdf(pdf_path)
    sections = extract_sections(text)

    build_vector_store(sections)

    return gr.update(choices=list(sections.keys())), "✅ Paper Loaded Successfully"
120
 
121
+
122
  # =========================
123
  # SUMMARIZATION
124
  # =========================
125
+
126
  def summarize_section(section_title):
127
  if section_title not in sections:
128
  return "Please load paper first."
129
+
130
  content = sections[section_title]
131
 
132
  prompt = f"""
 
140
 
141
  Section Title: {section_title}
142
  Section Content:
143
+ {content[:6000]}
144
  """
145
 
146
+ response = client.chat.completions.create(
147
+ model=MODEL_NAME,
148
+ messages=[{"role": "user", "content": prompt}],
149
+ temperature=0.3
150
+ )
151
+
152
+ return response.choices[0].message.content
 
 
 
153
 
 
154
 
155
  # =========================
156
  # RAG CHAT
157
  # =========================
158
+
159
  def rag_chat(message, history):
160
  global index
161
+
162
  if index is None:
163
  history.append((message, "Please load a paper first."))
164
+ return history, ""
165
 
166
  query_embedding = embedding_model.encode([message])
167
  query_embedding = np.array(query_embedding).astype("float32")
 
168
 
169
+ D, I = index.search(query_embedding, k=3)
170
+
171
+ retrieved = "\n\n".join([section_texts[i] for i in I[0]])
172
 
173
  prompt = f"""
174
  Answer strictly using the provided research paper context.
 
181
  Question:
182
  {message}
183
  """
 
 
 
 
 
 
 
 
 
 
184
 
185
+ response = client.chat.completions.create(
186
+ model=MODEL_NAME,
187
+ messages=[{"role": "user", "content": prompt}],
188
+ temperature=0.2
189
+ )
190
+
191
+ answer = response.choices[0].message.content
192
  history.append((message, answer))
193
+ return history, ""
194
+
195
 
196
  # =========================
197
  # VOICE CHAT
198
  # =========================
199
+
200
def voice_chat(audio, history):
    """Transcribe a recorded question with Whisper and route it through rag_chat.

    Args:
        audio: path/array from the Gradio audio component, or None if nothing
            was recorded.
        history: current chatbot history.

    Returns:
        Same (history, textbox value) pair that rag_chat returns.
    """
    # Nothing recorded — leave the chat untouched and clear the textbox.
    if audio is None:
        return history, ""

    segments, _ = whisper_model.transcribe(audio)
    question = "".join(segment.text for segment in segments)

    return rag_chat(question, history)
208
 
209
+
210
  # =========================
211
  # GRADIO UI
212
  # =========================
213
+
214
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
215
  gr.Markdown("# 📚 ArXiv RAG Research Assistant")
216
 
 
219
  load_button = gr.Button("Load Paper")
220
 
221
  load_status = gr.Markdown()
222
+
223
  section_dropdown = gr.Dropdown(label="Select Section")
224
  summarize_button = gr.Button("Generate Summary")
225
  summary_output = gr.Markdown()
 
239
  send.click(rag_chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
240
  voice_button.click(voice_chat, inputs=[audio_input, chatbot], outputs=[chatbot, msg])
241
 
242
+ demo.launch(debug=True)