| | """ |
| | Evaluate checkpoints using run_evaluation_heavy. |
| | Thin wrapper that discovers checkpoints and calls run_evaluation_heavy. |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import os |
| | from glob import glob |
| |
|
| | from .pipeline import run_evaluation_heavy |
| |
|
| |
|
def discover_checkpoints(checkpoint_pattern: str) -> list:
    """Return the paths matching ``checkpoint_pattern`` in natural order.

    Runs of digits inside a path are compared as integers, so
    ``checkpoint_step_2000`` is listed before ``checkpoint_step_10000``.
    A plain string sort would put them the other way round.

    Args:
        checkpoint_pattern: Glob pattern handed to ``glob``.

    Returns:
        A list of matching paths; empty when nothing matches.
    """
    # Imported locally so this function is self-contained and does not
    # require a change to the file's import section.
    import re

    def natural_key(path: str) -> tuple:
        # Splitting on a capturing group yields text, digits, text, ...:
        # odd positions are always digit runs. Converting by position keeps
        # element types aligned between any two keys, so comparisons never
        # mix int and str.
        pieces = re.split(r"(\d+)", path)
        key = tuple(
            int(piece) if index % 2 else piece
            for index, piece in enumerate(pieces)
        )
        # The raw path breaks ties such as "step_01" vs "step_1" so the
        # result does not depend on the order glob happens to return.
        return key, path

    return sorted(glob(checkpoint_pattern), key=natural_key)
| |
|
| |
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the checkpoint evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate checkpoints")
    parser.add_argument(
        "--checkpoint_pattern",
        type=str,
        required=True,
        help="Glob pattern for checkpoints, e.g., 'output/baseline/checkpoint_step_*/ema_weights.pt'",
    )
    parser.add_argument(
        "--model_config",
        type=str,
        required=True,
        help='Model config as JSON string, e.g., \'{"in_channels":16,"hidden_size":768,...}\'',
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default="REPA-E/e2e-qwenimage-vae",
        help="Path to VAE model",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default="Qwen/Qwen3-0.6B",
        help="Path to text encoder",
    )
    parser.add_argument(
        "--pooling",
        action="store_true",
        default=False,
        help="Use pooled text embeddings",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./precomputed_dataset/heavy-eval@256p",
        help="Path to evaluation dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2000,
        help="Number of samples for CLIP",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        default=False,
        help="Skip checkpoints with existing results",
    )
    return parser


def _write_json_atomic(payload, destination: str) -> None:
    """Write ``payload`` as JSON to ``destination`` without leaving a partial file.

    The data is serialized to a sibling temporary file first and then moved
    into place, so ``destination`` only ever appears fully written. This
    matters because ``--skip_existing`` treats the mere existence of the
    results file as "already evaluated".
    """
    staging_path = f"{destination}.tmp"
    try:
        with open(staging_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        os.replace(staging_path, destination)
    finally:
        # After a successful replace the staging file is gone; it only
        # still exists when serialization or the move failed.
        if os.path.exists(staging_path):
            os.remove(staging_path)


def main():
    """Evaluate every checkpoint matching ``--checkpoint_pattern``.

    For each match, ``run_evaluation_heavy`` is called with the parsed
    command-line options and its return value is stored as
    ``evaluation_results.json`` next to the checkpoint file. A failure on one
    checkpoint is reported with its traceback and does not stop the others.

    Exits through ``argparse`` (status 2) when the arguments are invalid,
    including a ``--model_config`` value that is not valid JSON.
    """
    import traceback

    parser = _build_parser()
    args = parser.parse_args()

    try:
        model_config = json.loads(args.model_config)
    except json.JSONDecodeError as err:
        # Report bad input as a usage error instead of a raw traceback.
        parser.error(f"--model_config is not valid JSON: {err}")

    checkpoints = discover_checkpoints(args.checkpoint_pattern)
    print(f"Found {len(checkpoints)} checkpoints")

    failed = []
    for checkpoint_path in checkpoints:
        # dirname() is "" for a checkpoint in the current directory, and
        # os.makedirs("") raises FileNotFoundError, so fall back to ".".
        checkpoint_dir = os.path.dirname(checkpoint_path) or "."
        results_file = os.path.join(checkpoint_dir, "evaluation_results.json")

        if args.skip_existing and os.path.exists(results_file):
            print(f"Skipping {checkpoint_path} (results exist)")
            continue

        print(f"\nEvaluating {checkpoint_path}")

        # Per-checkpoint boundary: anything that goes wrong here is reported
        # and the loop moves on to the next checkpoint.
        try:
            # Normally a no-op, since the directory holds the matched file.
            os.makedirs(checkpoint_dir, exist_ok=True)
            results = run_evaluation_heavy(
                checkpoint_path=checkpoint_path,
                model_config=model_config,
                vae_path=args.vae_path,
                text_encoder_path=args.text_encoder_path,
                pooling=args.pooling,
                save_path=checkpoint_dir,
                dataset_path=args.dataset_path,
                num_samples=args.num_samples,
                batch_size=args.batch_size,
                device=args.device,
            )
            _write_json_atomic(results, results_file)
            print(f"Results: {results}")
        except Exception as e:
            print(f"Error: {e}")
            traceback.print_exc()
            failed.append(checkpoint_path)

    if failed:
        # Errors are easy to miss in a long log; repeat them at the end.
        print(f"\n{len(failed)} of {len(checkpoints)} checkpoint(s) failed:")
        for checkpoint_path in failed:
            print(f"  {checkpoint_path}")
| |
|
| |
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|
| |
|