mahmoudalyosify committed on
Commit
7aefb49
·
verified ·
1 Parent(s): 65e8cd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -101
app.py CHANGED
@@ -1,141 +1,179 @@
1
  import gradio as gr
2
  import os
3
- from io import BytesIO
4
- from PIL import Image, ImageDraw, ImageFont
5
- from PIL import ImageColor
6
  import json
 
 
7
  import google.generativeai as genai
8
  from google.generativeai import types
9
  from dotenv import load_dotenv
10
 
11
-
12
- # 1. SETUP API KEY
13
- # ----------------
14
  load_dotenv()
15
- api_key = os.getenv("Gemini_API_Key")
16
- # Configure the Google AI library
17
- genai.configure(api_key=api_key)
18
-
19
 
20
- # 2. DEFINE MODEL AND INSTRUCTIONS
 
 
 
 
 
21
 
22
  bounding_box_system_instructions = """
23
- Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
24
- If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).
25
- """
26
- model = genai.GenerativeModel( model_name='gemini-2.5-flash', system_instruction=bounding_box_system_instructions , safety_settings=[ types.SafetySettingDict( category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH", ) ],)
27
- generation_config = genai.types.GenerationConfig(
28
- temperature=0.5,
29
-
30
- )
31
-
32
-
33
- def generate_bounding_boxes(prompt, image):
34
- image = image.resize((1024, int(1024 * image.height / image.width)))
35
- response = model.generate_content([prompt, image], generation_config=generation_config)
36
- bounding_boxes = parse_json(response.text)
37
- img=plot_bounding_boxes(image, bounding_boxes)
38
- return img
39
-
40
-
41
- def parse_json(json_output):
 
 
42
  lines = json_output.splitlines()
43
  for i, line in enumerate(lines):
44
- if line == "```json":
45
- json_output = "\n".join(lines[i+1:]) # Remove everything before "```json"
46
- json_output = json_output.split("```")[0] # Remove everything after the closing "```"
47
  break
48
- return json_output
 
 
 
49
 
 
 
 
50
  def plot_bounding_boxes(im, bounding_boxes):
51
- """
52
- Plots bounding boxes on an image with labels.
53
- """
54
- additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]
55
-
56
  im = im.copy()
57
  width, height = im.size
58
  draw = ImageDraw.Draw(im)
 
59
  colors = [
60
  'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
61
  'lime', 'magenta', 'violet', 'gold', 'silver'
62
  ] + additional_colors
63
 
64
  try:
65
- # Use a default font if NotoSansCJK is not available
66
- try:
67
- font = ImageFont.load_default()
68
- except OSError:
69
- print("NotoSansCJK-Regular.ttc not found. Using default font.")
70
- font = ImageFont.load_default()
71
-
72
- bounding_boxes_json = json.loads(bounding_boxes)
73
- for i, bounding_box in enumerate(bounding_boxes_json):
74
- color = colors[i % len(colors)]
75
- abs_y1 = int(bounding_box["box_2d"][0] / 1000 * height)
76
- abs_x1 = int(bounding_box["box_2d"][1] / 1000 * width)
77
- abs_y2 = int(bounding_box["box_2d"][2] / 1000 * height)
78
- abs_x2 = int(bounding_box["box_2d"][3] / 1000 * width)
79
-
80
- if abs_x1 > abs_x2:
81
- abs_x1, abs_x2 = abs_x2, abs_x1
82
-
83
- if abs_y1 > abs_y2:
84
- abs_y1, abs_y2 = abs_y2, abs_y1
85
-
86
- # Draw bounding box and label
87
- draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
88
- if "label" in bounding_box:
89
- draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)
90
- except Exception as e:
91
- print(f"Error drawing bounding boxes: {e}")
92
 
93
  return im
94
- def gradio_interface():
95
- """
96
- Gradio app interface for bounding box generation with example pairs.
97
- """
98
- # Example image + prompt pairs
99
- examples = [
100
- ["cookies.jpg", "Detect the cookies and label their types."],
101
- ["messed_room.jpg", "Find the unorganized item and suggest action in label in the image to fix them."],
102
- ["yoga.jpg", "Show the different yoga poses and name them."],
103
- ["zoom_face.png", "Label the tired faces in the image."]
104
- ]
105
-
106
- with gr.Blocks(gr.themes.Glass(secondary_hue= "rose")) as demo:
107
- gr.Markdown("# Gemini Bounding Box Generator")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  with gr.Row():
110
  with gr.Column():
111
- gr.Markdown("### Input Section")
112
- input_image = gr.Image(type="pil", label="Input Image")
113
- input_prompt = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe what to detect.")
114
- submit_btn = gr.Button("Generate")
 
 
 
 
115
 
116
  with gr.Column():
117
- gr.Markdown("### Output Section")
118
- output_image = gr.Image(type="pil", label="Output Image")
119
- #output_json = gr.Textbox(label="Bounding Boxes JSON")
 
 
 
 
 
 
 
120
 
121
- gr.Markdown("### Examples")
 
 
 
 
 
 
122
  gr.Examples(
123
  examples=examples,
124
- inputs=[input_image, input_prompt],
125
- label="Example Images with Prompts"
126
- )
127
-
128
- # Event to generate bounding boxes
129
- submit_btn.click(
130
- generate_bounding_boxes,
131
- inputs=[input_prompt, input_image],
132
- outputs=[output_image]
133
  )
134
 
135
  return demo
136
 
137
-
138
-
 
139
  if __name__ == "__main__":
140
- app = gradio_interface()
141
- app.launch()
 
1
  import gradio as gr
2
  import os
 
 
 
3
  import json
4
+ from io import BytesIO
5
+ from PIL import Image, ImageDraw, ImageFont, ImageColor
6
  import google.generativeai as genai
7
  from google.generativeai import types
8
  from dotenv import load_dotenv
9
 
10
+ # -----------------------------
11
+ # 1. LOAD API KEY
12
+ # -----------------------------
13
  load_dotenv()
14
+ DEFAULT_API_KEY = os.getenv("Gemini_API_Key") # fallback if user doesn't input
 
 
 
15
 
16
+ # -----------------------------
17
+ # 2. MODEL SETTINGS
18
+ # -----------------------------
19
+ DEFAULT_MODEL = "gemini-2.5-flash"
20
+ DEFAULT_TEMPERATURE = 0.5
21
+ DEFAULT_MAX_TOKENS = 500
22
 
23
  bounding_box_system_instructions = """
24
+ Return bounding boxes as a JSON array with labels. Never return masks or code fencing.
25
+ Limit to 25 objects. If an object is present multiple times, name them according to their unique characteristics
26
+ (colors, size, position, unique features, etc.). Also provide actionable suggestions for each object if applicable.
27
+ """
28
+
29
# -----------------------------
# 3. IMAGE PREPROCESSING
# -----------------------------
def preprocess_image(image, max_dim=1024):
    """Normalize an input image before sending it to the model.

    Converts to RGB (dropping alpha/palette modes) and, when either side
    exceeds *max_dim*, downscales so the longest side is exactly *max_dim*
    pixels while preserving aspect ratio. Images already within bounds keep
    their original size (no upscaling).

    Args:
        image: PIL.Image-like object (uses .convert/.resize/.width/.height).
        max_dim: maximum allowed width/height in pixels (default 1024,
            matching the previous hard-coded limit).

    Returns:
        The converted (and possibly resized) image.
    """
    image = image.convert("RGB")
    if image.width > max_dim or image.height > max_dim:
        # Scale both sides by the same ratio so aspect ratio is preserved.
        ratio = min(max_dim / image.width, max_dim / image.height)
        new_size = (int(image.width * ratio), int(image.height * ratio))
        image = image.resize(new_size)
    return image
40
+
41
# -----------------------------
# 4. PARSE JSON OUTPUT
# -----------------------------
def parse_json(json_output):
    """Extract and decode the JSON payload from a model response.

    The system prompt forbids code fencing, but models still sometimes wrap
    the payload in ```json ... ``` or a bare ``` ... ``` block. Both fence
    styles are stripped before decoding (previously only the exact "```json"
    marker was handled, so a bare fence made json.loads fail).

    Args:
        json_output: raw response text from the model.

    Returns:
        The decoded JSON value (normally a list of bounding-box dicts),
        or [] when the text is not valid JSON.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        fence = line.strip()
        # Accept both "```json" and a plain "```" opening fence.
        if fence == "```json" or fence == "```":
            json_output = "\n".join(lines[i + 1:])
            json_output = json_output.split("```")[0]
            break
    try:
        return json.loads(json_output)
    except json.JSONDecodeError:
        # Best-effort: callers treat an empty list as "nothing to draw".
        return []
55
 
56
# -----------------------------
# 5. PLOT BOUNDING BOXES
# -----------------------------
def plot_bounding_boxes(im, bounding_boxes):
    """Draw labelled bounding boxes (plus optional suggestions) on a copy of *im*.

    Args:
        im: source PIL image; it is left untouched, a copy is annotated.
        bounding_boxes: list of dicts as returned by parse_json(). Each dict
            may carry "box_2d" (Gemini emits [ymin, xmin, ymax, xmax],
            normalized to a 0-1000 grid), "label" and "suggestion".

    Returns:
        A new PIL image with the annotations drawn.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)

    # Named colors cycled per box; extended with every color Pillow knows.
    additional_colors = list(ImageColor.colormap.keys())
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + additional_colors

    # load_default() ships with Pillow and does not raise OSError, so the
    # previous try/except with an identical fallback was dead code.
    font = ImageFont.load_default()

    for i, bbox in enumerate(bounding_boxes):
        color = colors[i % len(colors)]
        # FIX: Gemini's box_2d is [ymin, xmin, ymax, xmax]; the previous
        # unpacking read it as [x1, y1, x2, y2], swapping the axes on any
        # non-square image.
        y1, x1, y2, x2 = bbox.get("box_2d", [0, 0, 0, 0])
        abs_x1 = int(x1 / 1000 * width)
        abs_y1 = int(y1 / 1000 * height)
        abs_x2 = int(x2 / 1000 * width)
        abs_y2 = int(y2 / 1000 * height)

        # Normalize inverted coordinates so rectangle() always gets
        # top-left <= bottom-right.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=3)
        label = bbox.get("label", "")
        suggestion = bbox.get("suggestion", "")
        if label:
            draw.text((abs_x1 + 5, abs_y1 + 5), f"{label}", fill=color, font=font)
        if suggestion:
            # Suggestion rendered on a second line below the label.
            draw.text((abs_x1 + 5, abs_y1 + 20), f"{suggestion}", fill=color, font=font)

    return im
