mhamoody committed on
Commit
1f70db6
·
verified ·
1 Parent(s): 8069f71
Files changed (1) hide show
  1. app.py +175 -13
app.py CHANGED
@@ -4,6 +4,8 @@ from typing import List, Tuple, Optional
4
  import google.genai as genai
5
  import gradio as gr
6
  from PIL import Image
 
 
7
 
8
  GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY")
9
 
@@ -12,8 +14,11 @@ IMAGE_WIDTH = 512
12
  system_instruction_analysis = "You are an expert of the given topic. Analyze the provided text with a focus on the topic, identifying recent issues, recent insights, or improvements relevant to academic standards and effectiveness. Offer actionable advice for enhancing knowledge and suggest real-life examples."
13
  model_name = "gemini-2.5-flash"
14
 
15
- # Initialize model (will be configured with API key in bot function)
16
- model = None
 
 
 
17
 
18
  # Helper Functions
19
  def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
@@ -111,6 +116,95 @@ def bot(
111
  except Exception as e:
112
  chatbot[-1]["content"] = f"Error processing response: {str(e)}"
113
  yield chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Components
115
  google_key_component = gr.Textbox(
116
  label="Google API Key",
@@ -127,6 +221,8 @@ text_prompt_component = gr.Textbox(
127
  lines=3
128
  )
129
  run_button_component = gr.Button("Submit")
 
 
130
  temperature_component = gr.Slider(
131
  minimum=0,
132
  maximum=1.0,
@@ -168,7 +264,7 @@ example_scenarios = [
168
  "Describe Multimodal AI",
169
  "What are the difference between multiagent llm and multiagent system",
170
  "Why it's difficult to integrate multimodality in prompt"]
171
- example_images = [["ex1.png"],["ex2.png"]]
172
 
173
  # Gradio Interface
174
  user_inputs = [text_prompt_component, chatbot_component]
@@ -184,19 +280,79 @@ bot_inputs = [
184
  ]
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  with gr.Blocks() as demo:
188
- gr.Markdown("<h1 style='font-size: 36px; font-weight: bold; font-family: Arial;'>Gemini 2.5 Multimodal Chatbot</h1>")
189
  with gr.Row():
190
  google_key_component.render()
191
  with gr.Row():
192
  chatbot_component.render()
193
  with gr.Row():
194
  with gr.Column(scale=1):
195
- text_prompt_component.render()
 
196
  with gr.Column(scale=1):
197
- image_prompt_component.render()
198
  with gr.Column(scale=1):
199
  run_button_component.render()
 
 
 
200
  with gr.Accordion("🧪Example Text 💬", open=False):
201
  example_radio = gr.Radio(
202
  choices=example_scenarios,
@@ -207,12 +363,6 @@ with gr.Blocks() as demo:
207
  fn=lambda query: query if query else "No query selected.",
208
  inputs=[example_radio],
209
  outputs=[text_prompt_component])
210
- with gr.Accordion("🧪Example Image 🩻", open=False):
211
- gr.Examples(
212
- examples=example_images,
213
- inputs=[image_prompt_component],
214
- label="Example Figures",
215
- )
216
  with gr.Accordion("🛠️Customize", open=False):
217
  temperature_component.render()
218
  max_output_tokens_component.render()
@@ -223,7 +373,19 @@ with gr.Blocks() as demo:
223
  run_button_component.click(
224
  fn=user, inputs=user_inputs, outputs=[text_prompt_component, chatbot_component]
225
  ).then(
226
- fn=bot, inputs=bot_inputs, outputs=[chatbot_component]
 
 
 
 
 
 
 
 
 
 
 
 
227
  )
228
 
229
  if __name__ == "__main__":
 
4
  import google.genai as genai
5
  import gradio as gr
6
  from PIL import Image
7
+ from PIL import ImageDraw, ImageFont, ImageColor
8
+ import json
9
 
10
  GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY")
11
 
 
14
  system_instruction_analysis = "You are an expert of the given topic. Analyze the provided text with a focus on the topic, identifying recent issues, recent insights, or improvements relevant to academic standards and effectiveness. Offer actionable advice for enhancing knowledge and suggest real-life examples."
15
  model_name = "gemini-2.5-flash"
16
 
17
+ # Bounding box system instruction
18
+ bounding_box_system_instructions = (
19
+ "Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects. "
20
+ "If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc.)."
21
+ )
22
 
23
  # Helper Functions
24
  def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
 
116
  except Exception as e:
117
  chatbot[-1]["content"] = f"Error processing response: {str(e)}"
118
  yield chatbot
119
+
120
+
121
+ def _strip_codefence_json(text: str) -> str:
122
+ """Strip markdown code fences and return the JSON payload portion."""
123
+ if not text:
124
+ return ""
125
+ lines = text.splitlines()
126
+ for i, line in enumerate(lines):
127
+ if line.strip().startswith("```json"):
128
+ payload = "\n".join(lines[i+1:])
129
+ payload = payload.split("```")[0]
130
+ return payload.strip()
131
+ # fallback: try to find first '[' or '{'
132
+ idx = min((text.find("{") if text.find("{")!=-1 else len(text)), (text.find("[") if text.find("[")!=-1 else len(text)))
133
+ return text[idx:].strip() if idx < len(text) else text.strip()
134
+
135
+
136
def generate_bounding_boxes(google_key: str, prompt: str, image: Optional[Image.Image]):
    """Generate bounding boxes from the model and return a PIL image with boxes drawn.

    Args:
        google_key: API key entered in the UI; falls back to the
            GEMINI_API_KEY environment variable (module-level GOOGLE_API_KEY).
        prompt: User text describing what to detect.
        image: Source image; ``None`` short-circuits to ``None``.

    Returns:
        A copy of *image* with labelled rectangles drawn on it, or ``None``
        when there is no image, the API call fails, or the response cannot
        be parsed as a JSON list of boxes.

    Raises:
        ValueError: If no API key is available from either source.
    """
    google_key = google_key or GOOGLE_API_KEY
    if not google_key:
        raise ValueError("GOOGLE_API_KEY is not set. Please set it up.")

    # Nothing to process; also guard degenerate images that would break
    # the aspect-ratio division below.
    if image is None or not image.width or not image.height:
        return None

    client = genai.Client(api_key=google_key)

    # Resize image for generation (1024px wide, aspect ratio preserved).
    img_for_model = image.resize((1024, int(1024 * image.height / image.width)))

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt, img_for_model],
            config=genai.types.GenerateContentConfig(
                system_instruction=bounding_box_system_instructions,
                temperature=0.3,
                max_output_tokens=1024,
            ),
        )
    except Exception as e:
        print("Error generating bounding boxes:", e)
        return None

    json_text = _strip_codefence_json(getattr(response, "text", "") or "")
    try:
        bounding_boxes = json.loads(json_text)
    except Exception as e:
        print("Failed to parse bounding box JSON:", e)
        return None

    # The model is instructed to return a JSON array; anything else
    # (e.g. a single dict) is unusable for the drawing loop below.
    if not isinstance(bounding_boxes, list):
        print("Unexpected bounding box payload (expected a list):",
              type(bounding_boxes).__name__)
        return None

    # Draw boxes
    try:
        out = image.copy()
        draw = ImageDraw.Draw(out)
        width, height = out.size

        # Font is best-effort; drawing falls back to PIL's default rendering.
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

        colors = list(ImageColor.colormap.keys())
        for i, bb in enumerate(bounding_boxes):
            # Expecting box_2d as [y1, x1, y2, x2] in 0-1000 scale.
            box = bb.get("box_2d") if isinstance(bb, dict) else None
            if not box or len(box) < 4:
                # Skip malformed entries instead of discarding every box.
                continue

            color = colors[i % len(colors)]
            y1 = int(box[0] / 1000 * height)
            x1 = int(box[1] / 1000 * width)
            y2 = int(box[2] / 1000 * height)
            x2 = int(box[3] / 1000 * width)

            # Normalize so (x1, y1) is the top-left corner.
            if x1 > x2:
                x1, x2 = x2, x1
            if y1 > y2:
                y1, y2 = y2, y1

            draw.rectangle(((x1, y1), (x2, y2)), outline=color, width=4)
            label = bb.get("label") or bb.get("name") or ""
            if label:
                draw.text((x1 + 6, y1 + 4), label, fill=color, font=font)

        return out
    except Exception as e:
        print("Error drawing bounding boxes:", e)
        return None
208
  # Components
209
  google_key_component = gr.Textbox(
210
  label="Google API Key",
 
221
  lines=3
222
  )
223
  run_button_component = gr.Button("Submit")
224
+ bbox_mode_component = gr.Checkbox(label="Bounding box mode (detect & label objects)", value=False)
225
+ output_image_component = gr.Image(type="pil", label="Output Image")
226
  temperature_component = gr.Slider(
227
  minimum=0,
228
  maximum=1.0,
 
264
  "Describe Multimodal AI",
265
  "What are the difference between multiagent llm and multiagent system",
266
  "Why it's difficult to integrate multimodality in prompt"]
267
+
268
 
269
  # Gradio Interface
270
  user_inputs = [text_prompt_component, chatbot_component]
 
280
  ]
