"""
Main entry point for CASWiT training and evaluation.

This script provides a unified interface for training, evaluation, and inference.
"""

import argparse
import sys
import logging
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))

from train.train import main as train_main, load_config
from train.eval import evaluate_model
from train.inference import inference_single_image
from model.build_model import build_model
from dataset.definition_dataset import build_transforms


def _build_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the ``train`` / ``eval`` / ``inference`` modes."""
    parser = argparse.ArgumentParser(description="CASWiT: Context-Aware Swin Transformer")
    parser.add_argument("mode", choices=["train", "eval", "inference"],
                        help="Mode: train, eval, or inference")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config YAML file")
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to model checkpoint (for eval/inference)")
    parser.add_argument("--image", type=str, default="",
                        help="Path to input image (for inference)")
    parser.add_argument("--output", type=str, default="prediction.png",
                        help="Path to save output (for inference)")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"],
                        help="Dataset split for evaluation")
    return parser


def _fail(message: str) -> None:
    """Print *message* to stderr and terminate the process with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _strip_wrapper_prefix(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with one leading ``module.`` removed from each key.

    The prefix is typically present when a model was saved while wrapped in
    ``torch.nn.DataParallel`` / ``DistributedDataParallel``. Only a *leading*
    prefix is stripped: a blanket ``str.replace`` would also corrupt keys that
    merely contain ``module.`` somewhere inside (e.g. ``encoder.submodule.weight``).

    Args:
        state_dict: Mapping of parameter name to value, as returned by ``torch.load``.

    Returns:
        A new dict with the same values and prefix-stripped keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_checkpoint(model, checkpoint_path: str, device) -> None:
    """Load weights from *checkpoint_path* into *model* non-strictly and report key mismatches.

    Args:
        model: The module to load weights into (modified in place).
        checkpoint_path: Path to a file that ``torch.load`` returns as a name -> tensor
            mapping. NOTE(review): a checkpoint wrapping the weights under another key
            (e.g. ``{"model": ...}``) would load nothing and only show up as missing keys.
        device: Passed to ``torch.load`` as ``map_location``.
    """
    import torch

    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints; consider weights_only=True if the installed
    # torch supports it and the checkpoint contains tensors only.
    state_dict = torch.load(checkpoint_path, map_location=device)
    state_dict = _strip_wrapper_prefix(state_dict)

    # strict=False: mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print(" Perfect match! All weights loaded successfully.")


def _run_inference(args: argparse.Namespace) -> None:
    """Build the model, load ``args.checkpoint`` and predict on ``args.image``.

    Exits with status 1 (message on stderr) when ``--checkpoint`` or ``--image``
    is missing, or when the checkpoint path is not an existing file.
    """
    if not args.checkpoint or not args.image:
        _fail("Error: --checkpoint and --image required for inference")

    # Imported here: only the inference path uses torch directly in this file.
    import torch

    cfg = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # is_file() is False both for missing paths and for directories.
    if not Path(args.checkpoint).is_file():
        _fail(f"Error: Checkpoint file not found: {args.checkpoint}")

    model = build_model(cfg).to(device)
    _load_checkpoint(model, args.checkpoint, device)

    transform = build_transforms()
    inference_single_image(model, args.image, device, transform, cfg, args.output)


def main() -> None:
    """Parse command-line arguments and dispatch to training, evaluation, or inference.

    Usage errors are written to stderr and exit with status 1.
    """
    args = _build_parser().parse_args()

    if args.mode == "train":
        train_main(args.config)
    elif args.mode == "eval":
        if not args.checkpoint:
            _fail("Error: --checkpoint required for evaluation")
        cfg = load_config(args.config)
        evaluate_model(cfg, args.checkpoint, args.split)
    elif args.mode == "inference":
        _run_inference(args)


if __name__ == "__main__":
    main()