# Hugging Face Space metadata (page-scrape residue, kept as a comment so the
# file remains valid Python): uploaded by mahmoudalyosify, commit 13e07e4
# (verified) — "Update app.py".
import html
import json
import os

import google.generativeai as genai
import gradio as gr
from dotenv import load_dotenv
from google.generativeai import types
from PIL import Image, ImageDraw, ImageFont, ImageColor
# -----------------------------
# 1️⃣ SETUP API KEY
# -----------------------------
# Load variables from a local .env file if present (no-op on HF Spaces,
# where secrets are injected as environment variables directly).
load_dotenv()
api_key = os.getenv("Gemini_API_Key")  # set this key in Hugging Face Secrets (or .env locally)
# NOTE(review): if the variable is unset, api_key is None and every later
# API call will fail with an auth error rather than failing fast here.
genai.configure(api_key=api_key)
# -----------------------------
# 2️⃣ DEFINE MODELS
# -----------------------------
MODEL_ID = "gemini-2.5-flash"

# System prompt steering the model to emit plain JSON bounding boxes
# (0-1000 normalised coordinates is Gemini's detection convention).
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
"""

# Shared model instance for image detection; dangerous-content filtering is
# relaxed to BLOCK_ONLY_HIGH so ordinary household-object photos aren't blocked.
IMAGE_MODEL = genai.GenerativeModel(
    model_name=MODEL_ID,
    system_instruction=bounding_box_system_instructions,
    safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")],
)

# Moderate temperature: some variety in labels, mostly deterministic boxes.
GEN_CONFIG = genai.types.GenerationConfig(temperature=0.5)
# -----------------------------
# 3️⃣ HELPER FUNCTIONS
# -----------------------------
def parse_json(json_output):
    """Strip Markdown code fencing from a model response.

    Despite the system instruction, Gemini sometimes wraps its JSON answer
    in a fenced block.  This returns only the content between the opening
    fence and the next closing ``` fence.

    Args:
        json_output: raw text returned by the model.

    Returns:
        The fenced content (with its trailing newline preserved), or the
        input unchanged when no opening fence is found.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Accept both a "```json" and a bare "```" opening fence — the
        # model emits either form.
        if line.strip() in ("```json", "```"):
            json_output = "\n".join(lines[i + 1:])
            # Keep everything up to the closing fence (or all of it if
            # the model forgot to close the block).
            json_output = json_output.split("```")[0]
            break
    return json_output
def plot_bounding_boxes(im, bounding_boxes):
    """Draw labelled rectangles on a copy of *im* and return the copy.

    Args:
        im: PIL.Image to annotate (the original is left unmodified).
        bounding_boxes: iterable of dicts, each with a "box_2d" entry of
            normalised [y1, x1, y2, x2] coordinates in the 0-1000 range
            (Gemini's detection convention) and an optional "label" string.

    Returns:
        A new PIL.Image with the boxes (and labels, when present) drawn on.
    """
    # Extend the fixed palette with every named colour PIL knows, so up to
    # 25 objects still get visually distinct colours.
    additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]
    colors = ['red','green','blue','yellow','orange','pink','purple','cyan','lime','magenta','violet','gold','silver'] + additional_colors

    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    font = ImageFont.load_default()

    for i, box in enumerate(bounding_boxes):
        # BUG FIX: the original wrapped the whole loop in one try/except,
        # so the first malformed box aborted drawing of ALL remaining
        # boxes.  Isolate failures per box instead.
        try:
            color = colors[i % len(colors)]
            # Convert the model's 0-1000 normalised coords to pixels.
            abs_y1 = int(box["box_2d"][0] / 1000 * height)
            abs_x1 = int(box["box_2d"][1] / 1000 * width)
            abs_y2 = int(box["box_2d"][2] / 1000 * height)
            abs_x2 = int(box["box_2d"][3] / 1000 * width)
            # The model occasionally swaps corners; normalise the order.
            if abs_x1 > abs_x2:
                abs_x1, abs_x2 = abs_x2, abs_x1
            if abs_y1 > abs_y2:
                abs_y1, abs_y2 = abs_y2, abs_y1
            draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
            if "label" in box:
                draw.text((abs_x1 + 8, abs_y1 + 6), box["label"], fill=color, font=font)
        except (KeyError, IndexError, TypeError, ValueError) as e:
            # Best-effort rendering: report and continue with the next box.
            print(f"Error drawing bounding boxes: {e}")
    return im
