# MSD / eval.py
# Source: Hugging Face Hub repo "MSD", uploaded by szxllm ("Update eval.py", commit c14bad3, verified).
import argparse
import json
import os
import random
import shutil
import re # Import regex for parsing filenames
import traceback # For potentially more detailed error logging
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from cleanfid import fid
from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionPipeline,
UNet2DConditionModel)
# Make sure safetensors is installed: pip install safetensors
from safetensors.torch import load_file as load_safetensors # Use safetensors loading for .safetensors
from PIL import Image, UnidentifiedImageError
from torchmetrics.multimodal import CLIPScore
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# --- Import Mamba Utilities ---
try:
from msd_utils import (MambaSequentialBlock,
replace_unet_self_attention_with_mamba)
print("Successfully imported Mamba utils from msd_utils.py")
except ImportError as e:
print(f"ERROR: Failed to import from msd_utils.py: {e}")
print("Ensure 'msd_utils.py' is in the current directory or your PYTHONPATH.")
exit(1)
def parse_args():
    """Build and parse the command-line arguments for the evaluation run.

    Returns:
        argparse.Namespace holding paths, sampling/generation settings,
        metric model names, and the Mamba hyperparameters used in training.
    """
    p = argparse.ArgumentParser(description="Evaluate Mamba-SD model with FID and CLIP-T on COCO val2014.")

    # Paths.
    p.add_argument("--model_checkpoint_path", type=str, required=True,
                   help="Path to the trained Mamba-SD checkpoint directory (e.g., /root/mamba/.../checkpoint-31000).")
    p.add_argument("--unet_subfolder", type=str, default="unet_mamba",
                   help="Name of the subfolder within the checkpoint containing the trained UNet weights (e.g., 'unet_mamba', 'unet_mamba_final').")
    p.add_argument("--base_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5",
                   help="Path or Hub ID of the base Stable Diffusion model (used for VAE, text encoder, etc.).")
    p.add_argument("--coco_val_images_path", type=str, required=True,
                   help="Path to the COCO val2014 image directory (e.g., /root/mamba/val2014).")
    p.add_argument("--coco_annotations_path", type=str, required=True,
                   help="Path to the COCO annotations directory containing 'captions_val2014.json'.")
    p.add_argument("--output_dir", type=str, default="./mamba_sd_eval_output",
                   help="Directory to save generated images and evaluation results.")

    # Evaluation parameters.
    p.add_argument("--num_samples", type=int, default=5000,
                   help="Number of validation samples to generate/evaluate. Set to -1 to use all. Must match existing samples if skipping generation.")
    p.add_argument("--batch_size", type=int, default=16,
                   help="Batch size for image generation (if generating).")
    p.add_argument("--guidance_scale", type=float, default=7.5,
                   help="Guidance scale for generation (if generating).")
    p.add_argument("--num_inference_steps", type=int, default=50,
                   help="Number of DDIM inference steps (if generating).")
    p.add_argument("--seed", type=int, default=42,
                   help="Random seed for generation (if generating) and sampling.")
    p.add_argument("--fid_clip_model_name", type=str, default="ViT-L/14",
                   help="CLIP model variant to use for FID computation with clean-fid.")
    p.add_argument("--clip_score_model_name", type=str, default="openai/clip-vit-large-patch14",
                   help="CLIP model variant to use for CLIPScore computation with torchmetrics.")

    # Control flags.
    p.add_argument("--skip_generation", action="store_true",
                   help="If set, skip image generation and attempt to load existing images from output_dir for metric calculation.")

    # Mamba parameters (MUST match the values used during training).
    p.add_argument("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension used during training.")
    p.add_argument("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension used during training.")
    p.add_argument("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor used during training.")

    # Performance.
    p.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],
                   help="Whether to use mixed precision during generation.")

    return p.parse_args()
