"""Evaluate a Mamba-SD checkpoint with FID and CLIP-T on COCO val2014.

Pipeline: load COCO captions, rebuild the Mamba-modified U-Net, generate
images (or reuse previously generated ones with --skip_generation), then
compute FID (clean-fid) and CLIP-T (torchmetrics CLIPScore).
"""

import argparse
import json
import os
import random
import re  # Regex for parsing generated-image filenames (gen_<id>.<ext>)
import shutil
import traceback  # Detailed error logging for weight-loading / CLIP failures
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from cleanfid import fid
from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionPipeline,
                       UNet2DConditionModel)
from PIL import Image, UnidentifiedImageError
# Make sure safetensors is installed: pip install safetensors
from safetensors.torch import load_file as load_safetensors
from torchmetrics.multimodal import CLIPScore
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# --- Import Mamba Utilities ---
try:
    from msd_utils import (MambaSequentialBlock,
                           replace_unet_self_attention_with_mamba)
    print("Successfully imported Mamba utils from msd_utils.py")
except ImportError as e:
    print(f"ERROR: Failed to import from msd_utils.py: {e}")
    print("Ensure 'msd_utils.py' is in the current directory or your PYTHONPATH.")
    exit(1)


def parse_args():
    """Parse command-line options for Mamba-SD evaluation.

    Returns:
        argparse.Namespace with paths, sampling/generation parameters,
        Mamba structural hyper-parameters (must match training), and
        control flags.
    """
    parser = argparse.ArgumentParser(description="Evaluate Mamba-SD model with FID and CLIP-T on COCO val2014.")
    # --- Paths ---
    parser.add_argument(
        "--model_checkpoint_path", type=str,
        default="/root/mamba/sd-mamba-mscoco-urltext-10k-run3/checkpoint-31000",
        help="Path to the trained Mamba-SD checkpoint directory (e.g., /root/mamba/.../checkpoint-31000)."
    )
    parser.add_argument(
        "--unet_subfolder", type=str, default="unet_mamba",
        help="Name of the subfolder within the checkpoint containing the trained UNet weights (e.g., 'unet_mamba', 'unet_mamba_final')."
    )
    parser.add_argument(
        "--base_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model (used for VAE, text encoder, etc.)."
    )
    parser.add_argument(
        "--coco_val_images_path", type=str, default="/root/mamba/val2014",
        help="Path to the COCO val2014 image directory (e.g., /root/mamba/val2014)."
    )
    parser.add_argument(
        "--coco_annotations_path", type=str, default="/root/mamba/annotations",
        help="Path to the COCO annotations directory containing 'captions_val2014.json'."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./mamba_sd_eval_output",
        help="Directory to save generated images and evaluation results."
    )
    # --- Evaluation Parameters ---
    parser.add_argument(
        "--num_samples", type=int, default=50,
        help="Number of validation samples to generate/evaluate. Set to -1 to use all. Must match existing samples if skipping generation."
    )
    parser.add_argument(
        "--batch_size", type=int, default=16,
        help="Batch size for image generation (if generating)."
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5,
        help="Guidance scale for generation (if generating)."
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=50,
        help="Number of DDIM inference steps (if generating)."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for generation (if generating) and sampling."
    )
    # NOTE(review): --fid_clip_model_name is parsed but never used below;
    # FID is computed with clean-fid's default feature extractor. Wire it
    # into fid.compute_fid(model_name=...) or drop the flag.
    parser.add_argument(
        "--fid_clip_model_name", type=str, default="ViT-L/14",
        help="CLIP model variant to use for FID computation with clean-fid."
    )
    parser.add_argument(
        "--clip_score_model_name", type=str, default="openai/clip-vit-large-patch14",
        help="CLIP model variant to use for CLIPScore computation with torchmetrics."
    )
    # --- Control Flags ---
    parser.add_argument(
        "--skip_generation", action="store_true",
        help="If set, skip image generation and attempt to load existing images from output_dir for metric calculation."
    )
    # --- Mamba Parameters (MUST match training) ---
    parser.add_argument("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension used during training.")
    parser.add_argument("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension used during training.")
    parser.add_argument("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor used during training.")
    # --- Performance ---
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],
                        help="Whether to use mixed precision during generation.")
    args = parser.parse_args()
    return args


