Spaces:
Sleeping
Sleeping
| """ | |
| evaluate.py — Model evaluation metrics. | |
| Computes FID score and CLIP similarity to evaluate model quality. | |
| """ | |
| import argparse | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from torchvision import transforms | |
_CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# Loaded (model, processor) pairs keyed by checkpoint name, so repeated calls
# reuse the already-loaded weights instead of re-reading them every time.
_clip_cache = {}


def _load_clip(model_name: str = _CLIP_MODEL_NAME):
    """Return a cached ``(CLIPModel, CLIPProcessor)`` pair for *model_name*."""
    cached = _clip_cache.get(model_name)
    if cached is None:
        # Imported lazily so this module can be imported without transformers.
        from transformers import CLIPModel, CLIPProcessor

        cached = (
            CLIPModel.from_pretrained(model_name),
            CLIPProcessor.from_pretrained(model_name),
        )
        _clip_cache[model_name] = cached
    return cached


def compute_clip_similarity(
    image1: "Image.Image",
    image2: "Image.Image",
    prompt: str,
    model_name: str = _CLIP_MODEL_NAME,
):
    """Compute CLIP-based similarity between the edited image and the prompt.

    Args:
        image1: Accepted for interface compatibility; not used by the
            computation (presumably the pre-edit image -- confirm with callers).
        image2: Image that is scored against ``prompt``.
        prompt: Text the image is compared to.
        model_name: Hugging Face CLIP checkpoint to use. It is loaded on the
            first call and cached for subsequent calls.

    Returns:
        The image-text similarity with CLIP's learned temperature divided out
        (a float), or ``None`` if anything fails (e.g. transformers missing,
        download error, invalid input). Failures are printed rather than
        raised so evaluation stays best-effort.
    """
    try:
        model, processor = _load_clip(model_name)
        inputs = processor(
            text=[prompt], images=[image2], return_tensors="pt", padding=True
        )
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
            # logits_per_image is the embedding similarity multiplied by the
            # learned temperature exp(logit_scale); divide that out rather than
            # assuming the scale is exactly 100 for every checkpoint.
            similarity = (outputs.logits_per_image / model.logit_scale.exp()).item()
        return similarity
    except Exception as e:  # best-effort metric: report and carry on
        print(f"CLIP evaluation failed: {e}")
        return None
def evaluate_samples(sample_dir: str, prompts_file: str = None):
    """Print basic statistics for the ``*.png`` samples in a directory.

    Args:
        sample_dir: Directory searched (non-recursively) for ``*.png`` files.
        prompts_file: Accepted for CLI compatibility but currently unused;
            no prompt-based metric is computed here.

    Returns:
        None. Results are printed: the set of distinct image sizes and the
        sample count, or a message when no samples are found.
    """
    sample_dir = Path(sample_dir)
    samples = sorted(sample_dir.glob("*.png"))
    if not samples:
        print(f"No samples found in {sample_dir}")
        return
    print(f"Evaluating {len(samples)} samples...")
    # Basic statistics. Image.open is lazy and keeps the file open until the
    # image is closed, so use it as a context manager to avoid leaking one file
    # handle per sample on large directories.
    sizes = set()
    for img_path in samples:
        with Image.open(img_path) as img:
            sizes.add(img.size)
    print(f" Image sizes: {sizes}")
    print(f" Total samples: {len(samples)}")
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the sample evaluation.
    cli = argparse.ArgumentParser(description="Evaluate model outputs")
    cli.add_argument("--sample-dir", type=str, required=True)
    cli.add_argument("--prompts", type=str, default=None)
    opts = cli.parse_args()
    evaluate_samples(sample_dir=opts.sample_dir, prompts_file=opts.prompts)