Spaces:
Sleeping
Sleeping
| """ | |
Text-to-Image Generator + Evaluation Metrics
Dataset: rhli/genarena | Model: runwayml/stable-diffusion-v1-5
Deploy on: Hugging Face Spaces (Gradio SDK)

Evaluation metrics
──────────────────
  • CLIP Score      — prompt-image alignment (higher = better; 0-100)
                      Analogue of recall: did the image capture the prompt?
  • FID             — Fréchet Inception Distance vs. a reference batch
                      (lower = better; 0 = identical distributions)
                      Analogue of precision: are generated images realistic?
                      (Not computed by the current app.)
  • Aesthetic Score — LAION aesthetic predictor (higher = better; 1-10)
| """ | |
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
| from datasets import load_dataset | |
# ─────────────────────────────────────────────────────────────────────────────
# 1. Device / dtype
# ─────────────────────────────────────────────────────────────────────────────
# Run on the GPU when CUDA is visible, otherwise on the CPU. Half precision is
# used only on CUDA; the CPU path stays in float32.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
DTYPE = {"cuda": torch.float16}.get(DEVICE, torch.float32)
# ─────────────────────────────────────────────────────────────────────────────
# 2. Generation pipeline
# ─────────────────────────────────────────────────────────────────────────────
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# NOTE(review): the runwayml/* repos were reportedly removed from the Hugging
# Face Hub in 2024. If this load fails with "repository not found", the
# community mirror "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual
# drop-in replacement — confirm before switching.
print(f"Loading generation model on {DEVICE} ...")
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    # The safety checker is switched off on purpose; requires_safety_checker=False
    # marks the omission as intentional.
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=DTYPE,  # fp16 on CUDA, fp32 on CPU (see DTYPE)
)
# Swap in the DPM-Solver multistep scheduler, reusing the pipeline's existing
# scheduler config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    # GPU only: sliced attention lowers peak memory at some speed cost.
    pipe.enable_attention_slicing()
print("Generation model ready")
# ─────────────────────────────────────────────────────────────────────────────
# 3. Evaluation models (lazy-loaded on first use to save startup time)
# ─────────────────────────────────────────────────────────────────────────────
# Module-level caches filled by the _load_* helpers; None means "not loaded yet".
_clip_model = _clip_processor = _aesthetic_model = None
def _load_clip():
    """Lazily load and cache CLIP ViT-B/32 (model + processor) for CLIP scoring.

    Returns:
        ``(model, processor)``. The model is moved to DEVICE and put in eval mode.

    Both loads complete before either global is assigned. A failure part-way
    through (e.g. the processor download) therefore leaves the cache empty and
    is retried on the next call, instead of caching a model without its
    processor.
    """
    global _clip_model, _clip_processor
    if _clip_model is None or _clip_processor is None:
        # Imported here so transformers is only loaded once a metric is requested.
        from transformers import CLIPModel, CLIPProcessor

        repo_id = "openai/clip-vit-base-patch32"
        print("Loading CLIP ViT-B/32 ...")
        model = CLIPModel.from_pretrained(repo_id).to(DEVICE)
        model.eval()
        processor = CLIPProcessor.from_pretrained(repo_id)
        _clip_model, _clip_processor = model, processor
        print("CLIP ready")
    return _clip_model, _clip_processor
