mahmoudalyosify committed on
Commit
10ac47c
·
verified ·
1 Parent(s): 4c17176

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -132
app.py CHANGED
@@ -8,172 +8,119 @@ from google.generativeai import types
8
  from dotenv import load_dotenv
9
 
10
# -----------------------------
# 1. LOAD API KEY
# -----------------------------
# Read a local .env file so development runs pick up the key without
# exporting it manually; in hosted environments the variable comes from
# the process environment (e.g. Hugging Face Secrets).
load_dotenv()
DEFAULT_API_KEY = os.getenv("Gemini_API_Key") # fallback if user doesn't input

# -----------------------------
# 2. MODEL SETTINGS
# -----------------------------
DEFAULT_MODEL = "gemini-2.5-flash"  # default Gemini model offered in the UI
DEFAULT_TEMPERATURE = 0.5           # moderate sampling randomness
DEFAULT_MAX_TOKENS = 500            # cap on generated output tokens

# System prompt steering the model to answer with a plain JSON array of
# boxes (no code fences, no masks) plus per-object suggestions.
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing.
Limit to 25 objects. If an object is present multiple times, name them according to their unique characteristics
(colors, size, position, unique features, etc.). Also provide actionable suggestions for each object if applicable.
"""
 
 
 
 
 
28
 
29
- # -----------------------------
30
- # 3. IMAGE PREPROCESSING
31
- # -----------------------------
32
- def preprocess_image(image):
33
- image = image.convert("RGB")
34
- max_dim = 1024
35
- if image.width > max_dim or image.height > max_dim:
36
- ratio = min(max_dim / image.width, max_dim / image.height)
37
- new_size = (int(image.width * ratio), int(image.height * ratio))
38
- image = image.resize(new_size)
39
- return image
40
 
41
# -----------------------------
# 4. PARSE JSON OUTPUT
# -----------------------------
def parse_json(json_output):
    """Extract and decode JSON from model text that may be ```json-fenced.

    When a line consisting of ```json is found, the payload is the text
    between that fence and the next ``` marker. Returns the decoded Python
    object, or [] when the (unfenced or extracted) text is not valid JSON.
    """
    raw_lines = json_output.splitlines()
    for idx, raw in enumerate(raw_lines):
        if raw.strip() == "```json":
            fenced = "\n".join(raw_lines[idx + 1:])
            json_output = fenced.split("```")[0]
            break
    try:
        return json.loads(json_output)
    except json.JSONDecodeError:
        return []
55
 
