import html
import json
import os

import google.generativeai as genai
import gradio as gr
from dotenv import load_dotenv
from google.generativeai import types
from PIL import Image, ImageDraw, ImageFont, ImageColor
|
|
| |
| |
| |
| load_dotenv() |
| api_key = os.getenv("Gemini_API_Key") |
| genai.configure(api_key=api_key) |
|
|
| |
| |
| |
| MODEL_ID = "gemini-2.5-flash" |
|
|
| bounding_box_system_instructions = """ |
| Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects. |
| If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..). |
| """ |
|
|
| IMAGE_MODEL = genai.GenerativeModel( |
| model_name=MODEL_ID, |
| system_instruction=bounding_box_system_instructions, |
| safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")], |
| ) |
|
|
| GEN_CONFIG = genai.types.GenerationConfig(temperature=0.5) |
|
|
| |
| |
| |
| def parse_json(json_output): |
| lines = json_output.splitlines() |
| for i, line in enumerate(lines): |
| if line.strip() == "```json": |
| json_output = "\n".join(lines[i+1:]) |
| json_output = json_output.split("```")[0] |
| break |
| return json_output |
|
|
| def plot_bounding_boxes(im, bounding_boxes): |
| additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()] |
| im = im.copy() |
| width, height = im.size |
| draw = ImageDraw.Draw(im) |
| colors = ['red','green','blue','yellow','orange','pink','purple','cyan','lime','magenta','violet','gold','silver'] + additional_colors |
|
|
| try: |
| font = ImageFont.load_default() |
| for i, box in enumerate(bounding_boxes): |
| color = colors[i % len(colors)] |
| abs_y1 = int(box["box_2d"][0] / 1000 * height) |
| abs_x1 = int(box["box_2d"][1] / 1000 * width) |
| abs_y2 = int(box["box_2d"][2] / 1000 * height) |
| abs_x2 = int(box["box_2d"][3] / 1000 * width) |
| if abs_x1 > abs_x2: abs_x1, abs_x2 = abs_x2, abs_x1 |
| if abs_y1 > abs_y2: abs_y1, abs_y2 = abs_y2, abs_y1 |
| draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4) |
| if "label" in box: |
| draw.text((abs_x1 + 8, abs_y1 + 6), box["label"], fill=color, font=font) |
| except Exception as e: |
| print(f"Error drawing bounding boxes: {e}") |
| return im |
|
|
| |
| |
| |
| def text_search_query(question): |
| try: |
| search_tool = types.Tool(google_search=types.GoogleSearch()) |
| response = genai.models.generate_content( |
| model=MODEL_ID, |
| contents=question, |
| config=types.GenerateContentConfig(tools=[search_tool]), |
| ) |
| search_results = response.candidates[0].grounding_metadata.search_entry_point.rendered_content |
| return search_results |
| except Exception as e: |
| return f"Error: {str(e)}" |
|
|
| |
| |
| |
| def detect_objects_with_references(prompt, image): |
| |
| image = image.resize((1024, int(1024 * image.height / image.width))) |
| |
| |
| response = IMAGE_MODEL.generate_content([prompt, image], generation_config=GEN_CONFIG) |
| bounding_boxes = json.loads(parse_json(response.text)) |
| |
| |
| for box in bounding_boxes: |
| label = box.get("label", "") |
| if label: |
| box["web_reference"] = text_search_query(f"What is {label}?") |
| |
| |
| img = plot_bounding_boxes(image, bounding_boxes) |
| |
| |
| refs_html = "<h3>Detected Object References:</h3>" |
| for i, box in enumerate(bounding_boxes): |
| label = box.get("label", "") |
| ref = box.get("web_reference", "") |
| refs_html += f"<b>{label}</b>: {ref}<br><br>" |
| |
| return img, refs_html |
|
|
| |
| |
| |
| def gradio_interface(): |
| examples = [ |
| ["cookies.jpg", "Detect the cookies and label their types."], |
| ["foreign_menu.jpg", "Detect menu items and label them."], |
| ["messed_room.jpg", "Find unorganized items and label them."], |
| ["objects and cartoon animals.jpg", "Identify animals and objects."] |
| ] |
|
|
| with gr.Blocks() as demo: |
| gr.Markdown("# Multimodal Gemini Assistant with Web References") |
| with gr.Row(): |
| with gr.Column(): |
| img_input = gr.Image(type="pil", label="Input Image") |
| prompt_input = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe what to detect") |
| submit_btn = gr.Button("Generate") |
| gr.Examples( |
| examples=examples, |
| inputs=[img_input, prompt_input], |
| label="Example Images with Prompts" |
| ) |
| with gr.Column(): |
| img_output = gr.Image(type="pil", label="Output Image") |
| refs_output = gr.HTML(label="Object References") |
| |
| submit_btn.click(detect_objects_with_references, inputs=[prompt_input, img_input], outputs=[img_output, refs_output]) |
| |
| return demo |
|
|
| |
| |
| |
| if __name__ == "__main__": |
| app = gradio_interface() |
| app.launch(share=True) |