# ARIN460-FINALv2 / app.py
# etchisone's picture
# Upload 2 files
# cfef6ce verified
"""
Text-to-Image Generator + Evaluation Metrics
Dataset : rhli/genarena | Model: runwayml/stable-diffusion-v1-5
Deploy on: Hugging Face Spaces (Gradio SDK)

Evaluation metrics
──────────────────
• CLIP Score – prompt-image alignment (higher = better; 0-100)
  Analogue of recall: did the image capture the prompt?
• FID – Fréchet Inception Distance vs. a reference batch
  (lower = better; 0 = identical distributions)
  Analogue of precision: are generated images realistic?
• Aesthetic Score – LAION aesthetic predictor (higher = better; 1-10)
"""
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import random
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from datasets import load_dataset
# ─────────────────────────────────────────────────────────────────────────────
# 1. Device / dtype
# ─────────────────────────────────────────────────────────────────────────────
# Pick the compute device once at import time. Half precision is paired with
# the CUDA path only; the CPU path stays in float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# ─────────────────────────────────────────────────────────────────────────────
# 2. Generation pipeline
# ─────────────────────────────────────────────────────────────────────────────
MODEL_ID = "runwayml/stable-diffusion-v1-5"

print(f"Loading generation model on {DEVICE} ...")

# Keyword arguments for the pipeline constructor: the dtype follows the
# device choice and the safety checker is explicitly disabled.
_PIPE_OPTIONS = {
    "torch_dtype": DTYPE,
    "safety_checker": None,
    "requires_safety_checker": False,
}
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **_PIPE_OPTIONS)

# Replace the pipeline's scheduler with a DPMSolverMultistepScheduler built
# from the configuration of the scheduler the checkpoint shipped with.
_dpm_scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = _dpm_scheduler

pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    # Attention slicing is only switched on for the GPU path.
    pipe.enable_attention_slicing()

print("Generation model ready")
# ─────────────────────────────────────────────────────────────────────────────
# 3. Evaluation models (lazy-loaded on first use to save startup time)
# ─────────────────────────────────────────────────────────────────────────────
# Module-level caches for the lazily loaded evaluation models.
_clip_model = None
_clip_processor = None
_aesthetic_model = None


def _load_clip():
    """Return the shared CLIP ViT-B/32 ``(model, processor)`` pair.

    The pair is loaded on the first call and cached in module globals, so
    later calls are cheap.

    Returns:
        tuple: ``(CLIPModel, CLIPProcessor)``; the model is on ``DEVICE`` and
        in eval mode.

    Raises:
        Exception: whatever ``from_pretrained`` raises when loading fails.
            Nothing is cached in that case, so the next call retries.
    """
    global _clip_model, _clip_processor
    if _clip_model is None or _clip_processor is None:
        from transformers import CLIPModel, CLIPProcessor

        print("Loading CLIP ViT-B/32 ...")
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        model.eval()
        # Publish both globals together, only after every step succeeded.
        # Assigning the model first would leave a (model, None) pair cached
        # if the processor load failed, and later calls would never retry.
        _clip_model, _clip_processor = model, processor
        print("CLIP ready")
    return _clip_model, _clip_processor