281
 
282
 
283
def handle_submit(
    google_key: str,
    image_prompt: Optional[Image.Image],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List,
    bbox_mode: bool,
):
    """Route submission: run the bounding-box generator when bbox mode is
    on (or detection keywords appear alongside an image); otherwise stream
    a text reply via `bot`.

    Yields:
        ``(chatbot, image)`` tuples — the updated chat history and either
        the annotated image (bounding-box path) or ``None`` (text path).
    """
    # Extract the last user text from the messages-format chat history.
    content = chatbot[-1]["content"] if chatbot else None
    text_prompt = None
    if isinstance(content, str):
        text_prompt = content.strip() if content else None
    elif isinstance(content, list) and len(content) > 0:
        for item in content:
            if isinstance(item, str):
                text_prompt = item.strip()
                break

    # Trigger detection: explicit checkbox, or detection keywords in the
    # prompt when an image is attached.  NOTE(review): substring matching
    # is deliberately loose ("box" also matches "inbox"), matching the
    # original behavior.
    bbox_triggers = ("detect", "bounding", "box", "label", "find the")
    trigger = bool(bbox_mode)
    if not trigger and image_prompt is not None and text_prompt:
        low = text_prompt.lower()
        trigger = any(kw in low for kw in bbox_triggers)

    if trigger and image_prompt is not None:
        out_img = generate_bounding_boxes(
            google_key, text_prompt or "Detect objects in the image", image_prompt
        )
        # Report failure honestly instead of claiming success with no image.
        if out_img is None:
            chatbot.append({
                "role": "assistant",
                "content": "Could not generate bounding boxes for this image.",
            })
        else:
            chatbot.append({
                "role": "assistant",
                "content": "Generated bounding boxes (see image).",
            })
        yield chatbot, out_img
        return

    # Fallback to text generation: stream from bot, keep image output empty.
    for chat_state in bot(
        google_key,
        image_prompt,
        temperature,
        max_output_tokens,
        stop_sequences,
        top_k,
        top_p,
        chatbot,
    ):
        yield chat_state, None
