MapleGel_T2I / app.py
Chyd19's picture
Update app.py
3272049 verified
# ---------------- Libraries ----------------
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from transformers import pipeline, BlipProcessor, BlipForQuestionAnswering
import lpips
import clip
from bert_score import score
import torchvision.transforms as T
# ---------------- Device Setup ----------------
# Run every model on the GPU when PyTorch reports one is available,
# otherwise fall back to the CPU. The string is reused by all loaders below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# ---------------- GPU Cache Free ----------------
def free_gpu_cache():
    """Ask PyTorch's CUDA caching allocator to release its unused cached memory.

    Called after each image generation to keep GPU memory headroom for the
    other models loaded by this app.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
# ---------------- Load SD Turbo & DreamShaper ----------------
# Half precision on the GPU, full precision on the CPU; shared by both pipelines.
_diffusion_dtype = torch.float16 if device == "cuda" else torch.float32

gen_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sd-turbo",
    torch_dtype=_diffusion_dtype,
)
gen_pipe = gen_pipe.to(device)

dreamshaper_pipe = DiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7",
    torch_dtype=_diffusion_dtype,
)
dreamshaper_pipe = dreamshaper_pipe.to(device)
# ---------------- Load NLP Models ----------------
# transformers pipelines take an integer device: 0 = first GPU, -1 = CPU.
_pipeline_device = 0 if device == "cuda" else -1

# BLIP captioner: image -> caption text.
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-large",
    device=_pipeline_device,
)
# Binary sentiment classifier applied to captions.
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=_pipeline_device,
)
# Named-entity recognizer; "simple" aggregation merges word pieces into entity groups.
ner_model = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
    device=_pipeline_device,
)
# Zero-shot topic classifier; candidate labels are supplied per call.
topic_model = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=_pipeline_device,
)
# ---------------- Load VQA Model ----------------
# Processor and model must come from the same checkpoint.
_VQA_CHECKPOINT = "Salesforce/blip-vqa-base"
vqa_processor = BlipProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = BlipForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
vqa_model = vqa_model.to(device)
# ---------------- Load CLIP & LPIPS ----------------
# CLIP ViT-B/32: its image encoder and preprocessing feed the CLIP
# similarity metric in compute_metrics_button.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
# LPIPS perceptual distance using the AlexNet backbone.
lpips_model = lpips.LPIPS(net="alex")
lpips_model = lpips_model.to(device)
# ---------------- Style Map ----------------
# UI style name -> descriptor text appended to the user prompt before
# generation (see the image generation functions). Insertion order is kept.
style_map = dict([
    ("Photorealistic", "photorealistic, ultra-detailed, 8k, cinematic lighting"),
    ("Real Life", "natural lighting, true-to-life colors, DSLR"),
    ("Documentary", "documentary handheld muted colors"),
    ("iPhone Camera", "iPhone photo natural HDR"),
    ("Street Photography", "candid street ambient shadows"),
    ("Cinematic", "cinematic lighting dramatic depth"),
    ("Anime", "anime cel shaded vibrant"),
    ("Watercolor", "watercolor soft wash art"),
    ("Macro", "macro lens shallow DOF"),
    ("Cyberpunk", "neon cyberpunk futuristic"),
])
# LPIPS preprocessing: convert an image to a tensor, then resize to 256x256.
# Callers rescale the result to [-1, 1] before passing it to lpips_model.
lpips_transform = T.Compose(
    [
        T.ToTensor(),
        T.Resize((256, 256)),
    ]
)
# ---------------- Image Generation Functions ----------------
def generate_image_and_store(prompt, negative, seed, style, images,
                             num_inference_steps=1, guidance_scale=0.0):
    """Generate one SD-Turbo image and record it in ``images``.

    SD-Turbo is a distilled model that its model card says to sample in 1-4
    steps with classifier-free guidance disabled (``guidance_scale=0.0``).
    The StableDiffusionPipeline call defaults (50 steps, guidance 7.5) target
    regular Stable Diffusion. They are therefore overridden explicitly here;
    callers may still pass other values.

    Args:
        prompt: Base text prompt; the ``style_map`` descriptor is appended.
        negative: Negative prompt. The pipeline only applies it when
            ``guidance_scale > 1``, so it is ignored at the default of 0.0.
        seed: Seed for the torch generator (cast with ``int``).
        style: Key into ``style_map``; unknown keys append an empty suffix.
        images: List the new image is appended to. A non-empty list is
            mutated in place; an empty list or ``None`` starts a new list.
        num_inference_steps: Number of denoising steps (default 1).
        guidance_scale: Classifier-free guidance scale (default 0.0 = off).

    Returns:
        tuple: ``(img, images)``: the generated image and the updated list.
    """
    images = images or []
    enhanced_prompt = f"{prompt}, {style_map.get(style,'')}"
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Autocast on GPU; on CPU only disable autograd bookkeeping.
    ctx = torch.autocast("cuda") if device == "cuda" else torch.no_grad()
    with ctx:
        img = gen_pipe(
            prompt=enhanced_prompt,
            negative_prompt=negative,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    images.append(img)
    free_gpu_cache()
    return img, images
def generate_dreamshaper_image(prompt, negative, seed, style, images):
    """Generate one DreamShaper-7 image and record it in ``images``.

    Args:
        prompt: Base text prompt; the ``style_map`` descriptor is appended.
        negative: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator (cast with ``int``).
        style: Key into ``style_map``; unknown keys append an empty suffix.
        images: List the new image is appended to. A non-empty list is
            mutated in place; an empty list or ``None`` starts a new list.

    Returns:
        tuple: ``(picture, gallery)``: the generated image and the updated list.
    """
    gallery = images or []
    style_suffix = style_map.get(style, '')
    full_prompt = f"{prompt}, {style_suffix}"
    rng = torch.Generator(device=device).manual_seed(int(seed))

    # Mixed precision on GPU; plain no-grad execution on CPU.
    if device == "cuda":
        guard = torch.autocast("cuda")
    else:
        guard = torch.no_grad()

    with guard:
        result = dreamshaper_pipe(prompt=full_prompt, negative_prompt=negative, generator=rng)
    picture = result.images[0]

    gallery.append(picture)
    free_gpu_cache()
    return picture, gallery
# ---------------- VQA ----------------
def answer_vqa(question, image):
    """Answer a free-form question about ``image`` using BLIP VQA.

    Args:
        question: Question text. ``None``, empty or whitespace-only values
            count as "no question" instead of raising.
        image: Image in any form ``vqa_processor`` accepts (the UI passes the
            preview component's value). ``None`` means no image.

    Returns:
        str: The decoded answer, or a prompt asking for the missing input.
    """
    # ``(question or "")`` keeps a None question from raising AttributeError.
    if image is None or not (question or "").strip():
        return "Upload an image and enter a question."
    inputs = vqa_processor(images=image, text=question, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = vqa_model.generate(**inputs)
    return vqa_processor.decode(generated_ids[0], skip_special_tokens=True)
# ---------------- Metrics Computation ----------------
def compute_metrics_button(images, captions, idx1, idx2):
    """Compare two stored items and return a Markdown metrics report.

    Args:
        images: Indexable collection of images; ``idx1``/``idx2`` pick the pair.
        captions: Captions aligned with ``images`` by index.
        idx1: Position of the first item (BERTScore candidate).
        idx2: Position of the second item (BERTScore reference).

    Returns:
        str: Markdown listing the CLIP image-embedding cosine similarity, the
        LPIPS distance, and the BERTScore F1 between the two captions.
    """
    first = images[idx1]
    second = images[idx2]

    # CLIP: cosine similarity between the two image embeddings.
    clip_a = clip_preprocess(first).unsqueeze(0).to(device)
    clip_b = clip_preprocess(second).unsqueeze(0).to(device)
    with torch.no_grad():
        emb_a = clip_model.encode_image(clip_a)
        emb_b = clip_model.encode_image(clip_b)
    clip_sim = float(torch.cosine_similarity(emb_a, emb_b).item())

    # LPIPS: lpips_transform output is rescaled from [0, 1] to [-1, 1].
    lp_a = lpips_transform(first).unsqueeze(0).to(device) * 2 - 1
    lp_b = lpips_transform(second).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        lpips_score = float(lpips_model(lp_a, lp_b).item())

    # BERTScore: only the F1 component is reported.
    _, _, f1 = score([captions[idx1]], [captions[idx2]], lang="en", verbose=False)
    bert_f1 = float(f1.mean().item())

    report_lines = [
        "",
        "**Metrics Comparison**",
        f"- CLIP Similarity: {clip_sim:.4f}",
        f"- LPIPS Score: {lpips_score:.4f}",
        f"- BERTScore F1: {bert_f1:.4f}",
        "",
    ]
    return "\n".join(report_lines)
# ---------------- Build Gradio UI ----------------
def build_ui_with_custom_ui():
    """Build (but do not launch) the Gradio Blocks app.

    Two per-session ``gr.State`` lists hold the images and their captions.
    Both use FIXED slots, so every image stays paired with its own caption and
    label no matter which buttons are clicked, in what order, or how often:

    * slot 0: uploaded reference image
    * slot 1: SD-Turbo image
    * slot 2: DreamShaper image

    Slots that have not been produced yet hold ``None``.

    Returns:
        gr.Blocks: The assembled demo.
    """
    import html  # escapes model-produced text embedded in raw HTML output

    ref_slot, sd_slot, ds_slot = 0, 1, 2
    loading = "<div class='loading-line'></div>"

    def _put(seq, slot, value):
        """Return a copy of ``seq`` with ``value`` at ``slot``, padding gaps with None."""
        out = list(seq or [])
        out.extend([None] * (slot + 1 - len(out)))
        out[slot] = value
        return out

    def _has_all_slots(seq):
        """True when the reference, SD-Turbo and DreamShaper slots are all filled."""
        seq = seq or []
        return len(seq) >= 3 and all(item is not None for item in seq[:3])

    with gr.Blocks(title="Multimodal AI Image Studio") as demo:
        # ---------------- CSS Styling ----------------
        gr.HTML("""
<style>
.heading-orange h2, .heading-orange h3 { color: #ff5500 !important; }
.orange-btn button { background-color: #ff5500 !important; color: white !important; border-radius: 6px !important; height: 36px !important; font-weight: bold; }
.teal-btn button { background-color: #008080 !important; color: white !important; border-radius: 6px !important; height: 40px !important; font-weight: bold; }
/* Horizontal thin spinner */
.loading-line {
height: 4px;
background: linear-gradient(90deg, #008080 0%, #00cccc 50%, #008080 100%);
background-size: 200% 100%;
animation: loading 1s linear infinite;
}
@keyframes loading {
0% { background-position: 200% 0; }
100% { background-position: -200% 0; }
}
</style>
""")
        # ---------------- Heading ----------------
        gr.Markdown("## Multimodal AI Image Studio: An Integrated Comparative Perspective", elem_classes="heading-orange")

        # ---------------- States ----------------
        images_state = gr.State([])
        captions_state = gr.State([])

        # ---------------- Step 1: Upload Reference Image ----------------
        gr.Markdown("### Upload Reference Image", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                upload_input = gr.Image(label="Drag & Drop Image", type="pil")
                upload_btn = gr.Button("Upload Image & Generate Caption", elem_classes="orange-btn")
            with gr.Column(scale=1):
                upload_preview = gr.Image(label="Uploaded Image", interactive=False)
                caption_out = gr.Markdown(label="Generated Caption")

        def upload_and_generate_caption_ui(img, images_state, captions_state):
            # Without this check captioner(None) raises a raw exception.
            if img is None:
                raise gr.Error("Please drop an image before uploading.")
            caption = captioner(img)[0]["generated_text"]
            # A new reference starts a fresh comparison: earlier generated
            # images/captions are dropped so slots 1 and 2 cannot go stale.
            return img, caption, [img], [caption]

        upload_btn.click(
            upload_and_generate_caption_ui,
            inputs=[upload_input, images_state, captions_state],
            outputs=[upload_preview, caption_out, images_state, captions_state],
        )

        # ---------------- Step 2: Generate SD-Turbo & DreamShaper ----------------
        gr.Markdown("### Generate Images from Caption", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                sd_btn = gr.Button("Generate SD-Turbo Image", elem_classes="orange-btn")
                sd_preview = gr.Image(label="SD-Turbo Image", interactive=False)
            with gr.Column(scale=1, min_width=300):
                ds_btn = gr.Button("Generate DreamShaper Image", elem_classes="orange-btn")
                ds_preview = gr.Image(label="DreamShaper Image", interactive=False)

        def _generate_into_slot(generate_fn, caption, seed, slot, images, captions):
            """Generate from the reference caption; store image + its caption at ``slot``."""
            if not (caption or "").strip():
                raise gr.Error("Upload a reference image and generate its caption first.")
            # A throwaway list keeps the generator's append away from session state;
            # the image is placed in its fixed slot instead.
            img, _ = generate_fn(caption, negative="", seed=seed, style="Photorealistic", images=[])
            new_caption = captioner(img)[0]["generated_text"]
            return img, _put(images, slot, img), _put(captions, slot, new_caption)

        def generate_sd_from_caption_ui(caption, images_state, captions_state):
            return _generate_into_slot(generate_image_and_store, caption, 42, sd_slot,
                                       images_state, captions_state)

        def generate_ds_from_caption_ui(caption, images_state, captions_state):
            return _generate_into_slot(generate_dreamshaper_image, caption, 123, ds_slot,
                                       images_state, captions_state)

        sd_btn.click(generate_sd_from_caption_ui, inputs=[caption_out, images_state, captions_state],
                     outputs=[sd_preview, images_state, captions_state])
        ds_btn.click(generate_ds_from_caption_ui, inputs=[caption_out, images_state, captions_state],
                     outputs=[ds_preview, images_state, captions_state])

        # ---------------- Step 3: Compute Pairwise Metrics ----------------
        gr.Markdown("### Compute Pairwise Metrics", elem_classes="heading-orange")
        metrics_btn = gr.Button("Compute Metrics for All Pairs", elem_classes="teal-btn")
        with gr.Row():
            metrics_A = gr.Markdown()
            metrics_B = gr.Markdown()
            metrics_C = gr.Markdown()

        def compute_metrics_all_pairs_ui(images, captions):
            yield loading, loading, loading
            # Captions are checked as well: compute_metrics_button indexes both lists.
            if not (_has_all_slots(images) and _has_all_slots(captions)):
                msg = "All three images and captions are required to compute metrics."
                yield msg, msg, msg
                return
            ref_vs_sd = compute_metrics_button(images, captions, ref_slot, sd_slot)
            ref_vs_ds = compute_metrics_button(images, captions, ref_slot, ds_slot)
            sd_vs_ds = compute_metrics_button(images, captions, sd_slot, ds_slot)
            yield (f"**Reference ↔ SD-Turbo**\n{ref_vs_sd}",
                   f"**Reference ↔ DreamShaper**\n{ref_vs_ds}",
                   f"**SD-Turbo ↔ DreamShaper**\n{sd_vs_ds}")

        metrics_btn.click(compute_metrics_all_pairs_ui, inputs=[images_state, captions_state],
                          outputs=[metrics_A, metrics_B, metrics_C])

        # ---------------- Step 4: NLP Analysis ----------------
        gr.Markdown("### NLP Analysis of Captions", elem_classes="heading-orange")
        nlp_btn = gr.Button("Analyze Captions", elem_classes="teal-btn")
        nlp_out = gr.HTML()

        def analyze_caption_pipeline_ui(captions):
            yield loading
            if not _has_all_slots(captions):
                yield "<b>All three captions are required for NLP analysis.</b>"
                return
            labels = ["Reference Image", "SD-Turbo", "DreamShaper"]
            blocks = []
            for label, caption in zip(labels, captions):
                sentiment = "<br>".join(
                    f"{s['label']}: {s['score']:.2f}" for s in sentiment_model(caption)
                )
                # Entity words come from model output; escape before embedding in HTML.
                ents = "<br>".join(
                    f"{e['entity_group']}: {html.escape(e['word'])}" for e in ner_model(caption)
                ) or "None"
                topics_data = topic_model(caption, candidate_labels=['people','animals','objects','food','nature'])
                topics = "<br>".join(
                    f"{l}: {sc:.2f}" for l, sc in zip(topics_data['labels'], topics_data['scores'])
                )
                block = f"<div style='flex:1;padding:10px;min-width:250px;'><h3><u>{label}</u></h3><b>Sentiment</b><br>{sentiment}<br><br><b>Entities</b><br>{ents}<br><br><b>Topics</b><br>{topics}</div>"
                blocks.append(block)
            yield f"<div style='display:flex; gap:20px; justify-content:space-between;'>{''.join(blocks)}</div>"

        nlp_btn.click(analyze_caption_pipeline_ui, inputs=[captions_state], outputs=[nlp_out])

        # ---------------- Step 5: Visual Question Answering ----------------
        gr.Markdown("### Visual Question Answering (VQA)", elem_classes="heading-orange")
        with gr.Row():
            with gr.Column(scale=1):
                vqa_input = gr.Textbox(label="Enter a question about the reference image")
                vqa_btn = gr.Button("Get Answer", elem_classes="teal-btn")
            with gr.Column(scale=1):
                vqa_out = gr.Markdown(label="VQA Output")

        def answer_vqa_ui(question, image):
            yield loading
            yield answer_vqa(question, image)

        vqa_btn.click(answer_vqa_ui, inputs=[vqa_input, upload_preview], outputs=[vqa_out])

    return demo
# Launch the interface.
demo = build_ui_with_custom_ui()
# Several handlers are generators (they stream a loading bar, then the result).
# Streaming needs the request queue: it is on by default in Gradio 4, but
# Gradio 3.x requires it to be enabled explicitly. Calling queue() is harmless
# either way.
demo.queue()
demo.launch()