94
+
95
# -----------------------------
# 6. GENERATE RESPONSE
# -----------------------------
def generate_response(
    user_prompt,
    user_image=None,
    api_key_input=None,
    model_choice=DEFAULT_MODEL,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS
):
    """Run one Gemini request, optionally annotating a supplied image.

    Args:
        user_prompt: text prompt describing what to detect or answer.
        user_image: optional PIL image; when given, bounding boxes parsed
            from the response are drawn onto a preprocessed copy of it.
        api_key_input: per-request API key; falls back to the .env key.
        model_choice: Gemini model name to use.
        temperature: sampling temperature for generation.
        max_tokens: max_output_tokens cap for the response.

    Returns:
        Tuple (response_text, annotated_image_or_None).

    Raises:
        ValueError: when no API key is available from either source.
    """
    api_key_to_use = api_key_input if api_key_input else DEFAULT_API_KEY
    if not api_key_to_use:
        # Fail fast with a clear message instead of letting the SDK fail
        # later with a cryptic authentication error.
        raise ValueError(
            "No API key available. Enter one in the UI or set Gemini_API_Key in .env."
        )
    genai.configure(api_key=api_key_to_use)

    model = genai.GenerativeModel(
        model_name=model_choice,
        system_instruction=bounding_box_system_instructions,
        safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
    )
    generation_config = types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_tokens
    )

    # Explicit None check: PIL images are always truthy, so `if user_image:`
    # worked, but `is not None` states the intent unambiguously.
    if user_image is not None:
        user_image = preprocess_image(user_image)
        response = model.generate_content([user_prompt, user_image], generation_config=generation_config)
        bboxes = parse_json(response.text)
        output_image = plot_bounding_boxes(user_image, bboxes)
        return response.text, output_image
    else:
        # Text-only request: no image to annotate.
        response = model.generate_content([user_prompt], generation_config=generation_config)
        return response.text, None
