File size: 13,700 Bytes
e9b57cc
 
d9cd647
2af634b
d9cd647
 
1f70db6
 
d9cd647
 
 
 
 
 
35b7a52
3f601c5
1f70db6
 
 
 
 
d9cd647
 
 
 
 
 
 
 
 
1631ff8
 
d9cd647
 
 
 
 
 
 
 
 
1631ff8
d9cd647
 
 
 
 
9b6f7aa
 
9c4953f
 
 
 
 
 
9b6f7aa
 
 
 
9c4953f
 
 
 
d9cd647
 
9c4953f
d9cd647
 
 
 
 
 
 
 
 
ad95629
 
 
 
 
 
 
 
 
3f601c5
ad95629
 
d9cd647
 
 
 
 
 
ad95629
 
 
 
 
 
 
 
d9cd647
1631ff8
d9cd647
 
 
 
1631ff8
ad95629
 
 
 
 
1631ff8
 
ad95629
 
 
1f70db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9cd647
 
 
 
 
 
 
 
b8000f8
2af634b
d9cd647
 
b8000f8
d9cd647
 
72fe726
1f70db6
 
d9cd647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f601c5
 
1f70db6
037a89d
 
 
 
 
e9b57cc
d9cd647
 
 
 
 
 
 
 
 
 
 
 
e9b57cc
1f70db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069dbea
7374901
e9b57cc
d9cd647
e9b57cc
d9cd647
 
069dbea
1f70db6
 
069dbea
1f70db6
069dbea
d9cd647
1f70db6
 
 
d9cd647
 
 
 
 
 
 
 
 
 
4dd28b2
 
 
 
 
d9cd647
 
 
 
 
 
e9b57cc
d9cd647
 
 
1f70db6
 
 
 
 
 
 
 
 
 
 
 
 
e9b57cc
3f601c5
 
069dbea
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import json
import os
import re
import time
from typing import List, Tuple, Optional

import google.genai as genai
import gradio as gr
from PIL import Image
from PIL import ImageDraw, ImageFont, ImageColor

# API key from the environment. GEMINI_API_KEY is checked first (the variable
# this app has always read); GOOGLE_API_KEY is accepted as a fallback because
# that is the name this app's "not set" error tells users to configure.
# Evaluates to None when neither variable holds a non-empty value.
GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")

# Width in pixels that chat images are resized to (aspect ratio preserved)
# before they are sent to the model.
IMAGE_WIDTH = 512

# System prompt for the free-form text/image analysis chat.
system_instruction_analysis = "You are an expert of the given topic. Analyze the provided text with a focus on the topic, identifying recent issues, recent insights, or improvements relevant to academic standards and effectiveness. Offer actionable advice for enhancing knowledge and suggest real-life examples."
# Gemini model used for both the chat and the bounding-box requests.
model_name = "gemini-2.5-flash"

# Bounding box system instruction
bounding_box_system_instructions = (
    "Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects. "
    "If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc.)."
)