class _AestheticPredictor(nn.Module):
    """LAION aesthetic head: maps a CLIP image embedding to a single score.

    The stack consists of five Linear layers narrowing
    input_size -> 1024 -> 128 -> 64 -> 16 -> 1, with Dropout(0.2), Dropout(0.2)
    and Dropout(0.1) after the first three.

    There is no activation between layers, so in eval mode the stack reduces to
    one affine map.

    The module order fixes the state_dict keys (layers.0/.2/.4/.6/.7) that the
    pretrained weights are loaded into, so keep it unchanged.
    """

    def __init__(self, input_size: int = 768):
        super().__init__()
        widths = (input_size, 1024, 128, 64, 16, 1)
        drop_rates = (0.2, 0.2, 0.1, None, None)
        stack = []
        for fan_in, fan_out, rate in zip(widths[:-1], widths[1:], drop_rates):
            stack.append(nn.Linear(fan_in, fan_out))
            if rate is not None:
                stack.append(nn.Dropout(rate))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw (unclipped) score for each row of *x*.

        Assumes *x* is (batch, input_size); TODO confirm.
        """
        return self.layers(x)
def _load_aesthetic():
    """Lazily load and cache the LAION aesthetic head (768-d CLIP ViT-L/14 input).

    Weights come from the ``camenduru/improved-aesthetic-predictor`` Hub repo.

    Returns:
        The cached ``_AestheticPredictor``, in eval mode and on DEVICE.

    The global is assigned only after the weights load successfully. A failed
    download or ``load_state_dict`` therefore raises every time, instead of
    leaving a randomly initialised head cached for later calls.
    """
    global _aesthetic_model
    if _aesthetic_model is None:
        from huggingface_hub import hf_hub_download

        print("Loading aesthetic predictor ...")
        weights_path = hf_hub_download(
            "camenduru/improved-aesthetic-predictor",
            filename="sac+logos+ava1-l14-linearMSE.pth",
        )
        model = _AestheticPredictor(input_size=768)
        try:
            # A plain state_dict of tensors loads with the restricted unpickler.
            state = torch.load(weights_path, map_location="cpu", weights_only=True)
        except Exception:
            # SECURITY(review): weights_only=False unpickles arbitrary objects
            # from a third-party Hub file, i.e. it can execute code. Kept only as
            # a fallback for checkpoints the restricted loader rejects.
            state = torch.load(weights_path, map_location="cpu", weights_only=False)
        model.load_state_dict(state)
        model.eval().to(DEVICE)
        _aesthetic_model = model
        print("Aesthetic predictor ready")
    return _aesthetic_model
# ─────────────────────────────────────────────────────────────────────────────
# 4. Metric helpers
# ─────────────────────────────────────────────────────────────────────────────
def compute_clip_score(image: Image.Image, prompt: str) -> float:
    """CLIP Score on a 0-100 scale (recall analogue; higher = better alignment).

    The CLIP ViT-B/32 image and text embeddings are L2-normalised. Their cosine
    similarity is multiplied by 100, clipped to [0, 100] and rounded to two
    decimals.

    The prompt is truncated to 77 tokens (truncation=True, max_length=77) so
    long prompts do not overflow CLIP's text encoder.
    """
    model, processor = _load_clip()
    encoded = processor(
        text=[prompt],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )
    # Move tensors one by one instead of relying on BatchEncoding.to().
    batch = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = model(**batch)
    image_vec = outputs.image_embeds
    text_vec = outputs.text_embeds
    image_vec = image_vec / image_vec.norm(dim=-1, keepdim=True)
    text_vec = text_vec / text_vec.norm(dim=-1, keepdim=True)
    cosine = (image_vec * text_vec).sum().item()
    scaled = float(np.clip(cosine * 100, 0, 100))
    return round(scaled, 2)
# Cache for the CLIP ViT-L/14 vision tower used by the aesthetic metric;
# None means "not loaded yet".
_clip_l14_model = None
_clip_l14_processor = None


def _load_clip_l14():
    """Lazily load and cache CLIP ViT-L/14 (vision model + processor).

    Both loads complete before either global is assigned, so a failed load is
    retried on the next call rather than being half-cached.
    """
    global _clip_l14_model, _clip_l14_processor
    if _clip_l14_model is None or _clip_l14_processor is None:
        from transformers import CLIPVisionModelWithProjection, CLIPProcessor

        repo_id = "openai/clip-vit-large-patch14"
        print("Loading CLIP ViT-L/14 ...")
        model = CLIPVisionModelWithProjection.from_pretrained(repo_id).to(DEVICE)
        model.eval()
        processor = CLIPProcessor.from_pretrained(repo_id)
        _clip_l14_model, _clip_l14_processor = model, processor
        print("CLIP ViT-L/14 ready")
    return _clip_l14_model, _clip_l14_processor


def compute_aesthetic_score(image: Image.Image) -> float:
    """LAION aesthetic score clipped to [1, 10]; higher = more visually pleasing.

    The image is embedded with CLIP ViT-L/14 via CLIPVisionModelWithProjection,
    whose output exposes ``image_embeds`` directly. The embedding is
    L2-normalised and scored by the LAION aesthetic head. Both models are
    loaded once and reused on later calls.

    Returns:
        The score rounded to two decimals, or -1.0 if anything fails. The error
        is printed rather than raised, since this metric is best-effort.
    """
    try:
        clip_v, proc_v = _load_clip_l14()
        aes = _load_aesthetic()
        pixel_values = proc_v(images=image, return_tensors="pt")["pixel_values"].to(DEVICE)
        with torch.no_grad():
            emb = clip_v(pixel_values=pixel_values).image_embeds
            # Unit-normalise the embedding before it goes into the aesthetic head.
            emb = emb / emb.norm(dim=-1, keepdim=True)
            score = aes(emb).item()
        return round(float(np.clip(score, 1, 10)), 2)
    except Exception as e:
        print(f"Aesthetic score skipped: {e}")
        return -1.0
# ─────────────────────────────────────────────────────────────────────────────
# 5. Dataset prompts
# ─────────────────────────────────────────────────────────────────────────────
PROMPT_COLUMN = "prompt"

# Used whenever the dataset is unreachable or yields no usable prompts.
_FALLBACK_PROMPTS = [
    "a futuristic city at sunset",
    "a cozy cottage in a misty forest",
    "a robot painting a watercolor",
    "an astronaut on a purple alien planet",
]


def _load_dataset_prompts(limit=200):
    """Return up to *limit* prompts from rhli/genarena, or a copy of the fallbacks.

    Only the first *limit* rows of the ``train`` split are used. Values in
    PROMPT_COLUMN that are not non-blank strings are dropped. If nothing usable
    remains, the fallback list is returned, so the result is never empty and
    random.choice() on it cannot fail.

    Load errors are printed, not raised.
    """
    try:
        ds = load_dataset("rhli/genarena", split="train")
        raw = [ds[i][PROMPT_COLUMN] for i in range(min(limit, len(ds)))]
    except Exception as e:  # network / schema problems: degrade to the fallback
        print(f"Dataset load failed: {e}")
        return list(_FALLBACK_PROMPTS)
    prompts = [p for p in raw if isinstance(p, str) and p.strip()]
    if not prompts:
        print(f"Dataset load failed: no usable '{PROMPT_COLUMN}' values in rhli/genarena")
        return list(_FALLBACK_PROMPTS)
    print(f"Loaded {len(prompts)} prompts from rhli/genarena")
    return prompts


DATASET_PROMPTS = _load_dataset_prompts()
# ─────────────────────────────────────────────────────────────────────────────
# 6. Core inference helpers wired to Gradio callbacks
# ─────────────────────────────────────────────────────────────────────────────
def _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed):
    """Run the Stable Diffusion pipeline once and return its first image (512x512).

    Args:
        prompt: Text prompt.
        negative_prompt: Optional negative prompt; an empty value is passed as None.
        num_steps: Number of denoising steps (coerced to int).
        guidance_scale: Classifier-free guidance scale (coerced to float).
        seed: Seed for a generator created on DEVICE, so runs are repeatable.
    """
    generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
    # No torch.autocast here. The pipeline is already loaded in DTYPE (fp16 on
    # CUDA), and diffusers advises against autocast around fp16 pipelines: it is
    # slower and can produce black images.
    with torch.no_grad():
        result = pipe(
            prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
            height=512,
            width=512,
        )
    return result.images[0]
def generate_image(prompt, negative_prompt, num_steps, guidance_scale, seed):
    """Gradio callback for the Generate tab.

    Returns:
        ``(image, status_markdown)``.
        - A blank or missing (None) prompt short-circuits with
          ``(None, "Please enter a prompt.")``.
        - Any generation error is reported in the status text rather than raised.
    """
    # Gradio normally sends "", but API callers can send None.
    if not (prompt or "").strip():
        return None, "Please enter a prompt."
    try:
        image = _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed)
    except Exception as e:
        return None, f"Error: {e}"
    return image, f"Generated with seed {int(seed)}"
def _format_eval_report(clip, aes, aes_requested):
    """Build the Markdown results table shown under the evaluated image.

    *aes* is either a positive score or -1.0. The -1.0 sentinel comes from
    compute_aesthetic_score when it fails, or from evaluate_single when the
    metric was not requested. *aes_requested* tells those two cases apart.
    """
    clip_status = "Good" if clip >= 25 else "Moderate" if clip >= 15 else "Low"
    rows = [
        "### Evaluation Results",
        "",
        "| Metric | Value | Status |",
        "|--------|-------|--------|",
        f"| **CLIP Score** (0-100, recall analogue) | `{clip:.1f}` | {clip_status} |",
    ]
    if aes > 0:
        aes_status = "Good" if aes >= 5 else "Moderate" if aes >= 3 else "Low"
        rows.append(f"| **Aesthetic Score** (1-10) | `{aes:.2f}` | {aes_status} |")
    elif aes_requested:
        # The checkbox was ticked but scoring failed; the cause is printed server-side.
        rows.append("| **Aesthetic Score** | `failed` | see server logs |")
    else:
        rows.append("| **Aesthetic Score** | `skipped` | enable checkbox to compute |")
    rows += [
        "",
        "**CLIP Score** — how well the image matches the prompt (recall analogue).",
        "**Aesthetic Score** — perceived visual quality via LAION predictor.",
    ]
    return "\n".join(rows)


def evaluate_single(prompt, negative_prompt, num_steps, guidance_scale, seed, run_aesthetic):
    """Generate one image and compute CLIP score + optionally aesthetic score.

    Returns:
        ``(image, clip_score, aesthetic_score, report_markdown)``.
        - aesthetic_score is 0.0 when it was not computed or failed.
        - A blank or missing prompt, or any error, yields None / 0.0 values and
          a message in place of the report.
    """
    if not (prompt or "").strip():
        return None, 0.0, 0.0, "Please enter a prompt."
    try:
        image = _run_pipe(prompt, negative_prompt, num_steps, guidance_scale, seed)
        clip = compute_clip_score(image, prompt)
        aes = compute_aesthetic_score(image) if run_aesthetic else -1.0
        report = _format_eval_report(clip, aes, bool(run_aesthetic))
        return image, clip, aes if aes > 0 else 0.0, report
    except Exception as e:
        return None, 0.0, 0.0, f"Error: {e}"
def random_prompt():
    """Pick one prompt uniformly at random from DATASET_PROMPTS."""
    pool = DATASET_PROMPTS
    return random.choice(pool)


def random_seed():
    """Draw a seed uniformly from the inclusive range [0, 2**31 - 1]."""
    highest_seed = (1 << 31) - 1
    return random.randint(0, highest_seed)
# ─────────────────────────────────────────────────────────────────────────────
# 7. Gradio UI
# ─────────────────────────────────────────────────────────────────────────────
# Three tabs: plain generation, generation + evaluation, and a static guide to
# the metrics. Event handlers are the module-level callbacks defined earlier.
with gr.Blocks(title="Text-to-Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Text-to-Image Generator\n"
        "Stable Diffusion v1.5 · Dataset: "
        "[rhli/genarena](https://huggingface.co/datasets/rhli/genarena)"
    )
    with gr.Tabs():
        # ── Tab 1: Generate ──────────────────────────────────────────────────
        with gr.TabItem("Generate"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_box = gr.Textbox(
                        label="Prompt",
                        lines=3,
                        placeholder="Describe the image you want...",
                    )
                    surprise_btn = gr.Button(
                        "Surprise me (dataset prompt)", variant="secondary", size="sm"
                    )
                    neg_box = gr.Textbox(
                        label="Negative prompt (optional)",
                        value="blurry, low quality, ugly, distorted",
                        lines=2,
                    )
                    with gr.Accordion("Advanced settings", open=False):
                        steps_sl = gr.Slider(10, 50, 20, step=1, label="Inference steps")
                        guide_sl = gr.Slider(1.0, 20.0, 7.5, step=0.5, label="Guidance scale")
                        with gr.Row():
                            seed_box = gr.Number(label="Seed", value=42, precision=0)
                            rand_seed_btn = gr.Button("Random seed", size="sm")
                    gen_btn = gr.Button("Generate", variant="primary")
                with gr.Column(scale=1):
                    gen_image = gr.Image(label="Generated image", type="pil")
                    gen_status = gr.Markdown("")

            gen_btn.click(
                generate_image,
                inputs=[prompt_box, neg_box, steps_sl, guide_sl, seed_box],
                outputs=[gen_image, gen_status],
            )
            surprise_btn.click(random_prompt, outputs=prompt_box)
            rand_seed_btn.click(random_seed, outputs=seed_box)

            # cache_examples=False: examples are not pre-generated at startup.
            gr.Examples(
                examples=[
                    ["a golden sunset over a calm ocean, photorealistic", "blurry, low quality", 20, 7.5, 42],
                    ["a watercolor painting of a Japanese cherry blossom garden", "", 25, 8.0, 7],
                    ["a futuristic robot chef in a neon-lit kitchen", "low quality", 20, 7.5, 99],
                    ["an ancient library filled with glowing magical books", "", 20, 9.0, 12],
                ],
                inputs=[prompt_box, neg_box, steps_sl, guide_sl, seed_box],
                outputs=[gen_image, gen_status],
                fn=generate_image,
                cache_examples=False,
            )

        # ── Tab 2: Single-image evaluation ───────────────────────────────────
        with gr.TabItem("Evaluate Single Image"):
            gr.Markdown(
                "Generate one image and measure:\n"
                "- **CLIP Score** (0-100) — prompt alignment. *Recall analogue.*\n"
                "- **Aesthetic Score** (1-10) — visual quality. *(adds ~30 s, loads an extra model)*"
            )
            with gr.Row():
                with gr.Column(scale=1):
                    eval_prompt = gr.Textbox(
                        label="Prompt",
                        lines=3,
                        placeholder="Enter your prompt...",
                    )
                    eval_neg = gr.Textbox(
                        label="Negative prompt",
                        value="blurry, low quality, ugly, distorted",
                        lines=2,
                    )
                    with gr.Accordion("Settings", open=False):
                        eval_steps = gr.Slider(10, 50, 20, step=1, label="Inference steps")
                        eval_guide = gr.Slider(1.0, 20.0, 7.5, step=0.5, label="Guidance scale")
                        with gr.Row():
                            eval_seed = gr.Number(label="Seed", value=42, precision=0)
                            eval_rand_btn = gr.Button("Random seed", size="sm")
                    eval_aes_chk = gr.Checkbox(label="Compute Aesthetic Score (slower)", value=False)
                    eval_btn = gr.Button("Generate + Evaluate", variant="primary")
                with gr.Column(scale=1):
                    eval_image = gr.Image(label="Generated image", type="pil")
                    clip_num = gr.Number(label="CLIP Score (0-100)", precision=2)
                    aes_num = gr.Number(label="Aesthetic Score (1-10)", precision=2)
                    eval_md = gr.Markdown("")

            eval_btn.click(
                evaluate_single,
                inputs=[eval_prompt, eval_neg, eval_steps, eval_guide, eval_seed, eval_aes_chk],
                outputs=[eval_image, clip_num, aes_num, eval_md],
            )
            eval_rand_btn.click(random_seed, outputs=eval_seed)

        # ── Tab 3: Metric guide (static Markdown) ────────────────────────────
        with gr.TabItem("Metric Guide"):
            gr.Markdown(
                """