337
+
338
+
339
  with gr.Blocks() as demo:
340
+ gr.Markdown("<h1 style='font-size: 36px; font-weight: bold; font-family: Arial;'>Gemini 2.0 Multimodal Chatbot</h1>")
341
  with gr.Row():
342
  google_key_component.render()
343
  with gr.Row():
344
  chatbot_component.render()
345
  with gr.Row():
346
  with gr.Column(scale=1):
347
+ text_prompt_component.render()
348
+ bbox_mode_component.render()
349
  with gr.Column(scale=1):
350
+ image_prompt_component.render()
351
  with gr.Column(scale=1):
352
  run_button_component.render()
353
+ with gr.Row():
354
+ with gr.Column(scale=1):
355
+ output_image_component.render()
356
  with gr.Accordion("🧪Example Text 💬", open=False):
357
  example_radio = gr.Radio(
358
  choices=example_scenarios,
 
363
  fn=lambda query: query if query else "No query selected.",
364
  inputs=[example_radio],
365
  outputs=[text_prompt_component])
 
 
 
 
 
 
366
  with gr.Accordion("🛠️Customize", open=False):
367
  temperature_component.render()
368
  max_output_tokens_component.render()
 
373
  run_button_component.click(
374
  fn=user, inputs=user_inputs, outputs=[text_prompt_component, chatbot_component]
375
  ).then(
376
+ fn=handle_submit,
377
+ inputs=[
378
+ google_key_component,
379
+ image_prompt_component,
380
+ temperature_component,
381
+ max_output_tokens_component,
382
+ stop_sequences_component,
383
+ top_k_component,
384
+ top_p_component,
385
+ chatbot_component,
386
+ bbox_mode_component,
387
+ ],
388
+ outputs=[chatbot_component, output_image_component],
389
  )
390
 
391
  if __name__ == "__main__":