Chyd19 committed on
Commit
9b6aa9b
·
verified ·
1 Parent(s): ee20230

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +382 -3
app.py CHANGED
@@ -13,6 +13,385 @@ import torchvision.transforms as T
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def free_gpu_cache():
17
  if device == "cuda":
18
  torch.cuda.empty_cache()
@@ -172,7 +551,7 @@ def compute_metrics(images, captions, i1, i2):
172
  # =========================
173
  def build_full_ui():
174
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
175
- gr.HTML("""
176
  <style>
177
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
178
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
@@ -183,7 +562,7 @@ def build_full_ui():
183
  .equal-height-row { display:flex; align-items:stretch; }
184
  .equal-height-row > .gr-column { display:flex; flex-direction:column; }
185
  </style>
186
- """)
187
 
188
  images_state = gr.State([None, None, None])
189
  captions_state = gr.State(["", "", ""])
@@ -279,7 +658,7 @@ def build_full_ui():
279
 
280
  # Launch
281
  demo = build_full_ui()
282
- demo.launch()
283
 
284
  """
285
  #Dumped code
 
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
def free_gpu_cache():
    """Ask PyTorch to release its cached CUDA memory; do nothing on CPU."""
    if device != "cuda":
        return
    torch.cuda.empty_cache()
19
+
20
+
21
+ # =========================
22
+ # MODELS
23
+ # =========================
24
+ # Image generation
25
# Weights are loaded in float16 on CUDA and float32 otherwise.
_TORCH_DTYPE = torch.float16 if device == "cuda" else torch.float32
# Device index handed to every transformers pipeline below (0 on CUDA, -1 otherwise).
_PIPELINE_DEVICE = 0 if device == "cuda" else -1

# Image generation
gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=_TORCH_DTYPE,
)
gen_pipe = gen_pipe.to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=_TORCH_DTYPE,
)
dreamshaper_pipe = dreamshaper_pipe.to(device)

# Captioning
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_PIPELINE_DEVICE,
    generate_kwargs={"max_new_tokens":256, "num_beams":5, "temperature":0.7},
)

# NLP models used for caption analysis
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_PIPELINE_DEVICE,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_PIPELINE_DEVICE,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_PIPELINE_DEVICE,
)

# VQA: processor plus model, with the model placed on the active device
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = vqa_model.to(device)
54
+
55
+ # Metrics
56
# CLIP encoder plus its matching image preprocessing, loaded onto the active device.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
# LPIPS perceptual-distance model (AlexNet backbone).
lpips_model = lpips.LPIPS(net='alex').to(device)
# PIL image -> tensor -> 256x256, applied to both images before LPIPS.
# The original line never closed the T.Compose( call, which is a SyntaxError.
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
59
+
60
+ # Style presets
61
+ style_map = {
62
+ "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
63
+ "Real Life": "natural lighting, true-to-life colors, DSLR",
64
+ "Documentary": "documentary handheld muted colors",
65
+ "iPhone Camera": "iPhone photo natural HDR",
66
+ "Street Photography": "candid street ambient shadows",
67
+ "Cinematic": "cinematic lighting dramatic depth",
68
+ "Anime": "anime cel shaded vibrant",
69
+ "Watercolor": "watercolor soft wash art",
70
+ "Macro": "macro lens shallow DOF",
71
+ "Cyberpunk": "neon cyberpunk futuristic",
72
+ }
73
+
74
+
75
+ # =========================
76
+ # IMAGE GENERATION FUNCTIONS
77
+ # =========================
78
def _generate_into_slot(pipe, slot, base_caption, enhancer, negative, seed, style, images):
    """Run ``pipe`` on the combined prompt and store the result in ``images[slot]``.

    Args:
        pipe: Diffusion pipeline callable returning an object with ``.images``.
        slot: Index of ``images`` that receives the generated image.
        base_caption: Main prompt text; ``None`` is treated as empty.
        enhancer: Optional extra prompt text; ``None`` is treated as empty.
        negative: Negative prompt forwarded to the pipeline.
        seed: Anything ``int()`` accepts; falls back to 42 when it does not.
        style: Key into ``style_map``; unknown keys add nothing to the prompt.
        images: Mutable list of images, updated in place on success.

    Returns:
        ``(img, images)`` where ``img`` is ``None`` if generation failed.
    """
    import logging

    base_caption = base_caption or ""
    enhancer = enhancer or ""

    # strip(", ") removes the dangling separators left by empty parts.
    final_prompt = f"{base_caption}, {enhancer}".strip(", ")
    final_prompt = f"{final_prompt}, {style_map.get(style, '')}".strip(", ")

    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        seed = 42

    generator = torch.Generator(device="cpu").manual_seed(seed)

    img = None
    try:
        with torch.no_grad():
            out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
        img = out.images[0]
    except Exception:
        # Best effort: the UI shows an empty preview, but the cause is logged
        # instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "Image generation failed for prompt %r", final_prompt
        )

    if img is not None:
        images[slot] = img

    free_gpu_cache()
    return img, images


def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    """Generate an image with the SD-Turbo pipeline and store it at ``images[1]``.

    Returns ``(img, images)``; ``img`` is ``None`` when generation fails.
    """
    # NOTE(review): the SDXL-Turbo model card recommends num_inference_steps of
    # 1-4 with guidance_scale=0.0; this call relies on the pipeline defaults.
    # Confirm whether that is intentional.
    return _generate_into_slot(
        gen_pipe, 1, base_caption, enhancer, negative, seed, style, images
    )


def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    """Generate an image with the DreamShaper pipeline and store it at ``images[2]``.

    Returns ``(img, images)``; ``img`` is ``None`` when generation fails.
    """
    return _generate_into_slot(
        dreamshaper_pipe, 2, base_caption, enhancer, negative, seed, style, images
    )
132
+
133
+
134
+ # =========================
135
+ # CAPTIONING
136
+ # =========================
137
def caption_for_image(img):
    """Return the captioner's generated text for ``img``.

    Any failure inside the captioning pipeline is logged and reported to the
    caller as the literal string ``"Caption failed."`` instead of raising.
    """
    import logging

    try:
        result = captioner(img)
        return result[0]["generated_text"]
    except Exception:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception("Captioning failed")
        return "Caption failed."
143
+
144
+
145
+ # =========================
146
+ # VQA (FIXED – now uses GPU + correct image)
147
+ # =========================
148
def answer_vqa(question, image):
    """Answer a free-text ``question`` about ``image`` with the BLIP VQA model.

    Args:
        question: The user's question; ``None`` or blank counts as missing.
        image: The image to ask about; ``None`` counts as missing.

    Returns:
        The decoded answer, ``"Provide image + question."`` when an input is
        missing, or ``"I could not determine the answer."`` when inference fails.
    """
    import logging

    # (question or "") keeps a None question from crashing on .strip().
    if image is None or not (question or "").strip():
        return "Provide image + question."

    try:
        inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs_raw.items()}
        with torch.no_grad():
            # Inference must go through generate(): the plain forward pass
            # needs labels or decoder inputs and raised on every call here.
            generated_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
    except Exception:
        logging.getLogger(__name__).exception("VQA inference failed")
        return "I could not determine the answer."
161
+
162
+
163
+ # =========================
164
+ # METRICS (UNCHANGED LOGIC, FIXED STATE)
165
+ # =========================
166
def compute_metrics(images, captions, i1, i2):
    """Compare the image/caption pair at index ``i1`` with the one at ``i2``.

    Returns:
        ``(clip_sim, lp, bert_f1)``: cosine similarity of the two CLIP image
        embeddings, the LPIPS model's output for the two images, and the
        BERTScore F value of the two captions (0.0 if either caption is empty).
    """
    first_img, second_img = images[i1], images[i2]
    first_cap, second_cap = captions[i1], captions[i2]

    def _clip_batch(img):
        # Preprocess one image and add a batch dimension.
        return clip_preprocess(img).unsqueeze(0).to(device)

    def _lpips_batch(img):
        # Resize to 256x256, add a batch dimension, then apply x*2 - 1.
        return (lpips_transform(img).unsqueeze(0) * 2 - 1).to(device)

    # CLIP
    clip_a = _clip_batch(first_img)
    clip_b = _clip_batch(second_img)
    with torch.no_grad():
        emb_a = clip_model.encode_image(clip_a)
        emb_b = clip_model.encode_image(clip_b)
        clip_sim = float(torch.cosine_similarity(emb_a, emb_b))

    # LPIPS
    lpips_a = _lpips_batch(first_img)
    lpips_b = _lpips_batch(second_img)
    with torch.no_grad():
        lp = float(lpips_model(lpips_a, lpips_b))

    # BERTScore
    bert_f1 = 0.0
    if first_cap and second_cap:
        _, _, f_scores = score([first_cap], [second_cap], lang="en", verbose=False)
        bert_f1 = float(f_scores.mean())

    return clip_sim, lp, bert_f1
192
+
193
+
194
+ # =========================
195
+ # UI BUILD
196
+ # =========================
197
def build_full_ui():
    """Assemble the five-section Gradio Blocks interface and return it.

    Sections: (1) upload a reference image and caption it, (2) generate
    SD-Turbo and DreamShaper images from that caption, (3) pairwise
    CLIP / LPIPS / BERTScore metrics, (4) NLP analysis of the three captions,
    (5) visual question answering on the reference image.

    Two ``gr.State`` lists carry data between handlers. In both, index 0 is
    the reference image, 1 the SD-Turbo output and 2 the DreamShaper output.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Local import so this block does not depend on the file-level imports.
    from html import escape

    with gr.Blocks(title="Multimodal AI Image Studio") as demo:

        # Page-wide CSS for headings, buttons and the loading bar.
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
        .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight:bold; }
        .loading-line { height:4px; background: linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%); background-size:200% 100%; animation:loading 1s linear infinite; }
        @keyframes loading { 0% { background-position:200% 0; } 100% { background-position:-200% 0; } }
        .enhancer-box textarea { width:100%!important;height:36px!important;font-size:14px; }
        </style>
        """)

        # States
        images_state = gr.State([None, None, None])
        captions_state = gr.State(["", "", ""])

        # =========================
        # Section 1: Upload Image
        # =========================
        gr.Markdown("## 1️⃣ Upload Reference Image", elem_classes="heading-orange")

        with gr.Row():
            with gr.Column():
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
                enhancer_box = gr.Textbox(label="Prompt Enhancer (Optional)", elem_classes="enhancer-box")

            with gr.Column():
                upload_preview = gr.Image(label="Uploaded Image")
                caption_out = gr.Markdown()

        def upload_and_caption(img, images_state, captions_state):
            """Store the uploaded image and its caption in slot 0."""
            if img is None:
                return None, "No image uploaded.", images_state, captions_state

            images_state[0] = img
            cap = caption_for_image(img)
            captions_state[0] = cap
            return img, cap, images_state, captions_state

        upload_btn.click(
            upload_and_caption,
            [upload_input, images_state, captions_state],
            [upload_preview, caption_out, images_state, captions_state],
        )

        # =========================
        # Section 2: Generate Images
        # =========================
        gr.Markdown("## 2️⃣ Generate Images from Caption", elem_classes="heading-orange")

        with gr.Row():
            with gr.Column():
                sd_btn = gr.Button("Generate SD-Turbo", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image")

            with gr.Column():
                ds_btn = gr.Button("Generate DreamShaper", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image")

        def generate_sd(caption, enhancer, images_state, captions_state):
            """Generate the SD-Turbo image (slot 1) and caption it."""
            img, images_state = generate_image_with_enhancer(
                caption, enhancer, "", 42, "Photorealistic", images_state
            )
            # Identity check: a failed generation is signalled by None.
            if img is not None:
                captions_state[1] = caption_for_image(img)
            return img, images_state, captions_state

        def generate_ds(caption, enhancer, images_state, captions_state):
            """Generate the DreamShaper image (slot 2) and caption it."""
            img, images_state = generate_dreamshaper_with_enhancer(
                caption, enhancer, "", 123, "Photorealistic", images_state
            )
            if img is not None:
                captions_state[2] = caption_for_image(img)
            return img, images_state, captions_state

        sd_btn.click(
            generate_sd,
            [caption_out, enhancer_box, images_state, captions_state],
            [sd_preview, images_state, captions_state],
        )

        ds_btn.click(
            generate_ds,
            [caption_out, enhancer_box, images_state, captions_state],
            [ds_preview, images_state, captions_state],
        )

        # =========================
        # Section 3: Metrics
        # =========================
        gr.Markdown("## 3️⃣ Compute Pairwise Metrics", elem_classes="heading-orange")

        metrics_btn = gr.Button("Compute Metrics", elem_classes="teal-btn")
        metrics_spinner = gr.HTML()
        metrics_out = gr.HTML()

        def compute_metrics_ui(images, captions):
            """Yield a loading bar first, then the three pairwise metric tables."""
            yield "<div class='loading-line'></div>", ""

            # `is None` avoids invoking image __eq__ the way `None in images` does.
            if any(img is None for img in images):
                yield "", "<b>All three images and captions are required.</b>"
                return

            ref_vs_sd = compute_metrics(images, captions, 0, 1)
            ref_vs_ds = compute_metrics(images, captions, 0, 2)
            sd_vs_ds = compute_metrics(images, captions, 1, 2)

            def fmt(m):
                return f"CLIP: {m[0]:.3f}<br>LPIPS: {m[1]:.3f}<br>BERTScore: {m[2]:.3f}"

            markup = f"""
            <div style='display:flex; gap:40px; justify-content:space-around;'>
            <div><b>Metrics A<br>(Ref ↔ SD)</b><br>{fmt(ref_vs_sd)}</div>
            <div><b>Metrics B<br>(Ref ↔ DS)</b><br>{fmt(ref_vs_ds)}</div>
            <div><b>Metrics C<br>(SD ↔ DS)</b><br>{fmt(sd_vs_ds)}</div>
            </div>
            """

            yield "", markup

        metrics_btn.click(
            compute_metrics_ui,
            [images_state, captions_state],
            [metrics_spinner, metrics_out],
        )

        # =========================
        # Section 4: NLP
        # =========================
        gr.Markdown("## 4️⃣ NLP Analysis of Captions", elem_classes="heading-orange")

        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_spinner = gr.HTML()
        nlp_out = gr.HTML()

        def analyze_captions_ui(captions):
            """Yield a loading bar, then sentiment / entities / topics per caption."""
            yield "<div class='loading-line'></div>", ""

            if any(not c for c in captions):
                yield "", "<b>All three captions required.</b>"
                return

            labels = ["Reference", "SD-Turbo", "DreamShaper"]
            blocks = []

            for label, caption in zip(labels, captions):
                # Model output is escaped before it is placed into raw HTML.
                sentiment = "<br>".join(
                    f"{escape(str(s['label']))}: {s['score']:.2f}"
                    for s in sentiment_model(caption)
                )
                ents = "<br>".join(
                    f"{escape(str(e['entity_group']))}: {escape(str(e['word']))}"
                    for e in ner_model(caption)
                ) or "None"

                topics_data = topic_model(
                    caption,
                    candidate_labels=['people', 'animals', 'objects', 'food', 'nature'],
                )
                topics = "<br>".join(
                    f"{escape(str(topic))}: {topic_score:.2f}"
                    for topic, topic_score in zip(topics_data['labels'], topics_data['scores'])
                )

                blocks.append(f"""
                <div style='flex:1; padding:10px; min-width:250px;'>
                <h3><u>{label}</u></h3>
                <b>Sentiment</b><br>{sentiment}<br><br>
                <b>Entities</b><br>{ents}<br><br>
                <b>Topics</b><br>{topics}
                </div>
                """)

            yield "", f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"

        nlp_btn.click(analyze_captions_ui, [captions_state], [nlp_spinner, nlp_out])

        # =========================
        # Section 5: VQA
        # =========================
        gr.Markdown("## 5️⃣ Visual Question Answering (VQA)", elem_classes="heading-orange")

        vqa_input = gr.Textbox(label="Enter a question about the reference image")
        vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
        vqa_spinner = gr.HTML()
        vqa_out = gr.Markdown()

        def vqa_ui(question, images_state):
            """Yield a loading bar, then the answer about the reference image."""
            yield "<div class='loading-line'></div>", ""
            ref_img = images_state[0]
            ans = answer_vqa(question, ref_img)
            yield "", f"**Answer:** {ans}"

        vqa_btn.click(vqa_ui, [vqa_input, images_state], [vqa_spinner, vqa_out])

    return demo
375
+
376
+
377
+ demo = build_full_ui()
378
+ demo.launch()
379
+ """
380
+ # =========================
381
+ # LIBRARIES & DEVICE SETUP
382
+ # =========================
383
+ import torch
384
+ import gradio as gr
385
+ from PIL import Image
386
+ from diffusers import DiffusionPipeline
387
+ from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
388
+ import lpips
389
+ import clip
390
+ from bert_score import score
391
+ import torchvision.transforms as T
392
+
393
+ device = "cuda" if torch.cuda.is_available() else "cpu"
394
+
395
  def free_gpu_cache():
396
  if device == "cuda":
397
  torch.cuda.empty_cache()
 
551
  # =========================
552
  def build_full_ui():
553
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
554
+ gr.HTML(
555
  <style>
556
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
557
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
 
562
  .equal-height-row { display:flex; align-items:stretch; }
563
  .equal-height-row > .gr-column { display:flex; flex-direction:column; }
564
  </style>
565
+ )
566
 
567
  images_state = gr.State([None, None, None])
568
  captions_state = gr.State(["", "", ""])
 
658
 
659
  # Launch
660
  demo = build_full_ui()
661
+ demo.launch()"""
662
 
663
  """
664
  #Dumped code