class _AestheticPredictor(nn.Module):
    """MLP head that maps an embedding of size ``input_size`` to one scalar.

    The stack is Linear/Dropout pairs for the first three stages followed by
    two plain Linear layers. The module order inside ``self.layers`` fixes
    the state-dict key names (``layers.0``, ``layers.2``, ``layers.4``,
    ``layers.6``, ``layers.7``), so it must not be rearranged.
    """

    def __init__(self, input_size: int = 768):
        super().__init__()
        # (output width, dropout probability or None) for each stage, in order.
        stages = ((1024, 0.2), (128, 0.2), (64, 0.1), (16, None), (1, None))
        modules = []
        width = input_size
        for out_width, drop_p in stages:
            modules.append(nn.Linear(width, out_width))
            if drop_p is not None:
                modules.append(nn.Dropout(drop_p))
            width = out_width
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)
def _load_aesthetic():
    """Return the shared aesthetic predictor, loading its weights on first use.

    The weights file is fetched with ``hf_hub_download`` and loaded into an
    ``_AestheticPredictor(input_size=768)``. The predictor is cached in the
    module-level ``_aesthetic_model`` only after the weights loaded cleanly.

    Returns:
        _AestheticPredictor: the predictor, in eval mode and moved to ``DEVICE``.

    Raises:
        Exception: whatever the download, ``torch.load`` or
            ``load_state_dict`` raises. Nothing is cached in that case, so
            the next call retries.
    """
    global _aesthetic_model
    if _aesthetic_model is None:
        from huggingface_hub import hf_hub_download

        print("Loading aesthetic predictor ...")
        weights_path = hf_hub_download(
            "camenduru/improved-aesthetic-predictor",
            filename="sac+logos+ava1-l14-linearMSE.pth",
        )
        # Build into a local first. Caching the network before its weights
        # are loaded would leave a randomly initialised model in the global
        # if loading failed, and later calls would return meaningless scores.
        predictor = _AestheticPredictor(input_size=768)
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary objects from a downloaded file. Kept as in the original;
        # confirm whether this checkpoint also loads with weights_only=True
        # and switch if it does.
        state = torch.load(weights_path, map_location="cpu", weights_only=False)
        predictor.load_state_dict(state)
        predictor.eval().to(DEVICE)
        _aesthetic_model = predictor
        print("Aesthetic predictor ready")
    return _aesthetic_model
# ─────────────────────────────────────────────────────────────────────────────
# 4. Metric helpers
# ─────────────────────────────────────────────────────────────────────────────
def compute_clip_score(image: Image.Image, prompt: str) -> float:
    """Score how well ``image`` matches ``prompt`` with CLIP.

    The image and text embeddings are L2-normalised, their dot product
    (cosine similarity) is multiplied by 100 and clipped to ``[0, 100]``.

    Args:
        image: the picture to score.
        prompt: the text it should depict; tokenised with truncation at
            ``max_length=77``.

    Returns:
        The clipped score rounded to two decimals. Higher means closer
        prompt alignment.
    """
    clip, preprocess = _load_clip()
    batch = preprocess(
        text=[prompt],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )

    # Each tensor is moved on its own rather than calling .to() on the
    # processor's return object.
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(DEVICE)

    def _unit(vec):
        # Scale to unit length along the embedding axis.
        return vec / vec.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        outputs = clip(**on_device)
        image_vec = _unit(outputs.image_embeds)
        text_vec = _unit(outputs.text_embeds)
        cosine = (image_vec * text_vec).sum().item()

    scaled = np.clip(cosine * 100, 0, 100)
    return round(float(scaled), 2)
# Cache for the CLIP ViT-L/14 vision tower used by the aesthetic predictor.
_clip_vision_model = None
_clip_vision_processor = None


def _load_clip_vision():
    """Return the cached CLIP ViT-L/14 ``(vision_model, processor)`` pair, loading it on first use."""
    global _clip_vision_model, _clip_vision_processor
    if _clip_vision_model is None or _clip_vision_processor is None:
        from transformers import CLIPVisionModelWithProjection, CLIPProcessor

        model = CLIPVisionModelWithProjection.from_pretrained(
            "openai/clip-vit-large-patch14"
        ).to(DEVICE)
        model.eval()
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        # Publish both together so a failed load never leaves half a pair cached.
        _clip_vision_model, _clip_vision_processor = model, processor
    return _clip_vision_model, _clip_vision_processor


def compute_aesthetic_score(image: Image.Image) -> float:
    """Predict an aesthetic score for ``image``.

    The image is embedded with ``CLIPVisionModelWithProjection``
    (``openai/clip-vit-large-patch14``), the embedding is L2-normalised and
    passed to the aesthetic predictor, and the output is clipped to
    ``[1, 10]``.

    The vision model and processor are loaded once and reused. Reloading
    them on every call repeated the full model load for each evaluation.
    The trade-off is that the model stays resident after first use.

    Args:
        image: the picture to score.

    Returns:
        The clipped score rounded to two decimals, or ``-1.0`` when any step
        fails. The failure is printed rather than raised, so callers must
        treat a non-positive value as "no score".
    """
    try:
        clip_v, proc_v = _load_clip_vision()
        aes = _load_aesthetic()
        pixel_values = proc_v(images=image, return_tensors="pt")["pixel_values"].to(DEVICE)
        with torch.no_grad():
            out = clip_v(pixel_values=pixel_values)
            emb = out.image_embeds
            emb = emb / emb.norm(dim=-1, keepdim=True)
            score = aes(emb).item()
        return round(float(np.clip(score, 1, 10)), 2)
    except Exception as e:
        # Deliberate best-effort: scoring problems must not break generation.
        print(f"Aesthetic score skipped: {e}")
        return -1.0