# -----------------------------
# 4️⃣ TEXT SEARCH FOR EACH OBJECT
# -----------------------------
def text_search_query(question):
    """Answer *question* using Google-Search grounding and return the HTML
    of the rendered search entry point.

    Args:
        question: natural-language question, e.g. "What is a croissant?".

    Returns:
        The grounding widget's rendered HTML, or an "Error: ..." string on
        any failure (callers embed the return value directly in the UI).
    """
    try:
        # BUG FIX: `genai.models.generate_content` / `types.GenerateContentConfig`
        # belong to the newer `google-genai` SDK and do not exist in
        # `google.generativeai`, so every call raised AttributeError.  With
        # this SDK, grounding is enabled via tools="google_search_retrieval".
        model = genai.GenerativeModel(MODEL_ID, tools="google_search_retrieval")
        response = model.generate_content(question)
        search_results = response.candidates[0].grounding_metadata.search_entry_point.rendered_content
        return search_results
    except Exception as e:
        # Deliberate best-effort boundary: a failed lookup must not crash
        # the whole detection pipeline.
        return f"Error: {str(e)}"
# -----------------------------
# 5️⃣ MAIN FUNCTION
# -----------------------------
def detect_objects_with_references(prompt, image):
    """Detect objects in *image* per *prompt*, annotate them, and look up a
    web reference for each detected label.

    Args:
        prompt: user instruction describing what to detect.
        image: input PIL.Image.

    Returns:
        (annotated PIL.Image, HTML string of per-object web references).
    """
    # Normalise to 1024px wide, preserving aspect ratio, before sending to
    # the model (keeps request size bounded and coordinates consistent).
    image = image.resize((1024, int(1024 * image.height / image.width)))

    # 1️⃣ Generate bounding boxes
    response = IMAGE_MODEL.generate_content([prompt, image], generation_config=GEN_CONFIG)
    try:
        bounding_boxes = json.loads(parse_json(response.text))
    except json.JSONDecodeError:
        # The model occasionally returns non-JSON text; fail gracefully
        # instead of crashing the Gradio handler.
        return image, "<h3>Could not parse any objects from the model response.</h3>"

    # 2️⃣ Generate web references for each detected object
    for box in bounding_boxes:
        label = box.get("label", "")
        if label:
            box["web_reference"] = text_search_query(f"What is {label}?")

    # 3️⃣ Draw bounding boxes
    img = plot_bounding_boxes(image, bounding_boxes)

    # 4️⃣ Prepare HTML for references — labels are model-generated text, so
    # escape them before interpolating into HTML.
    refs_html = "<h3>Detected Object References:</h3>"
    for box in bounding_boxes:
        label = html.escape(box.get("label", ""))
        ref = box.get("web_reference", "")
        refs_html += f"<b>{label}</b>: {ref}<br><br>"
    return img, refs_html
# -----------------------------
# 6️⃣ GRADIO INTERFACE
# -----------------------------
def gradio_interface():
    """Assemble and return the Gradio Blocks UI for the assistant."""
    sample_inputs = [
        ["cookies.jpg", "Detect the cookies and label their types."],
        ["foreign_menu.jpg", "Detect menu items and label them."],
        ["messed_room.jpg", "Find unorganized items and label them."],
        ["objects and cartoon animals.jpg", "Identify animals and objects."],
    ]
    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Gemini Assistant with Web References")
        with gr.Row():
            # Left column: inputs, trigger button, and clickable examples.
            with gr.Column():
                image_in = gr.Image(type="pil", label="Input Image")
                prompt_in = gr.Textbox(
                    lines=2,
                    label="Input Prompt",
                    placeholder="Describe what to detect",
                )
                run_button = gr.Button("Generate")
                gr.Examples(
                    examples=sample_inputs,
                    inputs=[image_in, prompt_in],
                    label="Example Images with Prompts",
                )
            # Right column: annotated image plus per-object references.
            with gr.Column():
                image_out = gr.Image(type="pil", label="Output Image")
                refs_out = gr.HTML(label="Object References")
        run_button.click(
            detect_objects_with_references,
            inputs=[prompt_in, image_in],
            outputs=[image_out, refs_out],
        )
    return demo
# -----------------------------
# 7️⃣ RUN
# -----------------------------
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    app = gradio_interface()
    # share=True requests a public gradio.live link; it is ignored when the
    # app runs on Hugging Face Spaces, which provides its own URL.
    app.launch(share=True)