scmlewis commited on
Commit
2890c7d
·
verified ·
1 Parent(s): 9ee7032

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -286
app.py CHANGED
@@ -1,296 +1,118 @@
1
- import os
2
- import tempfile
3
- import time
4
- from PIL import Image
5
- import gradio as gr
6
- from google import genai
7
- from google.genai import types
8
-
9
- # Helpers
10
- def save_binary_file(file_name, data):
11
- with open(file_name, "wb") as f:
12
- f.write(data)
13
-
14
- def generate_edit(prompt, pil_image, api_key, model="gemini-2.0-flash-exp"):
15
- # Initialize client
16
- client = genai.Client(api_key=(api_key.strip() if api_key and api_key.strip() != "" else os.environ.get("GEMINI_API_KEY")))
17
-
18
- # Save image to a temp path for upload
19
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
20
- image_path = tmp_img.name
21
- pil_image.save(image_path)
22
-
23
- # Upload and prepare content
24
- files = [client.files.upload(file=image_path)]
25
- contents = [
26
- types.Content(
27
- role="user",
28
- parts=[
29
- types.Part.from_uri(file_uri=files[0].uri, mime_type=files[0].mime_type),
30
- types.Part.from_text(text=prompt),
31
- ],
32
- ),
33
- ]
34
-
35
- # Config with image + text modalities
36
- generate_content_config = types.GenerateContentConfig(
37
- temperature=1,
38
- top_p=0.95,
39
- top_k=40,
40
- max_output_tokens=8192,
41
- response_modalities=["image", "text"],
42
- response_mime_type="text/plain",
43
- )
44
-
45
- text_response = ""
46
- image_out_path = None
47
-
48
- # Streamed generation to capture inline image data
49
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_out:
50
- out_path = tmp_out.name
51
- for chunk in client.models.generate_content_stream(
52
- model=model,
53
- contents=contents,
54
- config=generate_content_config,
55
- ):
56
- if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
57
- continue
58
- candidate = chunk.candidates[0].content.parts[0]
59
- if candidate.inline_data:
60
- save_binary_file(out_path, candidate.inline_data.data)
61
- image_out_path = out_path
62
- break
63
- else:
64
- text_response += chunk.text + "\n"
65
-
66
- # Cleanup
67
- try:
68
- del files
69
- except Exception:
70
- pass
71
- return image_out_path, text_response
72
-
73
- def process_image_and_prompt(pil_image, prompt, api_key, progress_callback=None):
74
- try:
75
- # Indicate starting
76
- if progress_callback:
77
- progress_callback("Generating…")
78
- image_path, text_out = generate_edit(prompt, pil_image, api_key)
79
- if image_path:
80
- img = Image.open(image_path)
81
- if img.mode == "RGBA":
82
- img = img.convert("RGB")
83
- # success
84
- if progress_callback:
85
- progress_callback("Done ✓")
86
- return img, "Image generated successfully!", None
87
- else:
88
- # fail to generate image
89
- if progress_callback:
90
- progress_callback("Failed to generate image")
91
- return None, f"⚠️ {text_out.strip()}", None
92
- except Exception as e:
93
- if progress_callback:
94
- progress_callback("Error")
95
- return None, f"❌ Generation failed: {str(e)}", None
96
-
97
- def reset_inputs(api_key_value=None):
98
- return None, "", api_key_value or "", ""
99
-
100
- # Styles
101
- css_style = """
102
- :root {
103
- --bg: #14161c;
104
- --panel: #1e1f25;
105
- --text: #e8eaf6;
106
- --muted: #a0aec0;
107
- --accent: #6a8efd;
108
- }
109
- body, .app-container {
110
- background: var(--bg);
111
- color: var(--text);
112
- }
113
- .header-block {
114
- width: 100%;
115
- display: flex;
116
- align-items: center;
117
- justify-content: center;
118
- padding: 18px;
119
- }
120
- .header-gradient {
121
- width: 100%;
122
- padding: 28px 0;
123
- border-radius: 14px;
124
- background: linear-gradient(90deg, #6a8efd, #44abc7);
125
- box-shadow: 0 2px 12px rgb(50 50 70 / 12%);
126
- text-align: center;
127
- }
128
- .header-title {
129
- margin: 0;
130
- font-size: 2.8rem;
131
- font-weight: 900;
132
- color: #fff;
133
- text-shadow: 1px 3px 12px rgba(0,0,0,.25);
134
- }
135
- .header-subtitle {
136
- margin-top: 6px;
137
- font-size: 1.05rem;
138
- color: #e8f2ff;
139
- }
140
- .gradient-button {
141
- background: linear-gradient(90deg, #44abc7, #6a8efd);
142
- color: white;
143
- font-weight: 700;
144
- border: none;
145
- padding: 12px 28px;
146
- border-radius: 10px;
147
- cursor: pointer;
148
- transition: background 0.25s ease;
149
- }
150
- .gradient-button:hover {
151
- background: linear-gradient(90deg, #6a8efd, #44abc7);
152
- }
153
- .main {
154
- display: flex;
155
- gap: 22px;
156
  }
157
- .sidebar {
158
- background: #1f2230;
159
- padding: 20px;
160
- border-radius: 12px;
161
- min-height: 360px;
162
- width: 320px;
163
- box-shadow: 0 2px 10px rgb(0 0 0 / 0.25);
164
- }
165
- .sidebar h2 {
166
- color: #8ab4ff;
167
- font-size: 1rem;
168
- margin: 6px 0 8px;
169
- }
170
- .sidebar ul {
171
- margin: 0;
172
- padding-left: 18px;
173
- color: #dbeafe;
174
- line-height: 1.8;
175
- }
176
- .sidebar a { color: #97b7ff; text-decoration: none; }
177
- .sidebar a:hover { text-decoration: underline; }
178
-
179
- .main-panel {
180
- flex: 1;
181
- min-width: 0;
182
- }
183
- .section-header {
184
- font-size: 1.15rem;
185
- font-weight: 700;
186
- color: #cbd5e1;
187
- margin: 8px 0;
188
- }
189
- .input-area, .output-area {
190
- background: #1b1e28;
191
- border-radius: 12px;
192
- padding: 14px;
193
- box-shadow: inset 0 0 0 rgba(0,0,0,0.0);
194
- }
195
- .input-area { margin-bottom: 12px; }
196
- .output-area { margin-top: 6px; text-align: center; }
197
- #status-text {
198
- height: 1.2em;
199
- line-height: 1.2em;
200
- font-weight: 600;
201
- text-align: left;
202
- overflow: hidden; /* prevent scrollbars for single line */
203
- white-space: nowrap;
204
- }
205
- #output-image {
206
- display: flex;
207
- justify-content: center;
208
- align-items: center;
209
- }
210
- #output-image img {
211
- max-width: 100%;
212
- max-height: 420px;
213
- width: auto;
214
- height: auto;
215
- object-fit: contain;
216
- border-radius: 12px;
217
- background: #23252b;
218
- }
219
- .input-header { font-family: inherit; margin: 6px 0 6px; font-weight: 700; }
220
- .small { font-size: .9rem; color: var(--muted); }
221
  """