# ─────────────────────────────────────────────────────────────────────────────
# 5. Dataset prompts
# ─────────────────────────────────────────────────────────────────────────────
PROMPT_COLUMN = "prompt"

# Used when the dataset cannot be loaded or yields no usable prompt text.
_FALLBACK_PROMPTS = [
    "a futuristic city at sunset",
    "a cozy cottage in a misty forest",
    "a robot painting a watercolor",
    "an astronaut on a purple alien planet",
]

try:
    _ds = load_dataset("rhli/genarena", split="train")
    # Only the first 200 rows are sampled for the "Surprise me" button.
    _limit = min(200, len(_ds))
    _candidates = [_ds[i][PROMPT_COLUMN] for i in range(_limit)]
    # Keep real, non-blank strings only.
    DATASET_PROMPTS = [p for p in _candidates if isinstance(p, str) and p.strip()]
    if not DATASET_PROMPTS:
        # An empty list would make random.choice() fail later, so treat it
        # like a failed load and use the fallback prompts.
        raise ValueError(f"no non-empty '{PROMPT_COLUMN}' values found")
    print(f"Loaded {len(DATASET_PROMPTS)} prompts from rhli/genarena")
except Exception as e:
    # Best-effort: the app must still start without the dataset.
    print(f"Dataset load failed: {e}")
    DATASET_PROMPTS = list(_FALLBACK_PROMPTS)
# ─────────────────────────────────────────────────────────────────────────────
# 6. Core inference helpers wired to Gradio callbacks
# ─────────────────────────────────────────────────────────────────────────────
def _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed):
    """Run the generation pipeline once and return the first image.

    Args:
        prompt: text passed to the pipeline as the positive prompt.
        negative_prompt: negative prompt; an empty or falsy value is sent as ``None``.
        num_steps: number of inference steps (cast to ``int``).
        guidance_scale: guidance scale (cast to ``float``).
        seed: seed for a ``torch.Generator`` created on ``DEVICE`` (cast to ``int``).

    Returns:
        ``result.images[0]`` from the pipeline call, generated at 512x512.
    """
    rng = torch.Generator(DEVICE).manual_seed(int(seed))

    # CUDA runs under autocast; every other device just disables gradients.
    on_gpu = DEVICE == "cuda"
    guard = torch.amp.autocast(device_type="cuda") if on_gpu else torch.no_grad()

    with guard:
        output = pipe(
            prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            generator=rng,
            height=512,
            width=512,
        )
    return output.images[0]
def generate_image(prompt, negative_prompt, num_steps, guidance_scale, seed):
    """Gradio callback: generate one image for ``prompt``.

    Args:
        prompt: text prompt. ``None``, empty and whitespace-only values are rejected.
        negative_prompt: optional negative prompt.
        num_steps: number of inference steps.
        guidance_scale: guidance scale.
        seed: random seed.

    Returns:
        ``(image, status_message)``. ``image`` is ``None`` when the prompt is
        missing or generation raised; the message then explains why.
    """
    # The guard runs before the try block, so it must tolerate None itself.
    if not (prompt or "").strip():
        return None, "Please enter a prompt."
    try:
        image = _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed)
        return image, f"Generated with seed {int(seed)}"
    except Exception as e:
        # UI boundary: report the failure in the status field instead of crashing.
        return None, f"Error: {e}"