def load_coco_data(ann_file_path, img_dir_path, num_samples=-1, seed=42):
    """Load COCO val2014 captions and map them to on-disk image file paths.

    Args:
        ann_file_path: Path to the COCO captions JSON (captions_val2014.json).
        img_dir_path: Directory containing the val2014 images.
        num_samples: Number of pairs to subsample; -1 (or any value >= the
            total) keeps every available pair.
        seed: Seed for the subsampling RNG (local — does not touch the
            global `random` state).

    Returns:
        Tuple of (list of {"image_path", "caption", "image_id"} dicts, count).

    Raises:
        FileNotFoundError: If the annotation file does not exist.
        NotADirectoryError: If the image directory does not exist.
        ValueError: If no valid image-caption pairs remain.
    """
    print(f"Loading COCO annotations from: {ann_file_path}")
    if not os.path.exists(ann_file_path):
        raise FileNotFoundError(f"Annotation file not found: {ann_file_path}")
    if not os.path.isdir(img_dir_path):
        raise NotADirectoryError(f"Image directory not found: {img_dir_path}")
    with open(ann_file_path, 'r') as f:
        data = json.load(f)
    annotations = data['annotations']
    images_info = {img['id']: img['file_name'] for img in data['images']}
    # Group every caption under its image id (only ids present in the index).
    captions_by_image = {}
    for ann in annotations:
        img_id = ann['image_id']
        if img_id in images_info:
            captions_by_image.setdefault(img_id, []).append(ann['caption'])
    # Keep one pair per image (first caption), skipping images missing on disk.
    evaluation_pairs = []
    for img_id, captions in captions_by_image.items():
        img_path = os.path.join(img_dir_path, images_info[img_id])
        if os.path.exists(img_path):
            evaluation_pairs.append({"image_path": img_path, "caption": captions[0], "image_id": img_id})
    print(f"Found {len(evaluation_pairs)} unique images with captions in source.")
    if num_samples > 0 and num_samples < len(evaluation_pairs):
        print(f"Selecting {num_samples} samples using seed {seed}...")
        # Use a dedicated Random instance: random.seed() here would silently
        # reseed the module-wide RNG for every other caller in the process.
        # Random(seed).sample(...) selects the exact same subset the previous
        # seed-the-global-state code did, so runs stay reproducible.
        rng = random.Random(seed)
        evaluation_pairs = rng.sample(evaluation_pairs, num_samples)
        print(f"Selected {len(evaluation_pairs)} samples for evaluation.")
    elif num_samples == -1:
        print(f"Using all {len(evaluation_pairs)} available samples.")
    else:
        print(f"Number of samples ({num_samples}) is invalid or >= total. Using all {len(evaluation_pairs)} samples.")
    if not evaluation_pairs:
        raise ValueError("No valid image-caption pairs selected or found. Check paths, annotation file format, and --num_samples.")
    return evaluation_pairs, len(evaluation_pairs)  # Pairs and their count