222
 
223
- # Layout
224
- with gr.Blocks(css=css_style) as app:
225
- # Header
226
- gr.HTML("""
227
- <div class='header-block'>
228
- <div class='header-gradient'>
229
- <h1 class='header-title'>🖼️ Image Editor <span style="font-size:1.1em;">(Powered by Gemini)</span> 🔮</h1>
230
- <div class='header-subtitle'>Edit images with AI, fast and simple.</div>
231
- </div>
232
- </div>
233
- """)
234
-
235
- with gr.Row():
236
- # Sidebar (instructions)
237
- with gr.Column(scale=3, elem_classes="sidebar"):
238
- gr.Markdown(
239
- """
240
- <h2>📖 How to Use</h2>
241
- <ul>
242
- <li>Step-by-step prompts guide the editing process.</li>
243
- <li>Upload a PNG image, enter a prompt, then generate.</li>
244
- <li>Keep your Gemini API key secure.</li>
245
- </ul>
246
- <hr>
247
- <h2>🔑 API Key</h2>
248
- <div>Get your key here: <a href="https://aistudio.google.com/apikey" target="_blank">Get your Google API key</a></div>
249
- """
250
- )
251
- # Main panel (steps and outputs)
252
- with gr.Column(scale=9, elem_classes="main-panel"):
253
- with gr.Column():
254
- # Step 1: Upload Image
255
- gr.Markdown("<div class='section-header'>Step 1: Upload Image</div>")
256
- image_input = gr.Image(type="pil", label=None, image_mode="RGBA")
257
 
258
- # Step 2: Prompt + API Key
259
- gr.Markdown("<div class='section-header'>Step 2: Enter Editing Prompt</div>")
260
- prompt_input = gr.Textbox(label="Edit Prompt", placeholder="Describe how to edit the image", lines=2)
261
- api_key_input = gr.Textbox(label="Gemini API Key (required)", placeholder="Enter your Gemini API key here", type="password")
262
 
263
- with gr.Row():
264
- submit_btn = gr.Button("Generate Edit", elem_classes="gradient-button")
265
- reset_btn = gr.Button("Reset Inputs")
266
 
267
- # Step 3: Output
268
- gr.Markdown("<div class='section-header'>Step 3: Image Output</div>")
269
- output_image = gr.Image(label=None, show_label=False, type="pil")
270
- status_text = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-text")
271
 
272
- # Callback wiring
273
- def on_submit(pil_img, prompt, key, progress=None):
274
- if not key or key.strip() == "":
275
- raise gr.Error("Gemini API Key is required!")
276
- # progress: a function to update status text
277
- def update(msg):
278
- if progress:
279
- progress(msg)
280
- img, stat, _ = process_image_and_prompt(pil_img, prompt, key)
281
- update("Completed" if img is not None else stat)
282
- return img, stat
283
 
284
- submit_btn.click(
285
- fn=on_submit,
286
- inputs=[image_input, prompt_input, api_key_input],
287
- outputs=[output_image, status_text]
288
- )
 
 
 
 
 
 
 
