File size: 5,832 Bytes
6a037e9
eb305b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97fa1e2
eb305b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aefb49
10ac47c
eb305b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97fa1e2
6a037e9
eb305b9
10ac47c
6a037e9
eb305b9
6a037e9
eb305b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5ded30
eb305b9
 
 
c5ded30
eb305b9
13e07e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import html
import json
import os

import gradio as gr
import google.generativeai as genai
from dotenv import load_dotenv
from google.generativeai import types
from PIL import Image, ImageDraw, ImageFont, ImageColor

# -----------------------------
# 1️⃣ SETUP API KEY
# -----------------------------
load_dotenv()
api_key = os.getenv("Gemini_API_Key")  # put your key in Hugging Face Secrets (or a local .env file)
genai.configure(api_key=api_key)

# -----------------------------
# 2️⃣ DEFINE MODELS
# -----------------------------
MODEL_ID = "gemini-2.5-flash"

# System prompt steering the model to emit a plain JSON array of boxes.
# parse_json() below still strips code fences defensively in case the model
# ignores the "no code fencing" instruction.
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
"""

# Shared vision-model instance used for all detection requests.
# Dangerous-content safety is set to block only high-severity matches.
IMAGE_MODEL = genai.GenerativeModel(
    model_name=MODEL_ID,
    system_instruction=bounding_box_system_instructions,
    safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")],
)

# Sampling temperature for box generation; shared by every detection call.
GEN_CONFIG = genai.types.GenerationConfig(temperature=0.5)

# -----------------------------
# 3️⃣ HELPER FUNCTIONS
# -----------------------------
def parse_json(json_output):
    """Strip a Markdown code fence from a model response and return raw JSON.

    Gemini sometimes wraps JSON answers in a ``` fence even when instructed
    not to. This removes an opening fence written either as "```json" or as
    a bare "```", drops everything from the closing fence onward, and trims
    surrounding whitespace. Input without a fence is returned (stripped)
    unchanged.

    Parameters:
        json_output: raw response text from the model.

    Returns:
        The fenced payload (or the whole input) as a stripped string.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Accept both "```json" and a bare "```" as the opening fence.
        if line.strip() in ("```json", "```"):
            json_output = "\n".join(lines[i + 1:])
            # Keep only the content before the closing fence, if any.
            json_output = json_output.split("```")[0]
            break
    return json_output.strip()

def plot_bounding_boxes(im, bounding_boxes):
    """Draw labeled bounding boxes on a copy of *im* and return the copy.

    Parameters:
        im: source PIL image (never mutated; a copy is annotated).
        bounding_boxes: iterable of dicts, each with a "box_2d" entry holding
            [y1, x1, y2, x2] coordinates normalized to a 0-1000 grid (the
            convention used by Gemini box output) and an optional "label".

    Returns:
        A new PIL image with rectangles and labels drawn on it.

    A malformed entry is skipped with a console message instead of aborting
    the remaining boxes.
    """
    # ImageColor.colormap maps color names to codes; only the names are needed.
    palette = ['red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple',
               'cyan', 'lime', 'magenta', 'violet', 'gold', 'silver']
    palette += list(ImageColor.colormap)

    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    font = ImageFont.load_default()

    for i, box in enumerate(bounding_boxes):
        color = palette[i % len(palette)]
        try:
            # Scale the 0-1000 normalized coordinates to pixel space.
            abs_y1 = int(box["box_2d"][0] / 1000 * height)
            abs_x1 = int(box["box_2d"][1] / 1000 * width)
            abs_y2 = int(box["box_2d"][2] / 1000 * height)
            abs_x2 = int(box["box_2d"][3] / 1000 * width)
            # The model occasionally swaps corners; normalize the rectangle.
            if abs_x1 > abs_x2:
                abs_x1, abs_x2 = abs_x2, abs_x1
            if abs_y1 > abs_y2:
                abs_y1, abs_y2 = abs_y2, abs_y1
            draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
            if "label" in box:
                draw.text((abs_x1 + 8, abs_y1 + 6), box["label"], fill=color, font=font)
        except Exception as e:
            # Skip only this box rather than abandoning everything after it.
            print(f"Error drawing bounding boxes: {e}")
    return im

