ArtFlow / artflow /evaluation /evaluate_checkpoints.py
kaupane's picture
Upload folder using huggingface_hub
eb52c18 verified
"""
Evaluate checkpoints using run_evaluation_heavy.
Thin wrapper that discovers checkpoints and calls run_evaluation_heavy.
"""
import argparse
import json
import os
from glob import glob
from .pipeline import run_evaluation_heavy
def discover_checkpoints(checkpoint_pattern: str) -> list:
"""Discover checkpoint files matching pattern."""
checkpoints = sorted(glob(checkpoint_pattern))
return checkpoints
def main():
parser = argparse.ArgumentParser(description="Evaluate checkpoints")
parser.add_argument(
"--checkpoint_pattern",
type=str,
required=True,
help="Glob pattern for checkpoints, e.g., 'output/baseline/checkpoint_step_*/ema_weights.pt'",
)
parser.add_argument(
"--model_config",
type=str,
required=True,
help='Model config as JSON string, e.g., \'{"in_channels":16,"hidden_size":768,...}\'',
)
parser.add_argument(
"--vae_path",
type=str,
default="REPA-E/e2e-qwenimage-vae",
help="Path to VAE model",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default="Qwen/Qwen3-0.6B",
help="Path to text encoder",
)
parser.add_argument(
"--pooling",
action="store_true",
default=False,
help="Use pooled text embeddings",
)
parser.add_argument(
"--dataset_path",
type=str,
default="./precomputed_dataset/heavy-eval@256p",
help="Path to evaluation dataset",
)
parser.add_argument(
"--num_samples",
type=int,
default=2000,
help="Number of samples for CLIP",
)
parser.add_argument(
"--batch_size",
type=int,
default=32,
help="Batch size",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0",
help="Device",
)
parser.add_argument(
"--skip_existing",
action="store_true",
default=False,
help="Skip checkpoints with existing results",
)
args = parser.parse_args()
# Parse model config
model_config = json.loads(args.model_config)
# Discover checkpoints
checkpoints = discover_checkpoints(args.checkpoint_pattern)
print(f"Found {len(checkpoints)} checkpoints")
for checkpoint_path in checkpoints:
checkpoint_dir = os.path.dirname(checkpoint_path)
results_file = os.path.join(checkpoint_dir, "evaluation_results.json")
# Skip if results exist
if args.skip_existing and os.path.exists(results_file):
print(f"Skipping {checkpoint_path} (results exist)")
continue
print(f"\nEvaluating {checkpoint_path}")
os.makedirs(checkpoint_dir, exist_ok=True)
try:
results = run_evaluation_heavy(
checkpoint_path=checkpoint_path,
model_config=model_config,
vae_path=args.vae_path,
text_encoder_path=args.text_encoder_path,
pooling=args.pooling,
save_path=checkpoint_dir,
dataset_path=args.dataset_path,
num_samples=args.num_samples,
batch_size=args.batch_size,
device=args.device,
)
# Save results
with open(results_file, "w") as f:
json.dump(results, f, indent=2)
print(f"Results: {results}")
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()