Chyd19 committed on
Commit
afb1906
·
verified ·
1 Parent(s): 936d869

my app.py

Browse files
Files changed (1) hide show
  1. app.py +373 -0
app.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Section 1
2
+ # ==============================
3
+ # SECTION 1
4
+ # ==============================
5
+
6
+ # Libraries
7
+ import torch
8
+ import gradio as gr
9
+ from PIL import Image
10
+ from diffusers import DiffusionPipeline
11
+ from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
12
+ import lpips
13
+ import clip
14
+ from bert_score import score
15
+ import torchvision.transforms as T
16
+
17
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"


def free_gpu_cache():
    """Ask PyTorch to empty its CUDA cache; does nothing when running on CPU."""
    if device != "cuda":
        return
    torch.cuda.empty_cache()
22
+
23
+ # ==============================
24
+ # MODELS
25
+ # ==============================
26
# Shared loading options, derived once from the global `device`.
# Diffusion weights are loaded in float16 on CUDA and float32 on CPU.
_torch_dtype = torch.float32 if device != "cuda" else torch.float16
# Device index handed to the transformers pipelines: 0 on CUDA, -1 on CPU.
_pipeline_device = -1 if device != "cuda" else 0

# --- Text-to-image generators ---
gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=_torch_dtype,
)
gen_pipe = gen_pipe.to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=_torch_dtype,
)
dreamshaper_pipe = dreamshaper_pipe.to(device)

# --- Image captioning ---
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_pipeline_device,
    generate_kwargs={"max_new_tokens": 256, "num_beams": 5, "temperature": 0.7},
)

# --- Caption NLP: sentiment, named entities, zero-shot topics ---
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_pipeline_device,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_pipeline_device,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_pipeline_device,
)

# --- Visual question answering (explicitly placed on the CPU) ---
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = vqa_model.to("cpu")

# --- Image similarity metrics ---
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)
# Convert to a tensor first, then resize to 256x256 for LPIPS.
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
56
+
57
# Style presets: UI style name -> prompt fragment appended to the generation prompt.
# Kept as ordered pairs so the listing order is explicit.
_STYLE_PRESETS = (
    ("Photorealistic", "photorealistic, ultra-detailed, 8k, cinematic lighting"),
    ("Real Life", "natural lighting, true-to-life colors, DSLR"),
    ("Documentary", "documentary handheld muted colors"),
    ("iPhone Camera", "iPhone photo natural HDR"),
    ("Street Photography", "candid street ambient shadows"),
    ("Cinematic", "cinematic lighting dramatic depth"),
    ("Anime", "anime cel shaded vibrant"),
    ("Watercolor", "watercolor soft wash art"),
    ("Macro", "macro lens shallow DOF"),
    ("Cyberpunk", "neon cyberpunk futuristic"),
)
style_map = dict(_STYLE_PRESETS)
69
+ # Section 2
70
+ # ==============================
71
+ # SECTION 2 — FUNCTIONS
72
+ # ==============================
73
def _build_prompt(base_caption, enhancer, style):
    """Join caption, enhancer and the style preset into one comma-separated prompt."""
    prompt = f"{base_caption or ''}, {enhancer or ''}".strip(", ")
    # Unknown styles contribute nothing; strip() removes the dangling separator.
    return f"{prompt}, {style_map.get(style, '')}".strip(", ")


def _coerce_seed(seed, default=42):
    """Return ``seed`` as an int, or ``default`` when it cannot be converted."""
    try:
        return int(seed)
    except (TypeError, ValueError, OverflowError):
        return default


def _run_text_to_image(pipe, pipe_label, base_caption, enhancer, negative, seed, style, images):
    """Run one diffusion pipeline and return ``(image_or_None, updated_images)``."""
    # Always work on a copy so the caller's list is never mutated behind its back
    # (the original mutated it only when it happened to be non-empty).
    images = list(images) if images else []
    final_prompt = _build_prompt(base_caption, enhancer, style)
    generator = torch.Generator(device="cpu").manual_seed(_coerce_seed(seed))

    img = None
    try:
        with torch.no_grad():
            out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
        img = out.images[0]
    except Exception as e:
        # Best effort: a failed generation is reported and surfaced as img=None.
        print(f"{pipe_label} failed:", e)
    finally:
        free_gpu_cache()

    if img is not None:
        images.append(img)
    return img, images