56
# -----------------------------
# 5. PLOT BOUNDING BOXES
# -----------------------------
def plot_bounding_boxes(im, bounding_boxes):
    """Draw labelled bounding boxes (plus optional suggestions) on a copy of *im*.

    Parameters
    ----------
    im : PIL.Image.Image
        Source image; never modified (a copy is annotated).
    bounding_boxes : list[dict]
        Parsed model output. Each entry may carry "box_2d" — per the Gemini
        convention, [ymin, xmin, ymax, xmax] normalized to a 0-1000 grid —
        plus optional "label" and "suggestion" strings.

    Returns
    -------
    PIL.Image.Image
        The annotated copy.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    # Base palette extended with every named color PIL knows, so 13+ boxes
    # still get distinct colors.
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + list(ImageColor.colormap)

    # BUG FIX: the original wrapped this call in `except OSError` with an
    # identical fallback (dead code); a single call suffices.
    font = ImageFont.load_default()

    for i, bbox in enumerate(bounding_boxes):
        color = colors[i % len(colors)]
        # BUG FIX: Gemini emits box_2d as [ymin, xmin, ymax, xmax]; the
        # original unpacked it as [x1, y1, x2, y2], drawing transposed
        # boxes on any non-square image.
        y1, x1, y2, x2 = bbox.get("box_2d", [0, 0, 0, 0])
        # Coordinates are normalized to 0-1000; scale to pixel space.
        abs_x1 = int(x1 / 1000 * width)
        abs_y1 = int(y1 / 1000 * height)
        abs_x2 = int(x2 / 1000 * width)
        abs_y2 = int(y2 / 1000 * height)

        # Guard against reversed corners so rectangle() never receives a
        # negative-extent box.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=3)
        label = bbox.get("label", "")
        suggestion = bbox.get("suggestion", "")
        if label:
            draw.text((abs_x1 + 5, abs_y1 + 5), f"{label}", fill=color, font=font)
        if suggestion:
            draw.text((abs_x1 + 5, abs_y1 + 20), f"{suggestion}", fill=color, font=font)

    return im
94
 
 
 
 
 
 
 
 
95
# -----------------------------
# 6. GENERATE RESPONSE
# -----------------------------
def generate_response(
    user_prompt,
    user_image=None,
    api_key_input=None,
    model_choice=DEFAULT_MODEL,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS
):
    """Send the prompt (and optional image) to Gemini.

    Returns a tuple of (response_text, annotated_image). The image slot is
    None for text-only requests; otherwise it is a copy of the preprocessed
    input with the model's bounding boxes drawn on it.
    """
    # Prefer a user-supplied key; fall back to the one from the environment.
    genai.configure(api_key=api_key_input or DEFAULT_API_KEY)

    gemini = genai.GenerativeModel(
        model_name=model_choice,
        system_instruction=bounding_box_system_instructions,
        safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
    )
    config = types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_tokens
    )

    # Text-only request: nothing to annotate.
    if not user_image:
        reply = gemini.generate_content([user_prompt], generation_config=config)
        return reply.text, None

    # Multimodal request: normalize the image, then draw the returned boxes.
    user_image = preprocess_image(user_image)
    reply = gemini.generate_content([user_prompt, user_image], generation_config=config)
    boxes = parse_json(reply.text)
    return reply.text, plot_bounding_boxes(user_image, boxes)
128
 
129
# -----------------------------
# 7. GRADIO INTERFACE
# -----------------------------
def build_ui():
    """Assemble the Gradio Blocks app: inputs and settings on the left, outputs on the right."""
    with gr.Blocks() as demo:
        gr.Markdown("# Multi-Modal Assistant with Bounding Boxes & Suggestions")

        with gr.Row():
            # Left column: prompt, optional image, API key, and model settings.
            with gr.Column():
                gr.Markdown("### User Inputs")
                prompt_box = gr.Textbox(lines=3, label="Prompt")
                picture = gr.Image(type="pil", label="Optional Image")
                key_box = gr.Textbox(label="Google API Key (Optional)", placeholder="Enter your API key")
                model_picker = gr.Radio(["gemini-2.5-flash", "gemini-2.0"], label="Select Model", value=DEFAULT_MODEL)
                temp_slider = gr.Slider(0, 1, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
                tokens_slider = gr.Slider(50, 2000, value=DEFAULT_MAX_TOKENS, step=50, label="Max Tokens")
                submit = gr.Button("Run")

            # Right column: model text plus the annotated image (when one was given).
            with gr.Column():
                gr.Markdown("### Outputs")
                text_out = gr.Textbox(label="Model Output (Text)", lines=15)
                image_out = gr.Image(type="pil", label="Output Image with Bounding Boxes (if image provided)")

        # Wire the Run button to the generation function.
        submit.click(
            generate_response,
            inputs=[prompt_box, picture, key_box, model_picker, temp_slider, tokens_slider],
            outputs=[text_out, image_out]
        )

        # Canned prompt/image pairs users can click to try the app.
        # NOTE(review): each example lists a filename first while `inputs`
        # lists the prompt textbox first — the values look swapped relative
        # to the components; confirm intended ordering.
        gr.Markdown("### Examples (Optional)")
        gr.Examples(
            examples=[
                ["cookies.jpg", "Detect types of cookies and provide suggestions."],
                ["messed_room.jpg", "Identify unorganized items and suggest actions."],
                ["yoga.jpg", "Label the different yoga poses."],
            ],
            inputs=[prompt_box, picture],
            label="Example Prompts & Images"
        )

    return demo
173
 
174
# -----------------------------
# 8. RUN APP
# -----------------------------
# Script entry point: build the Gradio UI and start the local web server.
if __name__ == "__main__":
    app = build_ui()
    app.launch()
 
8
  from dotenv import load_dotenv
9
 
10
# -----------------------------
# 1. SETUP API KEY
# -----------------------------
# Read a local .env file (development); in hosting the variable comes from
# the process environment.
load_dotenv()
api_key = os.getenv("Gemini_API_Key")  # the key must be set in Hugging Face Secrets
genai.configure(api_key=api_key)

# -----------------------------
# 2. DEFINE MODELS
# -----------------------------
# Text & Web Search Model
TEXT_MODEL_ID = "gemini-2.5-flash"

# Image / Bounding Box Model
# System prompt steering the vision model to emit plain JSON boxes
# (no code fences, no masks, at most 25 objects).
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing. Limit to 25 objects.
If an object is present multiple times, name them according to their unique characteristic (colors, size, position, unique characteristics, etc..).

"""
# Vision model configured once at import time and reused for every request.
IMAGE_MODEL = genai.GenerativeModel(
    model_name='gemini-2.5-flash',
    system_instruction=bounding_box_system_instructions,
    safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")],
)