def evaluate_single(prompt, negative_prompt, num_steps, guidance_scale, seed, run_aesthetic):
    """Gradio callback: generate one image and report its evaluation metrics.

    Args:
        prompt: text prompt. ``None``, empty and whitespace-only values are rejected.
        negative_prompt: optional negative prompt.
        num_steps: number of inference steps.
        guidance_scale: guidance scale.
        seed: random seed.
        run_aesthetic: when truthy, also compute the aesthetic score.

    Returns:
        ``(image, clip_score, aesthetic_score, markdown_report)``. The
        aesthetic score is ``0.0`` when it was not requested or could not be
        computed. On a missing prompt or any error the image is ``None``,
        both scores are ``0.0`` and the report carries the message.
    """
    # The guard runs before the try block, so it must tolerate None itself.
    if not (prompt or "").strip():
        return None, 0.0, 0.0, "Please enter a prompt."
    try:
        image = _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed)
        clip = compute_clip_score(image, prompt)
        # A non-positive value means "no score": either not requested, or
        # compute_aesthetic_score returned its failure value.
        aes = compute_aesthetic_score(image) if run_aesthetic else -1.0
        clip_status = "Good" if clip >= 25 else "Moderate" if clip >= 15 else "Low"
        rows = [
            "### Evaluation Results",
            "",
            "| Metric | Value | Status |",
            "|--------|-------|--------|",
            f"| **CLIP Score** (0-100, recall analogue) | `{clip:.1f}` | {clip_status} |",
        ]
        if aes > 0:
            aes_status = "Good" if aes >= 5 else "Moderate" if aes >= 3 else "Low"
            rows.append(f"| **Aesthetic Score** (1-10) | `{aes:.2f}` | {aes_status} |")
        elif run_aesthetic:
            # Requested but failed: do not tell the user to tick a box that
            # is already ticked.
            rows.append("| **Aesthetic Score** | `unavailable` | scoring failed, see server log |")
        else:
            rows.append("| **Aesthetic Score** | `skipped` | enable checkbox to compute |")
        rows += [
            "",
            "**CLIP Score** — how well the image matches the prompt (recall analogue).",
            "**Aesthetic Score** — perceived visual quality via LAION predictor.",
        ]
        return image, clip, aes if aes > 0 else 0.0, "\n".join(rows)
    except Exception as e:
        # UI boundary: report the failure in the report field instead of crashing.
        return None, 0.0, 0.0, f"Error: {e}"
def random_prompt():
    """Return one entry of ``DATASET_PROMPTS`` chosen at random."""
    pool = DATASET_PROMPTS
    return random.choice(pool)
def random_seed():
    """Return a random integer seed in ``[0, 2**31 - 1]``."""
    upper_exclusive = 2**31
    return random.randrange(upper_exclusive)
# ─────────────────────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────────────────────
# Help text for the "Metric Guide" tab. Built from explicit lines so the
# markdown does not depend on how the source file is indented. Blank entries
# separate the heading, table, rule and lists.
_METRIC_GUIDE_MD = "\n".join([
    "## Evaluation Metrics",
    "",
    "| Metric | Range | Better when | Analogue | Method |",
    "|--------|-------|-------------|----------|--------|",
    "| CLIP Score | 0 – 100 | Higher | **Recall** | Cosine sim of CLIP image & text embeddings |",
    "| Aesthetic Score | 1 – 10 | Higher | Quality | LAION linear head on CLIP ViT-L/14 features |",
    "",
    "---",
    "",
    "### CLIP Score — Recall analogue",
    "",
    "- **What it measures:** Did the image capture the content described in the prompt?",
    "- **How:** CLIP encodes the image and text into a shared embedding space; "
    "cosine similarity is computed and scaled to 0-100.",
    "- **Threshold:** ≥ 25 is generally good alignment for SD v1.5.",
    "- **Limit:** CLIP can miss subtle semantic errors and spatial relationships. "
    "Prompts are truncated to 77 tokens.",
    "",
    "### Aesthetic Score",
    "",
    "- **What it measures:** Perceived visual quality, independent of the prompt.",
    "- **How:** A small MLP trained on human LAION ratings predicts a score from "
    "CLIP ViT-L/14 embeddings.",
    "- **Threshold:** ≥ 5.0 is considered aesthetically pleasing.",
])

