Update app.py
Browse files
app.py
CHANGED
|
@@ -1,141 +1,179 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
-
from io import BytesIO
|
| 4 |
-
from PIL import Image, ImageDraw, ImageFont
|
| 5 |
-
from PIL import ImageColor
|
| 6 |
import json
|
|
|
|
|
|
|
| 7 |
import google.generativeai as genai
|
| 8 |
from google.generativeai import types
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
|
| 11 |
-
|
| 12 |
-
# 1.
|
| 13 |
-
# ----------------
|
| 14 |
load_dotenv()
|
| 15 |
-
|
| 16 |
-
# Configure the Google AI library
|
| 17 |
-
genai.configure(api_key=api_key)
|
| 18 |
-
|
| 19 |
|
| 20 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
bounding_box_system_instructions = """
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
return
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
lines = json_output.splitlines()
|
| 43 |
for i, line in enumerate(lines):
|
| 44 |
-
if line == "```json":
|
| 45 |
-
json_output = "\n".join(lines[i+1:])
|
| 46 |
-
json_output = json_output.split("```")[0]
|
| 47 |
break
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
|
|
|
|
|
|
|
|
|
|
| 50 |
def plot_bounding_boxes(im, bounding_boxes):
|
| 51 |
-
"""
|
| 52 |
-
Plots bounding boxes on an image with labels.
|
| 53 |
-
"""
|
| 54 |
-
additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]
|
| 55 |
-
|
| 56 |
im = im.copy()
|
| 57 |
width, height = im.size
|
| 58 |
draw = ImageDraw.Draw(im)
|
|
|
|
| 59 |
colors = [
|
| 60 |
'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
|
| 61 |
'lime', 'magenta', 'violet', 'gold', 'silver'
|
| 62 |
] + additional_colors
|
| 63 |
|
| 64 |
try:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4)
|
| 88 |
-
if "label" in bounding_box:
|
| 89 |
-
draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)
|
| 90 |
-
except Exception as e:
|
| 91 |
-
print(f"Error drawing bounding boxes: {e}")
|
| 92 |
|
| 93 |
return im
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
with gr.Row():
|
| 110 |
with gr.Column():
|
| 111 |
-
gr.Markdown("###
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
with gr.Column():
|
| 117 |
-
gr.Markdown("###
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
gr.Examples(
|
| 123 |
examples=examples,
|
| 124 |
-
inputs=[
|
| 125 |
-
label="Example
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
# Event to generate bounding boxes
|
| 129 |
-
submit_btn.click(
|
| 130 |
-
generate_bounding_boxes,
|
| 131 |
-
inputs=[input_prompt, input_image],
|
| 132 |
-
outputs=[output_image]
|
| 133 |
)
|
| 134 |
|
| 135 |
return demo
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
| 139 |
if __name__ == "__main__":
|
| 140 |
-
app =
|
| 141 |
-
app.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
|
|
|
|
|
|
|
|
|
| 3 |
import json
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from PIL import Image, ImageDraw, ImageFont, ImageColor
|
| 6 |
import google.generativeai as genai
|
| 7 |
from google.generativeai import types
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
+
# -----------------------------
|
| 11 |
+
# 1. LOAD API KEY
|
| 12 |
+
# -----------------------------
|
| 13 |
load_dotenv()
# Read from the environment (or a .env file); used when the UI key field is left blank.
DEFAULT_API_KEY = os.getenv("Gemini_API_Key")

# -----------------------------
# 2. MODEL SETTINGS
# -----------------------------
DEFAULT_MODEL = "gemini-2.5-flash"
DEFAULT_TEMPERATURE = 0.5
DEFAULT_MAX_TOKENS = 500

# plot_bounding_boxes reads the "box_2d", "label" and "suggestion" keys and
# divides coordinates by 1000, so the prompt spells that schema out instead of
# leaving the key names up to the model.
bounding_box_system_instructions = """
Return bounding boxes as a JSON array with labels. Never return masks or code fencing.
Limit to 25 objects. If an object is present multiple times, name them according to their unique characteristics
(colors, size, position, unique features, etc.). Also provide actionable suggestions for each object if applicable.
Each array element must be a JSON object with the keys "box_2d" (four numbers normalized to 0-1000),
"label" (string) and, when applicable, "suggestion" (short string).
"""
|
| 28 |
+
|
| 29 |
+
# -----------------------------
|
| 30 |
+
# 3. IMAGE PREPROCESSING
|
| 31 |
+
# -----------------------------
|
| 32 |
+
def preprocess_image(image, max_dim=1024):
    """Convert *image* to RGB and shrink it so neither side exceeds *max_dim*.

    Args:
        image: Object exposing the PIL image API used here (``convert``,
            ``width``, ``height`` and ``resize``).
        max_dim: Largest allowed width or height in pixels.

    Returns:
        The RGB image. It is downscaled with its aspect ratio preserved when
        either side is larger than ``max_dim``; smaller images are only
        converted.
    """
    image = image.convert("RGB")
    longest_side = max(image.width, image.height)
    if longest_side > max_dim:
        ratio = max_dim / longest_side
        # Clamp each side to 1px: for very elongated images int() truncation
        # would otherwise produce a zero-length side.
        new_size = (
            max(1, int(image.width * ratio)),
            max(1, int(image.height * ratio)),
        )
        image = image.resize(new_size)
    return image
|
| 40 |
+
|
| 41 |
+
# -----------------------------
|
| 42 |
+
# 4. PARSE JSON OUTPUT
|
| 43 |
+
# -----------------------------
|
| 44 |
+
def parse_json(json_output):
    """Extract the JSON array of bounding boxes from a model reply.

    The reply may be raw JSON or JSON wrapped in a Markdown code fence
    (```json, ```JSON or a bare ```), optionally surrounded by other text.

    Args:
        json_output: Text returned by the model; ``None`` or an empty string
            is treated as "no boxes".

    Returns:
        A list. A single JSON object is wrapped in a one-element list so the
        caller can always iterate over entries; unparseable text or any other
        JSON value yields ``[]``.
    """
    if not json_output:
        return []

    payload = json_output
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Accept the fence regardless of case or a missing language tag.
        if line.strip().lower() in ("```json", "```"):
            payload = "\n".join(lines[i + 1:]).split("```")[0]
            break

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return []

    if isinstance(parsed, dict):
        return [parsed]
    return parsed if isinstance(parsed, list) else []
|
| 55 |
|
| 56 |
+
# -----------------------------
|
| 57 |
+
# 5. PLOT BOUNDING BOXES
|
| 58 |
+
# -----------------------------
|
| 59 |
def plot_bounding_boxes(im, bounding_boxes):
    """Return a copy of *im* with one outlined, labelled rectangle per box.

    Args:
        im: PIL image to annotate. It is copied first, so the input is left
            unmodified.
        bounding_boxes: Iterable of dicts (``None`` is treated as empty).
            ``"box_2d"`` holds four numbers on a 0-1000 scale, read as
            ``[ymin, xmin, ymax, xmax]`` (the order the Gemini documentation
            gives for ``box_2d``). Optional ``"label"`` and ``"suggestion"``
            values are drawn inside the box.

    Returns:
        The annotated copy of the image.
    """
    im = im.copy()
    width, height = im.size
    draw = ImageDraw.Draw(im)
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'cyan',
        'lime', 'magenta', 'violet', 'gold', 'silver'
    ] + list(ImageColor.colormap)
    font = ImageFont.load_default()

    for i, bbox in enumerate(bounding_boxes or []):
        # The entries come from model output, so skip anything that is not a
        # dict carrying four numeric coordinates instead of crashing.
        if not isinstance(bbox, dict):
            continue
        coords = bbox.get("box_2d", [0, 0, 0, 0])
        if not isinstance(coords, (list, tuple)) or len(coords) != 4:
            continue
        try:
            y1, x1, y2, x2 = (float(value) for value in coords)
        except (TypeError, ValueError):
            continue

        color = colors[i % len(colors)]
        abs_x1 = int(x1 / 1000 * width)
        abs_y1 = int(y1 / 1000 * height)
        abs_x2 = int(x2 / 1000 * width)
        abs_y2 = int(y2 / 1000 * height)

        # Normalise corner order so the first point is always the top-left.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=3)
        label = bbox.get("label", "")
        suggestion = bbox.get("suggestion", "")
        if label:
            draw.text((abs_x1 + 5, abs_y1 + 5), f"{label}", fill=color, font=font)
        if suggestion:
            # Offset below the label line.
            draw.text((abs_x1 + 5, abs_y1 + 20), f"{suggestion}", fill=color, font=font)

    return im