289
 
290
- reset_btn.click(
291
- fn=reset_inputs,
292
- inputs=[api_key_input],
293
- outputs=[image_input, prompt_input, api_key_input, status_text]
294
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
- app.launch()
 
 
1
+ custom_css = """
2
+ /* Center main content and lock max width to 900px, with responsive shrink */
3
+ #main-app-area {
4
+ max-width: 900px;
5
+ margin-left: auto;
6
+ margin-right: auto;
7
+ padding: 0 16px;
8
+ }
9
+ /* Responsive for mobile (<950px) */
10
+ @media (max-width: 950px) {
11
+ #main-app-area {
12
+ max-width: 99vw;
13
+ padding: 0 2vw;
14
+ }
15
+ }
16
+ #app-title {
17
+ text-align: center;
18
+ font-size: 38px;
19
+ color: #53c9fc;
20
+ font-weight: bold;
21
+ padding-top: 12px;
22
+ }
23
+ #instructions {
24
+ text-align: center;
25
+ font-size: 19px;
26
+ margin: 14px 0 22px 0;
27
+ }
28
+ #generate-btn {
29
+ background: linear-gradient(90deg, #31b2fd 0%, #98f972 100%);
30
+ color: white;
31
+ font-size: 18px;
32
+ font-weight: bold;
33
+ border: none;
34
+ border-radius: 11px;
35
+ margin-top: 8px;
36
+ margin-bottom: 14px;
37
+ transition: 0.2s;
38
+ }
39
+ #generate-btn:hover {
40
+ filter: brightness(1.08);
41
+ box-shadow: 0 2px 16px #9efbc344;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
 
45
+ from transformers import BlipProcessor, BlipForConditionalGeneration
46
+ from ultralytics import YOLO
47
+ import torch
48
+ import gradio as gr
49
+ from PIL import Image
50
+ from collections import deque
51
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
54
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
55
+ detect_model = YOLO('yolov5s.pt')
 
56
 
57
+ MEMORY_SIZE = 10
58
+ last_images = deque([], maxlen=MEMORY_SIZE)
59
+ last_captions = deque([], maxlen=MEMORY_SIZE)
60
 
61
+ def preprocess_image(image):
62
+ if image.mode != "RGB":
63
+ image = image.convert("RGB")
64
+ return image
65
 
66
+ def detect_objects(image):
67
+ img_np = np.array(image)
68
+ results = detect_model(img_np)
69
+ detected_objs = set()
70
+ for r in results:
71
+ for box in r.boxes.data.tolist():
72
+ class_id = int(box[-1])
73
+ label = detect_model.names[class_id]
74
+ detected_objs.add(label)
75
+ return list(detected_objs)
 
76
 
77
+ def generate_caption(image):
78
+ image = preprocess_image(image)
79
+ inputs = processor(image, return_tensors="pt")
80
+ out = model.generate(**inputs, max_length=30, num_beams=5, early_stopping=True)
81
+ caption = processor.decode(out[0], skip_special_tokens=True)
82
+ detected_objs = detect_objects(image)
83
+ last_images.append(image)
84
+ last_captions.append(caption)
85
+ tags = ", ".join(detected_objs) if detected_objs else "None"
86
+ gallery = [(img, f"Detected objects: {tags}\nCaption: {caption}") for img, caption in zip(list(last_images), list(last_captions))]
87
+ result_text = f"Detected objects: {tags}\nCaption: {caption}"
88
+ return result_text, gallery
89
 
90
+ with gr.Blocks(css=custom_css) as iface:
91
+ gr.HTML('<div id="main-app-area">') # Start content region
92
+ gr.HTML('<div id="app-title">🖼️ Image Captioning with Object Detection</div>')
93
+ gr.HTML(
94
+ '<div id="instructions">'
95
+ '🙌 <b>Welcome!</b> Instantly analyze images using AI.<br>'
96
+ '1️⃣ <b>Upload</b> your image.<br>'
97
+ '2️⃣ Click <b>⭐ Generate Caption</b>.<br>'
98
+ '3️⃣ View and scroll through your history below.<br>'
99
+ '📜 <i>Last 10 results are stored for you.</i>'
100
+ '</div>'
101
+ )
102
+ image_input = gr.Image(type="pil", label="Upload Image")
103
+ generate_btn = gr.Button("⭐ Generate Caption", elem_id="generate-btn")
104
+ caption_output = gr.Textbox(label="📝 Caption and Detected Objects", lines=5, interactive=True)
105
+ gallery = gr.Gallery(label="Last 10 Images and Captions", scale=3)
106
+ def on_generate(image):
107
+ if image is None:
108
+ return "Please upload an image.", []
109
+ return generate_caption(image)
110
+ generate_btn.click(
111
+ fn=on_generate,
112
+ inputs=image_input,
113
+ outputs=[caption_output, gallery]
114
+ )
115
+ gr.HTML('</div>') # End content region
116
 
117
+ if __name__ == "__main__":
118
+ iface.launch()