kpbotla commited on
Commit
2bec23f
·
verified ·
1 Parent(s): 73cbc3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -12
app.py CHANGED
@@ -18,6 +18,10 @@ from transformers import pipeline
18
  import logging
19
 
20
 
 
 
 
 
21
  # (Keep Constants as is)
22
  # --- Constants ---
23
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -36,11 +40,12 @@ class BasicAgent:
36
  print(f"Agent returning fixed answer: {fixed_answer}")
37
  return fixed_answer
38
 
39
-
40
  class SmartResearchAgent:
41
  def __init__(self):
42
  logging.info("Initializing SmartResearchAgent")
43
  self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 
44
 
45
  def search_web(self, query: str) -> str:
46
  try:
@@ -85,39 +90,147 @@ class SmartResearchAgent:
85
  logging.error(f"Citation error: {e}")
86
  return "Error during citation generation."
87
 
88
- def __call__(self, question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
89
  logging.info(f"Received question: {question}")
90
  q_lower = question.lower().strip()
91
 
92
  try:
 
 
 
 
 
 
 
 
93
  if q_lower.startswith("search:"):
94
  query = question.split(":", 1)[1].strip()
95
- result = self.search_web(query)
96
  elif q_lower.startswith("summarize:"):
97
  target = question.split(":", 1)[1].strip()
98
- result = self.summarize(target)
99
  elif q_lower.startswith("generate citation:") or q_lower.startswith("cite:"):
100
  url = question.split(":", 1)[1].strip()
101
- result = self.generate_citation(url)
102
  else:
103
- # Default: search + summarize first link
104
  search_result = self.search_web(question)
105
- logging.debug(f"Search result:\n{search_result}")
106
  first_url = next((line.split(": ", 1)[-1] for line in search_result.splitlines() if "http" in line), None)
107
  if first_url:
108
  summary = self.summarize(first_url)
109
- result = f"{summary}\n\nSource: {first_url}"
110
  else:
111
- result = "Sorry, I couldn't find relevant information."
 
 
 
 
 
 
 
 
112
 
113
- logging.info(f"Agent answer: {result}")
114
- return result
 
 
 
 
 
 
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  except Exception as e:
117
  logging.exception("Unhandled error in agent call")
118
  return f"Agent error: {e}"
119
-
120
 
 
121
 
122
  def run_and_submit_all( profile: gr.OAuthProfile | None):
123
  """
 
18
  import logging
19
 
20
 
21
+ from PIL import Image
22
+ from transformers import BlipProcessor, BlipForConditionalGeneration
23
+
24
+
25
  # (Keep Constants as is)
26
  # --- Constants ---
27
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
40
  print(f"Agent returning fixed answer: {fixed_answer}")
41
  return fixed_answer
42
 
 
43
  class SmartResearchAgent:
44
  def __init__(self):
45
  logging.info("Initializing SmartResearchAgent")
46
  self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
47
+ self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
48
+ self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
49
 
50
  def search_web(self, query: str) -> str:
51
  try:
 
90
  logging.error(f"Citation error: {e}")
91
  return "Error during citation generation."
92
 
93
+ def caption_image(self, image_path: str) -> str:
94
+ try:
95
+ image = Image.open(image_path).convert("RGB")
96
+ inputs = self.blip_processor(image, return_tensors="pt")
97
+ out = self.blip_model.generate(**inputs)
98
+ caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
99
+ return f"Image analysis: {caption}"
100
+ except Exception as e:
101
+ logging.error(f"Image processing error: {e}")
102
+ return "Unable to process the image."
103
+
104
+ def __call__(self, question: str, image: Image.Image = None) -> str:
105
  logging.info(f"Received question: {question}")
106
  q_lower = question.lower().strip()
107
 
108
  try:
109
+ # 🔍 Handle image-based chess questions
110
+ if image is not None:
111
+ logging.info("Image input detected")
112
+ temp_path = "/tmp/input_image.jpg"
113
+ image.save(temp_path)
114
+ return self.caption_image(temp_path)
115
+
116
+ # 🔍 Handle text-based tasks
117
  if q_lower.startswith("search:"):
118
  query = question.split(":", 1)[1].strip()
119
+ return self.search_web(query)
120
  elif q_lower.startswith("summarize:"):
121
  target = question.split(":", 1)[1].strip()
122
+ return self.summarize(target)
123
  elif q_lower.startswith("generate citation:") or q_lower.startswith("cite:"):
124
  url = question.split(":", 1)[1].strip()
125
+ return self.generate_citation(url)
126
  else:
127
+ # Default: search + summarize
128
  search_result = self.search_web(question)
 
129
  first_url = next((line.split(": ", 1)[-1] for line in search_result.splitlines() if "http" in line), None)
130
  if first_url:
131
  summary = self.summarize(first_url)
132
+ return f"{summary}\n\nSource: {first_url}"
133
  else:
134
+ return "Sorry, I couldn't find relevant information."
135
+ except Exception as e:
136
+ logging.exception("Unhandled error in agent call")
137
+ return f"Agent error: {e}"class SmartResearchAgent:
138
+ def __init__(self):
139
+ logging.info("Initializing SmartResearchAgent")
140
+ self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
141
+ self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
142
+ self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
143
 
144
+ def search_web(self, query: str) -> str:
145
+ try:
146
+ with DDGS() as ddgs:
147
+ results = ddgs.text(query, max_results=3)
148
+ top = [f"{i+1}. {r['title']}: {r['href']}" for i, r in enumerate(results)]
149
+ return "\n".join(top) if top else "No results found."
150
+ except Exception as e:
151
+ logging.error(f"Search error: {e}")
152
+ return "Error during web search."
153
 
154
+ def summarize(self, input_text_or_url: str) -> str:
155
+ try:
156
+ if input_text_or_url.startswith("http"):
157
+ article = Article(input_text_or_url)
158
+ article.download()
159
+ article.parse()
160
+ input_text_or_url = article.text
161
+ if not input_text_or_url.strip():
162
+ return "No content to summarize."
163
+ summary = self.summarizer(input_text_or_url, max_length=160, min_length=40, do_sample=False)
164
+ return summary[0]['summary_text'].strip()
165
+ except Exception as e:
166
+ logging.error(f"Summarization error: {e}")
167
+ return "Error during summarization."
168
+
169
+ def generate_citation(self, url: str) -> str:
170
+ try:
171
+ citation_id = hashlib.md5(url.encode()).hexdigest()[:6]
172
+ year = datetime.datetime.now().year
173
+ citation = (
174
+ f"@article{{cite{citation_id},\n"
175
+ f" title={{Generated Reference}},\n"
176
+ f" author={{Unknown}},\n"
177
+ f" journal={{Online}},\n"
178
+ f" year={{ {year} }},\n"
179
+ f" url={{ {url} }}\n"
180
+ f"}}"
181
+ )
182
+ return citation
183
+ except Exception as e:
184
+ logging.error(f"Citation error: {e}")
185
+ return "Error during citation generation."
186
+
187
+ def caption_image(self, image_path: str) -> str:
188
+ try:
189
+ image = Image.open(image_path).convert("RGB")
190
+ inputs = self.blip_processor(image, return_tensors="pt")
191
+ out = self.blip_model.generate(**inputs)
192
+ caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
193
+ return f"Image analysis: {caption}"
194
+ except Exception as e:
195
+ logging.error(f"Image processing error: {e}")
196
+ return "Unable to process the image."
197
+
198
+ def __call__(self, question: str, image: Image.Image = None) -> str:
199
+ logging.info(f"Received question: {question}")
200
+ q_lower = question.lower().strip()
201
+
202
+ try:
203
+ # 🔍 Handle image-based chess questions
204
+ if image is not None:
205
+ logging.info("Image input detected")
206
+ temp_path = "/tmp/input_image.jpg"
207
+ image.save(temp_path)
208
+ return self.caption_image(temp_path)
209
+
210
+ # 🔍 Handle text-based tasks
211
+ if q_lower.startswith("search:"):
212
+ query = question.split(":", 1)[1].strip()
213
+ return self.search_web(query)
214
+ elif q_lower.startswith("summarize:"):
215
+ target = question.split(":", 1)[1].strip()
216
+ return self.summarize(target)
217
+ elif q_lower.startswith("generate citation:") or q_lower.startswith("cite:"):
218
+ url = question.split(":", 1)[1].strip()
219
+ return self.generate_citation(url)
220
+ else:
221
+ # Default: search + summarize
222
+ search_result = self.search_web(question)
223
+ first_url = next((line.split(": ", 1)[-1] for line in search_result.splitlines() if "http" in line), None)
224
+ if first_url:
225
+ summary = self.summarize(first_url)
226
+ return f"{summary}\n\nSource: {first_url}"
227
+ else:
228
+ return "Sorry, I couldn't find relevant information."
229
  except Exception as e:
230
  logging.exception("Unhandled error in agent call")
231
  return f"Agent error: {e}"
 
232
 
233
+
234
 
235
  def run_and_submit_all( profile: gr.OAuthProfile | None):
236
  """