# Shared sampling settings for image requests.
GEN_CONFIG = genai.types.GenerationConfig(temperature=0.5)
 
 
 
 
 
 
 
 
 
 
35
 
36
# -----------------------------
# 3. IMAGE FUNCTIONS
# -----------------------------
def parse_json(json_output):
    """Strip a Markdown ```json fence from *json_output*, if present.

    Returns the fenced payload (the text between the first ```json line and
    the next ``` marker), or the input unchanged when no fence is found.
    The result is still a string; callers decode it with json.loads().
    """
    all_lines = json_output.splitlines()
    for idx, current in enumerate(all_lines):
        if current.strip() != "```json":
            continue
        remainder = "\n".join(all_lines[idx + 1:])
        return remainder.split("```")[0]
    return json_output
 
 
 
47
 
 
 
 
48
def plot_bounding_boxes(im, bounding_boxes):
    """Decode the model's JSON box list and draw it on a copy of *im*.

    Parameters
    ----------
    im : PIL.Image.Image
        Source image; a copy is annotated and returned.
    bounding_boxes : str
        JSON text: an array of objects with "box_2d" — per the Gemini
        convention, [ymin, xmin, ymax, xmax] normalized to 0-1000 — and an
        optional "label".

    Returns
    -------
    PIL.Image.Image
        The annotated copy; returned unannotated when the JSON is invalid.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    # Base palette extended with every named color PIL knows.
    colors = ['red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple',
              'cyan', 'lime', 'magenta', 'violet', 'gold', 'silver'] + list(ImageColor.colormap)

    # IMPROVEMENT: the original wrapped font loading, JSON parsing and ALL
    # drawing in one broad try/except Exception, so a malformed payload
    # masked drawing bugs and one bad box aborted every remaining box.
    # Parse up front with the specific exception instead.
    try:
        boxes = json.loads(bounding_boxes)
    except json.JSONDecodeError as e:
        print(f"Error drawing bounding boxes: {e}")
        return im

    font = ImageFont.load_default()
    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]
        try:
            # box_2d is [ymin, xmin, ymax, xmax] on a 0-1000 grid; scale to pixels.
            abs_y1 = int(bounding_box["box_2d"][0] / 1000 * height)
            abs_x1 = int(bounding_box["box_2d"][1] / 1000 * width)
            abs_y2 = int(bounding_box["box_2d"][2] / 1000 * height)
            abs_x2 = int(bounding_box["box_2d"][3] / 1000 * width)
            # Guard against reversed corners.
            if abs_x1 > abs_x2:
                abs_x1, abs_x2 = abs_x2, abs_x1
            if abs_y1 > abs_y2:
                abs_y1, abs_y2 = abs_y2, abs_y1
            draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
            if "label" in bounding_box:
                draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)
        except Exception as e:
            # Skip a malformed entry but keep drawing the rest.
            print(f"Error drawing bounding boxes: {e}")

    return im
71
 
72
def generate_bounding_boxes(prompt, image):
    """Run the bounding-box model on *image* and return the annotated image.

    Parameters
    ----------
    prompt : str
        What to detect; passed verbatim to the model.
    image : PIL.Image.Image
        Input image from the Gradio component.

    Returns
    -------
    PIL.Image.Image
        Copy of the (possibly downscaled) image with predicted boxes drawn.
    """
    # IMPROVEMENT: only downscale. The original resized every image to width
    # 1024, needlessly upscaling (and blurring) smaller inputs. max(1, ...)
    # guards against a zero pixel height for extremely wide images.
    if image.width > 1024:
        image = image.resize((1024, max(1, int(1024 * image.height / image.width))))
    response = IMAGE_MODEL.generate_content([prompt, image], generation_config=GEN_CONFIG)
    bounding_boxes = parse_json(response.text)
    return plot_bounding_boxes(image, bounding_boxes)
78
+
79
# -----------------------------
# 4. TEXT / SEARCH FUNCTION
# -----------------------------
def text_search_query(question):
    """Answer *question* with Gemini, grounded in Google Search.

    Returns a tuple (answer_text, rendered_search_html). On any failure the
    first element is an "Error: ..." string and the second is empty, so the
    Gradio outputs always receive two values.
    """
    try:
        # BUG FIX: the previous code called genai.models.generate_content(...)
        # with types.GenerateContentConfig — that is the *new* google-genai
        # SDK API and does not exist on the legacy google.generativeai module
        # imported here, so the call always raised AttributeError. Use the
        # legacy GenerativeModel API with Search grounding instead.
        # NOTE(review): "google_search_retrieval" is the legacy SDK's grounding
        # tool name — confirm it is supported for the configured model.
        model = genai.GenerativeModel(TEXT_MODEL_ID)
        response = model.generate_content(question, tools="google_search_retrieval")
        ai_response = response.text
        # Grounding metadata may be absent when the model answered without
        # searching; degrade to an empty sources panel rather than erroring.
        try:
            search_results = response.candidates[0].grounding_metadata.search_entry_point.rendered_content
        except (AttributeError, IndexError):
            search_results = ""
        return ai_response, search_results
    except Exception as e:
        return f"Error: {str(e)}", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
# -----------------------------
# 5. GRADIO INTERFACE
# -----------------------------
def gradio_interface():
    """Build the two-tab Gradio app: text + web search, and image bounding boxes."""
    with gr.Blocks() as demo:
        gr.Markdown("# Multimodal Gemini Assistant")

        # Tab 1: free-form questions answered with Google Search grounding.
        with gr.Tab("Text & Web Search"):
            with gr.Row():
                with gr.Column():
                    question_box = gr.Textbox(lines=2, label="Ask a Question")
                    ask_button = gr.Button("Submit")
                with gr.Column():
                    answer_box = gr.Textbox(label="AI Response")
                    sources_html = gr.HTML(label="Search Results")
            ask_button.click(text_search_query, inputs=question_box, outputs=[answer_box, sources_html])

        # Tab 2: object detection with boxes drawn on the uploaded image.
        with gr.Tab("Image Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    picture = gr.Image(type="pil", label="Input Image")
                    detect_prompt = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe what to detect")
                    detect_button = gr.Button("Generate")
                with gr.Column():
                    annotated = gr.Image(type="pil", label="Output Image")
            detect_button.click(generate_bounding_boxes, inputs=[detect_prompt, picture], outputs=annotated)

    return demo
123
 
 
 
 
124
# Script entry point: build the UI and launch the web server.
if __name__ == "__main__":
    app = gradio_interface()
    # NOTE(review): share=True requests a public gradio.live tunnel; on
    # Hugging Face Spaces it is ignored (the Space is already public) — confirm
    # it is intended for local runs.
    app.launch(share=True)