Spaces:
Sleeping
Sleeping
| # ---------------- Libraries ---------------- | |
| import torch | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline | |
| from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering | |
| import lpips | |
| import clip | |
| from bert_score import score | |
| import torchvision.transforms as T | |
# ---------------- Device Setup ----------------
# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
# Kept as a plain string because the rest of the file compares it to "cuda".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ---------------- GPU Cache Free ----------------
def free_gpu_cache():
    """Ask PyTorch to empty its CUDA memory cache.

    Called after each image generation so cached blocks are handed back
    between runs. Returns None.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
# ---------------- Load SD Turbo & DreamShaper ----------------
# Both diffusion pipelines share one precision setting: half precision on the
# GPU, full precision on the CPU.
_pipeline_dtype = torch.float32
if device == "cuda":
    _pipeline_dtype = torch.float16

gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sd-turbo",
    torch_dtype=_pipeline_dtype,
)
gen_pipe = gen_pipe.to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=_pipeline_dtype,
)
dreamshaper_pipe = dreamshaper_pipe.to(device)
# ---------------- Load NLP Models ----------------
# Device index handed to every transformers pipeline below:
# 0 when running on CUDA, -1 otherwise.
_hf_device = -1
if device == "cuda":
    _hf_device = 0

# Image captioning (used for the reference image and for generated images).
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_hf_device,
)

# Sentiment classification of captions.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_hf_device,
)

# Named-entity recognition with grouped (aggregated) entities.
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_hf_device,
)

# Zero-shot topic classification of captions.
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_hf_device,
)
# ---------------- Load VQA Model ----------------
# Processor and model are loaded from the same checkpoint.
_VQA_CHECKPOINT = "Salesforce/blip-vqa-base"
vqa_processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
vqa_model = vqa_model.to(device)

# ---------------- Load CLIP & LPIPS ----------------
# CLIP returns the model together with its matching image preprocessing.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex')
lpips_model = lpips_model.to(device)
# ---------------- Style Map ----------------
# Maps a style name to the prompt suffix that the image generation functions
# append to the user prompt. Unknown names are looked up with a "" default.
style_map = {
    "Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
    "Real Life": "natural lighting, true-to-life colors, DSLR",
    "Documentary": "documentary handheld muted colors",
    "iPhone Camera": "iPhone photo natural HDR",
    "Street Photography": "candid street ambient shadows",
    "Cinematic": "cinematic lighting dramatic depth",
    "Anime": "anime cel shaded vibrant",
    "Watercolor": "watercolor soft wash art",
    "Macro": "macro lens shallow DOF",
    "Cyberpunk": "neon cyberpunk futuristic",
}
# Preprocessing applied to images before LPIPS: convert to a tensor, then
# resize to a fixed 256x256 so both inputs have the same shape.
_LPIPS_INPUT_SIZE = (256, 256)
lpips_transform = T.Compose([
    T.ToTensor(),
    T.Resize(_LPIPS_INPUT_SIZE),
])
# ---------------- Image Generation Functions ----------------
def generate_image_and_store(prompt, negative, seed, style, images,
                             num_inference_steps=1, guidance_scale=0.0):
    """Generate one image with the SD-Turbo pipeline and add it to a history.

    Args:
        prompt: Text prompt. The prompt suffix registered for ``style`` in
            ``style_map`` is appended when there is one.
        negative: Negative prompt forwarded to the pipeline.
            NOTE(review): per the SD-Turbo model card the negative prompt has
            no effect while ``guidance_scale`` is 0.0 - confirm if it matters.
        seed: Value convertible with ``int()``; seeds the random generator.
        style: Key into ``style_map``; unknown keys add no suffix.
        images: Existing list of images, or None/empty to start a new one.
            The list is copied, never modified in place.
        num_inference_steps: Denoising steps. Defaults to 1, the setting the
            SD-Turbo model card recommends for this distilled model.
        guidance_scale: Classifier-free guidance scale. Defaults to 0.0
            (guidance disabled), as the SD-Turbo model card specifies.

    Returns:
        Tuple ``(img, history)``: the generated image and a new list holding
        the previous images followed by ``img``.
    """
    # Copy so the caller's list (e.g. a Gradio session state) is not mutated.
    history = list(images) if images else []

    # Join only non-empty parts so a missing style does not leave a dangling
    # ", " at the end of the prompt.
    style_suffix = style_map.get(style, "")
    enhanced_prompt = ", ".join(part for part in (prompt, style_suffix) if part)

    generator = torch.Generator(device=device).manual_seed(int(seed))
    ctx = torch.autocast("cuda") if device == "cuda" else torch.no_grad()
    try:
        with ctx:
            img = gen_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative,
                generator=generator,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
    finally:
        # Release cached GPU memory even when generation fails.
        free_gpu_cache()

    history.append(img)
    return img, history
def generate_dreamshaper_image(prompt, negative, seed, style, images):
    """Generate one image with the DreamShaper pipeline and add it to a history.

    Args:
        prompt: Text prompt. The prompt suffix registered for ``style`` in
            ``style_map`` is appended when there is one.
        negative: Negative prompt forwarded to the pipeline.
        seed: Value convertible with ``int()``; seeds the random generator.
        style: Key into ``style_map``; unknown keys add no suffix.
        images: Existing list of images, or None/empty to start a new one.
            The list is copied, never modified in place.

    Returns:
        Tuple ``(img, history)``: the generated image and a new list holding
        the previous images followed by ``img``.
    """
    # Copy so the caller's list (e.g. a Gradio session state) is not mutated.
    history = list(images) if images else []

    # Join only non-empty parts so a missing style does not leave a dangling
    # ", " at the end of the prompt.
    style_suffix = style_map.get(style, "")
    enhanced_prompt = ", ".join(part for part in (prompt, style_suffix) if part)

    generator = torch.Generator(device=device).manual_seed(int(seed))
    ctx = torch.autocast("cuda") if device == "cuda" else torch.no_grad()
    try:
        with ctx:
            img = dreamshaper_pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative,
                generator=generator,
            ).images[0]
    finally:
        # Release cached GPU memory even when generation fails.
        free_gpu_cache()

    history.append(img)
    return img, history
# ---------------- VQA ----------------
def answer_vqa(question, image):
    """Answer a free-text question about an image with the BLIP VQA model.

    Args:
        question: Question text. None, empty, or whitespace-only values are
            treated as "no question".
        image: Image to ask about, passed to ``vqa_processor`` as ``images``;
            None means no image.

    Returns:
        The decoded answer string, or a fixed hint message when the image or
        the question is missing.
    """
    # ``question`` may be None (e.g. a cleared text box), so normalise it
    # before calling ``strip`` instead of raising AttributeError.
    if image is None or not (question or "").strip():
        return "Upload an image and enter a question."

    inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = vqa_model.generate(**inputs)
    return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
# ---------------- Metrics Computation ----------------
def _metric_input(items, idx, kind):
    """Return ``items[idx]``, raising a descriptive IndexError when it is absent."""
    try:
        item = (items or [])[idx]
    except IndexError:
        item = None
    if item is None:
        raise IndexError(f"No {kind} is stored at index {idx}; cannot compute metrics.")
    return item


def compute_metrics_button(images, captions, idx1, idx2):
    """Compare two stored images and their captions and format the scores.

    Three scores are computed: cosine similarity of the CLIP image features,
    the LPIPS distance between the images, and the BERTScore F1 between the
    two captions.

    Args:
        images: Sequence of PIL images (both ``clip_preprocess`` and the RGB
            conversion below are applied to its entries).
        captions: Sequence of caption strings, index-aligned with ``images``.
        idx1: Index of the first image/caption.
        idx2: Index of the second image/caption.

    Returns:
        A Markdown-formatted string listing the three scores.

    Raises:
        IndexError: If an image or caption is missing (out of range or None)
            at either index. Raised before any model is run.
    """
    # Validate everything first so a missing caption is reported clearly
    # instead of failing after the expensive CLIP/LPIPS passes.
    image_a = _metric_input(images, idx1, "image")
    image_b = _metric_input(images, idx2, "image")
    caption_a = _metric_input(captions, idx1, "caption")
    caption_b = _metric_input(captions, idx2, "caption")

    with torch.no_grad():
        # CLIP: cosine similarity between the two image embeddings.
        feat1 = clip_model.encode_image(clip_preprocess(image_a).unsqueeze(0).to(device))
        feat2 = clip_model.encode_image(clip_preprocess(image_b).unsqueeze(0).to(device))
        clip_sim = float(torch.cosine_similarity(feat1, feat2).item())

        # LPIPS: force three channels so RGBA / greyscale / palette images
        # do not break the network, then map the tensor values from
        # [0, 1] to [-1, 1] with ``* 2 - 1``.
        img1_lp = lpips_transform(image_a.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
        img2_lp = lpips_transform(image_b.convert("RGB")).unsqueeze(0).to(device) * 2 - 1
        lpips_score = float(lpips_model(img1_lp, img2_lp).item())

    # BERTScore: only the F1 component is reported.
    _, _, F1 = score([caption_a], [caption_b], lang="en", verbose=False)
    bert_f1 = float(F1.mean().item())

    return (
        "\n"
        "**Metrics Comparison**\n"
        f"- CLIP Similarity: {clip_sim:.4f}\n"
        f"- LPIPS Score: {lpips_score:.4f}\n"
        f"- BERTScore F1: {bert_f1:.4f}\n"
    )
# ---------------- Build Gradio UI with Original Look ----------------
# Fixed positions of the three images / captions inside the session state
# lists. Using fixed slots keeps images and captions aligned no matter in
# which order, or how often, the generate buttons are clicked.
_REFERENCE_SLOT = 0
_SD_TURBO_SLOT = 1
_DREAMSHAPER_SLOT = 2
_SLOT_COUNT = 3
_LOADING_BAR_HTML = "<div class='loading-line'></div>"


def _store_in_slot(items, slot, value):
    """Return a copy of ``items`` with ``value`` at ``slot``, padding gaps with None."""
    updated = list(items or [])
    updated.extend([None] * (slot + 1 - len(updated)))
    updated[slot] = value
    return updated


def _all_slots_filled(items):
    """Return True when the reference, SD-Turbo and DreamShaper slots are all set."""
    items = items or []
    return len(items) >= _SLOT_COUNT and all(
        item is not None for item in items[:_SLOT_COUNT]
    )


def _has_reference(images):
    """Return True when a reference image has been stored in slot 0."""
    return bool(images) and images[_REFERENCE_SLOT] is not None


def build_ui_with_custom_ui():
    """Build the Gradio Blocks app and wire up all event handlers.

    The page has five steps: upload a reference image and caption it,
    generate SD-Turbo and DreamShaper images from that caption, compute
    pairwise metrics, run NLP analysis on the three captions, and answer
    questions about the reference image.

    Per-session state is two lists, images and captions, indexed by fixed
    slot: 0 = reference, 1 = SD-Turbo, 2 = DreamShaper.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # ---------------- CSS Styling ----------------
        gr.HTML("""
        <style>
        .heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
        .orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
        .teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
        /* Horizontal thin spinner */
        .loading-line {
        height: 4px;
        background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%);
        background-size: 200% 100%;
        animation: loading 1s linear infinite;
        }
        @keyframes loading {
        0% { background-position: 200% 0; }
        100% { background-position: -200% 0; }
        }
        </style>
        """)

        # ---------------- Heading ----------------
        gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")

        # ---------------- States ----------------
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Step 1: Upload Reference Image ----------------
        gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
            with gr.Column(scale=1):
                upload_preview = gr.Image(label="Uploaded Image", interactive=False)
                caption_out = gr.Markdown(label="Generated Caption")

        def upload_and_generate_caption_ui(img, images_state, captions_state):
            """Caption the uploaded image and restart the session state with it."""
            if img is None:
                raise gr.Error("Please upload an image before generating a caption.")
            caption = captioner(img)[0]["generated_text"]
            # A new reference starts a fresh comparison: both lists are
            # rebuilt with only the reference slot filled.
            return img, caption, [img], [caption]

        upload_btn.click(
            upload_and_generate_caption_ui,
            inputs=[upload_input, images_state, captions_state],
            outputs=[upload_preview, caption_out, images_state, captions_state],
        )

        # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
        gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
            with gr.Column(scale=1, min_width=300):
                ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image", interactive=False)

        def _generate_into_slot(generate_fn, slot, seed, caption, images, captions):
            """Generate an image, caption it, and store both at ``slot``."""
            if not _has_reference(images):
                raise gr.Error("Upload a reference image and generate its caption first.")
            # Pass no history: the generator's returned list is ignored and the
            # image is placed at its fixed slot in a copy of the state instead.
            img, _ = generate_fn(caption, negative="", seed=seed, style="Photorealistic", images=None)
            new_caption = captioner(img)[0]["generated_text"]
            return (
                img,
                _store_in_slot(images, slot, img),
                _store_in_slot(captions, slot, new_caption),
            )

        def generate_sd_from_caption_ui(caption, images_state, captions_state):
            """Generate the SD-Turbo image from the reference caption (slot 1)."""
            return _generate_into_slot(
                generate_image_and_store, _SD_TURBO_SLOT, 42,
                caption, images_state, captions_state,
            )

        def generate_ds_from_caption_ui(caption, images_state, captions_state):
            """Generate the DreamShaper image from the reference caption (slot 2)."""
            return _generate_into_slot(
                generate_dreamshaper_image, _DREAMSHAPER_SLOT, 123,
                caption, images_state, captions_state,
            )

        sd_btn.click(generate_sd_from_caption_ui, inputs=[caption_out, images_state, captions_state],
                     outputs=[sd_preview, images_state, captions_state])
        ds_btn.click(generate_ds_from_caption_ui, inputs=[caption_out, images_state, captions_state],
                     outputs=[ds_preview, images_state, captions_state])

        # ---------------- Step 3: Compute Pairwise Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row():
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            """Yield a loading bar, then the metrics for all three image pairs."""
            yield _LOADING_BAR_HTML, _LOADING_BAR_HTML, _LOADING_BAR_HTML
            # Both lists are checked: the metrics need every image AND caption.
            if not (_all_slots_filled(images) and _all_slots_filled(captions)):
                msg = "All three images and captions are required to compute metrics."
                yield msg, msg, msg
                return
            A = compute_metrics_button(images, captions, _REFERENCE_SLOT, _SD_TURBO_SLOT)
            B = compute_metrics_button(images, captions, _REFERENCE_SLOT, _DREAMSHAPER_SLOT)
            C = compute_metrics_button(images, captions, _SD_TURBO_SLOT, _DREAMSHAPER_SLOT)
            yield (f"**Reference ↔ SD-Turbo**\n{A}",
                   f"**Reference ↔ DreamShaper**\n{B}",
                   f"**SD-Turbo ↔ DreamShaper**\n{C}")

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
                          outputs=[metrics_A, metrics_B, metrics_C])

        # ---------------- Step 4: NLP Analysis ----------------
        gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_out = gr.HTML()

        def analyze_caption_pipeline_ui(captions):
            """Yield a loading bar, then sentiment/entity/topic blocks per caption."""
            yield _LOADING_BAR_HTML
            if not _all_slots_filled(captions):
                yield "<b>All three captions are required for NLP analysis.</b>"
                return
            labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
            blocks = []
            for label, caption in zip(labels, captions):
                sentiment = "<br>".join([f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)])
                ents = "<br>".join([f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)]) or "None"
                topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature'])
                topics = "<br>".join([f"{l}: {sc:.2f}" for l, sc in zip(topics_data['labels'], topics_data['scores'])])
                block = f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3><b>Sentiment</b><br>{sentiment}<br><br><b>Entities</b><br>{ents}<br><br><b>Topics</b><br>{topics}</div>"
                blocks.append(block)
            yield f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"

        nlp_btn.click(analyze_caption_pipeline_ui, inputs=[captions_state], outputs=[nlp_out])

        # ---------------- Step 5: Visual Question Answering ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                vqa_input = gr.Textbox(label="Enter a question about the reference image")
                vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
            with gr.Column(scale=1):
                vqa_out = gr.Markdown(label="VQA Output")

        def answer_vqa_ui(question, image):
            """Yield a loading bar, then the VQA answer for the previewed image."""
            yield _LOADING_BAR_HTML
            yield answer_vqa(question, image)

        vqa_btn.click(answer_vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])

    return demo
# Launch the interface
demo = build_ui_with_custom_ui()
# Several handlers are generators (they yield a loading bar before the real
# result). Gradio 3.x only supports generator handlers when the queue is
# enabled, so enable it explicitly.
# NOTE(review): newer Gradio versions reportedly enable the queue by default,
# which would make this call redundant - confirm against the pinned version.
demo.queue()
demo.launch()