ha7naa committed on
Commit
c46b909
·
verified ·
1 Parent(s): 9be97dd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +125 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import os
4
+ from io import BytesIO
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from PIL import ImageColor
7
+ import json
8
+ import google.generativeai as genai
9
+ from google.generativeai import types
10
+ from dotenv import load_dotenv
11
+ from IPython.display import display
12
+
13
# 1. SETUP API KEY
# ----------------
# Pull variables from a local .env file (if present) into the process
# environment before the key is read.
load_dotenv()
# The variable is looked up exactly as spelled here ("Gemini_API_Key").
# os.getenv returns None when it is unset, and that None is handed straight
# to genai.configure below without any check.
api_key = os.getenv("Gemini_API_Key")
# Configure the Google AI library
genai.configure(api_key=api_key)


# 2. DEFINE MODEL AND INSTRUCTIONS

# System instruction attached to the model object created below; it asks for
# a plain JSON array of labelled boxes (no masks, no code fencing, at most 25
# objects).
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
"""
# Shared model handle used by detect_and_draw_gradio for every request.
model = genai.GenerativeModel( model_name='gemini-2.5-flash', system_instruction=bounding_box_system_instructions)
# Generation settings passed alongside each generate_content call.
generation_config = genai.types.GenerationConfig(
    temperature=0.5,

)
32
+
33
+
34
+
35
+ # 3. HELPER FUNCTIONS: PARSE THE MODEL RESPONSE AND DRAW BOUNDING BOXES
36
+
37
+
38
+
39
+
40
def parse_json(json_output):
    """Strip Markdown code fencing from a model response.

    The model is asked not to fence its output, but it sometimes wraps the
    JSON in a fenced block anyway. This returns the text between the first
    opening fence line and the next fence marker, or the input unchanged
    when no fence line is present.

    Args:
        json_output: Raw text returned by the model. ``None`` or an empty
            string is treated as "no output".

    Returns:
        The fenced payload as a string (not parsed), the original text if
        it contains no fence line, or ``""`` for empty input.
    """
    if not json_output:
        return ""

    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Accept any opening fence (indented, trailing spaces, bare ``` or a
        # language tag in any casing), not just the exact string "```json".
        if line.strip().startswith("```"):
            fenced = "\n".join(lines[i + 1:])
            # Drop the closing fence and anything that follows it.
            return fenced.split("```", 1)[0]
    return json_output
49
+
50
+
51
def plot_bounding_boxes(im, bounding_boxes):
    """Draw labelled bounding boxes on a copy of an image.

    Args:
        im: PIL image to annotate. It is copied first, so the caller's
            image is left untouched.
        bounding_boxes: JSON text holding a list of objects (a single object
            is also accepted). Each object carries ``"box_2d"`` as
            ``[y1, x1, y2, x2]`` on a 0-1000 scale relative to the image
            height/width, plus an optional ``"label"``.

    Returns:
        The annotated copy. Drawing is best-effort: unparseable JSON returns
        the plain copy, and a malformed entry is reported and skipped
        without discarding the remaining boxes.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)

    # A few high-contrast colours first, then every named PIL colour, so
    # consecutive boxes are easy to tell apart.
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + list(ImageColor.colormap)

    try:
        font = ImageFont.load_default()
    except OSError as e:
        print(f"Could not load the default font: {e}")
        font = None  # let ImageDraw choose its own fallback

    try:
        boxes = json.loads(bounding_boxes)
    except (TypeError, ValueError) as e:
        print(f"Error drawing bounding boxes: {e}")
        return im

    # Tolerate a single object where a list was requested.
    if isinstance(boxes, dict):
        boxes = [boxes]
    if not isinstance(boxes, list):
        print(f"Error drawing bounding boxes: expected a list, got {type(boxes).__name__}")
        return im

    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]
        try:
            y1, x1, y2, x2 = bounding_box["box_2d"][:4]

            # Convert the 0-1000 normalised coordinates to pixels.
            abs_y1 = int(y1 / 1000 * height)
            abs_x1 = int(x1 / 1000 * width)
            abs_y2 = int(y2 / 1000 * height)
            abs_x2 = int(x2 / 1000 * width)

            # ImageDraw needs the top-left corner first.
            if abs_x1 > abs_x2:
                abs_x1, abs_x2 = abs_x2, abs_x1
            if abs_y1 > abs_y2:
                abs_y1, abs_y2 = abs_y2, abs_y1

            # Draw bounding box and label
            draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
            if "label" in bounding_box:
                draw.text((abs_x1 + 8, abs_y1 + 6), str(bounding_box["label"]), fill=color, font=font)
        except (KeyError, IndexError, TypeError, ValueError, OSError) as e:
            # Skip only the offending entry; keep drawing the others.
            print(f"Error drawing bounding boxes: {e}")

    return im
95
+
96
# Fallback instruction used when the caller does not type a prompt.
PROMPT = (
    "Detect the 2d bounding boxes of the prominent objects in the image "
    "(with \"label\" as a short description of each object)."
)


def detect_and_draw_gradio(user_prompt: str, image: "Image.Image", max_width: int = 1024):
    """Ask the model for bounding boxes and draw them on the image.

    Args:
        user_prompt: Detection request. Blank or whitespace-only input falls
            back to the module-level ``PROMPT``.
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        max_width: Images wider than this are downscaled (aspect ratio kept)
            before being sent. Non-positive or falsy values disable
            resizing.

    Returns:
        A ``(image, text)`` pair:
        * ``(None, message)`` when no image was supplied;
        * ``(resized image, message)`` when the response carries no text;
        * ``(resized image, raw model text)`` when the reply is not valid
          JSON, so the raw output can be inspected;
        * ``(annotated image, JSON text)`` on success.

    Errors raised by the model request itself are not caught here.
    """
    if image is None:
        return None, "Please upload an image."

    if not user_prompt or not user_prompt.strip():
        user_prompt = PROMPT

    image = image.convert("RGB")
    W, H = image.size

    # Downscale wide images; box coordinates are relative (0-1000), so they
    # can be drawn directly on the resized image.
    if max_width and 0 < max_width < W:
        newW = int(max_width)
        # Keep at least one pixel of height for extreme aspect ratios.
        newH = max(1, int(newW * H / W))
        im_resized = image.resize((newW, newH), Image.Resampling.LANCZOS)
    else:
        im_resized = image

    # send prompt + image
    response = model.generate_content([user_prompt, im_resized], generation_config=generation_config)

    # The `.text` accessor raises ValueError when the response has no usable
    # text part (e.g. a blocked reply); getattr's default does not cover that.
    try:
        raw_text = response.text or ""
    except (AttributeError, ValueError):
        raw_text = ""

    if not raw_text.strip():
        return im_resized, "The model returned no text (the response may have been blocked)."

    bounding_boxes = parse_json(raw_text)
    try:
        json.loads(bounding_boxes)
    except ValueError:
        # Not valid JSON: surface the raw reply for debugging.
        return im_resized, raw_text

    out_img = plot_bounding_boxes(im_resized, bounding_boxes)
    return out_img, bounding_boxes
requirements.txt ADDED
File without changes