# =====================================================
# Multimodal AI Image Studio
# =====================================================
# Purpose:
#   This script provides a unified interface for generating,
#   comparing, and analyzing AI-generated images.
#
# Key Features:
#   1. Upload a reference image and automatically generate captions.
#   2. Enhance prompts to generate images using:
#      - SD-Turbo (Stability AI)
#      - DreamShaper (artistic style model)
#   3. Compute pairwise metrics between images:
#      - CLIP similarity
#      - LPIPS perceptual similarity
#      - BERTScore textual similarity
#   4. NLP analysis of captions:
#      - Sentiment analysis
#      - Named entity recognition
#      - Topic classification
#   5. Visual Question Answering (VQA) on the reference image.
#
# Requirements:
#   - Python >= 3.9
#   - GPU recommended for faster image generation
#
# Usage:
#   1. Install dependencies (see requirements.txt)
#   2. Run this script
#   3. Access the Gradio web interface for interactive exploration
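#
# A minimal requirements.txt sketch inferred from the imports in this script
# (versions are intentionally left unpinned; the `clip` import is assumed to be
# OpenAI's CLIP package, typically installed from GitHub rather than PyPI):
#
#   torch
#   torchvision
#   gradio
#   diffusers
#   transformers
#   lpips
#   bert-score
#   Pillow
#   requests
#   git+https://github.com/openai/CLIP.git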
# ---------------- Section One: Imports ----------------
import torch
import gradio as gr
from PIL import Image
from diffusers import DiffusionPipeline
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T
import requests
from io import BytesIO

device = "cuda" if torch.cuda.is_available() else "cpu"


def free_gpu_cache():
    """Release cached GPU memory between generations."""
    if device == "cuda":
        torch.cuda.empty_cache()

# ==============================
# MODELS
# ==============================
# Text-to-image pipelines
gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Image captioning (BLIP)
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=0 if device == "cuda" else -1,
)

# NLP pipelines (kept on CPU)
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=-1,
)
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=-1,
)
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=-1,
)

# Visual Question Answering (BLIP VQA)
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)

# Image similarity models
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256, 256))])
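
# Note (assumption, not used below): SDXL-Turbo is intended for very few denoising
# steps with guidance disabled, so a faster call would look roughly like
#   gen_pipe(prompt="a cat", num_inference_steps=1, guidance_scale=0.0)
# The generation functions in this script rely on the pipeline defaults instead.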

style_map = {
    "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
    "Real Life": "natural lighting, true-to-life colors, DSLR",
    "Documentary": "documentary handheld muted colors",
    "iPhone Camera": "iPhone photo natural HDR",
    "Street Photography": "candid street ambient shadows",
    "Cinematic": "cinematic lighting dramatic depth",
    "Anime": "anime cel shaded vibrant",
    "Watercolor": "watercolor soft wash art",
    "Macro": "macro lens shallow DOF",
    "Cyberpunk": "neon cyberpunk futuristic",
}
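
# Example of how a preset combines with a caption in generate_image_with_enhancer
# below: the base caption "a dog in a park" with an empty enhancer and the
# "Cyberpunk" preset produces the prompt "a dog in a park, neon cyberpunk futuristic".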

# ---------------- Section Two: Functions ----------------
def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=gen_pipe):
    """Compose the final prompt and generate one image with the given pipeline."""
    images = images or []
    base_caption = base_caption or ""
    enhancer = enhancer or ""
    final_prompt = f"{base_caption}, {enhancer}".strip(", ")
    final_prompt = f"{final_prompt}, {style_map.get(style, '')}".strip(", ")
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = 42
    generator = torch.Generator(device=device).manual_seed(seed)
    try:
        with torch.no_grad():
            out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
        img = out.images[0]
    except Exception as e:
        print(f"{pipe.__class__.__name__} failed:", e)
        img = None
    if img is not None:
        images.append(img)
    free_gpu_cache()
    return img, images
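
# Illustrative call (all argument values here are hypothetical):
#   img, imgs = generate_image_with_enhancer(
#       "a cat on a windowsill", "soft morning light",
#       negative="blurry", seed=7, style="Anime", images=[])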

def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    """Same as generate_image_with_enhancer, but using the DreamShaper pipeline."""
    return generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images,
                                        pipe=dreamshaper_pipe)

def caption_for_image(img):
    """Return a BLIP caption for an image (standalone helper; the UI uses caption_and_store)."""
    try:
        out = captioner(img)
        return out[0]["generated_text"]
    except Exception:
        return "Caption failed."

def answer_vqa(question, image):
    """Answer a question about an image with BLIP VQA (standalone helper; the UI uses answer_vqa_ui)."""
    if image is None or not question.strip():
        return "Provide image + question."
    try:
        inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
        with torch.no_grad():
            out_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(out_ids[0], skip_special_tokens=True)
    except Exception:
        return "VQA failed."

def compute_metrics(images, captions, i1, i2):
    """Compute CLIP cosine similarity, LPIPS distance, and BERTScore F1 for one image pair."""
    img1, img2 = images[i1], images[i2]
    cap1, cap2 = captions[i1], captions[i2]
    # CLIP image-embedding cosine similarity
    t1 = clip_preprocess(img1).unsqueeze(0).to(device)
    t2 = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = clip_model.encode_image(t1)
        f2 = clip_model.encode_image(t2)
        clip_sim = float(torch.cosine_similarity(f1, f2))
    # LPIPS expects tensors scaled to [-1, 1]
    l1 = (lpips_transform(img1).unsqueeze(0) * 2 - 1).to(device)
    l2 = (lpips_transform(img2).unsqueeze(0) * 2 - 1).to(device)
    with torch.no_grad():
        lp = float(lpips_model(l1, l2))
    # BERTScore F1 between the two captions
    if cap1 and cap2:
        _, _, f = score([cap1], [cap2], lang="en", verbose=False)
        bert_f1 = float(f.mean())
    else:
        bert_f1 = 0.0
    return clip_sim, lp, bert_f1
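
# Interpretation: CLIP similarity and BERTScore F1 are "higher means more similar"
# scores (roughly 0-1), whereas LPIPS is a perceptual distance, so lower values
# mean the two images look more alike.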

def caption_and_store(img, images, captions):
    if img is None:
        return None, "", images, captions
    try:
        caption = captioner(img)[0]["generated_text"]
    except Exception as e:
        print("Captioning failed:", e)
        caption = "Caption failed."
    images = images + [img]
    captions = captions + [caption]
    return img, caption, images, captions

def fetch_and_caption(url, images, captions):
    if not url:
        return None, "", images, captions
    try:
        response = requests.get(url, timeout=15)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        print("Failed to fetch image from URL:", e)
        return None, "Failed to fetch image", images, captions
    return caption_and_store(img, images, captions)
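
# The UI below fills images_state/captions_state in a fixed order: index 0 is the
# reference image, index 1 the SD-Turbo result, and index 2 the DreamShaper result.
# compute_metrics_all_pairs_ui assumes exactly this ordering when labelling pairs.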

# ---------------- Section Three: UI ----------------
def build_ui_with_custom_ui():
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # ---------------- CSS Styling ----------------
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
        .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
        .loading-line { height: 4px; background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%); background-size: 200% 100%; animation: loading 1s linear infinite; margin-bottom: 4px; }
        @keyframes loading { 0% { background-position: 200% 0; } 100% { background-position: -200% 0; } }
        .enhancer-box textarea { width: 100% !important; height: 36px !important; font-size: 14px; }
        .equal-height-row { display: flex; align-items: stretch; }
        .equal-height-row > .gr-column { display: flex; flex-direction: column; }
        .stretch-img .gr-image-container { flex-grow: 1; display: flex; }
        .stretch-img img { width: 100% !important; height: 100% !important; object-fit: contain; }
        .metrics-row { display: flex; gap: 20px; }
        .metrics-row > div { flex: 1; }
        .gradio-tabs button.selected { background-color: #ff5500 !important; color: white !important; font-weight: bold; }
        </style>
        """)

        # ---------------- Heading ----------------
        gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective",
                    elem_classes="heading-orange")

        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Step 1: Upload Image ----------------
        gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
        with gr.Tabs():
            with gr.Tab("📁 Upload Image"):
                with gr.Row(elem_classes="equal-height-row"):
                    with gr.Column(scale=1):
                        upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                        upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
                    with gr.Column(scale=1):
                        upload_preview = gr.Image(label="Uploaded Image", interactive=False, elem_classes="stretch-img")
                        enhancer_box = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
                        caption_out = gr.Markdown(label="Generated Caption")
            with gr.Tab("📷 Webcam"):
                with gr.Row(elem_classes="equal-height-row"):
                    with gr.Column(scale=1):
                        webcam_input = gr.Image(label="Webcam Live", type="pil", sources=["webcam"], elem_classes="stretch-img")
                        webcam_btn = gr.Button("Capture & Generate Caption", elem_classes="orange-btn")
                    with gr.Column(scale=1):
                        webcam_preview = gr.Image(label="Captured Image", interactive=False, elem_classes="stretch-img")
                        enhancer_box_webcam = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
                        caption_out_webcam = gr.Markdown(label="Generated Caption")
            with gr.Tab("🔗 From URL"):
                url_input = gr.Textbox(label="Paste Image URL")
                url_btn = gr.Button("Fetch & Generate Caption", elem_classes="orange-btn")

        # ---------------- Caption Buttons ----------------
        upload_btn.click(caption_and_store, [upload_input, images_state, captions_state],
                         [upload_preview, caption_out, images_state, captions_state])
        webcam_btn.click(caption_and_store, [webcam_input, images_state, captions_state],
                         [webcam_preview, caption_out_webcam, images_state, captions_state])
        url_btn.click(fetch_and_caption, [url_input, images_state, captions_state],
                      [upload_preview, caption_out, images_state, captions_state])

        # ---------------- Step 2: Generate Images ----------------
        gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column():
                sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image")
            with gr.Column():
                ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image")

        # ---------------- Image Generation Functions ----------------
        # Both callbacks return five values to match the five outputs wired below
        # (preview image, states, and interactivity updates for the metrics/NLP buttons).
        def generate_sd(_, enhancer, images, captions):
            if not captions:
                return None, images, captions, gr.update(interactive=False), gr.update(interactive=False)
            base_caption = captions[-1]
            img, images = generate_image_with_enhancer(base_caption, enhancer or "", negative="",
                                                       seed=42, style="Photorealistic", images=images)
            if img is not None:
                captions = captions + [captioner(img)[0]["generated_text"]]
            # Enable the metrics/NLP buttons only once reference + both generations exist.
            ready = len(images) >= 3 and len(captions) >= 3
            return img, images, captions, gr.update(interactive=ready), gr.update(interactive=ready)

        def generate_ds(_, enhancer, images, captions):
            if not captions:
                return None, images, captions, gr.update(interactive=False), gr.update(interactive=False)
            base_caption = captions[-1]
            img, images = generate_dreamshaper_with_enhancer(base_caption, enhancer or "", negative="",
                                                             seed=123, style="Photorealistic", images=images)
            if img is not None:
                captions = captions + [captioner(img)[0]["generated_text"]]
            ready = len(images) >= 3 and len(captions) >= 3
            return img, images, captions, gr.update(interactive=ready), gr.update(interactive=ready)

        # ---------------- Step 3: Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn", interactive=False)
        with gr.Row(elem_classes="metrics-row"):
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            yield ("<div class='loading-line'></div>",) * 3
            if len(images) < 3 or len(captions) < 3:
                msg = "⚠️ All three images and captions required."
                yield msg, msg, msg
                return
            pairs = [(0, 1, "Reference ↔ SD-Turbo"),
                     (0, 2, "Reference ↔ DreamShaper"),
                     (1, 2, "SD-Turbo ↔ DreamShaper")]
            results = []
            for i1, i2, label in pairs:
                clip_sim, lp, bert_f1 = compute_metrics(images, captions, i1, i2)
                results.append(f"**{label}**<br>CLIP similarity: {clip_sim:.3f}<br>LPIPS: {lp:.3f}<br>BERT F1: {bert_f1:.3f}")
            yield tuple(results)

        metrics_btn.click(compute_metrics_all_pairs_ui, [images_state, captions_state],
                          [metrics_A, metrics_B, metrics_C])

        # ---------------- Step 4: NLP ----------------
        gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn", interactive=False)
        with gr.Row(elem_classes="metrics-row"):
            nlp_out_A = gr.HTML()
            nlp_out_B = gr.HTML()
            nlp_out_C = gr.HTML()

        def analyze_caption_pipeline_ui(captions):
            yield ("<div class='loading-line'></div>",) * 3
            if len(captions) < 3:
                msg = "<b>All three captions required.</b>"
                yield msg, msg, msg
                return
            labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
            results = []
            for label, caption in zip(labels, captions):
                sentiment = "<br>".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption))
                ents = "<br>".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None"
                topics_data = topic_model(caption, candidate_labels=["people", "animals", "objects", "food", "nature"])
                topics = "<br>".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"]))
                results.append(f"<b>{label}</b><br><b>Sentiment</b><br>{sentiment}<br><b>Entities</b><br>{ents}<br><b>Topics</b><br>{topics}")
            yield tuple(results)

        nlp_btn.click(analyze_caption_pipeline_ui, captions_state,
                      [nlp_out_A, nlp_out_B, nlp_out_C])

        # ==============================
        # Wire SD / DS buttons (after metrics_btn & nlp_btn exist)
        # ==============================
        sd_btn.click(generate_sd, [caption_out, enhancer_box, images_state, captions_state],
                     [sd_preview, images_state, captions_state, metrics_btn, nlp_btn])
        ds_btn.click(generate_ds, [caption_out, enhancer_box, images_state, captions_state],
                     [ds_preview, images_state, captions_state, metrics_btn, nlp_btn])

        # ---------------- Enable Metrics/NLP only when ready ----------------
| """ | |
| def enable_metrics_nlp(images, captions): | |
| ready = len(images) >= 3 and len(captions) >= 3 | |
| return ( | |
| gr.update(interactive=ready), | |
| gr.update(interactive=ready) | |
| )""" | |
        def enable_metrics_nlp(images, captions):
            ready = (
                len(images) == 3 and
                len(captions) == 3 and
                all(c and c != "Caption failed." for c in captions)
            )
            return gr.update(interactive=ready), gr.update(interactive=ready)

        images_state.change(enable_metrics_nlp, [images_state, captions_state], [metrics_btn, nlp_btn])

        # ---------------- Step 5: VQA ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            # Left column: question input and button
            with gr.Column(scale=1):
                vqa_input = gr.Textbox(label="Enter a question about the reference image")
                vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
            # Right column: VQA output
            with gr.Column(scale=1):
                vqa_out = gr.Markdown(label="VQA Output")

        def answer_vqa_ui(question, image):
            yield "<div class='loading-line'></div>"
            if image is None or not question.strip():
                yield "⚠️ Provide image + question."
                return
            try:
                # Prepare inputs and decode the answer with generate()
                inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
                out_ids = vqa_model.generate(**inputs)
                answer = vqa_processor.decode(out_ids[0], skip_special_tokens=True)
                yield answer
            except Exception as e:
                yield f"⚠️ VQA failed: {e}"

        vqa_btn.click(answer_vqa_ui, [vqa_input, upload_preview], vqa_out)

    return demo


# ---------------- Launch ----------------
demo = build_ui_with_custom_ui()
demo.launch()