# Helper Functions
def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Turn the comma-separated stop-sequence textbox value into a list.

    Each entry is whitespace-stripped, and empty entries (from a trailing
    comma, doubled commas, or a blank box) are dropped so they are never sent
    to the API as stop sequences.

    Args:
        stop_sequences: Raw textbox text, e.g. ``"STOP, END"``; may be empty or None.

    Returns:
        The non-empty stop sequences in their original order, or None when
        there are none.
    """
    if not stop_sequences:
        return None
    cleaned = [part for part in (piece.strip() for piece in stop_sequences.split(",")) if part]
    return cleaned or None

def preprocess_image(image: Image.Image, width: Optional[int] = None) -> Image.Image:
    """Resize *image* to a fixed width, scaling the height proportionally.

    Args:
        image: Source PIL image; it is not modified (``resize`` returns a new image).
        width: Target width in pixels. Defaults to ``IMAGE_WIDTH``.

    Returns:
        A new image ``width`` pixels wide. Its height keeps the original aspect
        ratio but is never below 1 pixel, so extremely wide images do not
        collapse to a zero-height (invalid) size.

    Raises:
        ValueError: if *width* is not a positive integer.
    """
    target_width = IMAGE_WIDTH if width is None else width
    if target_width <= 0:
        raise ValueError(f"width must be positive, got {target_width}")
    target_height = max(1, int(image.height * target_width / image.width))
    return image.resize((target_width, target_height))

def user(text_prompt: str, chatbot: List):
    """Record the submitted text as a user turn and clear the input box.

    Args:
        text_prompt: The text typed by the user (may be empty).
        chatbot: Current message-format history; None is treated as empty.

    Returns:
        A tuple ``("", history)``: an empty string to reset the textbox, and a
        new history list with ``{"role": "user", "content": text_prompt}``
        appended. The incoming list is not mutated.
    """
    history = list(chatbot) if chatbot else []
    history.append({"role": "user", "content": text_prompt})
    return "", history

def _extract_prompt_text(content) -> Optional[str]:
    """Return the stripped text of a chat message's content, or None if it has none.

    Handles plain-string content and the list form used for multimodal
    messages, whose items are bare strings or ``{"type": "text", "text": ...}``
    dicts; the first such text item is used.
    """
    if isinstance(content, str):
        return content.strip() or None
    if isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                return item.strip() or None
            if isinstance(item, dict) and item.get("type") == "text":
                return (item.get("text") or "").strip() or None
    return None


def bot(
    google_key: str,
    image_prompt: Optional[Image.Image],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List
):
    """Answer the latest user turn with Gemini and reveal the reply progressively.

    The last entry of *chatbot* is taken as the user's prompt. If an image is
    attached, it is resized with ``preprocess_image`` and sent alongside the
    text. Status and error messages are appended as assistant turns, so the
    user's own message is never overwritten.

    Args:
        google_key: API key from the UI; falls back to ``GOOGLE_API_KEY``.
        image_prompt: Optional PIL image to analyze.
        temperature, max_output_tokens, top_k, top_p: Sampling settings
            forwarded to ``GenerateContentConfig``.
        stop_sequences: Comma-separated stop sequences (see ``preprocess_stop_sequences``).
        chatbot: Message-format history; mutated in place (None is treated as empty).

    Yields:
        The *chatbot* list after each update, for Gradio to re-render.

    Raises:
        ValueError: if no API key is available.
    """
    google_key = google_key or GOOGLE_API_KEY
    if not google_key:
        raise ValueError("GOOGLE_API_KEY is not set. Please set it up.")

    if chatbot is None:
        chatbot = []
    text_prompt = _extract_prompt_text(chatbot[-1]["content"]) if chatbot else None

    if not text_prompt and image_prompt is None:
        chatbot.append({"role": "assistant", "content": "Prompt cannot be empty. Please provide input text or an image."})
        yield chatbot
        return

    if image_prompt is not None:
        # Image-only requests get a default instruction; text+image requests
        # are nudged to cover the image as well.
        text_prompt = (
            f"{text_prompt}. Also, analyze the provided image."
            if text_prompt
            else "Describe the image"
        )

    client = genai.Client(api_key=google_key)

    config_kwargs = {
        "system_instruction": system_instruction_analysis,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "top_k": top_k,
        "top_p": top_p,
    }
    stops = preprocess_stop_sequences(stop_sequences)
    if stops:
        config_kwargs["stop_sequences"] = stops

    contents = [text_prompt]
    if image_prompt is not None:
        contents.append(preprocess_image(image_prompt))

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=genai.types.GenerateContentConfig(**config_kwargs),
        )
    except Exception as e:
        # Report API/network failures as an assistant reply rather than
        # replacing the user's message in chatbot[-1].
        chatbot.append({"role": "assistant", "content": f"Error occurred: {str(e)}"})
        yield chatbot
        return

    chatbot.append({"role": "assistant", "content": ""})
    try:
        reply = response.text or ""
    except Exception as e:
        chatbot[-1]["content"] = f"Error processing response: {str(e)}"
        yield chatbot
        return

    if not reply:
        # Without this, the empty assistant turn would never be yielded and the
        # UI would show no feedback at all.
        chatbot[-1]["content"] = (
            "The model returned an empty response (it may have hit the response-length "
            "limit or been blocked). Try increasing the response length or rephrasing the prompt."
        )
        yield chatbot
        return

    # Simulated streaming: the full reply is already available, so reveal it
    # 10 characters at a time with a short pause between UI updates.
    step = 10
    for start in range(0, len(reply), step):
        chatbot[-1]["content"] += reply[start:start + step]
        time.sleep(0.01)
        yield chatbot


def _strip_codefence_json(text: str) -> str:
    """Strip markdown code fences and return the JSON payload portion."""
    if not text:
        return ""
    lines = text.splitlines()
    for i, line in enumerate(lines):
        if line.strip().startswith("```json"):
            payload = "\n".join(lines[i+1:])
            payload = payload.split("```")[0]
            return payload.strip()
    # fallback: try to find first '[' or '{'
    idx = min((text.find("{") if text.find("{")!=-1 else len(text)), (text.find("[") if text.find("[")!=-1 else len(text)))
    return text[idx:].strip() if idx < len(text) else text.strip()


def _draw_bounding_boxes(image: Image.Image, boxes: list) -> Image.Image:
    """Return a copy of *image* with each entry of *boxes* outlined and labelled.

    Each entry is expected to carry ``box_2d`` as ``[y1, x1, y2, x2]`` on a
    0-1000 scale relative to the image, plus an optional ``label`` (or
    ``name``). Entries without a usable ``box_2d`` are skipped individually
    rather than aborting the whole drawing.
    """
    # Non-RGB(A) images (e.g. grayscale/palette) are converted so the named
    # outline colors render as colors.
    out = image.copy() if image.mode in ("RGB", "RGBA") else image.convert("RGB")
    draw = ImageDraw.Draw(out)
    width, height = out.size

    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    colors = list(ImageColor.colormap.keys())
    for i, bb in enumerate(boxes):
        try:
            y1, x1, y2, x2 = (float(v) for v in bb["box_2d"])
        except (KeyError, TypeError, ValueError) as e:
            print("Skipping malformed bounding box:", bb, e)
            continue

        color = colors[i % len(colors)]
        # Scale from the 0-1000 grid to pixels and order the corners so that
        # (left, top) is the upper-left point even if the model swapped them.
        left, right = sorted((int(x1 / 1000 * width), int(x2 / 1000 * width)))
        top, bottom = sorted((int(y1 / 1000 * height), int(y2 / 1000 * height)))

        draw.rectangle(((left, top), (right, bottom)), outline=color, width=4)
        label = bb.get("label") or bb.get("name") or ""
        if label:
            draw.text((left + 6, top + 4), str(label), fill=color, font=font)

    return out


def generate_bounding_boxes(google_key: str, prompt: str, image: Optional[Image.Image]):
    """Ask the model for labelled bounding boxes and draw them on *image*.

    Args:
        google_key: API key from the UI; falls back to ``GOOGLE_API_KEY``.
        prompt: Detection instruction sent alongside the image.
        image: The PIL image to analyze; None short-circuits to None.

    Returns:
        A new PIL image with the boxes drawn, or None if there is no image,
        the request fails, the reply is not parseable JSON, or drawing fails.
        Failures are reported with ``print``.

    Raises:
        ValueError: if no API key is available.
    """
    google_key = google_key or GOOGLE_API_KEY
    if not google_key:
        raise ValueError("GOOGLE_API_KEY is not set. Please set it up.")

    if image is None:
        return None

    client = genai.Client(api_key=google_key)

    # The model sees a 1024px-wide copy (aspect ratio kept, height at least
    # 1px). Boxes come back on a 0-1000 grid, so they map onto the original
    # image regardless of this resize.
    img_for_model = image.resize((1024, max(1, int(1024 * image.height / image.width))))

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt, img_for_model],
            config=genai.types.GenerateContentConfig(
                system_instruction=bounding_box_system_instructions,
                temperature=0.3,
                max_output_tokens=1024,
                # Ask for raw JSON, and disable thinking so reasoning tokens do
                # not consume the 1024-token output budget and truncate the JSON.
                response_mime_type="application/json",
                thinking_config=genai.types.ThinkingConfig(thinking_budget=0),
            ),
        )
    except Exception as e:
        print("Error generating bounding boxes:", e)
        return None

    json_text = _strip_codefence_json(getattr(response, "text", "") or "")
    try:
        bounding_boxes = json.loads(json_text)
    except ValueError as e:
        print("Failed to parse bounding box JSON:", e)
        return None

    # Tolerate a single box returned as a bare object instead of an array.
    if isinstance(bounding_boxes, dict):
        bounding_boxes = [bounding_boxes]
    if not isinstance(bounding_boxes, list):
        print("Unexpected bounding box payload type:", type(bounding_boxes).__name__)
        return None

    try:
        return _draw_bounding_boxes(image, bounding_boxes)
    except Exception as e:
        print("Error drawing bounding boxes:", e)
        return None
# Components
# The key box is shown whenever no usable key came from the environment; an
# empty-string env value counts as "no key" (matching the `or` fallback used
# where the key is consumed), otherwise the box would be hidden with no key.
google_key_component = gr.Textbox(
    label="Google API Key",
    type="password",
    placeholder="Enter your Google API Key",
    visible=not GOOGLE_API_KEY
)

image_prompt_component = gr.Image(type="pil", label="Input Image (Optional: Figure/Graph)")
chatbot_component = gr.Chatbot(label="Chatbot")
text_prompt_component = gr.Textbox(
    placeholder="Type your question here...",
    label="Ask",
    lines=3
)
run_button_component = gr.Button("Submit")
bbox_mode_component = gr.Checkbox(label="Bounding box mode (detect & label objects)", value=False)
output_image_component = gr.Image(type="pil", label="Output Image")
temperature_component = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=0.4,
    step=0.05,
    label="Creativity (Temperature)",
    info="Controls the randomness of the response. Higher values result in more creative answers."
)
max_output_tokens_component = gr.Slider(
    minimum=1,
    maximum=2048,
    value=1024,
    step=1,
    label="Response Length (Token Limit)",
    info="Sets the maximum number of tokens in the output response."
)
stop_sequences_component = gr.Textbox(
    label="Stop Sequences (Optional)",
    placeholder="Enter stop sequences, e.g., STOP, END",
    info="Specify sequences to stop the generation."
)
top_k_component = gr.Slider(
    minimum=1,
    maximum=40,
    value=32,
    step=1,
    label="Top-K Sampling",
    info="Limits token selection to the top K most probable tokens. Lower values produce conservative outputs."
)
top_p_component = gr.Slider(
    minimum=0,
    maximum=1,
    value=1,
    step=0.01,
    label="Top-P Sampling",
    info="Limits token selection to tokens with a cumulative probability up to P. Lower values produce conservative outputs."
)
# Sample prompts offered in the "Example Text" radio.
example_scenarios = [
    "Describe Multimodal AI",
    "What are the difference between multiagent llm and multiagent system",
    "Why it's difficult to integrate multimodality in prompt"]

# Sample images for gr.Examples; paths are relative, so they resolve against
# the working directory the app is launched from.
example_images = [
    ["ex1.png"],
    ["ex2.png"]
]

# Gradio Interface
# Components feeding the `user` step: the typed prompt and the current history.
user_inputs = list((text_prompt_component, chatbot_component))

# Components feeding `bot`, in the order of its positional parameters:
# key, image, the five generation controls, then the history.
bot_inputs = [google_key_component, image_prompt_component]
bot_inputs += [
    temperature_component,
    max_output_tokens_component,
    stop_sequences_component,
    top_k_component,
    top_p_component,
]
bot_inputs.append(chatbot_component)


# Words that switch an image-bearing request into bounding-box mode. The
# leading word boundary keeps e.g. "inbox", "sandbox" or "checkbox" from
# triggering, while still matching "boxes", "labels", "detection".
_BBOX_TRIGGER_RE = re.compile(r"\b(?:detect|bounding|box|label|find the)", re.IGNORECASE)


def handle_submit(
    google_key: str,
    image_prompt: Optional[Image.Image],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: List,
    bbox_mode: bool,
):
    """Route a submission to bounding-box detection or to the text chat.

    Bounding-box detection runs only when an image is attached and either the
    "Bounding box mode" checkbox is ticked or the latest user message contains
    a trigger word (detect, bounding, box, label, "find the"). Every other
    request is streamed through ``bot``.

    Yields:
        ``(chatbot, image)`` pairs for the chat and output-image components;
        the image is None on the text path.

    Raises:
        ValueError: propagated from ``generate_bounding_boxes``/``bot`` when no
            API key is available.
    """
    if chatbot is None:
        chatbot = []

    # Extract the last user text, accepting both plain-string content and the
    # list form (bare strings or {"type": "text", "text": ...} items).
    content = chatbot[-1]["content"] if chatbot else None
    text_prompt = None
    if isinstance(content, str):
        text_prompt = content.strip() or None
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                text_prompt = item.strip() or None
                break
            if isinstance(item, dict) and item.get("type") == "text":
                text_prompt = (item.get("text") or "").strip() or None
                break

    wants_boxes = image_prompt is not None and (
        bbox_mode
        or (text_prompt is not None and _BBOX_TRIGGER_RE.search(text_prompt) is not None)
    )

    if wants_boxes:
        out_img = generate_bounding_boxes(google_key, text_prompt or "Detect objects in the image", image_prompt)
        # generate_bounding_boxes returns None on any failure; don't claim success.
        if out_img is None:
            reply = "Could not generate bounding boxes for this image. Please try again or rephrase the request."
        else:
            reply = "Generated bounding boxes (see image)."
        chatbot.append({"role": "assistant", "content": reply})
        yield chatbot, out_img
        return

    # Text path: stream from bot and keep the output image empty.
    for chat_state in bot(
        google_key,
        image_prompt,
        temperature,
        max_output_tokens,
        stop_sequences,
        top_k,
        top_p,
        chatbot,
    ):
        yield chat_state, None


with gr.Blocks() as demo:
    gr.Markdown("<h1 style='font-size: 36px; font-weight: bold; font-family: Arial;'>Gemini 2.5 Multimodal Chatbot</h1>")
    with gr.Row():
        google_key_component.render()
    with gr.Row():
        chatbot_component.render()
    with gr.Row():
        with gr.Column(scale=1):
            text_prompt_component.render()
            bbox_mode_component.render()
        with gr.Column(scale=1):
            image_prompt_component.render()
        with gr.Column(scale=1):
            run_button_component.render()
    with gr.Row():
        with gr.Column(scale=1):
            output_image_component.render()
    with gr.Accordion("🧪Example Text 💬", open=False):
        example_radio = gr.Radio(
            choices=example_scenarios,
            label="Example Queries",
            info="Select an example query.",
        )
        # Copy the chosen example into the prompt box. A cleared selection
        # (None/empty) clears the box instead of inserting placeholder text
        # that would then be submitted to the model as the prompt.
        example_radio.change(
            fn=lambda query: query or "",
            inputs=[example_radio],
            outputs=[text_prompt_component],
        )
        gr.Examples(
            examples=example_images,
            inputs=[image_prompt_component],
            label="Example Figures",
        )
    with gr.Accordion("🛠️Customize", open=False):
        temperature_component.render()
        max_output_tokens_component.render()
        stop_sequences_component.render()
        top_k_component.render()
        top_p_component.render()

    # Two-step submit: `user` appends the prompt to the history and clears the
    # textbox, then `handle_submit` produces the reply (text or annotated image).
    run_button_component.click(
        fn=user, inputs=user_inputs, outputs=[text_prompt_component, chatbot_component]
    ).then(
        fn=handle_submit,
        inputs=[
            google_key_component,
            image_prompt_component,
            temperature_component,
            max_output_tokens_component,
            stop_sequences_component,
            top_k_component,
            top_p_component,
            chatbot_component,
            bbox_mode_component,
        ],
        outputs=[chatbot_component, output_image_component],
    )

if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, without a public share link,
    # using the "earneleh/paris" theme.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": "earneleh/paris",
    }
    demo.launch(**launch_kwargs)