# -----------------------------
# 4️⃣ TEXT SEARCH FOR EACH OBJECT
# -----------------------------
def text_search_query(question):
    """Ask Gemini *question* with Google Search grounding and return the
    rendered HTML of the search entry point.

    Never raises: any failure is converted to an "Error: ..." string so the
    caller can embed the result directly into the output HTML.
    """
    try:
        # NOTE(review): types.Tool(google_search=...), GenerateContentConfig,
        # and the genai.models.generate_content call below belong to the
        # newer google-genai client API, while this file imports the legacy
        # google.generativeai package. With the legacy package installed this
        # call raises and the except branch returns the error string —
        # confirm which SDK is intended.
        search_tool = types.Tool(google_search=types.GoogleSearch())
        response = genai.models.generate_content(
            model=MODEL_ID,
            contents=question,
            config=types.GenerateContentConfig(tools=[search_tool]),
        )
        # Grounded responses expose a pre-rendered HTML search widget.
        search_results = response.candidates[0].grounding_metadata.search_entry_point.rendered_content
        return search_results
    except Exception as e:
        return f"Error: {str(e)}"

# -----------------------------
# 5️⃣ MAIN FUNCTION
# -----------------------------
def detect_objects_with_references(prompt, image, *, target_width=1024):
    """Detect objects in *image* per *prompt*, annotate them, gather web refs.

    Parameters:
        prompt: instruction describing what to detect.
        image: input PIL image.
        target_width: width the image is resized to before inference; the
            height scales proportionally. Defaults to 1024 (the previous
            hard-coded value), so existing callers are unaffected.

    Returns:
        Tuple of (annotated PIL image, HTML string listing one web reference
        per labeled detection).

    Raises:
        ValueError: if the model response cannot be parsed as a JSON array.
    """
    # Resize to a fixed width so normalized box coordinates map consistently.
    image = image.resize((target_width, int(target_width * image.height / image.width)))

    # 1️⃣ Ask the vision model for bounding boxes (JSON per system instruction).
    response = IMAGE_MODEL.generate_content([prompt, image], generation_config=GEN_CONFIG)
    try:
        bounding_boxes = json.loads(parse_json(response.text))
    except json.JSONDecodeError as e:
        # Surface a clear message instead of a bare traceback in the UI.
        raise ValueError(f"Model did not return valid JSON bounding boxes: {e}") from e
    if not isinstance(bounding_boxes, list):
        raise ValueError("Expected a JSON array of bounding boxes.")

    # 2️⃣ Attach a grounded web reference to every labeled detection.
    for box in bounding_boxes:
        label = box.get("label", "")
        if label:
            box["web_reference"] = text_search_query(f"What is {label}?")

    # 3️⃣ Render the annotated image.
    img = plot_bounding_boxes(image, bounding_boxes)

    # 4️⃣ Build the reference list. Labels are escaped because they are
    # model-generated text; web_reference is already rendered HTML.
    refs_html = "<h3>Detected Object References:</h3>"
    for box in bounding_boxes:
        label = html.escape(box.get("label", ""))
        ref = box.get("web_reference", "")
        refs_html += f"<b>{label}</b>: {ref}<br><br>"

    return img, refs_html

# -----------------------------
# 6️⃣ GRADIO INTERFACE
# -----------------------------
def gradio_interface():
    """Build and return the Gradio Blocks UI for the detection demo."""
    sample_inputs = [
        ["cookies.jpg", "Detect the cookies and label their types."],
        ["foreign_menu.jpg", "Detect menu items and label them."],
        ["messed_room.jpg", "Find unorganized items and label them."],
        ["objects and cartoon animals.jpg", "Identify animals and objects."],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Gemini Assistant with Web References")
        with gr.Row():
            # Left column: user inputs plus clickable example pairs.
            with gr.Column():
                image_in = gr.Image(type="pil", label="Input Image")
                prompt_in = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe what to detect")
                run_button = gr.Button("Generate")
                gr.Examples(
                    examples=sample_inputs,
                    inputs=[image_in, prompt_in],
                    label="Example Images with Prompts",
                )
            # Right column: annotated image and per-object references.
            with gr.Column():
                image_out = gr.Image(type="pil", label="Output Image")
                refs_out = gr.HTML(label="Object References")

        run_button.click(
            detect_objects_with_references,
            inputs=[prompt_in, image_in],
            outputs=[image_out, refs_out],
        )

    return demo

# -----------------------------
# 7️⃣ RUN
# -----------------------------
if __name__ == "__main__":
    # Build the UI and launch it with a public share link.
    gradio_interface().launch(share=True)