Chyd19 committed on
Commit
8ddb2d1
·
verified ·
1 Parent(s): 24b980d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -6
app.py CHANGED
@@ -1,4 +1,259 @@
1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  # ==============================
3
  # SECTION 1
4
  # ==============================
@@ -181,7 +436,7 @@ def compute_metrics(images, captions, i1, i2):
181
  def build_ui_with_custom_ui():
182
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
183
  # ---------------- CSS Styling ----------------
184
- gr.HTML("""
185
  <style>
186
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
187
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
@@ -217,7 +472,7 @@ def build_ui_with_custom_ui():
217
  flex-direction: column;
218
  }
219
  </style>
220
- """)
221
 
222
  # ---------------- Heading ----------------
223
  gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")
@@ -404,8 +659,8 @@ def build_ui_with_custom_ui():
404
  demo = build_ui_with_custom_ui()
405
  demo.launch()
406
 
407
-
408
- """
409
  # Section 3
410
  # ---------------- Build Gradio UI with Custom Look ----------------
411
  def build_ui_with_custom_ui():
@@ -597,6 +852,5 @@ def build_ui_with_custom_ui():
597
 
598
  # Launch the interface
599
  demo = build_ui_with_custom_ui()
600
- demo.launch()
601
- """
602
 
 
1
 
2
+ # ==============================
3
+ # Libraries
4
+ # ==============================
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image
8
+ from diffusers import DiffusionPipeline
9
+ from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
10
+ import lpips
11
+ import clip
12
+ from bert_score import score
13
+ import torchvision.transforms as T
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
def free_gpu_cache():
    """Empty PyTorch's CUDA allocator cache when the app is running on the GPU.

    Reads the module-level ``device`` string; does nothing unless it is
    ``"cuda"``.
    """
    if device != "cuda":
        return
    torch.cuda.empty_cache()
20
+
21
+ # ==============================
22
+ # Load Models (HF-ready, memory safe)
23
+ # ==============================
24
+ # SDXL-Turbo
25
# Shared settings derived from the module-level ``device`` string:
# half precision on GPU / full precision on CPU for the diffusion pipelines,
# and the integer device index the transformers ``pipeline`` helper expects.
_on_cuda = device == "cuda"
_diffusion_dtype = torch.float16 if _on_cuda else torch.float32
_hf_device_index = 0 if _on_cuda else -1

gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=_diffusion_dtype,
).to(device)

# DreamShaper
dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=_diffusion_dtype,
).to(device)

# BLIP Captioning
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_hf_device_index,
    generate_kwargs={"max_new_tokens": 256, "num_beams": 5, "temperature": 0.7},
)

# Sentiment / NER / Topic
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_hf_device_index,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_hf_device_index,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_hf_device_index,
)

# BLIP VQA (kept on the CPU regardless of ``device``)
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to("cpu")

