salu_Image_Editter / scripts /evaluate.py
Raghava Pulugu
Clean deployment
cad10d9
Raw
History Blame Contribute Delete
1.95 kB
"""
evaluate.py — Model evaluation metrics.
Computes a CLIP image/prompt similarity score and prints basic statistics for
generated samples. (FID scoring is not implemented in this file.)
"""
import argparse
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from torchvision import transforms
# Loaded (model, processor) pairs keyed by checkpoint name, so that scoring many
# images does not reload the CLIP weights on every call.
_CLIP_CACHE = {}


def _load_clip(checkpoint: str = "openai/clip-vit-base-patch32"):
    """Return a cached ``(CLIPModel, CLIPProcessor)`` pair for *checkpoint*."""
    if checkpoint not in _CLIP_CACHE:
        # Imported lazily so the rest of the script works without transformers.
        from transformers import CLIPProcessor, CLIPModel

        model = CLIPModel.from_pretrained(checkpoint)
        processor = CLIPProcessor.from_pretrained(checkpoint)
        _CLIP_CACHE[checkpoint] = (model, processor)
    return _CLIP_CACHE[checkpoint]


def compute_clip_similarity(image1: Image.Image, image2: Image.Image, prompt: str):
    """Compute a CLIP-based similarity score between ``image2`` and ``prompt``.

    Args:
        image1: Not used by the computation; kept so existing callers that pass
            two images keep working.
        image2: The image that is scored against the prompt.
        prompt: The text the image is compared with.

    Returns:
        ``logits_per_image`` for the single image/prompt pair divided by 100.0,
        as a Python float, or ``None`` if anything fails (missing
        ``transformers`` package, model download error, bad input, ...). On
        failure the error is printed rather than raised.
    """
    try:
        model, processor = _load_clip()
        inputs = processor(text=[prompt], images=[image2], return_tensors="pt", padding=True)
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        # NOTE(review): the division by 100 presumably undoes CLIP's learned
        # logit scale so the result is close to a cosine similarity; confirm
        # against the checkpoint's logit_scale.
        return outputs.logits_per_image.item() / 100.0
    except Exception as e:
        # Deliberately best-effort: evaluation must not crash the caller.
        print(f"CLIP evaluation failed: {e}")
        return None
def evaluate_samples(sample_dir: str, prompts_file: str = None):
    """Print basic statistics for the generated ``*.png`` samples in a directory.

    Args:
        sample_dir: Directory whose top level is searched for ``*.png`` files.
        prompts_file: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        None. The number of samples and the set of distinct image sizes are
        printed to stdout. If no ``*.png`` file is found, a message is printed
        and the function returns early.
    """
    sample_dir = Path(sample_dir)
    samples = sorted(sample_dir.glob("*.png"))
    if not samples:
        print(f"No samples found in {sample_dir}")
        return

    print(f"Evaluating {len(samples)} samples...")

    # Basic statistics
    sizes = set()
    for img_path in samples:
        # Use a context manager so each file handle is released as soon as the
        # size has been read, instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            sizes.add(img.size)
    print(f" Image sizes: {sizes}")
    print(f" Total samples: {len(samples)}")
if __name__ == "__main__":
    # Command-line entry point: --sample-dir is mandatory, --prompts is an
    # optional path that is forwarded unchanged to evaluate_samples().
    cli = argparse.ArgumentParser(description="Evaluate model outputs")
    cli.add_argument("--sample-dir", required=True, type=str)
    cli.add_argument("--prompts", default=None, type=str)
    options = cli.parse_args()
    evaluate_samples(options.sample_dir, options.prompts)