def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    """Generate an image with the SDXL-Turbo pipeline (``gen_pipe``).

    Args:
        base_caption: Base prompt text; ``None`` is treated as empty.
        enhancer: Extra prompt text appended after the caption; ``None`` is treated as empty.
        negative: Negative prompt forwarded to the pipeline unchanged.
        seed: Anything ``int()`` accepts; falls back to 42 when it cannot be converted.
        style: Key into ``style_map``; unknown keys add nothing to the prompt.
        images: Previously generated images, or ``None``. Not modified.

    Returns:
        ``(img, images)`` where ``img`` is the generated image, or ``None`` if the
        pipeline raised, and ``images`` is a new list with ``img`` appended on success.
    """
    # NOTE(review): the pipeline runs with its default step count and guidance.
    # The SDXL-Turbo model card reportedly recommends guidance_scale=0.0 with
    # 1-4 inference steps -- TODO confirm whether the defaults are intended.
    return _run_text_to_image(
        gen_pipe, "SD Turbo", base_caption, enhancer, negative, seed, style, images
    )


def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    """Generate an image with the DreamShaper pipeline (``dreamshaper_pipe``).

    Takes the same arguments and returns the same ``(img, images)`` pair as
    ``generate_image_with_enhancer``; only the underlying pipeline differs.
    """
    return _run_text_to_image(
        dreamshaper_pipe, "DreamShaper", base_caption, enhancer, negative, seed, style, images
    )
130
+
131
def caption_for_image(img):
    """Return the captioner's text for ``img``, or "Caption failed." on any error.

    Args:
        img: Image passed straight to the ``captioner`` pipeline. ``None`` is
            answered with the failure string without calling the model.

    Returns:
        The ``generated_text`` of the first captioner result, or the literal
        string "Caption failed." when there is no image or captioning raised.
    """
    if img is None:
        return "Caption failed."
    try:
        out = captioner(img)
        return out[0]["generated_text"]
    except Exception as e:
        # Catch Exception rather than everything, so Ctrl-C still interrupts,
        # and log the cause instead of discarding it.
        print("Captioning error:", e)
        return "Caption failed."
137
+
138
def answer_vqa(question, image):
    """Answer a free-text question about an image with the BLIP VQA model.

    Args:
        question: The question text. ``None`` or a blank string is rejected.
        image: The image to ask about (a PIL image or a numpy array).
            ``None`` is rejected.

    Returns:
        The decoded answer string; "Provide image + question." when either
        input is missing; "VQA failed." when the model call raises.
    """
    # Compare against None explicitly: truth-testing a numpy array raises
    # ValueError, and the question may arrive as None.
    if image is None or not (question or "").strip():
        return "Provide image + question."
    try:
        encoded = vqa_processor(images=image, text=question, return_tensors="pt")
        # The inputs are kept on the CPU, as in the original code.
        inputs = {name: tensor.to("cpu") for name, tensor in encoded.items()}
        with torch.no_grad():
            # generate() is the documented inference path for
            # BlipForQuestionAnswering; a plain forward pass is intended for
            # training (it expects labels / decoder inputs).
            answer_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(answer_ids[0], skip_special_tokens=True)
    except Exception as e:
        print("VQA error:", e)
        return "VQA failed."
150
+
151
def compute_metrics(images, captions, i1, i2):
    """Compare two images (and their captions) with CLIP, LPIPS and BERTScore.

    Args:
        images: Sequence of images, indexed by ``i1`` and ``i2``.
        captions: Sequence of caption strings parallel to ``images``.
        i1: Index of the first image/caption.
        i2: Index of the second image/caption.

    Returns:
        ``(clip_sim, lpips_dist, bert_f1)`` as Python floats: cosine similarity
        of the CLIP image embeddings, the LPIPS model output, and the BERTScore
        F1 of the two captions (0.0 when either caption is empty).

    Raises:
        IndexError: If ``i1`` or ``i2`` is out of range for either sequence.
    """
    img1, img2 = images[i1], images[i2]
    cap1, cap2 = captions[i1], captions[i2]

    # CLIP: inputs must be on the same device as clip_model, which is loaded
    # onto `device`; pinning them to the CPU breaks on CUDA.
    t1 = clip_preprocess(img1).unsqueeze(0).to(device)
    t2 = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = clip_model.encode_image(t1)
        f2 = clip_model.encode_image(t2)
    clip_sim = float(torch.cosine_similarity(f1, f2))

    # LPIPS: rescale ToTensor's [0, 1] output to [-1, 1], then move the tensors
    # to the device lpips_model was loaded on.
    lp_in1 = (lpips_transform(img1).unsqueeze(0) * 2 - 1).to(device)
    lp_in2 = (lpips_transform(img2).unsqueeze(0) * 2 - 1).to(device)
    with torch.no_grad():
        lp = float(lpips_model(lp_in1, lp_in2))

    # BERTScore: only meaningful when both captions are present.
    if cap1 and cap2:
        _, _, f_scores = score([cap1], [cap2], lang="en", verbose=False)
        bert_f1 = float(f_scores.mean())
    else:
        bert_f1 = 0.0

    return clip_sim, lp, bert_f1