def main():
    """Run the full evaluation pipeline.

    Steps (main process does the real work; other accelerate ranks only hit
    the barriers):
      1. Sample caption/image pairs from COCO val2014.
      2. Either generate images with the Mamba-SD pipeline, or scan
         --output_dir for previously generated `gen_<id>.*` files.
      3. Compute FID (clean-fid, dir-vs-dir) and CLIP-T (torchmetrics)
         over the generated/ground-truth pairs and write results.json.
    """
    args = parse_args()

    # --- Accelerator Setup ---
    # Mixed precision is only useful while generating; metric-only runs use "no".
    accelerator = Accelerator(mixed_precision=args.mixed_precision if not args.skip_generation else "no") # Don't need mixed precision if only calculating metrics
    device = accelerator.device
    print(f"Using device: {device}, Mixed Precision: {accelerator.mixed_precision}")

    # --- Prepare Output Directory ---
    output_dir = Path(args.output_dir)
    generated_images_dir = output_dir / "generated_images"
    ground_truth_dir = output_dir / "ground_truth_images"
    # NOTE(review): `results_file` is never used — the same path is rebuilt
    # as `results_file_path` just before saving. Safe to delete.
    results_file = output_dir / "results.json"
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        generated_images_dir.mkdir(exist_ok=True)
        ground_truth_dir.mkdir(exist_ok=True)
        print(f"Output directory: {output_dir}")
        print(f"Generated images dir: {generated_images_dir}")
        print(f"Ground truth images dir: {ground_truth_dir}")

    # --- Load COCO Data (Needed in both generation and skip scenarios) ---
    # Only the main process loads the list; worker ranks keep the empty
    # defaults and never read them (all uses below are is_main_process-guarded).
    all_evaluation_pairs = []
    num_selected_samples = 0
    if accelerator.is_main_process:
        all_evaluation_pairs, num_selected_samples = load_coco_data(
            os.path.join(args.coco_annotations_path, "captions_val2014.json"),
            args.coco_val_images_path,
            args.num_samples,
            args.seed
        )
        print(f"Target number of samples for evaluation: {num_selected_samples}")
        # Lookup by image_id — used to match pre-generated files when skipping.
        data_lookup = {item['image_id']: item for item in all_evaluation_pairs}
    accelerator.wait_for_everyone() # Ensure all processes know the dirs exist

    # --- Initialize lists ---
    # Filled either by the generation loop or by scanning existing files.
    generated_image_paths = []
    ground_truth_image_paths = []
    captions_used = []

    # --- Generation or Loading ---
    if not args.skip_generation:
        # --- Load Models (only needed for generation) ---
        # NOTE(review): models are loaded on every rank but only the main
        # process generates below — presumably intentional; verify if running
        # multi-GPU, as the extra ranks just hold memory.
        print("Loading models for generation...")
        print("Loading base models (Tokenizer, Text Encoder, VAE)...")
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(args.base_model_name_or_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(args.base_model_name_or_path, subfolder="vae")
        scheduler = DDPMScheduler.from_pretrained(args.base_model_name_or_path, subfolder="scheduler")
        print("Loading base U-Net config and creating new Mamba-U-Net structure...")
        # Fresh (untrained) UNet from config; trained weights are loaded below
        # AFTER the attention layers are swapped, so the state dict keys match.
        unet_config = UNet2DConditionModel.load_config(args.base_model_name_or_path, subfolder="unet")
        unet = UNet2DConditionModel.from_config(unet_config)
        print("Replacing Self-Attention with Mamba blocks...")
        # These hyperparameters must match training or load_state_dict fails.
        mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand}
        try:
            unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
            print("Mamba replacement successful.")
        except Exception as e:
            print(f"ERROR during Mamba replacement: {e}")
            exit(1)

        # --- Load Trained Mamba U-Net Weights ---
        unet_weights_path = Path(args.model_checkpoint_path) / args.unet_subfolder
        print(f"Loading trained Mamba U-Net weights from: {unet_weights_path}")
        if not unet_weights_path.exists():
            print(f"ERROR: Trained UNet subfolder not found: {unet_weights_path}")
            exit(1)
        try:
            # Accept any of the three filenames diffusers/accelerate may have
            # written, in priority order: diffusers-style safetensors, then
            # .bin, then generic model.safetensors.
            unet_state_dict_path_safetensors_specific = unet_weights_path / "diffusion_pytorch_model.safetensors"
            unet_state_dict_path_bin = unet_weights_path / "diffusion_pytorch_model.bin"
            unet_state_dict_path_safetensors_generic = unet_weights_path / "model.safetensors"
            unet_state_dict_path = None
            state_dict = None
            if unet_state_dict_path_safetensors_specific.exists():
                print(f"Found specific safetensors file: {unet_state_dict_path_safetensors_specific}")
                unet_state_dict_path = unet_state_dict_path_safetensors_specific
                state_dict = load_safetensors(unet_state_dict_path, device="cpu") # Load using safetensors library
            elif unet_state_dict_path_bin.exists():
                print(f"Found bin file: {unet_state_dict_path_bin}")
                unet_state_dict_path = unet_state_dict_path_bin
                state_dict = torch.load(unet_state_dict_path, map_location="cpu") # Load using torch
            elif unet_state_dict_path_safetensors_generic.exists():
                print(f"Found generic safetensors file: {unet_state_dict_path_safetensors_generic}")
                unet_state_dict_path = unet_state_dict_path_safetensors_generic
                state_dict = load_safetensors(unet_state_dict_path, device="cpu") # Load using safetensors library
            else:
                raise FileNotFoundError(f"Could not find 'diffusion_pytorch_model.safetensors', 'diffusion_pytorch_model.bin', or 'model.safetensors' in UNet subfolder: {unet_weights_path}")
            # Load the state dict into the model (strict by default — any
            # key mismatch with the Mamba-modified UNet raises here).
            unet.load_state_dict(state_dict)
            print(f"Successfully loaded trained U-Net weights from {unet_state_dict_path}.")
            del state_dict # Free memory
        except FileNotFoundError as fnf_error:
            print(f"ERROR: {fnf_error}")
            exit(1)
        except Exception as e:
            print(f"ERROR loading U-Net weights from {unet_weights_path}: {e}")
            print(traceback.format_exc()) # Print full traceback for debugging other errors
            exit(1)

        print("Creating Stable Diffusion Pipeline...")
        # Map accelerate's mixed-precision setting to a torch dtype.
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        print(f"Setting pipeline dtype to: {weight_dtype}")
        vae.to(device=device, dtype=weight_dtype)
        text_encoder.to(device=device, dtype=weight_dtype)
        unet.to(device=device, dtype=weight_dtype)
        # Safety checker disabled deliberately: evaluation-only pipeline.
        pipeline = StableDiffusionPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
            safety_checker=None, feature_extractor=None, requires_safety_checker=False,
        )
        # xformers is a best-effort optimization; absence is not an error.
        try:
            import xformers
            pipeline.enable_xformers_memory_efficient_attention()
            print("Enabled xformers memory efficient attention.")
        except ImportError:
            print("xformers not installed. Running without it.")
        # Single generator seeded once — images are reproducible per run order.
        generator = torch.Generator(device=device).manual_seed(args.seed)

        # --- Generate Images (Main Process Only) ---
        if accelerator.is_main_process:
            print(f"Generating {num_selected_samples} images...")
            pipeline.set_progress_bar_config(disable=False)
            for i in tqdm(range(0, num_selected_samples, args.batch_size), desc="Generating Batches"):
                batch_data = all_evaluation_pairs[i : i + args.batch_size]
                if not batch_data: continue
                prompts = [item["caption"] for item in batch_data]
                gt_paths = [item["image_path"] for item in batch_data]
                image_ids = [item["image_id"] for item in batch_data]
                # autocast only when mixed precision is active; dtype=None keeps
                # the torch default when running fp32.
                with torch.no_grad(), torch.autocast(device_type=accelerator.device.type, dtype=weight_dtype if weight_dtype != torch.float32 else None, enabled=accelerator.mixed_precision != "no"):
                    images = pipeline(
                        prompt=prompts, guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps, generator=generator
                    ).images
                # Save each generated image and mirror its ground-truth file,
                # keyed by COCO image id so runs can be resumed/matched later.
                for idx, (pil_image, gt_path, img_id, prompt) in enumerate(zip(images, gt_paths, image_ids, prompts)):
                    generated_filename = f"gen_{img_id}.png"
                    gt_filename = f"gt_{img_id}.png"
                    gen_save_path = generated_images_dir / generated_filename
                    gt_save_path = ground_truth_dir / gt_filename
                    try:
                        pil_image.save(gen_save_path)
                        generated_image_paths.append(str(gen_save_path))
                        # Copy ground truth image (skip if already present from
                        # an earlier run).
                        if not gt_save_path.exists(): # Avoid re-copying if already there
                            shutil.copyfile(gt_path, gt_save_path)
                        ground_truth_image_paths.append(str(gt_save_path)) # Add even if it existed
                        captions_used.append(prompt)
                    except Exception as e:
                        # Best-effort: a single failed save should not abort the
                        # whole evaluation run.
                        print(f"Warning: Could not save generated image or copy GT for image_id {img_id}: {e}")
            print(f"Finished generation. Generated {len(generated_image_paths)} images.")
            # Final check after loop
            if len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of generated images ({len(generated_image_paths)}) does not match target ({num_selected_samples}). Check for errors during saving.")
    else: # --- Skip Generation: Load Existing Images ---
        if accelerator.is_main_process:
            print(f"Skipping generation. Loading existing images from {generated_images_dir} and {ground_truth_dir}")
            if not generated_images_dir.exists() or not ground_truth_dir.exists():
                print(f"Error: Cannot skip generation. Directory not found: {generated_images_dir} or {ground_truth_dir}")
                exit(1)
            # Regex to extract the COCO image ID from a gen_<id>.<ext> filename.
            gen_pattern = re.compile(r"gen_(\d+)\.(png|jpg|jpeg|webp)$", re.IGNORECASE)
            found_gen_files = list(generated_images_dir.glob("gen_*.*"))
            print(f"Found {len(found_gen_files)} potential generated image files.")
            for gen_file_path in tqdm(found_gen_files, desc="Scanning existing images"):
                match = gen_pattern.match(gen_file_path.name)
                if match:
                    try:
                        image_id = int(match.group(1))
                        # Only accept files whose id is in the *currently*
                        # sampled COCO subset (same --num_samples/--seed as the
                        # generation run, otherwise ids will not match).
                        if image_id in data_lookup:
                            original_data = data_lookup[image_id]
                            gt_filename = f"gt_{image_id}.png" # Assume GT is png for consistency
                            gt_path_expected = ground_truth_dir / gt_filename
                            gt_path_original = original_data["image_path"]
                            caption = original_data["caption"]
                            # Check if GT image exists in target dir, copy if not
                            if not gt_path_expected.exists():
                                if os.path.exists(gt_path_original):
                                    print(f"Copying missing GT image: {gt_filename}")
                                    shutil.copyfile(gt_path_original, gt_path_expected)
                                else:
                                    print(f"Warning: Cannot find original GT image to copy: {gt_path_original}")
                                    continue # Skip if GT is missing
                            # Only add if both gen and GT exist (or GT was copied)
                            if gt_path_expected.exists():
                                generated_image_paths.append(str(gen_file_path))
                                ground_truth_image_paths.append(str(gt_path_expected))
                                captions_used.append(caption)
                            else:
                                print(f"Warning: Skipping image_id {image_id} because ground truth image could not be found/copied to {gt_path_expected}")
                        else:
                            print(f"Warning: Found generated image {gen_file_path.name} but its ID {image_id} is not in the selected COCO samples list. Skipping.")
                    except ValueError:
                        print(f"Warning: Could not parse image ID from filename {gen_file_path.name}. Skipping.")
                    except Exception as e:
                        print(f"Warning: Error processing existing file {gen_file_path}: {e}")
            print(f"Loaded {len(generated_image_paths)} existing generated images and corresponding GT paths/captions.")
            if len(generated_image_paths) == 0:
                print("Error: No generated images found in the specified directory matching the expected format (gen_ID.png/jpg...). Cannot calculate metrics.")
                exit(1)
            elif len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of loaded images ({len(generated_image_paths)}) does not match the expected number of samples ({num_selected_samples}). Metrics will be calculated on the loaded images.")
                print("This might happen if generation was interrupted or if --num_samples differs from the initial generation.")

    # --- Wait for main process to finish generation OR loading ---
    accelerator.wait_for_everyone() # Important barrier

    # --- Calculate Metrics (Main Process Only) ---
    # These stay None (or become an error string) if metrics cannot be computed;
    # results.json records whichever outcome occurred.
    fid_score = None
    clip_t_score = None
    # Ensure lists are populated (either by generation or loading) before metrics
    if accelerator.is_main_process and generated_image_paths and ground_truth_image_paths:
        print(f"\nProceeding to calculate metrics using {len(generated_image_paths)} image pairs.")
        print("\n--- Calculating FID Score ---")
        try:
            # FID is computed directory-vs-directory by clean-fid, so it uses
            # EVERY file in both dirs — not just the pairs collected above.
            # Ensure both directories contain images before calculating.
            if not any(ground_truth_dir.iterdir()):
                print(f"Error: Ground truth directory '{ground_truth_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - GT dir empty"
            elif not any(generated_images_dir.iterdir()):
                print(f"Error: Generated images directory '{generated_images_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - Gen dir empty"
            else:
                fid_score = fid.compute_fid(
                    str(generated_images_dir),
                    str(ground_truth_dir),
                    mode="clean",
                    num_workers=min(os.cpu_count(), 8) # Use reasonable number of workers
                )
                # NOTE(review): self-assignment below is a no-op left over from
                # an edit — safe to delete.
                fid_score=fid_score
                print(f"FID Score: {fid_score:.2f}")
        except Exception as e:
            print(f"Error calculating FID: {e}")
            fid_score = "Error"

        print("\n--- Calculating CLIP-T Score ---")
        try:
            clip_scorer = CLIPScore(model_name_or_path=args.clip_score_model_name).to(device)
            clip_batch_size = 64 # Adjust based on GPU memory
            for i in tqdm(range(0, len(generated_image_paths), clip_batch_size), desc="Calculating CLIP Scores"):
                gen_paths_batch = generated_image_paths[i : i + clip_batch_size]
                captions_batch = captions_used[i : i + clip_batch_size]
                if not gen_paths_batch: continue
                # Load images defensively; missing/corrupt files are skipped so
                # one bad file cannot kill the whole metric pass.
                images_batch = []
                valid_captions_batch = []
                for img_path, caption in zip(gen_paths_batch, captions_batch):
                    try:
                        # Add check for file existence and basic PIL load
                        if not os.path.exists(img_path):
                            print(f"Warning: CLIP - Image file not found: {img_path}. Skipping.")
                            continue
                        img = Image.open(img_path).convert("RGB")
                        images_batch.append(img)
                        valid_captions_batch.append(caption)
                    except UnidentifiedImageError:
                        print(f"Warning: CLIP - Cannot identify image file (corrupted?): {img_path}. Skipping.")
                        continue
                    except Exception as img_err:
                        print(f"Warning: CLIP - Skipping image due to load error: {img_path} - {img_err}")
                        continue
                if not images_batch: continue
                # HWC uint8 PIL arrays -> CHW tensors, as a list (images may
                # differ in size). Presumably CLIPScore accepts uint8 CHW
                # tensors here — confirm against the torchmetrics version used.
                image_tensors = [torch.tensor(np.array(img)).permute(2, 0, 1) for img in images_batch]
                # Move tensors to the correct device for the metric
                image_tensors_dev = [t.to(device) for t in image_tensors]
                # Update metric - ensure inputs are on the same device as the metric module
                clip_scorer.update(image_tensors_dev, valid_captions_batch)
                # Clear tensor list to potentially free memory
                del image_tensors_dev, image_tensors, images_batch # Explicitly delete
                if torch.cuda.is_available():
                    torch.cuda.empty_cache() # Be cautious with explicit cache clearing
            final_clip_score = clip_scorer.compute().item()
            # torchmetrics reports CLIPScore scaled by 100; divide to get the
            # conventional 0-1 CLIP-T value.
            clip_t_score = final_clip_score/100
            print(f"CLIP-T Score : {clip_t_score:.3f}")
        except Exception as e:
            print(f"Error calculating CLIP-T score: {e}")
            print(traceback.format_exc()) # Print traceback for CLIP errors too
            clip_t_score = "Error"

        # --- Save Results ---
        # Metric fields may hold floats or error strings (see above).
        results = {
            "model_checkpoint": args.model_checkpoint_path,
            "unet_subfolder": args.unet_subfolder,
            "num_samples_target": num_selected_samples,
            "num_samples_evaluated": len(generated_image_paths), # Actual number used
            "coco_val_images_path": args.coco_val_images_path,
            "generation_skipped": args.skip_generation,
            "guidance_scale": args.guidance_scale if not args.skip_generation else "N/A (skipped generation)",
            "num_inference_steps": args.num_inference_steps if not args.skip_generation else "N/A (skipped generation)",
            "seed": args.seed,
            "mixed_precision": args.mixed_precision if not args.skip_generation else "N/A (skipped generation)",
            "fid_score": fid_score,
            "clip_t_score": clip_t_score,
            "fid_args": { "gen_dir": str(generated_images_dir), "gt_dir": str(ground_truth_dir) },
            "clip_score_args": { "model_name": args.clip_score_model_name }
        }
        results_file_path = output_dir / "results.json"
        with open(results_file_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"\nResults saved to: {results_file_path}")
    elif accelerator.is_main_process:
        print("\nSkipping metric calculation because image lists are empty (check generation/loading steps).")

    print("\nEvaluation finished.")
if __name__ == "__main__":
    # Script entry point: run the full evaluation pipeline.
    main()