def load_coco_data(ann_file_path, img_dir_path, num_samples=-1, seed=42):
    """Loads COCO val2014 captions and maps them to image file paths.

    Args:
        ann_file_path: Path to captions_val2014.json.
        img_dir_path: Directory containing the val2014 images.
        num_samples: Number of pairs to sample; -1 (or any invalid value)
            means use all available pairs.
        seed: Seed for reproducible sampling.

    Returns:
        (evaluation_pairs, count) where each pair is a dict with keys
        'image_path', 'caption' (first caption only), and 'image_id'.

    Raises:
        FileNotFoundError / NotADirectoryError for bad paths,
        ValueError if no valid pairs remain.
    """
    print(f"Loading COCO annotations from: {ann_file_path}")
    if not os.path.exists(ann_file_path):
        raise FileNotFoundError(f"Annotation file not found: {ann_file_path}")
    if not os.path.isdir(img_dir_path):
        raise NotADirectoryError(f"Image directory not found: {img_dir_path}")

    with open(ann_file_path, 'r') as f:
        data = json.load(f)

    annotations = data['annotations']
    images_info = {img['id']: img['file_name'] for img in data['images']}

    # Group every caption under its image id (only for images present in 'images').
    captions_by_image = {}
    for ann in annotations:
        img_id = ann['image_id']
        if img_id in images_info:
            if img_id not in captions_by_image:
                captions_by_image[img_id] = []
            captions_by_image[img_id].append(ann['caption'])

    evaluation_pairs = []
    for img_id, captions in captions_by_image.items():
        img_filename = images_info[img_id]
        img_path = os.path.join(img_dir_path, img_filename)
        if os.path.exists(img_path):
            # Use only the first caption per image.
            evaluation_pairs.append({"image_path": img_path, "caption": captions[0], "image_id": img_id})
        # else: silently skip images missing on disk (warning would be noisy)

    print(f"Found {len(evaluation_pairs)} unique images with captions in source.")

    if num_samples > 0 and num_samples < len(evaluation_pairs):
        print(f"Selecting {num_samples} samples using seed {seed}...")
        # FIX: use a local Random instance instead of reseeding the global RNG
        # (same sample sequence for a given seed, but no global side effect).
        evaluation_pairs = random.Random(seed).sample(evaluation_pairs, num_samples)
        print(f"Selected {len(evaluation_pairs)} samples for evaluation.")
    elif num_samples == -1:
        print(f"Using all {len(evaluation_pairs)} available samples.")
    else:
        print(f"Number of samples ({num_samples}) is invalid or >= total. Using all {len(evaluation_pairs)} samples.")

    if not evaluation_pairs:
        raise ValueError("No valid image-caption pairs selected or found. Check paths, annotation file format, and --num_samples.")
    return evaluation_pairs, len(evaluation_pairs)  # Return pairs and the count


