# **Purpose** # ===================================================== # Multimodal AI Image Studio # ===================================================== # Purpose: # This script provides a unified interface for generating, # comparing, and analyzing AI-generated images. # # Key Features: # 1. Upload a reference image and automatically generate captions. # 2. Enhance prompts to generate images using: # - SD-Turbo (Stability AI) # - DreamShaper (Artistic style model) # 3. Compute pairwise metrics between images: # - CLIP similarity # - LPIPS perceptual similarity # - BERTScore textual similarity # 4. NLP analysis of captions: # - Sentiment analysis # - Named entity recognition # - Topic classification # 5. Visual Question Answering (VQA) on the reference image. # # Requirements: # - Python >= 3.9 # - GPU recommended for faster image generation # # Usage: # 1. Install dependencies (see requirements.txt) # 2. Run this script # 3. Access the Gradio web interface for interactive exploration """ # **Section One** # ============================== # SECTION 1 # ============================== # Install # Section One # ---------------- Install Libraries ---------------- # Libraries import torch import gradio as gr from PIL import Image from diffusers import DiffusionPipeline from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering import lpips import clip from bert_score import score import torchvision.transforms as T import requests from io import BytesIO device = "cuda" if torch.cuda.is_available() else "cpu" def free_gpu_cache(): if device == "cuda": torch.cuda.empty_cache() # ============================== # MODELS # ============================== gen_pipe = DiffusionPipeline.from_pretrained( "stabilityai/sdxl-turbo", torch_dtype=torch.float16 if device=="cuda" else torch.float32 ).to(device) dreamshaper_pipe = DiffusionPipeline.from_pretrained( "Lykon/dreamshaper-7", torch_dtype=torch.float16 if device=="cuda" else torch.float32 ).to(device) captioner = pipeline( "image-to-text", model="Salesforce/blip-image-captioning-large", device=0 if device=="cuda" else -1 ) sentiment_model = pipeline( "sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=-1 ) ner_model = pipeline( "ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", aggregation_strategy="simple", device=-1 ) topic_model = pipeline( "zero-shot-classification", model="facebook/bart-large-mnli", device=-1 ) vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device) clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) lpips_model = lpips.LPIPS(net='alex').to(device) lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))]) style_map = { "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting", "Real Life": "natural lighting, true-to-life colors, DSLR", "Documentary": "documentary handheld muted colors", "iPhone Camera": "iPhone photo natural HDR", "Street Photography": "candid street ambient shadows", "Cinematic": "cinematic lighting dramatic depth", "Anime": "anime cel shaded vibrant", "Watercolor": "watercolor soft wash art", "Macro": "macro lens shallow DOF", "Cyberpunk": "neon cyberpunk futuristic", } # SEction Two # ============================== # FUNCTIONS # ============================== def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=gen_pipe): images = images or [] base_caption = base_caption or "" enhancer = enhancer or "" final_prompt = f"{base_caption}, {enhancer}".strip(", ") final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ") try: seed = int(seed) except: seed = 42 generator = torch.Generator(device=device).manual_seed(seed) try: with torch.no_grad(): out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator) img = out.images[0] except Exception as e: print(f"{pipe} failed:", e) img = None if img: images.append(img) free_gpu_cache() return img, images generate_dreamshaper_with_enhancer = lambda base_caption, enhancer, negative, seed, style, images: \ generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=dreamshaper_pipe) def caption_for_image(img): try: out = captioner(img) return out[0]["generated_text"] except: return "Caption failed." def answer_vqa(question, image): if not image or not question.strip(): return "Provide image + question." try: inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt") inputs = {k:v.to(device) for k,v in inputs_raw.items()} with torch.no_grad(): out = vqa_model(**inputs) ans_id = out.logits.argmax(-1) return vqa_processor.decode(ans_id[0], skip_special_tokens=True) except: return "VQA failed." def compute_metrics(images, captions, i1, i2): img1, img2 = images[i1], images[i2] cap1, cap2 = captions[i1], captions[i2] t1 = clip_preprocess(img1).unsqueeze(0).to(device) t2 = clip_preprocess(img2).unsqueeze(0).to(device) with torch.no_grad(): f1 = clip_model.encode_image(t1) f2 = clip_model.encode_image(t2) clip_sim = float(torch.cosine_similarity(f1, f2)) L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1).to(device) L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1).to(device) with torch.no_grad(): lp = float(lpips_model(L1, L2)) if cap1 and cap2: _, _, F = score([cap1],[cap2], lang="en", verbose=False) bert_f1 = float(F.mean()) else: bert_f1 = 0.0 return clip_sim, lp, bert_f1 def caption_and_store(img, images, captions): if img is None: return None, "", images, captions try: caption = captioner(img)[0]["generated_text"] except Exception as e: print("Captioning failed:", e) caption = "Caption failed." images = images + [img] captions = captions + [caption] return img, caption, images, captions def fetch_and_caption(url, images, captions): if not url: return None, "", images, captions try: response = requests.get(url) img = Image.open(BytesIO(response.content)).convert("RGB") except Exception as e: print("Failed to fetch image from URL:", e) return None, "Failed to fetch image", images, captions return caption_and_store(img, images, captions) # SECTION THREE # ---------------- Section Three: UI ---------------- def build_ui_with_custom_ui(): with gr.Blocks(title="Multimodal AI Image Studio") as demo: # ---------------- CSS Styling ---------------- gr.HTML("" "") # ---------------- Heading ---------------- gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange") images_state = gr.State([]) captions_state = gr.State([]) # ---------------- Step 1: Upload Image ---------------- gr.Markdown("### Upload Reference Image", elem_classes="heading-orange") with gr.Tabs(): with gr.Tab("📁 Upload Image"): with gr.Row(elem_classes="equal-height-row"): with gr.Column(scale=1): upload_input = gr.Image(label="Drag & Drop Image", type="pil") upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn") with gr.Column(scale=1): upload_preview = gr.Image(label="Uploaded Image", interactive=False, elem_classes="stretch-img") enhancer_box = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box") caption_out = gr.Markdown(label="Generated Caption") with gr.Tab("📷 Webcam"): with gr.Row(elem_classes="equal-height-row"): with gr.Column(scale=1): webcam_input = gr.Image(label="Webcam Live", type="pil", sources=["webcam"], elem_classes="stretch-img") webcam_btn = gr.Button("Capture & Generate Caption", elem_classes="orange-btn") with gr.Column(scale=1): webcam_preview = gr.Image(label="Captured Image", interactive=False, elem_classes="stretch-img") enhancer_box_webcam = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box") caption_out_webcam = gr.Markdown(label="Generated Caption") with gr.Tab("🔗 From URL"): url_input = gr.Textbox(label="Paste Image URL") url_btn = gr.Button("Fetch & Generate Caption", elem_classes="orange-btn") # ---------------- Caption Buttons ---------------- upload_btn.click(caption_and_store, [upload_input, images_state, captions_state], [upload_preview, caption_out, images_state, captions_state]) webcam_btn.click(caption_and_store, [webcam_input, images_state, captions_state], [webcam_preview, caption_out_webcam, images_state, captions_state]) url_btn.click(fetch_and_caption, [url_input, images_state, captions_state], [upload_preview, caption_out, images_state, captions_state]) # ---------------- Step 2: Generate Images ---------------- gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange") with gr.Row(): with gr.Column(): sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn") sd_preview = gr.Image(label="SD-Turbo Image") with gr.Column(): ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn") ds_preview = gr.Image(label="DreamShaper Image") # ---------------- Image Generation Functions ---------------- def generate_sd(_, enhancer, images, captions): if not captions: return None, images, captions base_caption = captions[-1] img, images = generate_image_with_enhancer(base_caption, enhancer or "", negative="", seed=42, style="Photorealistic", images=images) if img: new_caption = captioner(img)[0]["generated_text"] captions = captions + [new_caption] return img, images, captions def generate_ds(_, enhancer, images, captions): if not captions: return None, images, captions base_caption = captions[-1] img, images = generate_dreamshaper_with_enhancer(base_caption, enhancer or "", negative="", seed=123, style="Photorealistic", images=images) if img: new_caption = captioner(img)[0]["generated_text"] captions = captions + [new_caption] return img, images, captions # ---------------- Attach Clicks ---------------- sd_btn.click(generate_sd, [caption_out, enhancer_box, images_state, captions_state], [sd_preview, images_state, captions_state]) ds_btn.click(generate_ds, [caption_out, enhancer_box, images_state, captions_state], [ds_preview, images_state, captions_state]) # ---------------- Step 3: Metrics ---------------- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange") metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn") with gr.Row(elem_classes="metrics-row"): metrics_A = gr.Markdown() metrics_B = gr.Markdown() metrics_C = gr.Markdown() def compute_metrics_all_pairs_ui(images, captions): yield ("
",) * 3 if len(images) < 3 or len(captions) < 3: msg = "⚠️ All three images and captions required." yield msg, msg, msg return pairs = [(0,1,"Reference ↔ SD-Turbo"), (0,2,"Reference ↔ DreamShaper"), (1,2,"SD-Turbo ↔ DreamShaper")] results = [] for i1, i2, label in pairs: clip_sim, lp, bert_f1 = compute_metrics(images, captions, i1, i2) results.append(f"**{label}**
CLIP similarity: {clip_sim:.3f}
LPIPS: {lp:.3f}
BERT F1: {bert_f1:.3f}") yield tuple(results) metrics_btn.click(compute_metrics_all_pairs_ui, [images_state, captions_state], [metrics_A, metrics_B, metrics_C]) # ---------------- Step 4: NLP ---------------- gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange") nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn") with gr.Row(elem_classes="metrics-row"): nlp_out_A = gr.HTML() nlp_out_B = gr.HTML() nlp_out_C = gr.HTML() def analyze_caption_pipeline_ui(captions): yield ("
",) * 3 if len(captions) < 3: yield "All three captions required.", "All three captions required.", "All three captions required." return labels = ["Reference Image","SD-Turbo","DreamShaper"] results = [] for label, caption in zip(labels, captions): sentiment = "
".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)) ents = "
".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None" topics_data = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"]) topics = "
".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"])) results.append(f"{label}
Sentiment
{sentiment}
Entities
{ents}
Topics
{topics}") yield tuple(results) nlp_btn.click(analyze_caption_pipeline_ui, captions_state, [nlp_out_A, nlp_out_B, nlp_out_C]) # ---------------- Step 5: VQA ---------------- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange") with gr.Row(): # Left column: question input and button with gr.Column(scale=1): vqa_input = gr.Textbox(label="Enter a question about the reference image") vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn") # Right column: VQA output with gr.Column(scale=1): vqa_out = gr.Markdown(label="VQA Output") def answer_vqa_ui(question, image): yield "
" if image is None or not question.strip(): yield "⚠️ Provide image + question." return try: # Prepare inputs inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device) # Use generate() for inference out_ids = vqa_model.generate(**inputs) answer = vqa_processor.decode(out_ids[0], skip_special_tokens=True) yield answer except Exception as e: yield f"⚠️ VQA failed: {str(e)}" vqa_btn.click(answer_vqa_ui, [vqa_input, upload_preview], vqa_out) return demo # ---------------- Launch ---------------- demo = build_ui_with_custom_ui() demo.launch() """ # **Purpose** # ===================================================== # Multimodal AI Image Studio # ===================================================== # Purpose: # This script provides a unified interface for generating, # comparing, and analyzing AI-generated images. # # Key Features: # 1. Upload a reference image and automatically generate captions. # 2. Enhance prompts to generate images using: # - SD-Turbo (Stability AI) # - DreamShaper (Artistic style model) # 3. Compute pairwise metrics between images: # - CLIP similarity # - LPIPS perceptual similarity # - BERTScore textual similarity # 4. NLP analysis of captions: # - Sentiment analysis # - Named entity recognition # - Topic classification # 5. Visual Question Answering (VQA) on the reference image. # # Requirements: # - Python >= 3.9 # - GPU recommended for faster image generation # # Usage: # 1. Install dependencies (see requirements.txt) # 2. Run this script # 3. Access the Gradio web interface for interactive exploration # Section One # ---------------- Install Libraries ---------------- # Libraries import torch import gradio as gr from PIL import Image from diffusers import DiffusionPipeline from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering import lpips import clip from bert_score import score import torchvision.transforms as T import requests from io import BytesIO device = "cuda" if torch.cuda.is_available() else "cpu" def free_gpu_cache(): if device == "cuda": torch.cuda.empty_cache() # ============================== # MODELS # ============================== gen_pipe = DiffusionPipeline.from_pretrained( "stabilityai/sdxl-turbo", torch_dtype=torch.float16 if device=="cuda" else torch.float32 ).to(device) dreamshaper_pipe = DiffusionPipeline.from_pretrained( "Lykon/dreamshaper-7", torch_dtype=torch.float16 if device=="cuda" else torch.float32 ).to(device) captioner = pipeline( "image-to-text", model="Salesforce/blip-image-captioning-large", device=0 if device=="cuda" else -1 ) sentiment_model = pipeline( "sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=-1 ) ner_model = pipeline( "ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", aggregation_strategy="simple", device=-1 ) topic_model = pipeline( "zero-shot-classification", model="facebook/bart-large-mnli", device=-1 ) vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device) clip_model, clip_preprocess = clip.load("ViT-B/32", device=device) lpips_model = lpips.LPIPS(net='alex').to(device) lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))]) style_map = { "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting", "Real Life": "natural lighting, true-to-life colors, DSLR", "Documentary": "documentary handheld muted colors", "iPhone Camera": "iPhone photo natural HDR", "Street Photography": "candid street ambient shadows", "Cinematic": "cinematic lighting dramatic depth", "Anime": "anime cel shaded vibrant", "Watercolor": "watercolor soft wash art", "Macro": "macro lens shallow DOF", "Cyberpunk": "neon cyberpunk futuristic", } # Section Two # SEction Two # ============================== # FUNCTIONS # ============================== def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=gen_pipe): images = images or [] base_caption = base_caption or "" enhancer = enhancer or "" final_prompt = f"{base_caption}, {enhancer}".strip(", ") final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ") try: seed = int(seed) except: seed = 42 generator = torch.Generator(device=device).manual_seed(seed) try: with torch.no_grad(): out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator) img = out.images[0] except Exception as e: print(f"{pipe} failed:", e) img = None if img: images.append(img) free_gpu_cache() return img, images generate_dreamshaper_with_enhancer = lambda base_caption, enhancer, negative, seed, style, images: \ generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=dreamshaper_pipe) def caption_for_image(img): try: out = captioner(img) return out[0]["generated_text"] except: return "Caption failed." def answer_vqa(question, image): if not image or not question.strip(): return "Provide image + question." try: inputs_raw = vqa_processor(images=image, text=question, return_tensors="pt") inputs = {k:v.to(device) for k,v in inputs_raw.items()} with torch.no_grad(): out = vqa_model(**inputs) ans_id = out.logits.argmax(-1) return vqa_processor.decode(ans_id[0], skip_special_tokens=True) except: return "VQA failed." def compute_metrics(images, captions, i1, i2): img1, img2 = images[i1], images[i2] cap1, cap2 = captions[i1], captions[i2] t1 = clip_preprocess(img1).unsqueeze(0).to(device) t2 = clip_preprocess(img2).unsqueeze(0).to(device) with torch.no_grad(): f1 = clip_model.encode_image(t1) f2 = clip_model.encode_image(t2) clip_sim = float(torch.cosine_similarity(f1, f2)) L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1).to(device) L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1).to(device) with torch.no_grad(): lp = float(lpips_model(L1, L2)) if cap1 and cap2: _, _, F = score([cap1],[cap2], lang="en", verbose=False) bert_f1 = float(F.mean()) else: bert_f1 = 0.0 return clip_sim, lp, bert_f1 def caption_and_store(img, images, captions): if img is None: return None, "", images, captions try: caption = captioner(img)[0]["generated_text"] except Exception as e: print("Captioning failed:", e) caption = "Caption failed." images = images + [img] captions = captions + [caption] return img, caption, images, captions def fetch_and_caption(url, images, captions): if not url: return None, "", images, captions try: response = requests.get(url) img = Image.open(BytesIO(response.content)).convert("RGB") except Exception as e: print("Failed to fetch image from URL:", e) return None, "Failed to fetch image", images, captions return caption_and_store(img, images, captions) # Section Three # ---------------- Section Three: UI ---------------- def build_ui_with_custom_ui(): with gr.Blocks(title="Multimodal AI Image Studio") as demo: # ---------------- CSS Styling ---------------- gr.HTML(""" """) # ---------------- Heading ---------------- gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange") images_state = gr.State([]) captions_state = gr.State([]) # ---------------- Step 1: Upload Image ---------------- gr.Markdown("### Upload Reference Image", elem_classes="heading-orange") with gr.Tabs(): with gr.Tab("📁 Upload Image"): with gr.Row(elem_classes="equal-height-row"): with gr.Column(scale=1): upload_input = gr.Image(label="Drag & Drop Image", type="pil") upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn") with gr.Column(scale=1): upload_preview = gr.Image(label="Uploaded Image", interactive=False, elem_classes="stretch-img") enhancer_box = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box") caption_out = gr.Markdown(label="Generated Caption") with gr.Tab("📷 Webcam"): with gr.Row(elem_classes="equal-height-row"): with gr.Column(scale=1): webcam_input = gr.Image(label="Webcam Live", type="pil", sources=["webcam"], elem_classes="stretch-img") webcam_btn = gr.Button("Capture & Generate Caption", elem_classes="orange-btn") with gr.Column(scale=1): webcam_preview = gr.Image(label="Captured Image", interactive=False, elem_classes="stretch-img") enhancer_box_webcam = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box") caption_out_webcam = gr.Markdown(label="Generated Caption") with gr.Tab("🔗 From URL"): url_input = gr.Textbox(label="Paste Image URL") url_btn = gr.Button("Fetch & Generate Caption", elem_classes="orange-btn") # ---------------- Caption Buttons ---------------- upload_btn.click(caption_and_store, [upload_input, images_state, captions_state], [upload_preview, caption_out, images_state, captions_state]) webcam_btn.click(caption_and_store, [webcam_input, images_state, captions_state], [webcam_preview, caption_out_webcam, images_state, captions_state]) url_btn.click(fetch_and_caption, [url_input, images_state, captions_state], [upload_preview, caption_out, images_state, captions_state]) # ---------------- Step 2: Generate Images ---------------- gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange") with gr.Row(): with gr.Column(): sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn") sd_preview = gr.Image(label="SD-Turbo Image") with gr.Column(): ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn") ds_preview = gr.Image(label="DreamShaper Image") # ---------------- Image Generation Functions ---------------- def generate_sd(_, enhancer, images, captions): if not captions: return None, images, captions, gr.update(interactive=False), gr.update(interactive=False) base_caption = captions[-1] img, images = generate_image_with_enhancer(base_caption, enhancer or "", negative="", seed=42, style="Photorealistic", images=images) if img: captions = captions + [captioner(img)[0]["generated_text"]] ready = len(images) >= 1 and len(captions) >= 1 return img, images, captions #,gr.update(interactive=ready), gr.update(interactive=ready) def generate_ds(_, enhancer, images, captions): if not captions: return None, images, captions, gr.update(interactive=False), gr.update(interactive=False) base_caption = captions[-1] img, images = generate_dreamshaper_with_enhancer(base_caption, enhancer or "", negative="", seed=123, style="Photorealistic", images=images) if img: captions = captions + [captioner(img)[0]["generated_text"]] ready = len(images) >= 1 and len(captions) >= 1 return img, images, captions #, gr.update(interactive=ready), gr.update(interactive=ready) # ---------------- Step 3: Metrics ---------------- gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange") metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn", interactive=False) with gr.Row(elem_classes="metrics-row"): metrics_A = gr.Markdown() metrics_B = gr.Markdown() metrics_C = gr.Markdown() def compute_metrics_all_pairs_ui(images, captions): yield ("
",) * 3 pairs = [(0,1,"Reference ↔ SD-Turbo"), (0,2,"Reference ↔ DreamShaper"), (1,2,"SD-Turbo ↔ DreamShaper")] results = [] if len(images) < 3 or len(captions) < 3: msg = "⚠️ All three images and captions required." yield msg, msg, msg return for i1, i2, label in pairs: clip_sim, lp, bert_f1 = compute_metrics(images, captions, i1, i2) results.append(f"**{label}**
CLIP similarity: {clip_sim:.3f}
LPIPS: {lp:.3f}
BERT F1: {bert_f1:.3f}") yield tuple(results) metrics_btn.click(compute_metrics_all_pairs_ui, [images_state, captions_state], [metrics_A, metrics_B, metrics_C]) # ---------------- Step 4: NLP ---------------- gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange") nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn", interactive=False) with gr.Row(elem_classes="metrics-row"): nlp_out_A = gr.HTML() nlp_out_B = gr.HTML() nlp_out_C = gr.HTML() def analyze_caption_pipeline_ui(captions): yield ("
",) * 3 if len(captions) < 3: yield "All three captions required.", "All three captions required.", "All three captions required." return labels = ["Reference Image","SD-Turbo","DreamShaper"] results = [] for label, caption in zip(labels, captions): sentiment = "
".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)) ents = "
".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None" topics_data = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"]) topics = "
".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"])) results.append(f"{label}
Sentiment
{sentiment}
Entities
{ents}
Topics
{topics}") yield tuple(results) nlp_btn.click(analyze_caption_pipeline_ui, captions_state, [nlp_out_A, nlp_out_B, nlp_out_C]) # =============================== # Wire SD / DS buttons (AFTER metrics_btn & nlp_btn exist) # =============================== sd_btn.click(generate_sd, [caption_out, enhancer_box, images_state, captions_state], [sd_preview, images_state, captions_state, metrics_btn, nlp_btn]) ds_btn.click(generate_ds, [caption_out, enhancer_box, images_state, captions_state], [ds_preview, images_state, captions_state, metrics_btn, nlp_btn]) # ---------------- Enable Metrics/NLP only when ready ---------------- """ def enable_metrics_nlp(images, captions): ready = len(images) >= 3 and len(captions) >= 3 return ( gr.update(interactive=ready), gr.update(interactive=ready) )""" def enable_metrics_nlp(images, captions): ready = ( len(images) == 3 and len(captions) == 3 and all(c and c != "Caption failed." for c in captions) ) return gr.update(interactive=ready), gr.update(interactive=ready) images_state.change(enable_metrics_nlp, [images_state, captions_state], [metrics_btn, nlp_btn]) # ---------------- Step 5: VQA ---------------- gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange") with gr.Row(): with gr.Column(scale=1): vqa_input = gr.Textbox(label="Enter a question about the reference image") vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn") with gr.Column(scale=1): vqa_out = gr.Markdown(label="VQA Output") def answer_vqa_ui(question, image): yield "
" if image is None or not question.strip(): yield "⚠️ Provide image + question." return try: # Prepare inputs inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device) # Use generate() for inference out_ids = vqa_model.generate(**inputs) answer = vqa_processor.decode(out_ids[0], skip_special_tokens=True) yield answer except Exception as e: yield f"⚠️ VQA failed: {str(e)}" vqa_btn.click(answer_vqa_ui, [vqa_input, upload_preview], vqa_out) return demo # ---------------- Launch ---------------- demo = build_ui_with_custom_ui() demo.launch()