# CLIP / LPIPS
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net="alex").to(device)
# Tensor conversion first, then a fixed 256x256 resize, so both LPIPS inputs share a shape.
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
60
+
61
+ # Style map
62
+ style_map = {
63
+ "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
64
+ "Real Life": "natural lighting, true-to-life colors, DSLR",
65
+ "Documentary": "documentary handheld muted colors",
66
+ "iPhone Camera": "iPhone photo natural HDR",
67
+ "Street Photography": "candid street ambient shadows",
68
+ "Cinematic": "cinematic lighting dramatic depth",
69
+ "Anime": "anime cel shaded vibrant",
70
+ "Watercolor": "watercolor soft wash art",
71
+ "Macro": "macro lens shallow DOF",
72
+ "Cyberpunk": "neon cyberpunk futuristic",
73
+ }
74
+
75
+ # ==============================
76
+ # Functions
77
+ # ==============================
78
+
79
def generate_image(pipe, caption, enhancer, negative, seed, style):
    """Generate one 512x512 image from a caption, optional enhancer and style.

    Args:
        pipe: Callable pipeline. It is invoked with ``prompt``,
            ``negative_prompt``, ``generator``, ``height`` and ``width``
            keyword arguments, and its result must expose ``.images``.
        caption: Base prompt text.
        enhancer: Extra prompt text appended after the caption; may be empty.
        negative: Negative prompt, forwarded unchanged.
        seed: Any value ``int()`` accepts; anything else falls back to 42.
        style: Key into the module-level ``style_map``; unknown keys add
            nothing to the prompt.

    Returns:
        The first image produced by the pipeline, or ``None`` when the
        pipeline call raised (the error is printed, not re-raised).
    """
    # strip(", ") removes the dangling separator left by an empty part.
    final_prompt = f"{caption}, {enhancer}".strip(", ")
    final_prompt = f"{final_prompt}, {style_map.get(style, '')}".strip(", ")

    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default seed; a bare
        # ``except`` here would also swallow KeyboardInterrupt/SystemExit.
        seed = 42

    generator = torch.Generator(device="cpu").manual_seed(seed)
    img = None

    try:
        with torch.no_grad():
            out = pipe(
                prompt=final_prompt,
                negative_prompt=negative,
                generator=generator,
                height=512,
                width=512,
            )
        img = out.images[0]
    except Exception as e:
        # Best-effort: callers treat None as "no image". Log the pipeline's
        # class name rather than interpolating the whole pipeline object.
        print(f"{type(pipe).__name__} generation failed:", e)

    free_gpu_cache()
    return img
100
+
101
def caption_for_image(img):
    """Return the caption produced by the module-level ``captioner`` for ``img``.

    Args:
        img: Image passed straight to ``captioner``.

    Returns:
        The ``"generated_text"`` of the first result, or the literal string
        ``"Caption failed."`` if captioning raised or returned an unexpected
        structure.
    """
    try:
        out = captioner(img)
        return out[0]["generated_text"]
    except Exception as e:
        # Keep the best-effort fallback, but leave a trace of what went wrong
        # and let KeyboardInterrupt/SystemExit propagate.
        print("Captioning failed:", e)
        return "Caption failed."
107
+
108
def compute_metrics(images, captions, i1, i2):
    """Compare two image/caption pairs with CLIP, LPIPS and BERTScore.

    Args:
        images: Indexable collection of images. Each selected entry must
            support ``.convert("RGB")`` (assumes PIL images -- confirm if
            other types are ever stored).
        captions: Indexable collection of caption strings, parallel to
            ``images``.
        i1: Index of the first pair.
        i2: Index of the second pair.

    Returns:
        A tuple ``(clip_sim, lp, bert_f1)`` of floats: cosine similarity of
        the CLIP image embeddings, the LPIPS distance, and the BERTScore F1
        of the two captions (0.0 when either caption is empty).
    """
    img1, img2 = images[i1], images[i2]
    cap1, cap2 = captions[i1], captions[i2]

    # CLIP similarity
    t1 = clip_preprocess(img1).unsqueeze(0).to(device)
    t2 = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1, f2 = clip_model.encode_image(t1), clip_model.encode_image(t2)
        clip_sim = float(torch.cosine_similarity(f1, f2))

    # LPIPS
    # Force three channels first: an RGBA or grayscale upload would otherwise
    # reach the LPIPS network with the wrong channel count.
    # ToTensor yields [0, 1]; ``*2 - 1`` rescales to [-1, 1].
    lp_in1 = (lpips_transform(img1.convert("RGB")).unsqueeze(0) * 2 - 1).to(device)
    lp_in2 = (lpips_transform(img2.convert("RGB")).unsqueeze(0) * 2 - 1).to(device)
    with torch.no_grad():
        lp = float(lpips_model(lp_in1, lp_in2))

    # BERTScore
    if cap1 and cap2:
        _, _, f1_scores = score([cap1], [cap2], lang="en", verbose=False)
        bert_f1 = float(f1_scores.mean())
    else:
        bert_f1 = 0.0

    return clip_sim, lp, bert_f1