def main():
    args = parse_args()

    # --- Accelerator Setup ---
    # Mixed precision is only relevant when we actually run the diffusion model.
    accelerator = Accelerator(mixed_precision=args.mixed_precision if not args.skip_generation else "no")
    device = accelerator.device
    print(f"Using device: {device}, Mixed Precision: {accelerator.mixed_precision}")

    # --- Prepare Output Directory ---
    output_dir = Path(args.output_dir)
    generated_images_dir = output_dir / "generated_images"
    ground_truth_dir = output_dir / "ground_truth_images"
    results_file = output_dir / "results.json"
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        generated_images_dir.mkdir(exist_ok=True)
        ground_truth_dir.mkdir(exist_ok=True)
        print(f"Output directory: {output_dir}")
        print(f"Generated images dir: {generated_images_dir}")
        print(f"Ground truth images dir: {ground_truth_dir}")

    # --- Load COCO Data (needed in both generation and skip scenarios) ---
    # Only the main process loads the full list; FIX: initialize data_lookup
    # on every process so later references never hit an undefined name.
    all_evaluation_pairs = []
    num_selected_samples = 0
    data_lookup = {}
    if accelerator.is_main_process:
        all_evaluation_pairs, num_selected_samples = load_coco_data(
            os.path.join(args.coco_annotations_path, "captions_val2014.json"),
            args.coco_val_images_path,
            args.num_samples,
            args.seed,
        )
        print(f"Target number of samples for evaluation: {num_selected_samples}")
        # Lookup by image_id, used when --skip_generation rescans existing files.
        data_lookup = {item['image_id']: item for item in all_evaluation_pairs}

    accelerator.wait_for_everyone()  # Ensure all processes know the dirs exist

    # Parallel lists: index i of each refers to the same sample. Kept aligned
    # by only appending all three together (see generation loop below).
    generated_image_paths = []
    ground_truth_image_paths = []
    captions_used = []

    # --- Generation or Loading ---
    if not args.skip_generation:
        # --- Load Models (only needed for generation) ---
        print("Loading models for generation...")
        print("Loading base models (Tokenizer, Text Encoder, VAE)...")
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(args.base_model_name_or_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(args.base_model_name_or_path, subfolder="vae")
        scheduler = DDPMScheduler.from_pretrained(args.base_model_name_or_path, subfolder="scheduler")

        print("Loading base U-Net config and creating new Mamba-U-Net structure...")
        unet_config = UNet2DConditionModel.load_config(args.base_model_name_or_path, subfolder="unet")
        unet = UNet2DConditionModel.from_config(unet_config)

        print("Replacing Self-Attention with Mamba blocks...")
        mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand}
        try:
            unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
            print("Mamba replacement successful.")
        except Exception as e:
            print(f"ERROR during Mamba replacement: {e}")
            exit(1)

        # --- Load Trained Mamba U-Net Weights ---
        unet_weights_path = Path(args.model_checkpoint_path) / args.unet_subfolder
        print(f"Loading trained Mamba U-Net weights from: {unet_weights_path}")
        if not unet_weights_path.exists():
            print(f"ERROR: Trained UNet subfolder not found: {unet_weights_path}")
            exit(1)
        try:
            # Probe known weight filenames in priority order; the loader used
            # depends on the file format (safetensors vs torch pickle).
            candidates = [
                (unet_weights_path / "diffusion_pytorch_model.safetensors", "safetensors"),
                (unet_weights_path / "diffusion_pytorch_model.bin", "torch"),
                (unet_weights_path / "model.safetensors", "safetensors"),
            ]
            unet_state_dict_path = None
            state_dict = None
            for path, fmt in candidates:
                if path.exists():
                    print(f"Found weight file: {path}")
                    unet_state_dict_path = path
                    if fmt == "safetensors":
                        state_dict = load_safetensors(path, device="cpu")
                    else:
                        state_dict = torch.load(path, map_location="cpu")
                    break
            if state_dict is None:
                raise FileNotFoundError(
                    f"Could not find 'diffusion_pytorch_model.safetensors', 'diffusion_pytorch_model.bin', or 'model.safetensors' in UNet subfolder: {unet_weights_path}"
                )

            unet.load_state_dict(state_dict)
            print(f"Successfully loaded trained U-Net weights from {unet_state_dict_path}.")
            del state_dict  # Free memory
        except FileNotFoundError as fnf_error:
            print(f"ERROR: {fnf_error}")
            exit(1)
        except Exception as e:
            print(f"ERROR loading U-Net weights from {unet_weights_path}: {e}")
            print(traceback.format_exc())  # Full traceback for debugging other errors
            exit(1)

        print("Creating Stable Diffusion Pipeline...")
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        print(f"Setting pipeline dtype to: {weight_dtype}")

        vae.to(device=device, dtype=weight_dtype)
        text_encoder.to(device=device, dtype=weight_dtype)
        unet.to(device=device, dtype=weight_dtype)

        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )
        try:
            import xformers  # noqa: F401  # Optional dependency for memory-efficient attention
            pipeline.enable_xformers_memory_efficient_attention()
            print("Enabled xformers memory efficient attention.")
        except ImportError:
            print("xformers not installed. Running without it.")

        generator = torch.Generator(device=device).manual_seed(args.seed)

        # --- Generate Images (Main Process Only) ---
        if accelerator.is_main_process:
            print(f"Generating {num_selected_samples} images...")
            pipeline.set_progress_bar_config(disable=False)
            for i in tqdm(range(0, num_selected_samples, args.batch_size), desc="Generating Batches"):
                batch_data = all_evaluation_pairs[i : i + args.batch_size]
                if not batch_data:
                    continue
                prompts = [item["caption"] for item in batch_data]
                gt_paths = [item["image_path"] for item in batch_data]
                image_ids = [item["image_id"] for item in batch_data]

                with torch.no_grad(), torch.autocast(
                    device_type=accelerator.device.type,
                    dtype=weight_dtype if weight_dtype != torch.float32 else None,
                    enabled=accelerator.mixed_precision != "no",
                ):
                    # NOTE(review): SD v1.5 was trained at 512x512 — confirm
                    # 768x768 generation is intentional for this evaluation.
                    images = pipeline(
                        prompt=prompts,
                        guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps,
                        height=768, width=768,
                        generator=generator,
                    ).images

                for pil_image, gt_path, img_id, prompt in zip(images, gt_paths, image_ids, prompts):
                    gen_save_path = generated_images_dir / f"gen_{img_id}.png"
                    gt_save_path = ground_truth_dir / f"gt_{img_id}.png"
                    try:
                        pil_image.save(gen_save_path)
                        if not gt_save_path.exists():  # Avoid re-copying if already there
                            shutil.copyfile(gt_path, gt_save_path)
                    except Exception as e:
                        print(f"Warning: Could not save generated image or copy GT for image_id {img_id}: {e}")
                        continue
                    # FIX: append to all three lists only after both files are
                    # safely on disk, so the parallel lists never misalign.
                    generated_image_paths.append(str(gen_save_path))
                    ground_truth_image_paths.append(str(gt_save_path))
                    captions_used.append(prompt)

            print(f"Finished generation. Generated {len(generated_image_paths)} images.")
            # Final check after loop
            if len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of generated images ({len(generated_image_paths)}) does not match target ({num_selected_samples}). Check for errors during saving.")
    else:
        # --- Skip Generation: Load Existing Images ---
        if accelerator.is_main_process:
            print(f"Skipping generation. Loading existing images from {generated_images_dir} and {ground_truth_dir}")
            if not generated_images_dir.exists() or not ground_truth_dir.exists():
                print(f"Error: Cannot skip generation. Directory not found: {generated_images_dir} or {ground_truth_dir}")
                exit(1)

            # Regex to extract the COCO image ID from a generated filename.
            gen_pattern = re.compile(r"gen_(\d+)\.(png|jpg|jpeg|webp)$", re.IGNORECASE)
            found_gen_files = list(generated_images_dir.glob("gen_*.*"))
            print(f"Found {len(found_gen_files)} potential generated image files.")

            for gen_file_path in tqdm(found_gen_files, desc="Scanning existing images"):
                match = gen_pattern.match(gen_file_path.name)
                if not match:
                    continue
                try:
                    image_id = int(match.group(1))
                    if image_id in data_lookup:
                        original_data = data_lookup[image_id]
                        gt_filename = f"gt_{image_id}.png"  # Assume GT is png for consistency
                        gt_path_expected = ground_truth_dir / gt_filename
                        gt_path_original = original_data["image_path"]
                        caption = original_data["caption"]

                        # Check if GT image exists in target dir, copy if not.
                        if not gt_path_expected.exists():
                            if os.path.exists(gt_path_original):
                                print(f"Copying missing GT image: {gt_filename}")
                                shutil.copyfile(gt_path_original, gt_path_expected)
                            else:
                                print(f"Warning: Cannot find original GT image to copy: {gt_path_original}")
                                continue  # Skip if GT is missing

                        # Only add if both gen and GT exist (or GT was copied).
                        if gt_path_expected.exists():
                            generated_image_paths.append(str(gen_file_path))
                            ground_truth_image_paths.append(str(gt_path_expected))
                            captions_used.append(caption)
                        else:
                            print(f"Warning: Skipping image_id {image_id} because ground truth image could not be found/copied to {gt_path_expected}")
                    else:
                        print(f"Warning: Found generated image {gen_file_path.name} but its ID {image_id} is not in the selected COCO samples list. Skipping.")
                except ValueError:
                    print(f"Warning: Could not parse image ID from filename {gen_file_path.name}. Skipping.")
                except Exception as e:
                    print(f"Warning: Error processing existing file {gen_file_path}: {e}")

            print(f"Loaded {len(generated_image_paths)} existing generated images and corresponding GT paths/captions.")
            if len(generated_image_paths) == 0:
                print("Error: No generated images found in the specified directory matching the expected format (gen_ID.png/jpg...). Cannot calculate metrics.")
                exit(1)
            elif len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of loaded images ({len(generated_image_paths)}) does not match the expected number of samples ({num_selected_samples}). Metrics will be calculated on the loaded images.")
                print("This might happen if generation was interrupted or if --num_samples differs from the initial generation.")

    # --- Wait for main process to finish generation OR loading ---
    accelerator.wait_for_everyone()  # Important barrier

    # --- Calculate Metrics (Main Process Only) ---
    fid_score = None
    clip_t_score = None
    # Ensure lists are populated (either by generation or loading) before metrics.
    if accelerator.is_main_process and generated_image_paths and ground_truth_image_paths:
        print(f"\nProceeding to calculate metrics using {len(generated_image_paths)} image pairs.")

        print("\n--- Calculating FID Score ---")
        try:
            # Ensure both directories contain images before calculating.
            if not any(ground_truth_dir.iterdir()):
                print(f"Error: Ground truth directory '{ground_truth_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - GT dir empty"
            elif not any(generated_images_dir.iterdir()):
                print(f"Error: Generated images directory '{generated_images_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - Gen dir empty"
            else:
                fid_score = fid.compute_fid(
                    str(generated_images_dir),
                    str(ground_truth_dir),
                    mode="clean",
                    # FIX: os.cpu_count() can return None; guard before min().
                    num_workers=min(os.cpu_count() or 1, 8),
                )
                print(f"FID Score: {fid_score:.2f}")
        except Exception as e:
            print(f"Error calculating FID: {e}")
            fid_score = "Error"

        print("\n--- Calculating CLIP-T Score ---")
        try:
            clip_scorer = CLIPScore(model_name_or_path=args.clip_score_model_name).to(device)
            clip_batch_size = 64  # Adjust based on GPU memory
            for i in tqdm(range(0, len(generated_image_paths), clip_batch_size), desc="Calculating CLIP Scores"):
                gen_paths_batch = generated_image_paths[i : i + clip_batch_size]
                captions_batch = captions_used[i : i + clip_batch_size]
                if not gen_paths_batch:
                    continue
                images_batch = []
                valid_captions_batch = []
                for img_path, caption in zip(gen_paths_batch, captions_batch):
                    try:
                        # Check file existence and basic PIL load before scoring.
                        if not os.path.exists(img_path):
                            print(f"Warning: CLIP - Image file not found: {img_path}. Skipping.")
                            continue
                        img = Image.open(img_path).convert("RGB")
                        images_batch.append(img)
                        valid_captions_batch.append(caption)
                    except UnidentifiedImageError:
                        print(f"Warning: CLIP - Cannot identify image file (corrupted?): {img_path}. Skipping.")
                        continue
                    except Exception as img_err:
                        print(f"Warning: CLIP - Skipping image due to load error: {img_path} - {img_err}")
                        continue
                if not images_batch:
                    continue

                # HWC uint8 -> CHW tensors, moved to the metric's device.
                image_tensors = [torch.tensor(np.array(img)).permute(2, 0, 1) for img in images_batch]
                image_tensors_dev = [t.to(device) for t in image_tensors]
                clip_scorer.update(image_tensors_dev, valid_captions_batch)

                # Explicitly release per-batch tensors to limit peak memory.
                del image_tensors_dev, image_tensors, images_batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()  # Be cautious with explicit cache clearing

            final_clip_score = clip_scorer.compute().item()
            # torchmetrics reports CLIPScore on a 0-100 scale; normalize to 0-1.
            clip_t_score = final_clip_score / 100
            print(f"CLIP-T Score : {clip_t_score:.3f}")
        except Exception as e:
            print(f"Error calculating CLIP-T score: {e}")
            print(traceback.format_exc())  # Print traceback for CLIP errors too
            clip_t_score = "Error"

        # --- Save Results ---
        results = {
            "model_checkpoint": args.model_checkpoint_path,
            "unet_subfolder": args.unet_subfolder,
            "num_samples_target": num_selected_samples,
            "num_samples_evaluated": len(generated_image_paths),  # Actual number used
            "coco_val_images_path": args.coco_val_images_path,
            "generation_skipped": args.skip_generation,
            "guidance_scale": args.guidance_scale if not args.skip_generation else "N/A (skipped generation)",
            "num_inference_steps": args.num_inference_steps if not args.skip_generation else "N/A (skipped generation)",
            "seed": args.seed,
            "mixed_precision": args.mixed_precision if not args.skip_generation else "N/A (skipped generation)",
            "fid_score": fid_score,
            "clip_t_score": clip_t_score,
            "fid_args": {"gen_dir": str(generated_images_dir), "gt_dir": str(ground_truth_dir)},
            "clip_score_args": {"model_name": args.clip_score_model_name},
        }
        # FIX: reuse results_file computed above instead of rebuilding the path.
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"\nResults saved to: {results_file}")
    elif accelerator.is_main_process:
        print("\nSkipping metric calculation because image lists are empty (check generation/loading steps).")

    print("\nEvaluation finished.")


if __name__ == "__main__":
    main()