# Layout: three tabs (Generate, Evaluate Single Image, Metric Guide). Event
# handlers are wired next to the components they use.
with gr.Blocks(title="Text-to-Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Text-to-Image Generator\n"
        "Stable Diffusion v1.5 · Dataset: "
        "[rhli/genarena](https://huggingface.co/datasets/rhli/genarena)"
    )
    with gr.Tabs():
        # ── Tab 1: Generate ──────────────────────────────────────────────────
        with gr.TabItem("Generate"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_box = gr.Textbox(
                        label="Prompt",
                        lines=3,
                        placeholder="Describe the image you want...",
                    )
                    surprise_btn = gr.Button(
                        "Surprise me (dataset prompt)",
                        variant="secondary",
                        size="sm",
                    )
                    neg_box = gr.Textbox(
                        label="Negative prompt (optional)",
                        value="blurry, low quality, ugly, distorted",
                        lines=2,
                    )
                    with gr.Accordion("Advanced settings", open=False):
                        steps_sl = gr.Slider(10, 50, 20, step=1, label="Inference steps")
                        guide_sl = gr.Slider(1.0, 20.0, 7.5, step=0.5, label="Guidance scale")
                        with gr.Row():
                            seed_box = gr.Number(label="Seed", value=42, precision=0)
                            rand_seed_btn = gr.Button("Random seed", size="sm")
                    gen_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    gen_image = gr.Image(label="Generated image", type="pil")
                    gen_status = gr.Markdown("")
            gen_btn.click(
                generate_image,
                inputs=[prompt_box, neg_box, steps_sl, guide_sl, seed_box],
                outputs=[gen_image, gen_status],
            )
            surprise_btn.click(random_prompt, outputs=prompt_box)
            rand_seed_btn.click(random_seed, outputs=seed_box)
            # Example rows follow the order of `inputs`: prompt, negative
            # prompt, steps, guidance scale, seed. Results are not cached.
            gr.Examples(
                examples=[
                    ["a golden sunset over a calm ocean, photorealistic", "blurry, low quality", 20, 7.5, 42],
                    ["a watercolor painting of a Japanese cherry blossom garden", "", 25, 8.0, 7],
                    ["a futuristic robot chef in a neon-lit kitchen", "low quality", 20, 7.5, 99],
                    ["an ancient library filled with glowing magical books", "", 20, 9.0, 12],
                ],
                inputs=[prompt_box, neg_box, steps_sl, guide_sl, seed_box],
                outputs=[gen_image, gen_status],
                fn=generate_image,
                cache_examples=False,
            )
        # ── Tab 2: Single-image evaluation ───────────────────────────────────
        with gr.TabItem("Evaluate Single Image"):
            gr.Markdown(
                "Generate one image and measure:\n"
                "- **CLIP Score** (0-100) — prompt alignment. *Recall analogue.*\n"
                "- **Aesthetic Score** (1-10) — visual quality. *(adds ~30 s, loads an extra model)*"
            )
            with gr.Row():
                with gr.Column(scale=1):
                    eval_prompt = gr.Textbox(
                        label="Prompt",
                        lines=3,
                        placeholder="Enter your prompt...",
                    )
                    eval_neg = gr.Textbox(
                        label="Negative prompt",
                        value="blurry, low quality, ugly, distorted",
                        lines=2,
                    )
                    with gr.Accordion("Settings", open=False):
                        eval_steps = gr.Slider(10, 50, 20, step=1, label="Inference steps")
                        eval_guide = gr.Slider(1.0, 20.0, 7.5, step=0.5, label="Guidance scale")
                        with gr.Row():
                            eval_seed = gr.Number(label="Seed", value=42, precision=0)
                            eval_rand_btn = gr.Button("Random seed", size="sm")
                        eval_aes_chk = gr.Checkbox(label="Compute Aesthetic Score (slower)", value=False)
                    eval_btn = gr.Button("Generate + Evaluate", variant="primary")
                with gr.Column(scale=1):
                    eval_image = gr.Image(label="Generated image", type="pil")
                    clip_num = gr.Number(label="CLIP Score (0-100)", precision=2)
                    aes_num = gr.Number(label="Aesthetic Score (1-10)", precision=2)
                    eval_md = gr.Markdown("")
            eval_btn.click(
                evaluate_single,
                inputs=[eval_prompt, eval_neg, eval_steps, eval_guide, eval_seed, eval_aes_chk],
                outputs=[eval_image, clip_num, aes_num, eval_md],
            )
            eval_rand_btn.click(random_seed, outputs=eval_seed)
        # ── Tab 3: Metric guide ───────────────────────────────────────────────
        with gr.TabItem("Metric Guide"):
            gr.Markdown(_METRIC_GUIDE_MD)
# ─────────────────────────────────────────────────────────────────────────────
# 8. Launch
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Turn on Gradio's request queue before serving. The generation and
    # evaluation callbacks are long-running, and the original author notes
    # that un-queued calls time out silently after a few seconds.
    app = demo.queue()
    app.launch()