128
+
129
# -----------------------------
# 7. GRADIO INTERFACE
# -----------------------------
def build_ui():
    """Build and return the Gradio Blocks app for the assistant.

    Lays out the input column (prompt, optional image, API key, model and
    sampling controls), the output column (model text + annotated image),
    wires the Run button to generate_response, and registers example pairs.

    Returns:
        The assembled gr.Blocks demo (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Multi-Modal Assistant with Bounding Boxes & Suggestions")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### User Inputs")
                text_input = gr.Textbox(lines=3, label="Prompt")
                image_input = gr.Image(type="pil", label="Optional Image")
                api_key_input = gr.Textbox(label="Google API Key (Optional)", placeholder="Enter your API key")
                # FIX: "gemini-2.0" is not a published model id; use
                # "gemini-2.0-flash" so the second choice actually works.
                model_choice = gr.Radio(["gemini-2.5-flash", "gemini-2.0-flash"], label="Select Model", value=DEFAULT_MODEL)
                temperature_slider = gr.Slider(0, 1, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
                max_tokens_slider = gr.Slider(50, 2000, value=DEFAULT_MAX_TOKENS, step=50, label="Max Tokens")
                run_btn = gr.Button("Run")

            with gr.Column():
                gr.Markdown("### Outputs")
                chatbot_output = gr.Textbox(label="Model Output (Text)", lines=15)
                output_image = gr.Image(type="pil", label="Output Image with Bounding Boxes (if image provided)")

        # Event: run one generation, filling both outputs.
        run_btn.click(
            generate_response,
            inputs=[text_input, image_input, api_key_input, model_choice, temperature_slider, max_tokens_slider],
            outputs=[chatbot_output, output_image]
        )

        gr.Markdown("### Examples (Optional)")
        # Each example pair is [image_path, prompt].
        examples = [
            ["cookies.jpg", "Detect types of cookies and provide suggestions."],
            ["messed_room.jpg", "Identify unorganized items and suggest actions."],
            ["yoga.jpg", "Label the different yoga poses."],
        ]
        gr.Examples(
            examples=examples,
            # FIX: inputs must match the example column order. It was
            # [text_input, image_input], which put the image path into the
            # prompt box and the prompt into the image component.
            inputs=[image_input, text_input],
            label="Example Prompts & Images"
        )

    return demo
173
 
174
# -----------------------------
# 8. RUN APP
# -----------------------------
if __name__ == "__main__":
    # Assemble the Gradio Blocks UI and start serving it.
    build_ui().launch()