## Evaluation Metrics

| Metric | Range | Better when | Analogue | Method |
|--------|-------|-------------|----------|--------|
| CLIP Score | 0 – 100 | Higher | **Recall** | Cosine sim of CLIP image & text embeddings |
| Aesthetic Score | 1 – 10 | Higher | Quality | LAION linear head on CLIP ViT-L/14 features |

---

### CLIP Score — Recall analogue

- **What it measures:** Did the image capture the content described in the prompt?
- **How:** CLIP encodes the image and text into a shared embedding space; cosine similarity is computed and scaled to 0-100.
- **Threshold:** ≥ 25 is generally good alignment for SD v1.5.
- **Limit:** CLIP can miss subtle semantic errors and spatial relationships. Prompts are truncated to 77 tokens.

### Aesthetic Score

- **What it measures:** Perceived visual quality, independent of the prompt.
- **How:** A small MLP trained on human LAION ratings predicts a score from CLIP ViT-L/14 embeddings.
- **Threshold:** ≥ 5.0 is considered aesthetically pleasing.
"""
            )
# ─────────────────────────────────────────────────────────────────────────────
# 8. Launch
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Route events through Gradio's request queue so the slow diffusion and
    # evaluation handlers run as queued jobs, then start the server.
    app = demo.queue()
    app.launch()