|
| 94 |
+
|
| 95 |
+
# -----------------------------
|
| 96 |
+
# 6. GENERATE RESPONSE
|
| 97 |
+
# -----------------------------
|
| 98 |
+
def generate_response(
    user_prompt,
    user_image=None,
    api_key_input=None,
    model_choice=DEFAULT_MODEL,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS
):
    """Send the prompt (and optional image) to Gemini and return its reply.

    Args:
        user_prompt: Text prompt.
        user_image: Optional PIL image. When given, it is preprocessed, sent
            with the prompt, and annotated with the boxes parsed from the reply.
        api_key_input: API key typed by the user; falls back to
            ``DEFAULT_API_KEY`` when blank.
        model_choice: Model name passed to ``genai.GenerativeModel``.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Returns:
        A ``(text, image)`` tuple. ``image`` is the annotated image when an
        input image was supplied and the reply contained text, otherwise
        ``None``. Missing-key and empty-reply conditions are reported through
        ``text`` rather than raised.
    """
    api_key_to_use = (api_key_input or "").strip() or DEFAULT_API_KEY
    if not api_key_to_use:
        return "No API key available: enter one or set Gemini_API_Key in the environment.", None
    genai.configure(api_key=api_key_to_use)

    model = genai.GenerativeModel(
        model_name=model_choice,
        system_instruction=bounding_box_system_instructions,
        safety_settings=[types.SafetySettingDict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
    )
    generation_config = types.GenerationConfig(
        temperature=temperature,
        # The value comes from a UI slider and may arrive as a float.
        max_output_tokens=int(max_tokens)
    )

    # Explicit None check: "no image" is the only case that should skip plotting.
    has_image = user_image is not None
    contents = [user_prompt]
    if has_image:
        user_image = preprocess_image(user_image)
        contents.append(user_image)

    response = model.generate_content(contents, generation_config=generation_config)
    try:
        # Read once; the accessor raises ValueError when the reply carries no
        # text part (for example a blocked or empty candidate).
        text = response.text
    except ValueError as exc:
        return f"The model returned no text: {exc}", None

    if not has_image:
        return text, None

    bboxes = parse_json(text)
    output_image = plot_bounding_boxes(user_image, bboxes)
    return text, output_image
|
| 128 |
+
|
| 129 |
+
# -----------------------------
|
| 130 |
+
# 7. GRADIO INTERFACE
|
| 131 |
+
# -----------------------------
|
| 132 |
+
def build_ui():
    """Build the Gradio Blocks app and return it without launching.

    Layout: a left column with the prompt, optional image, API key, model
    choice, temperature and max-token controls plus a Run button; a right
    column with the model's text reply and the annotated image. The Run button
    calls ``generate_response`` with the six inputs in its parameter order.

    Returns:
        The ``gr.Blocks`` instance.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Multi-Modal Assistant with Bounding Boxes & Suggestions")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### User Inputs")
                text_input = gr.Textbox(lines=3, label="Prompt")
                image_input = gr.Image(type="pil", label="Optional Image")
                # Masked so the key is not displayed in clear text on screen.
                api_key_input = gr.Textbox(
                    label="Google API Key (Optional)",
                    placeholder="Enter your API key",
                    type="password",
                )
                # NOTE(review): the original offered "gemini-2.0", which does not
                # appear to be a published model id; "gemini-2.0-flash" is used
                # instead - confirm against the current model list.
                model_choice = gr.Radio(["gemini-2.5-flash", "gemini-2.0-flash"], label="Select Model", value=DEFAULT_MODEL)
                temperature_slider = gr.Slider(0, 1, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
                max_tokens_slider = gr.Slider(50, 2000, value=DEFAULT_MAX_TOKENS, step=50, label="Max Tokens")
                run_btn = gr.Button("Run")

            with gr.Column():
                gr.Markdown("### Outputs")
                chatbot_output = gr.Textbox(label="Model Output (Text)", lines=15)
                output_image = gr.Image(type="pil", label="Output Image with Bounding Boxes (if image provided)")

        # Event
        run_btn.click(
            generate_response,
            inputs=[text_input, image_input, api_key_input, model_choice, temperature_slider, max_tokens_slider],
            outputs=[chatbot_output, output_image]
        )

        # Add example images + prompts if desired
        gr.Markdown("### Examples (Optional)")
        examples = [
            ["cookies.jpg", "Detect types of cookies and provide suggestions."],
            ["messed_room.jpg", "Identify unorganized items and suggest actions."],
            ["yoga.jpg", "Label the different yoga poses."],
        ]
        gr.Examples(
            examples=examples,
            # Each example row is [image path, prompt], so the image component
            # must come first; the reverse order put the file name in the prompt box.
            inputs=[image_input, text_input],
            label="Example Prompts & Images"
        )

    return demo
|
| 173 |
|
| 174 |
+
# -----------------------------
|
| 175 |
+
# 8. RUN APP
|
| 176 |
+
# -----------------------------
|
| 177 |
if __name__ == "__main__":
    # Build the interface and start the Gradio server only when run as a script.
    app = build_ui()
    app.launch()
|