132
+
133
def answer_vqa(question, image):
    """Answer a free-text question about an image with the BLIP VQA model.

    Args:
        question: Question text; ``None``, empty and whitespace-only values
            are rejected.
        image: Image handed to ``vqa_processor``; ``None`` is rejected.

    Returns:
        The decoded answer string, ``"Provide image + question."`` when either
        input is missing, or ``"I could not determine the answer."`` when the
        processor or model raised.
    """
    # Compare against None explicitly: truth-testing a multi-element numpy
    # array raises, and a None question has no ``.strip()``.
    if image is None or not question or not question.strip():
        return "Provide image + question."
    try:
        inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
        # The VQA model is kept on the CPU, so its inputs go there too.
        inputs = {k: v.to("cpu") for k, v in inputs_raw.items()}
        with torch.no_grad():
            # Inference has to go through generate(); a plain forward call
            # without decoder inputs or labels raises for this model.
            answer_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(answer_ids[0], skip_special_tokens=True)
    except Exception as e:
        print("VQA failed:", e)
        return "I could not determine the answer."
145
+
146
+ # ==============================
147
+ # Gradio UI
148
+ # ==============================
149
def build_ui():
    """Build the Gradio Blocks app and return it without launching.

    The page has five steps: upload and caption a reference image, generate
    an image from that caption with each diffusion pipeline, compute pairwise
    metrics, run NLP analysis on the three captions, and answer questions
    about the reference image.

    Two per-session ``gr.State`` lists hold the images and captions. Slot 0
    is the reference, slot 1 the SD-Turbo result, slot 2 the DreamShaper
    result.

    Returns:
        The ``gr.Blocks`` instance.
    """
    slot_labels = ["Reference", "SD-Turbo", "DreamShaper"]

    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        images_state = gr.State([None, None, None])
        captions_state = gr.State(["", "", ""])

        gr.Markdown("## Multimodal AI Image Studio (HF-ready)")

        # --- Step 1: Upload Reference ---
        upload_input = gr.Image(label="Upload Reference Image", type="pil")
        upload_btn = gr.Button("Upload & Caption")
        # type="pil" so the VQA step, which reads this component, receives a
        # PIL image instead of a numpy array.
        upload_preview = gr.Image(interactive=False, type="pil")
        caption_out = gr.Markdown()

        def upload_and_caption(img, images_state, captions_state):
            if img is None:
                return None, "No image uploaded.", images_state, captions_state
            caption = caption_for_image(img)
            images_state[0] = img
            captions_state[0] = caption
            return img, caption, images_state, captions_state

        upload_btn.click(
            upload_and_caption,
            inputs=[upload_input, images_state, captions_state],
            outputs=[upload_preview, caption_out, images_state, captions_state],
        )

        # --- Step 2: Generate SDXL & DreamShaper ---
        sd_btn = gr.Button("Generate SD-Turbo")
        ds_btn = gr.Button("Generate DreamShaper")
        sd_preview = gr.Image(interactive=False)
        ds_preview = gr.Image(interactive=False)

        def generate_into_slot(pipe, slot, seed, images_state, captions_state):
            # Use the stored reference caption as the prompt. The Markdown box
            # can hold status text such as "No image uploaded.", which must
            # not be fed to the generator.
            prompt = captions_state[0]
            if not prompt:
                return None, images_state, captions_state
            img = generate_image(
                pipe, prompt, enhancer="", negative="", seed=seed, style="Photorealistic"
            )
            if img is not None:
                images_state[slot] = img
                captions_state[slot] = caption_for_image(img)
            return img, images_state, captions_state

        def gen_sd(images_state, captions_state):
            return generate_into_slot(gen_pipe, 1, 42, images_state, captions_state)

        def gen_ds(images_state, captions_state):
            return generate_into_slot(dreamshaper_pipe, 2, 123, images_state, captions_state)

        sd_btn.click(
            gen_sd,
            inputs=[images_state, captions_state],
            outputs=[sd_preview, images_state, captions_state],
        )
        ds_btn.click(
            gen_ds,
            inputs=[images_state, captions_state],
            outputs=[ds_preview, images_state, captions_state],
        )

        # --- Step 3: Metrics ---
        metrics_btn = gr.Button("Compute Metrics")
        metrics_out = gr.Markdown()

        def metrics_ui(images_state, captions_state):
            imgs = images_state or []
            caps = captions_state or []
            incomplete = (
                len(imgs) < 3
                or len(caps) < 3
                or any(img is None for img in imgs)
                or "" in caps
            )
            if incomplete:
                return "All three images and captions are required."
            rows = []
            for a, b in [(0, 1), (0, 2), (1, 2)]:
                clip_sim, lp, bert_f1 = compute_metrics(imgs, caps, a, b)
                rows.append(
                    f"**{slot_labels[a]} ↔ {slot_labels[b]}**: "
                    f"CLIP similarity {clip_sim:.3f}, LPIPS {lp:.3f}, "
                    f"BERTScore F1 {bert_f1:.3f}"
                )
            # Blank lines between rows: Markdown folds single newlines into
            # one paragraph.
            return "\n\n".join(rows)

        metrics_btn.click(metrics_ui, inputs=[images_state, captions_state], outputs=[metrics_out])

        # --- Step 4: NLP ---
        nlp_btn = gr.Button("Analyze Captions")
        nlp_out = gr.HTML()

        def analyze_nlp(captions_state):
            caps = captions_state or []
            if len(caps) < 3 or "" in caps:
                return "<b>All three captions are required.</b>"
            html_blocks = []
            for label, cap in zip(slot_labels, caps):
                # Sentiment
                sentiment = "<br>".join(
                    f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(cap)
                )
                # Entities
                ents = "<br>".join(
                    f"{e['entity_group']}: {e['word']}" for e in ner_model(cap)
                )
                # Topics
                topics_data = topic_model(
                    cap, candidate_labels=['people', 'animals', 'objects', 'food', 'nature']
                )
                topics = "<br>".join(
                    f"{name}: {sc:.2f}"
                    for name, sc in zip(topics_data['labels'], topics_data['scores'])
                )
                html_blocks.append(
                    f"<div style='padding:10px;'><h3>{label}</h3>"
                    f"<b>Sentiment</b><br>{sentiment}<br>"
                    f"<b>Entities</b><br>{ents}<br>"
                    f"<b>Topics</b><br>{topics}</div>"
                )
            return "<div style='display:flex;gap:20px;'>" + "".join(html_blocks) + "</div>"

        nlp_btn.click(analyze_nlp, inputs=[captions_state], outputs=[nlp_out])

        # --- Step 5: VQA ---
        vqa_input = gr.Textbox(label="Ask about reference image")
        vqa_btn = gr.Button("Get Answer")
        vqa_out = gr.Markdown()

        def vqa_ui(question, img):
            return answer_vqa(question, img)

        vqa_btn.click(vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])

    return demo
249
+
250
+ # Launch
251
# Build at import time so ``demo`` stays available as a module attribute,
# but only start the server when this file is executed as a script.
demo = build_ui()

if __name__ == "__main__":
    demo.launch()
253
+
254
+ # Dumped section
255
+ """
256
+ ####################################################################################
257
  # ==============================
258
  # SECTION 1
259
  # ==============================
 
436
  def build_ui_with_custom_ui():
437
  with gr.Blocks(title="Multimodal AI Image Studio") as demo:
438
  # ---------------- CSS Styling ----------------
439
+ gr.HTML(
440
  <style>
441
  .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
442
  .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
 
472
  flex-direction: column;
473
  }
474
  </style>
475
+ )
476
 
477
  # ---------------- Heading ----------------
478
  gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")
 
659
  demo = build_ui_with_custom_ui()
660
  demo.launch()
661
 
662
+ ####################################################################################
663
+
664
  # Section 3
665
  # ---------------- Build Gradio UI with Custom Look ----------------
666
  def build_ui_with_custom_ui():
 
852
 
853
  # Launch the interface
854
  demo = build_ui_with_custom_ui()
855
+ demo.launch()"""
 
856