# **Purpose**
# =====================================================
# Multimodal AI Image Studio
# =====================================================
# Purpose:
# This script provides a unified interface for generating,
# comparing, and analyzing AI-generated images.
#
# Key Features:
# 1. Upload a reference image and automatically generate captions.
# 2. Enhance prompts to generate images using:
# - SD-Turbo (Stability AI)
# - DreamShaper (Artistic style model)
# 3. Compute pairwise metrics between images:
# - CLIP similarity
# - LPIPS perceptual similarity
# - BERTScore textual similarity
# 4. NLP analysis of captions:
# - Sentiment analysis
# - Named entity recognition
# - Topic classification
# 5. Visual Question Answering (VQA) on the reference image.
#
# Requirements:
# - Python >= 3.9
# - GPU recommended for faster image generation
#
# Usage:
# 1. Install dependencies (see requirements.txt)
# 2. Run this script
# 3. Access the Gradio web interface for interactive exploration
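#
# Note: requirements.txt is not reproduced here; the list below is an assumption
# inferred from the imports in this script (pin versions as appropriate):
#   torch, torchvision, diffusers, transformers, gradio, lpips, bert-score,
#   requests, Pillow, and OpenAI CLIP (e.g. pip install git+https://github.com/openai/CLIP.git)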
# ==============================
# SECTION 1: IMPORTS & SETUP
# ==============================
import torch
import gradio as gr
from PIL import Image
from diffusers import DiffusionPipeline
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T
import requests
from io import BytesIO
device = "cuda" if torch.cuda.is_available() else "cpu"
def free_gpu_cache():
if device == "cuda":
torch.cuda.empty_cache()
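# The generation pipelines below are loaded in float16 on CUDA (float32 on CPU) to
# reduce memory use; the first run downloads the model weights from the Hugging Face Hub.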
# ==============================
# MODELS
# ==============================
gen_pipe = DiffusionPipeline.from_pretrained(
"stabilityai/sdxl-turbo",
torch_dtype=torch.float16 if device=="cuda" else torch.float32
).to(device)
dreamshaper_pipe = DiffusionPipeline.from_pretrained(
"Lykon/dreamshaper-7",
torch_dtype=torch.float16 if device=="cuda" else torch.float32
).to(device)
captioner = pipeline(
"image-to-text",
model="Salesforce/blip-image-captioning-large",
device=0 if device=="cuda" else -1
)
sentiment_model = pipeline(
"sentiment-analysis",
model="distilbert-base-uncased-finetuned-sst-2-english",
device=-1
)
ner_model = pipeline(
"ner",
model="dbmdz/bert-large-cased-finetuned-conll03-english",
aggregation_strategy="simple",
device=-1
)
topic_model = pipeline(
"zero-shot-classification",
model="facebook/bart-large-mnli",
device=-1
)
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base").to(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
lpips_model = lpips.LPIPS(net='alex').to(device)
lpips_transform = T.Compose([T.ToTensor(), T.Resize((256,256))])
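# lpips_transform yields tensors in [0, 1]; compute_metrics() rescales them to [-1, 1]
# (x * 2 - 1) before calling lpips_model, which expects inputs in that range.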
style_map = {
"Photorealistic": "photorealistic, ultra-detailed, 8k, cinematic lighting",
"Real Life": "natural lighting, true-to-life colors, DSLR",
"Documentary": "documentary handheld muted colors",
"iPhone Camera": "iPhone photo natural HDR",
"Street Photography": "candid street ambient shadows",
"Cinematic": "cinematic lighting dramatic depth",
"Anime": "anime cel shaded vibrant",
"Watercolor": "watercolor soft wash art",
"Macro": "macro lens shallow DOF",
"Cyberpunk": "neon cyberpunk futuristic",
}
# ==============================
# SECTION 2: FUNCTIONS
# ==============================
def generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=gen_pipe):
images = images or []
base_caption = base_caption or ""
enhancer = enhancer or ""
final_prompt = f"{base_caption}, {enhancer}".strip(", ")
final_prompt = f"{final_prompt}, {style_map.get(style,'')}".strip(", ")
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        # Fall back to a fixed seed when the input is empty or not an integer.
        seed = 42
generator = torch.Generator(device=device).manual_seed(seed)
try:
with torch.no_grad():
out = pipe(prompt=final_prompt, negative_prompt=negative, generator=generator)
img = out.images[0]
except Exception as e:
print(f"{pipe} failed:", e)
img = None
if img:
images.append(img)
free_gpu_cache()
return img, images
def generate_dreamshaper_with_enhancer(base_caption, enhancer, negative, seed, style, images):
    # Same generation path as above, routed through the DreamShaper pipeline.
    return generate_image_with_enhancer(base_caption, enhancer, negative, seed, style, images, pipe=dreamshaper_pipe)
def caption_for_image(img):
try:
out = captioner(img)
return out[0]["generated_text"]
    except Exception as e:
        print("Captioning failed:", e)
        return "Caption failed."
def answer_vqa(question, image):
    if not image or not question.strip():
        return "Provide image + question."
    try:
        inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
        with torch.no_grad():
            # BLIP VQA is generative: decode an answer with generate() rather than taking
            # an argmax over single-pass logits.
            out_ids = vqa_model.generate(**inputs)
        return vqa_processor.decode(out_ids[0], skip_special_tokens=True)
    except Exception as e:
        print("VQA failed:", e)
        return "VQA failed."
def compute_metrics(images, captions, i1, i2):
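    """Compare the images/captions at indices i1 and i2.

    Returns (clip_sim, lpips_dist, bert_f1): higher CLIP similarity and BERTScore F1
    indicate a more similar pair, while a lower LPIPS distance means the images are
    perceptually closer.
    """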
img1, img2 = images[i1], images[i2]
cap1, cap2 = captions[i1], captions[i2]
t1 = clip_preprocess(img1).unsqueeze(0).to(device)
t2 = clip_preprocess(img2).unsqueeze(0).to(device)
with torch.no_grad():
f1 = clip_model.encode_image(t1)
f2 = clip_model.encode_image(t2)
clip_sim = float(torch.cosine_similarity(f1, f2))
L1 = (lpips_transform(img1).unsqueeze(0)*2 - 1).to(device)
L2 = (lpips_transform(img2).unsqueeze(0)*2 - 1).to(device)
with torch.no_grad():
lp = float(lpips_model(L1, L2))
if cap1 and cap2:
_, _, F = score([cap1],[cap2], lang="en", verbose=False)
bert_f1 = float(F.mean())
else:
bert_f1 = 0.0
return clip_sim, lp, bert_f1
def caption_and_store(img, images, captions):
if img is None:
return None, "", images, captions
try:
caption = captioner(img)[0]["generated_text"]
except Exception as e:
print("Captioning failed:", e)
caption = "Caption failed."
images = images + [img]
captions = captions + [caption]
return img, caption, images, captions
def fetch_and_caption(url, images, captions):
if not url:
return None, "", images, captions
try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGB")
except Exception as e:
print("Failed to fetch image from URL:", e)
return None, "Failed to fetch image", images, captions
return caption_and_store(img, images, captions)
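# Note: the handlers above return new image/caption lists instead of mutating the
# gr.State values in place, so each click produces a fresh state update for Gradio.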
# ==============================
# SECTION 3: GRADIO UI
# ==============================
def build_ui_with_custom_ui():
with gr.Blocks(title="Multimodal AI Image Studio") as demo:
# ---------------- CSS Styling ----------------
gr.HTML("""
<style>
.heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
.orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
.teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
.loading-line { height: 4px; background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%); background-size: 200% 100%; animation: loading 1s linear infinite; margin-bottom:4px; }
@keyframes loading { 0% { background-position: 200% 0; } 100% { background-position: -200% 0; } }
.enhancer-box textarea { width: 100% !important; height: 36px !important; font-size: 14px; }
.equal-height-row { display: flex; align-items: stretch; }
.equal-height-row > .gr-column { display: flex; flex-direction: column; }
.stretch-img .gr-image-container { flex-grow: 1; display: flex; }
.stretch-img img { width: 100% !important; height: 100% !important; object-fit: contain; }
.metrics-row { display: flex; gap: 20px; }
.metrics-row > div { flex: 1; }
.gradio-tabs button.selected { background-color: #ff5500 !important; color: white !important; font-weight: bold; }
</style>
""")
# ---------------- Heading ----------------
gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective",
elem_classes="heading-orange")
images_state = gr.State([])
captions_state = gr.State([])
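        # State convention: index 0 = reference image, 1 = SD-Turbo output, 2 = DreamShaper
        # output. The pairwise metrics and NLP panels below rely on this ordering.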
# ---------------- Step 1: Upload Image ----------------
gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
with gr.Tabs():
with gr.Tab("📁 Upload Image"):
with gr.Row(elem_classes="equal-height-row"):
with gr.Column(scale=1):
upload_input = gr.Image(label="Drag & Drop Image", type="pil")
upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
with gr.Column(scale=1):
upload_preview = gr.Image(label="Uploaded Image", interactive=False, elem_classes="stretch-img")
enhancer_box = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
caption_out = gr.Markdown(label="Generated Caption")
with gr.Tab("📷 Webcam"):
with gr.Row(elem_classes="equal-height-row"):
with gr.Column(scale=1):
webcam_input = gr.Image(label="Webcam Live", type="pil", sources=["webcam"], elem_classes="stretch-img")
webcam_btn = gr.Button("Capture & Generate Caption", elem_classes="orange-btn")
with gr.Column(scale=1):
webcam_preview = gr.Image(label="Captured Image", interactive=False, elem_classes="stretch-img")
enhancer_box_webcam = gr.Textbox(label="Add Prompt Enhancer (Optional)", elem_classes="enhancer-box")
caption_out_webcam = gr.Markdown(label="Generated Caption")
with gr.Tab("🔗 From URL"):
url_input = gr.Textbox(label="Paste Image URL")
url_btn = gr.Button("Fetch & Generate Caption", elem_classes="orange-btn")
# ---------------- Caption Buttons ----------------
upload_btn.click(caption_and_store, [upload_input, images_state, captions_state],
[upload_preview, caption_out, images_state, captions_state])
webcam_btn.click(caption_and_store, [webcam_input, images_state, captions_state],
[webcam_preview, caption_out_webcam, images_state, captions_state])
url_btn.click(fetch_and_caption, [url_input, images_state, captions_state],
[upload_preview, caption_out, images_state, captions_state])
# ---------------- Step 2: Generate Images ----------------
gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
with gr.Row():
with gr.Column():
sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
sd_preview = gr.Image(label="SD-Turbo Image")
with gr.Column():
ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
ds_preview = gr.Image(label="DreamShaper Image")
# ---------------- Image Generation Functions ----------------
        def generate_sd(_, enhancer, images, captions):
            if not captions:
                return None, images, captions, gr.update(interactive=False), gr.update(interactive=False)
            base_caption = captions[-1]
            img, images = generate_image_with_enhancer(base_caption, enhancer or "", negative="", seed=42, style="Photorealistic", images=images)
            if img:
                captions = captions + [captioner(img)[0]["generated_text"]]
            # Metrics/NLP need the reference plus both generated images (three in total).
            ready = len(images) >= 3 and len(captions) >= 3
            return img, images, captions, gr.update(interactive=ready), gr.update(interactive=ready)
        def generate_ds(_, enhancer, images, captions):
            if not captions:
                return None, images, captions, gr.update(interactive=False), gr.update(interactive=False)
            base_caption = captions[-1]
            img, images = generate_dreamshaper_with_enhancer(base_caption, enhancer or "", negative="", seed=123, style="Photorealistic", images=images)
            if img:
                captions = captions + [captioner(img)[0]["generated_text"]]
            ready = len(images) >= 3 and len(captions) >= 3
            return img, images, captions, gr.update(interactive=ready), gr.update(interactive=ready)
# ---------------- Step 3: Metrics ----------------
gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn", interactive=False)
with gr.Row(elem_classes="metrics-row"):
metrics_A = gr.Markdown()
metrics_B = gr.Markdown()
metrics_C = gr.Markdown()
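        # The button handlers below are generators: their first yield shows the animated
        # loading line, and the final yield replaces it with the computed results.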
def compute_metrics_all_pairs_ui(images, captions):
yield ("<div class='loading-line'></div>",) * 3
pairs = [(0,1,"Reference ↔ SD-Turbo"), (0,2,"Reference ↔ DreamShaper"), (1,2,"SD-Turbo ↔ DreamShaper")]
results = []
if len(images) < 3 or len(captions) < 3:
msg = "⚠️ All three images and captions required."
yield msg, msg, msg
return
for i1, i2, label in pairs:
clip_sim, lp, bert_f1 = compute_metrics(images, captions, i1, i2)
results.append(f"**{label}**<br>CLIP similarity: {clip_sim:.3f}<br>LPIPS: {lp:.3f}<br>BERT F1: {bert_f1:.3f}")
yield tuple(results)
metrics_btn.click(compute_metrics_all_pairs_ui, [images_state, captions_state],
[metrics_A, metrics_B, metrics_C])
# ---------------- Step 4: NLP ----------------
gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn", interactive=False)
with gr.Row(elem_classes="metrics-row"):
nlp_out_A = gr.HTML()
nlp_out_B = gr.HTML()
nlp_out_C = gr.HTML()
def analyze_caption_pipeline_ui(captions):
yield ("<div class='loading-line'></div>",) * 3
            if len(captions) < 3:
                msg = "<b>All three captions required.</b>"
                yield msg, msg, msg
                return
labels = ["Reference Image","SD-Turbo","DreamShaper"]
results = []
for label, caption in zip(labels, captions):
sentiment = "<br>".join(f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption))
ents = "<br>".join(f"{e['entity_group']}: {e['word']}" for e in ner_model(caption)) or "None"
topics_data = topic_model(caption, candidate_labels=["people","animals","objects","food","nature"])
topics = "<br>".join(f"{l}: {sc:.2f}" for l, sc in zip(topics_data["labels"], topics_data["scores"]))
results.append(f"<b>{label}</b><br><b>Sentiment</b><br>{sentiment}<br><b>Entities</b><br>{ents}<br><b>Topics</b><br>{topics}")
yield tuple(results)
nlp_btn.click(analyze_caption_pipeline_ui, captions_state,
[nlp_out_A, nlp_out_B, nlp_out_C])
# ===============================
# Wire SD / DS buttons (AFTER metrics_btn & nlp_btn exist)
# ===============================
sd_btn.click(generate_sd, [caption_out, enhancer_box, images_state, captions_state],
[sd_preview, images_state, captions_state, metrics_btn, nlp_btn])
ds_btn.click(generate_ds, [caption_out, enhancer_box, images_state, captions_state],
[ds_preview, images_state, captions_state, metrics_btn, nlp_btn])
# ---------------- Enable Metrics/NLP only when ready ----------------
"""
def enable_metrics_nlp(images, captions):
ready = len(images) >= 3 and len(captions) >= 3
return (
gr.update(interactive=ready),
gr.update(interactive=ready)
)"""
def enable_metrics_nlp(images, captions):
            ready = (
                len(images) >= 3 and
                len(captions) >= 3 and
                all(c and c != "Caption failed." for c in captions)
            )
return gr.update(interactive=ready), gr.update(interactive=ready)
images_state.change(enable_metrics_nlp, [images_state, captions_state], [metrics_btn, nlp_btn])
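        # images_state changes after every upload or generation, so the Metrics and NLP
        # buttons unlock only once all three images have usable captions.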
# ---------------- Step 5: VQA ----------------
gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
with gr.Row():
with gr.Column(scale=1):
vqa_input = gr.Textbox(label="Enter a question about the reference image")
vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
with gr.Column(scale=1):
vqa_out = gr.Markdown(label="VQA Output")
def answer_vqa_ui(question, image):
yield "<div class='loading-line'></div>"
if image is None or not question.strip():
yield "⚠️ Provide image + question."
return
try:
# Prepare inputs
inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
# Use generate() for inference
out_ids = vqa_model.generate(**inputs)
answer = vqa_processor.decode(out_ids[0], skip_special_tokens=True)
yield answer
except Exception as e:
yield f"⚠️ VQA failed: {str(e)}"
vqa_btn.click(answer_vqa_ui, [vqa_input, upload_preview], vqa_out)
return demo
# ---------------- Launch ----------------
demo = build_ui_with_custom_ui()
demo.launch()
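# When running locally, demo.launch(share=True) would additionally expose a temporary
# public link; the default launch() is sufficient for a hosted deployment.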