# =========================
# LIBRARIES & DEVICE SETUP
# =========================
import torch
import gradio as gr
from PIL import Image
from diffusers import DiffusionPipeline
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T

device = "cuda" if torch.cuda.is_available() else "cpu"

def free_gpu_cache():
    if device == "cuda":
        torch.cuda.empty_cache()
# =========================
# MODELS
# =========================
# Image generation
gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Captioning
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=0 if device == "cuda" else -1,
    generate_kwargs={"max_new_tokens": 256, "num_beams": 5, "temperature": 0.7},
)
# NLP models
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if device == "cuda" else -1,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=0 if device == "cuda" else -1,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=0 if device == "cuda" else -1,
)

# VQA (kept on the same device as the other models)
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

# Metrics
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net="alex").to(device)
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
# Style presets appended to prompts
style_map = {
    "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
    "Real Life": "natural lighting, true-to-life colors, DSLR",
    "Documentary": "documentary handheld muted colors",
    "iPhone Camera": "iPhone photo natural HDR",
    "Street Photography": "candid street ambient shadows",
    "Cinematic": "cinematic lighting dramatic depth",
    "Anime": "anime cel shaded vibrant",
    "Watercolor": "watercolor soft wash art",
    "Macro": "macro lens shallow DOF",
    "Cyberpunk": "neon cyberpunk futuristic",
}
# =========================
# IMAGE GENERATION FUNCTIONS
# =========================
def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    base_caption = base_caption or ""
    enhancer = enhancer or ""
    final_prompt = f"{base_caption}, {enhancer}".strip(", ")
    final_prompt = f"{final_prompt}, {style_map.get(style, '')}".strip(", ")
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = 42
    # Seed on CPU so results are reproducible regardless of device
    generator = torch.Generator(device="cpu").manual_seed(seed)
    try:
        with torch.no_grad():
            out = gen_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
        img = out.images[0]
    except Exception as e:
        print("SD-Turbo generation failed:", e)
        img = None
    if img is not None:
        images[1] = img  # SD-Turbo is always stored at index 1
    free_gpu_cache()
    return img, images

def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    base_caption = base_caption or ""
    enhancer = enhancer or ""
    final_prompt = f"{base_caption}, {enhancer}".strip(", ")
    final_prompt = f"{final_prompt}, {style_map.get(style, '')}".strip(", ")
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = 42
    generator = torch.Generator(device="cpu").manual_seed(seed)
    try:
        with torch.no_grad():
            out = dreamshaper_pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
        img = out.images[0]
    except Exception as e:
        print("DreamShaper generation failed:", e)
        img = None
    if img is not None:
        images[2] = img  # DreamShaper is always stored at index 2
    free_gpu_cache()
    return img, images
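# Note (assumption, not part of the original flow): SDXL-Turbo is a few-step
# distilled model that is normally sampled with guidance disabled. The calls
# above rely on the pipeline defaults; a minimal sketch of the commonly
# documented settings would be:
#
#   out = gen_pipe(prompt=final_prompt, negative_prompt=negative,
#                  num_inference_steps=1, guidance_scale=0.0,
#                  generator=generator)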
# =========================
# CAPTIONING
# =========================
def caption_for_image(img):
    try:
        out = captioner(img)
        return out[0]["generated_text"]
    except Exception:
        return "Caption failed."
# =========================
# VQA (answers questions about the reference image)
# =========================
def answer_vqa(question, image):
    if image is None or not (question or "").strip():
        return "Provide image + question."
    try:
        inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs_raw.items()}
        with torch.no_grad():
            # BLIP VQA produces answers via generate(); a plain forward pass
            # expects labels/decoder inputs and is only meant for training.
            out_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(out_ids[0], skip_special_tokens=True)
    except Exception:
        return "I could not determine the answer."
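# Usage sketch (assumes `ref_img` is a PIL image already held in state):
#   answer_vqa("What color is the car?", ref_img)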
# =========================
# METRICS
# =========================
def compute_metrics(images, captions, i1, i2):
    img1, img2 = images[i1], images[i2]
    cap1, cap2 = captions[i1], captions[i2]
    # CLIP image-image similarity
    t1 = clip_preprocess(img1).unsqueeze(0).to(device)
    t2 = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = clip_model.encode_image(t1)
        f2 = clip_model.encode_image(t2)
        clip_sim = float(torch.cosine_similarity(f1, f2))
    # LPIPS perceptual distance (inputs rescaled from [0, 1] to the [-1, 1] range LPIPS expects)
    L1 = (lpips_transform(img1).unsqueeze(0) * 2 - 1).to(device)
    L2 = (lpips_transform(img2).unsqueeze(0) * 2 - 1).to(device)
    with torch.no_grad():
        lp = float(lpips_model(L1, L2))
    # BERTScore between the two captions
    if cap1 and cap2:
        _, _, F = score([cap1], [cap2], lang="en", verbose=False)
        bert_f1 = float(F.mean())
    else:
        bert_f1 = 0.0
    return clip_sim, lp, bert_f1
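# Reading the three numbers: CLIP cosine similarity is higher when the images
# are more semantically alike; LPIPS is a perceptual *distance*, so lower means
# more similar; BERTScore F1 is higher when the two captions agree. Sketch for
# the reference vs. SD-Turbo pair (slots 0 and 1 of the state lists):
#   clip_sim, lp, bert_f1 = compute_metrics(images, captions, 0, 1)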
# =========================
# UI BUILD
# =========================
def build_full_ui():
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # CSS styling
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
        .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
        .loading-line { height:4px; background: linear-gradient(90deg,#008080 0%,#00cccc 50%,#008080 100%); background-size:200% 100%; animation:loading 1s linear infinite; }
        @keyframes loading { 0% { background-position:200% 0; } 100% { background-position:-200% 0; } }
        .enhancer-box textarea { width:100% !important; height:36px !important; font-size:14px; }
        </style>
        """)

        # Shared state: index 0 = reference, 1 = SD-Turbo, 2 = DreamShaper
        images_state = gr.State([None, None, None])
        captions_state = gr.State(["", "", ""])
        # =========================
        # Section 1: Upload Image
        # =========================
        gr.Markdown("## 1️⃣ Upload Reference Image", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column():
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
                enhancer_box = gr.Textbox(label="Prompt Enhancer (Optional)", elem_classes="enhancer-box")
            with gr.Column():
                upload_preview = gr.Image(label="Uploaded Image")
                caption_out = gr.Markdown()

        def upload_and_caption(img, images_state, captions_state):
            if img is None:
                return None, "No image uploaded.", images_state, captions_state
            images_state[0] = img
            cap = caption_for_image(img)
            captions_state[0] = cap
            return img, cap, images_state, captions_state

        upload_btn.click(upload_and_caption, inputs=[upload_input, images_state, captions_state],
                         outputs=[upload_preview, caption_out, images_state, captions_state])
        # =========================
        # Section 2: Generate Images
        # =========================
        gr.Markdown("## 2️⃣ Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column():
                sd_btn = gr.Button("Generate SD-Turbo", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image")
            with gr.Column():
                ds_btn = gr.Button("Generate DreamShaper", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image")

        def generate_sd(caption, enhancer, images_state, captions_state):
            img, images_state = generate_image_with_enhancer(caption, enhancer, "", 42, "Photorealistic", images_state)
            if img is not None:
                captions_state[1] = caption_for_image(img)
            return img, images_state, captions_state

        def generate_ds(caption, enhancer, images_state, captions_state):
            img, images_state = generate_dreamshaper_with_enhancer(caption, enhancer, "", 123, "Photorealistic", images_state)
            if img is not None:
                captions_state[2] = caption_for_image(img)
            return img, images_state, captions_state

        sd_btn.click(generate_sd, inputs=[caption_out, enhancer_box, images_state, captions_state],
                     outputs=[sd_preview, images_state, captions_state])
        ds_btn.click(generate_ds, inputs=[caption_out, enhancer_box, images_state, captions_state],
                     outputs=[ds_preview, images_state, captions_state])
        # =========================
        # Section 3: Metrics
        # =========================
        gr.Markdown("## 3️⃣ Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics", elem_classes="teal-btn")
        metrics_spinner = gr.HTML()
        metrics_out = gr.HTML()

        def compute_metrics_ui(images, captions):
            yield "<div class='loading-line'></div>", ""
            if None in images:
                yield "", "<b>All three images and captions are required.</b>"
                return
            A = compute_metrics(images, captions, 0, 1)
            B = compute_metrics(images, captions, 0, 2)
            C = compute_metrics(images, captions, 1, 2)

            def fmt(m):
                return f"CLIP: {m[0]:.3f}<br>LPIPS: {m[1]:.3f}<br>BERTScore: {m[2]:.3f}"

            html = f"""
            <div style='display:flex; gap:40px; justify-content:space-around;'>
                <div><b>Metrics A<br>(Ref ↔ SD)</b><br>{fmt(A)}</div>
                <div><b>Metrics B<br>(Ref ↔ DS)</b><br>{fmt(B)}</div>
                <div><b>Metrics C<br>(SD ↔ DS)</b><br>{fmt(C)}</div>
            </div>
            """
            yield "", html

        metrics_btn.click(compute_metrics_ui, inputs=[images_state, captions_state],
                          outputs=[metrics_spinner, metrics_out])
        # =========================
        # Section 4: NLP
        # =========================
        gr.Markdown("## 4️⃣ NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_spinner = gr.HTML()
        nlp_out = gr.HTML()

        def analyze_captions_ui(captions):
            yield "<div class='loading-line'></div>", ""
            if any(c == "" for c in captions):
                yield "", "<b>All three captions are required.</b>"
                return
            labels = ["Reference", "SD-Turbo", "DreamShaper"]
            blocks = []
            for label, caption in zip(labels, captions):
                sentiment = "<br>".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption))
                ents_list = ner_model(caption)
                ents = "<br>".join(f"{e['entity_group']}: {e['word']}" for e in ents_list) or "None"
                topics_data = topic_model(caption, candidate_labels=["people", "animals", "objects", "food", "nature"])
                topics = "<br>".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"]))
                block = f"""
                <div style='flex:1; padding:10px; min-width:250px;'>
                    <h3><u>{label}</u></h3>
                    <b>Sentiment</b><br>{sentiment}<br><br>
                    <b>Entities</b><br>{ents}<br><br>
                    <b>Topics</b><br>{topics}
                </div>
                """
                blocks.append(block)
            yield "", f"<div style='display:flex; gap:20px;'>{''.join(blocks)}</div>"

        nlp_btn.click(analyze_captions_ui, inputs=[captions_state], outputs=[nlp_spinner, nlp_out])
        # =========================
        # Section 5: VQA
        # =========================
        gr.Markdown("## 5️⃣ Visual Question Answering (VQA)", elem_classes="heading-orange")
        vqa_input = gr.Textbox(label="Enter a question about the reference image")
        vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
        vqa_spinner = gr.HTML()
        vqa_out = gr.Markdown()

        def vqa_ui(question, images_state):
            yield "<div class='loading-line'></div>", ""
            ans = answer_vqa(question, images_state[0])  # always answer against the reference image
            yield "", f"**Answer:** {ans}"

        vqa_btn.click(vqa_ui, inputs=[vqa_input, images_state], outputs=[vqa_spinner, vqa_out])

    return demo
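# Note: several click handlers above are generator functions (they yield a
# progress indicator, then the result). Recent Gradio releases queue events by
# default; on older 3.x versions, streaming handlers need the queue enabled
# explicitly, e.g. `demo.queue().launch()`.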
demo = build_full_ui()
demo.launch()