179
+
180
+ # Section 3
181
+ # ---------------- Build Gradio UI with Custom Look ----------------
182
def build_ui_with_custom_ui():
    """Build the Gradio Blocks interface for the image studio and return it.

    The page has five steps: upload a reference image and caption it, generate
    an SD-Turbo and a DreamShaper image from that caption, compute pairwise
    image/caption metrics, run NLP analysis on the captions, and answer
    questions about the reference image.

    Per-session state is two parallel lists with fixed slots:
    0 = reference, 1 = SD-Turbo, 2 = DreamShaper. A slot that has not been
    produced yet holds ``None``. Fixed slots keep images and captions aligned
    whatever order the buttons are clicked in, and however often.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    # Slot 0 is the reference image/caption, written by the upload handler.
    sd_slot, ds_slot = 1, 2
    slot_count = 3
    loading_html = "<div class='loading-line'></div>"

    def _with_slot(items, slot, value):
        """Return a copy of ``items``, padded with None as needed, with ``items[slot] = value``."""
        padded = list(items or [])
        padded.extend([None] * (slot + 1 - len(padded)))
        padded[slot] = value
        return padded

    def _all_slots_filled(items):
        """True when the reference, SD-Turbo and DreamShaper slots are all populated."""
        items = items or []
        return len(items) >= slot_count and all(item is not None for item in items[:slot_count])

    def _format_metrics(title, metrics):
        """Render one compute_metrics() result as labelled Markdown."""
        clip_sim, lpips_dist, bert_f1 = metrics
        return (
            f"**{title}**\n\n"
            f"- CLIP similarity: {clip_sim:.4f}\n"
            f"- LPIPS distance: {lpips_dist:.4f}\n"
            f"- BERTScore F1: {bert_f1:.4f}"
        )

    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # ---------------- CSS Styling ----------------
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
        .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }

        /* Horizontal thin spinner */
        .loading-line {
            height: 4px;
            background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%);
            background-size: 200% 100%;
            animation: loading 1s linear infinite;
        }
        @keyframes loading {
            0% { background-position: 200% 0; }
            100% { background-position: -200% 0; }
        }

        /* Match enhancer box to upload button */
        .enhancer-box textarea {
            width: 100% !important;
            height: 36px !important;
            box-sizing: border-box;
            font-size: 14px;
        }

        /* Equal-height styling for Step-1 columns */
        .equal-height-row {
            display: flex;
            align-items: stretch;
        }
        .equal-height-row > .gr-column {
            display: flex;
            flex-direction: column;
        }
        </style>
        """)

        # ---------------- Heading ----------------
        gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")

        # ---------------- States ----------------
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Step 1: Upload Reference Image ----------------
        gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")

        # The equal-height class keeps both Step-1 columns the same height.
        with gr.Row(elem_classes="equal-height-row"):
            with gr.Column(scale=1):
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
            with gr.Column(scale=1):
                # type="pil" so the VQA step receives a PIL image, not a numpy array.
                upload_preview = gr.Image(label="Uploaded Image", interactive=False, type="pil")
                enhancer_box = gr.Textbox(
                    label="Add Prompt Enhancer (Optional)",
                    placeholder="Example: 'at night with neon lights', 'wearing a red jacket', etc.",
                    elem_classes="enhancer-box"
                )
                caption_out = gr.Markdown(label="Generated Caption")

        def upload_and_generate_caption_ui(img):
            """Caption the uploaded image and reset both state lists to the reference only."""
            if img is None:
                return None, "No image uploaded.", [], []

            try:
                output = captioner(img)
                has_text = len(output) > 0 and "generated_text" in output[0]
                caption = output[0]["generated_text"] if has_text else "Caption failed."
            except Exception as e:
                print("Captioning error:", e)
                caption = "Caption failed."

            # A new reference invalidates any previously generated images/captions.
            return img, caption, [img], [caption]

        upload_btn.click(
            upload_and_generate_caption_ui,
            inputs=[upload_input],
            outputs=[upload_preview, caption_out, images_state, captions_state]
        )

        # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
        gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
            with gr.Column(scale=1, min_width=300):
                ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image", interactive=False)

        def _generate_into_slot(generate_fn, seed, slot, caption, enhancer, images, captions):
            """Generate one image, caption it, and store both in ``slot`` of the state lists."""
            final_prompt = f"{caption or ''}, {enhancer or ''}".strip(", ")
            # The list returned by the generator is ignored: the image goes into
            # a fixed slot instead of being appended in click order.
            img, _ = generate_fn(final_prompt, enhancer="", negative="", seed=seed,
                                 style="Photorealistic", images=[])
            if img is None:
                generated_caption = "Caption failed."
            else:
                try:
                    generated_caption = captioner(img)[0]["generated_text"]
                except Exception as e:
                    print("Captioning error:", e)
                    generated_caption = "Caption failed."
            return (img,
                    _with_slot(images, slot, img),
                    _with_slot(captions, slot, generated_caption))

        def generate_sd_from_caption_ui(caption, enhancer, images, captions):
            """Click handler: SD-Turbo image and caption go into slot 1."""
            return _generate_into_slot(generate_image_with_enhancer, 42, sd_slot,
                                       caption, enhancer, images, captions)

        def generate_ds_from_caption_ui(caption, enhancer, images, captions):
            """Click handler: DreamShaper image and caption go into slot 2."""
            return _generate_into_slot(generate_dreamshaper_with_enhancer, 123, ds_slot,
                                       caption, enhancer, images, captions)

        sd_btn.click(generate_sd_from_caption_ui, inputs=[caption_out, enhancer_box, images_state, captions_state],
                     outputs=[sd_preview, images_state, captions_state])
        ds_btn.click(generate_ds_from_caption_ui, inputs=[caption_out, enhancer_box, images_state, captions_state],
                     outputs=[ds_preview, images_state, captions_state])

        # ---------------- Step 3: Compute Pairwise Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row():
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            """Yield a loading bar, then the metrics for each of the three image pairs."""
            yield (loading_html, loading_html, loading_html)
            if not (_all_slots_filled(images) and _all_slots_filled(captions)):
                msg = "All three images and captions are required to compute metrics."
                yield msg, msg, msg
                return
            ref_vs_sd = compute_metrics(images, captions, 0, 1)
            ref_vs_ds = compute_metrics(images, captions, 0, 2)
            sd_vs_ds = compute_metrics(images, captions, 1, 2)
            yield (_format_metrics("Reference ↔ SD-Turbo", ref_vs_sd),
                   _format_metrics("Reference ↔ DreamShaper", ref_vs_ds),
                   _format_metrics("SD-Turbo ↔ DreamShaper", sd_vs_ds))

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
                          outputs=[metrics_A, metrics_B, metrics_C])

        # ---------------- Step 4: NLP Analysis ----------------
        gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_out = gr.HTML()

        def analyze_caption_pipeline_ui(captions):
            """Yield a loading bar, then sentiment / entities / topics for each caption."""
            yield loading_html
            if not _all_slots_filled(captions):
                yield "<b>All three captions are required for NLP analysis.</b>"
                return
            labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
            blocks = []
            for label, caption in zip(labels, captions):
                sentiment = "<br>".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption))
                ents = "<br>".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None"
                topics_data = topic_model(caption, candidate_labels=['people', 'animals', 'objects', 'food', 'nature'])
                topics = "<br>".join(
                    f"{topic}: {sc:.2f}" for topic, sc in zip(topics_data['labels'], topics_data['scores'])
                )
                blocks.append(
                    f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3>"
                    f"<b>Sentiment</b><br>{sentiment}<br><br>"
                    f"<b>Entities</b><br>{ents}<br><br>"
                    f"<b>Topics</b><br>{topics}</div>"
                )
            yield f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"

        nlp_btn.click(analyze_caption_pipeline_ui, inputs=[captions_state], outputs=[nlp_out])

        # ---------------- Step 5: Visual Question Answering ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                vqa_input = gr.Textbox(label="Enter a question about the reference image")
                vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
            with gr.Column(scale=1):
                vqa_out = gr.Markdown(label="VQA Output")

        def answer_vqa_ui(question, image):
            """Yield a loading bar, then the VQA answer for the reference image."""
            yield loading_html
            yield answer_vqa(question, image)

        vqa_btn.click(answer_vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])

    return demo
368
+
369
# Build the interface at import time so `demo` stays available as a module
# attribute, but only start the web server when this file is run as a script.
demo = build_ui_with_custom_ui()

if __name__ == "__main__":
    demo.launch()
372
+
373
+