|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import random |
|
|
import shutil |
|
|
import re |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from accelerate import Accelerator |
|
|
from cleanfid import fid |
|
|
from diffusers import (AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, |
|
|
UNet2DConditionModel) |
|
|
|
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
from PIL import Image, UnidentifiedImageError |
|
|
from torchmetrics.multimodal import CLIPScore |
|
|
from tqdm.auto import tqdm |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
|
|
|
|
|
# Import the project-local Mamba utilities: the Mamba block implementation and
# the function that swaps the U-Net's self-attention modules for Mamba blocks.
# The script cannot do anything useful without them, so a failed import is fatal.
try:
    from msd_utils import (MambaSequentialBlock,
                           replace_unet_self_attention_with_mamba)
    print("Successfully imported Mamba utils from msd_utils.py")
except ImportError as e:
    print(f"ERROR: Failed to import from msd_utils.py: {e}")
    print("Ensure 'msd_utils.py' is in the current directory or your PYTHONPATH.")
    # Hard-exit: every later stage depends on the Mamba replacement helpers.
    exit(1)
|
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for the evaluation run.

    Returns:
        argparse.Namespace: the populated argument namespace.
    """
    parser = argparse.ArgumentParser(description="Evaluate Mamba-SD model with FID and CLIP-T on COCO val2014.")
    add = parser.add_argument  # shorthand: every option below goes through this

    # --- paths: checkpoint, base model, data, outputs ---
    add("--model_checkpoint_path", type=str,
        default="/root/mamba/sd-mamba-mscoco-urltext-10k-run3/checkpoint-31000",
        help="Path to the trained Mamba-SD checkpoint directory (e.g., /root/mamba/.../checkpoint-31000).")
    add("--unet_subfolder", type=str, default="unet_mamba",
        help="Name of the subfolder within the checkpoint containing the trained UNet weights (e.g., 'unet_mamba', 'unet_mamba_final').")
    add("--base_model_name_or_path", type=str, default="runwayml/stable-diffusion-v1-5",
        help="Path or Hub ID of the base Stable Diffusion model (used for VAE, text encoder, etc.).")
    add("--coco_val_images_path", type=str, default="/root/mamba/val2014",
        help="Path to the COCO val2014 image directory (e.g., /root/mamba/val2014).")
    add("--coco_annotations_path", type=str, default="/root/mamba/annotations",
        help="Path to the COCO annotations directory containing 'captions_val2014.json'.")
    add("--output_dir", type=str, default="./mamba_sd_eval_output",
        help="Directory to save generated images and evaluation results.")

    # --- sampling / generation hyper-parameters ---
    add("--num_samples", type=int, default=50,
        help="Number of validation samples to generate/evaluate. Set to -1 to use all. Must match existing samples if skipping generation.")
    add("--batch_size", type=int, default=16,
        help="Batch size for image generation (if generating).")
    add("--guidance_scale", type=float, default=7.5,
        help="Guidance scale for generation (if generating).")
    add("--num_inference_steps", type=int, default=50,
        help="Number of DDIM inference steps (if generating).")
    add("--seed", type=int, default=42,
        help="Random seed for generation (if generating) and sampling.")

    # --- metric backends ---
    add("--fid_clip_model_name", type=str, default="ViT-L/14",
        help="CLIP model variant to use for FID computation with clean-fid.")
    add("--clip_score_model_name", type=str, default="openai/clip-vit-large-patch14",
        help="CLIP model variant to use for CLIPScore computation with torchmetrics.")

    # --- workflow switch: metrics-only mode ---
    add("--skip_generation", action="store_true",
        help="If set, skip image generation and attempt to load existing images from output_dir for metric calculation.")

    # --- Mamba architecture knobs (must match the values used at training time) ---
    add("--mamba_d_state", type=int, default=16, help="Mamba ssm state dimension used during training.")
    add("--mamba_d_conv", type=int, default=4, help="Mamba ssm convolution dimension used during training.")
    add("--mamba_expand", type=int, default=2, help="Mamba ssm expansion factor used during training.")

    # --- precision for the generation pass ---
    add("--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"],
        help="Whether to use mixed precision during generation.")

    return parser.parse_args()
|
|
|
|
|
def load_coco_data(ann_file_path, img_dir_path, num_samples=-1, seed=42):
    """Load COCO val2014 captions and map them to image file paths.

    Args:
        ann_file_path: Path to the 'captions_val2014.json' annotation file.
        img_dir_path: Directory containing the val2014 images.
        num_samples: Number of pairs to sample; -1 (or any out-of-range value)
            keeps all pairs.
        seed: Seed for reproducible random sampling.

    Returns:
        Tuple of (evaluation_pairs, count) where each pair is a dict with
        keys "image_path", "caption" (first caption for the image), and
        "image_id".

    Raises:
        FileNotFoundError: If the annotation file does not exist.
        NotADirectoryError: If the image directory does not exist.
        ValueError: If no valid image-caption pairs are found.
    """
    print(f"Loading COCO annotations from: {ann_file_path}")
    if not os.path.exists(ann_file_path):
        raise FileNotFoundError(f"Annotation file not found: {ann_file_path}")
    if not os.path.isdir(img_dir_path):
        raise NotADirectoryError(f"Image directory not found: {img_dir_path}")

    with open(ann_file_path, 'r') as f:
        data = json.load(f)

    # image_id -> file name index from the 'images' section.
    images_info = {img['id']: img['file_name'] for img in data['images']}

    # Group captions by image id, keeping only ids present in the image index.
    captions_by_image = {}
    for ann in data['annotations']:
        img_id = ann['image_id']
        if img_id in images_info:
            captions_by_image.setdefault(img_id, []).append(ann['caption'])

    # One pair per image (first caption), and only if the file exists on disk.
    evaluation_pairs = []
    for img_id, captions in captions_by_image.items():
        img_path = os.path.join(img_dir_path, images_info[img_id])
        if os.path.exists(img_path):
            evaluation_pairs.append({"image_path": img_path, "caption": captions[0], "image_id": img_id})

    print(f"Found {len(evaluation_pairs)} unique images with captions in source.")

    if num_samples > 0 and num_samples < len(evaluation_pairs):
        print(f"Selecting {num_samples} samples using seed {seed}...")
        # Use a dedicated RNG instance so sampling is reproducible without
        # clobbering the module-global random state (the previous version
        # called random.seed(seed) as a side effect). A fresh Random(seed)
        # yields the exact same sample as seeding the global RNG.
        evaluation_pairs = random.Random(seed).sample(evaluation_pairs, num_samples)
        print(f"Selected {len(evaluation_pairs)} samples for evaluation.")
    elif num_samples == -1:
        print(f"Using all {len(evaluation_pairs)} available samples.")
    else:
        print(f"Number of samples ({num_samples}) is invalid or >= total. Using all {len(evaluation_pairs)} samples.")

    if not evaluation_pairs:
        raise ValueError("No valid image-caption pairs selected or found. Check paths, annotation file format, and --num_samples.")

    return evaluation_pairs, len(evaluation_pairs)
|
|
|
|
|
def main():
    """Drive the full evaluation: load data, optionally generate images with
    the Mamba-SD pipeline, then compute FID and CLIP-T and save results.json.
    """
    args = parse_args()

    # When only computing metrics on existing images, force full precision;
    # mixed precision is only relevant for the generation pass.
    accelerator = Accelerator(mixed_precision=args.mixed_precision if not args.skip_generation else "no")
    device = accelerator.device
    print(f"Using device: {device}, Mixed Precision: {accelerator.mixed_precision}")

    # Layout: output_dir/generated_images and output_dir/ground_truth_images.
    output_dir = Path(args.output_dir)
    generated_images_dir = output_dir / "generated_images"
    ground_truth_dir = output_dir / "ground_truth_images"
    # NOTE(review): results_file is never used below; the same path is
    # recomputed later as results_file_path.
    results_file = output_dir / "results.json"

    # Only the main process touches the filesystem.
    if accelerator.is_main_process:
        output_dir.mkdir(parents=True, exist_ok=True)
        generated_images_dir.mkdir(exist_ok=True)
        ground_truth_dir.mkdir(exist_ok=True)
        print(f"Output directory: {output_dir}")
        print(f"Generated images dir: {generated_images_dir}")
        print(f"Ground truth images dir: {ground_truth_dir}")

    # Load the caption/image pairs on the main process only; other processes
    # keep empty lists (the generation and metric passes below are also
    # main-process-only).
    all_evaluation_pairs = []
    num_selected_samples = 0
    if accelerator.is_main_process:
        all_evaluation_pairs, num_selected_samples = load_coco_data(
            os.path.join(args.coco_annotations_path, "captions_val2014.json"),
            args.coco_val_images_path,
            args.num_samples,
            args.seed
        )
        print(f"Target number of samples for evaluation: {num_selected_samples}")

    # Index by image_id for the --skip_generation scan below.
    data_lookup = {item['image_id']: item for item in all_evaluation_pairs}

    accelerator.wait_for_everyone()

    # Parallel lists: generated image path, matching GT image path, caption.
    generated_image_paths = []
    ground_truth_image_paths = []
    captions_used = []

    if not args.skip_generation:
        # --- Build the pipeline: base SD components + Mamba-modified U-Net ---
        print("Loading models for generation...")
        print("Loading base models (Tokenizer, Text Encoder, VAE)...")
        tokenizer = CLIPTokenizer.from_pretrained(args.base_model_name_or_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(args.base_model_name_or_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(args.base_model_name_or_path, subfolder="vae")
        scheduler = DDPMScheduler.from_pretrained(args.base_model_name_or_path, subfolder="scheduler")

        # Fresh (untrained) U-Net with the base architecture; weights come
        # from the checkpoint after the Mamba swap.
        print("Loading base U-Net config and creating new Mamba-U-Net structure...")
        unet_config = UNet2DConditionModel.load_config(args.base_model_name_or_path, subfolder="unet")
        unet = UNet2DConditionModel.from_config(unet_config)

        # Swap self-attention for Mamba blocks BEFORE loading the trained
        # state dict so parameter names/shapes match the checkpoint.
        print("Replacing Self-Attention with Mamba blocks...")
        mamba_kwargs = {'d_state': args.mamba_d_state, 'd_conv': args.mamba_d_conv, 'expand': args.mamba_expand}
        try:
            unet = replace_unet_self_attention_with_mamba(unet, mamba_kwargs)
            print("Mamba replacement successful.")
        except Exception as e:
            print(f"ERROR during Mamba replacement: {e}")
            exit(1)

        unet_weights_path = Path(args.model_checkpoint_path) / args.unet_subfolder
        print(f"Loading trained Mamba U-Net weights from: {unet_weights_path}")
        if not unet_weights_path.exists():
            print(f"ERROR: Trained UNet subfolder not found: {unet_weights_path}")
            exit(1)

        try:
            # Weights may be stored under several filenames depending on how
            # the checkpoint was saved; try them in order of preference.
            unet_state_dict_path_safetensors_specific = unet_weights_path / "diffusion_pytorch_model.safetensors"
            unet_state_dict_path_bin = unet_weights_path / "diffusion_pytorch_model.bin"
            unet_state_dict_path_safetensors_generic = unet_weights_path / "model.safetensors"

            unet_state_dict_path = None
            state_dict = None

            if unet_state_dict_path_safetensors_specific.exists():
                print(f"Found specific safetensors file: {unet_state_dict_path_safetensors_specific}")
                unet_state_dict_path = unet_state_dict_path_safetensors_specific
                state_dict = load_safetensors(unet_state_dict_path, device="cpu")
            elif unet_state_dict_path_bin.exists():
                print(f"Found bin file: {unet_state_dict_path_bin}")
                unet_state_dict_path = unet_state_dict_path_bin
                state_dict = torch.load(unet_state_dict_path, map_location="cpu")
            elif unet_state_dict_path_safetensors_generic.exists():
                print(f"Found generic safetensors file: {unet_state_dict_path_safetensors_generic}")
                unet_state_dict_path = unet_state_dict_path_safetensors_generic
                state_dict = load_safetensors(unet_state_dict_path, device="cpu")
            else:
                raise FileNotFoundError(f"Could not find 'diffusion_pytorch_model.safetensors', 'diffusion_pytorch_model.bin', or 'model.safetensors' in UNet subfolder: {unet_weights_path}")

            # Strict load (default): all keys must match the Mamba-modified U-Net.
            unet.load_state_dict(state_dict)
            print(f"Successfully loaded trained U-Net weights from {unet_state_dict_path}.")
            del state_dict  # free CPU memory before moving models to device

        except FileNotFoundError as fnf_error:
            print(f"ERROR: {fnf_error}")
            exit(1)
        except Exception as e:
            print(f"ERROR loading U-Net weights from {unet_weights_path}: {e}")
            print(traceback.format_exc())
            exit(1)

        print("Creating Stable Diffusion Pipeline...")
        # Map accelerate's mixed-precision setting to a torch dtype.
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        print(f"Setting pipeline dtype to: {weight_dtype}")

        vae.to(device=device, dtype=weight_dtype)
        text_encoder.to(device=device, dtype=weight_dtype)
        unet.to(device=device, dtype=weight_dtype)

        # Safety checker disabled: this is an offline benchmark pipeline.
        pipeline = StableDiffusionPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
            safety_checker=None, feature_extractor=None, requires_safety_checker=False,
        )

        # Optional memory optimization; purely best-effort.
        try:
            import xformers
            pipeline.enable_xformers_memory_efficient_attention()
            print("Enabled xformers memory efficient attention.")
        except ImportError:
            print("xformers not installed. Running without it.")

        # One seeded generator reused across all batches for reproducibility.
        generator = torch.Generator(device=device).manual_seed(args.seed)

        if accelerator.is_main_process:
            print(f"Generating {num_selected_samples} images...")
            pipeline.set_progress_bar_config(disable=False)

            for i in tqdm(range(0, num_selected_samples, args.batch_size), desc="Generating Batches"):
                batch_data = all_evaluation_pairs[i : i + args.batch_size]
                if not batch_data: continue

                prompts = [item["caption"] for item in batch_data]
                gt_paths = [item["image_path"] for item in batch_data]
                image_ids = [item["image_id"] for item in batch_data]

                # NOTE(review): generation is hard-coded to 768x768 here while
                # SD v1.5's native resolution is 512 — confirm this matches the
                # resolution the Mamba U-Net was trained/intended for.
                with torch.no_grad(), torch.autocast(device_type=accelerator.device.type, dtype=weight_dtype if weight_dtype != torch.float32 else None, enabled=accelerator.mixed_precision != "no"):
                    images = pipeline(
                        prompt=prompts, guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps, height=768,width=768,generator=generator
                    ).images

                # Save each generated image and copy its ground-truth
                # counterpart (named gen_<id>.png / gt_<id>.png).
                for idx, (pil_image, gt_path, img_id, prompt) in enumerate(zip(images, gt_paths, image_ids, prompts)):
                    generated_filename = f"gen_{img_id}.png"
                    gt_filename = f"gt_{img_id}.png"
                    gen_save_path = generated_images_dir / generated_filename
                    gt_save_path = ground_truth_dir / gt_filename

                    try:
                        pil_image.save(gen_save_path)
                        generated_image_paths.append(str(gen_save_path))

                        # Copy GT only once; repeated runs reuse the existing copy.
                        if not gt_save_path.exists():
                            shutil.copyfile(gt_path, gt_save_path)
                        ground_truth_image_paths.append(str(gt_save_path))
                        captions_used.append(prompt)
                    except Exception as e:
                        # Best-effort: a single failed save should not abort the run.
                        print(f"Warning: Could not save generated image or copy GT for image_id {img_id}: {e}")

            print(f"Finished generation. Generated {len(generated_image_paths)} images.")

            if len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of generated images ({len(generated_image_paths)}) does not match target ({num_selected_samples}). Check for errors during saving.")

    else:
        # --- Metrics-only mode: reuse previously generated images ---
        if accelerator.is_main_process:
            print(f"Skipping generation. Loading existing images from {generated_images_dir} and {ground_truth_dir}")

            if not generated_images_dir.exists() or not ground_truth_dir.exists():
                print(f"Error: Cannot skip generation. Directory not found: {generated_images_dir} or {ground_truth_dir}")
                exit(1)

            # Filenames must look like gen_<numeric id>.<ext>.
            gen_pattern = re.compile(r"gen_(\d+)\.(png|jpg|jpeg|webp)$", re.IGNORECASE)

            found_gen_files = list(generated_images_dir.glob("gen_*.*"))
            print(f"Found {len(found_gen_files)} potential generated image files.")

            for gen_file_path in tqdm(found_gen_files, desc="Scanning existing images"):
                match = gen_pattern.match(gen_file_path.name)
                if match:
                    try:
                        image_id = int(match.group(1))
                        # Only keep images whose id is in the currently
                        # selected COCO sample set.
                        if image_id in data_lookup:
                            original_data = data_lookup[image_id]
                            gt_filename = f"gt_{image_id}.png"
                            gt_path_expected = ground_truth_dir / gt_filename
                            gt_path_original = original_data["image_path"]
                            caption = original_data["caption"]

                            # Re-copy the GT image if it is missing from a
                            # previous run; skip the pair if the original is gone.
                            if not gt_path_expected.exists():
                                if os.path.exists(gt_path_original):
                                    print(f"Copying missing GT image: {gt_filename}")
                                    shutil.copyfile(gt_path_original, gt_path_expected)
                                else:
                                    print(f"Warning: Cannot find original GT image to copy: {gt_path_original}")
                                    continue

                            if gt_path_expected.exists():
                                generated_image_paths.append(str(gen_file_path))
                                ground_truth_image_paths.append(str(gt_path_expected))
                                captions_used.append(caption)
                            else:
                                print(f"Warning: Skipping image_id {image_id} because ground truth image could not be found/copied to {gt_path_expected}")

                        else:
                            print(f"Warning: Found generated image {gen_file_path.name} but its ID {image_id} is not in the selected COCO samples list. Skipping.")
                    except ValueError:
                        print(f"Warning: Could not parse image ID from filename {gen_file_path.name}. Skipping.")
                    except Exception as e:
                        print(f"Warning: Error processing existing file {gen_file_path}: {e}")

            print(f"Loaded {len(generated_image_paths)} existing generated images and corresponding GT paths/captions.")

            if len(generated_image_paths) == 0:
                print("Error: No generated images found in the specified directory matching the expected format (gen_ID.png/jpg...). Cannot calculate metrics.")
                exit(1)
            elif len(generated_image_paths) != num_selected_samples:
                print(f"Warning: Number of loaded images ({len(generated_image_paths)}) does not match the expected number of samples ({num_selected_samples}). Metrics will be calculated on the loaded images.")
                print("This might happen if generation was interrupted or if --num_samples differs from the initial generation.")

    accelerator.wait_for_everyone()

    # Metric results default to None; may become a float or an "Error..." string.
    fid_score = None
    clip_t_score = None

    if accelerator.is_main_process and generated_image_paths and ground_truth_image_paths:
        print(f"\nProceeding to calculate metrics using {len(generated_image_paths)} image pairs.")

        # --- FID: directory-vs-directory comparison via clean-fid ---
        print("\n--- Calculating FID Score ---")
        try:
            if not any(ground_truth_dir.iterdir()):
                print(f"Error: Ground truth directory '{ground_truth_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - GT dir empty"
            elif not any(generated_images_dir.iterdir()):
                print(f"Error: Generated images directory '{generated_images_dir}' is empty. Cannot calculate FID.")
                fid_score = "Error - Gen dir empty"
            else:
                fid_score = fid.compute_fid(
                    str(generated_images_dir),
                    str(ground_truth_dir),
                    mode="clean",
                    num_workers=min(os.cpu_count(), 8)
                )

                print(f"FID Score: {fid_score:.2f}")
        except Exception as e:
            print(f"Error calculating FID: {e}")
            fid_score = "Error"

        # --- CLIP-T: caption/image alignment via torchmetrics CLIPScore ---
        print("\n--- Calculating CLIP-T Score ---")
        try:
            clip_scorer = CLIPScore(model_name_or_path=args.clip_score_model_name).to(device)

            # Batch size for the metric pass (independent of generation batch size).
            clip_batch_size = 64

            for i in tqdm(range(0, len(generated_image_paths), clip_batch_size), desc="Calculating CLIP Scores"):
                gen_paths_batch = generated_image_paths[i : i + clip_batch_size]
                captions_batch = captions_used[i : i + clip_batch_size]
                if not gen_paths_batch: continue

                # Load images defensively; a missing/corrupted file drops
                # that pair (caption list is kept aligned via valid_captions_batch).
                images_batch = []
                valid_captions_batch = []
                for img_path, caption in zip(gen_paths_batch, captions_batch):
                    try:
                        if not os.path.exists(img_path):
                            print(f"Warning: CLIP - Image file not found: {img_path}. Skipping.")
                            continue
                        img = Image.open(img_path).convert("RGB")
                        images_batch.append(img)
                        valid_captions_batch.append(caption)
                    except UnidentifiedImageError:
                        print(f"Warning: CLIP - Cannot identify image file (corrupted?): {img_path}. Skipping.")
                        continue
                    except Exception as img_err:
                        print(f"Warning: CLIP - Skipping image due to load error: {img_path} - {img_err}")
                        continue

                if not images_batch: continue

                # HWC uint8 PIL -> CHW tensors, as a list (images may differ in size).
                image_tensors = [torch.tensor(np.array(img)).permute(2, 0, 1) for img in images_batch]

                image_tensors_dev = [t.to(device) for t in image_tensors]

                # Accumulate; the final score is computed once after the loop.
                clip_scorer.update(image_tensors_dev, valid_captions_batch)

                # Release per-batch tensors to keep GPU memory bounded.
                del image_tensors_dev, image_tensors, images_batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            final_clip_score = clip_scorer.compute().item()
            # CLIPScore returns a 0-100 scale; report as 0-1.
            clip_t_score = final_clip_score/100

            print(f"CLIP-T Score : {clip_t_score:.3f}")

        except Exception as e:
            print(f"Error calculating CLIP-T score: {e}")
            print(traceback.format_exc())
            clip_t_score = "Error"

        # --- Persist configuration + metrics for this run ---
        results = {
            "model_checkpoint": args.model_checkpoint_path,
            "unet_subfolder": args.unet_subfolder,
            "num_samples_target": num_selected_samples,
            "num_samples_evaluated": len(generated_image_paths),
            "coco_val_images_path": args.coco_val_images_path,
            "generation_skipped": args.skip_generation,
            "guidance_scale": args.guidance_scale if not args.skip_generation else "N/A (skipped generation)",
            "num_inference_steps": args.num_inference_steps if not args.skip_generation else "N/A (skipped generation)",
            "seed": args.seed,
            "mixed_precision": args.mixed_precision if not args.skip_generation else "N/A (skipped generation)",
            "fid_score": fid_score,
            "clip_t_score": clip_t_score,
            "fid_args": { "gen_dir": str(generated_images_dir), "gt_dir": str(ground_truth_dir) },
            "clip_score_args": { "model_name": args.clip_score_model_name }
        }
        results_file_path = output_dir / "results.json"
        with open(results_file_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"\nResults saved to: {results_file_path}")

    elif accelerator.is_main_process:
        print("\nSkipping metric calculation because image lists are empty (check generation/loading steps).")

    print("\nEvaluation finished.")
|
|
|
|
|
# Standard script entry point.
if __name__ == "__main__":
    main()