| | """ |
| | Main entry point for CASWiT training and evaluation. |
| | |
| | This script provides a unified interface for training, evaluation, and inference. |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | import logging |
| | from pathlib import Path |
| |
|
| | |
# Put this script's own directory at the front of sys.path so the sibling
# packages imported below (train, model, dataset) can be found.
sys.path.insert(0, str(Path(__file__).parent))
| |
|
| | from train.train import main as train_main, load_config |
| | from train.eval import evaluate_model |
| | from train.inference import inference_single_image |
| | from model.build_model import build_model |
| | from dataset.definition_dataset import build_transforms |
| |
|
| |
|
def _build_parser():
    """Return the argument parser for the CASWiT command-line interface."""
    parser = argparse.ArgumentParser(description="CASWiT: Context-Aware Swin Transformer")
    parser.add_argument("mode", choices=["train", "eval", "inference"],
                        help="Mode: train, eval, or inference")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config YAML file")
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to model checkpoint (for eval/inference)")
    parser.add_argument("--image", type=str, default="",
                        help="Path to input image (for inference)")
    parser.add_argument("--output", type=str, default="prediction.png",
                        help="Path to save output (for inference)")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"],
                        help="Dataset split for evaluation")
    return parser


def _fail(message):
    """Print *message* to stderr and exit the process with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from each key.

    ``torch.nn.DataParallel`` / ``DistributedDataParallel`` store the wrapped
    model as ``.module``, so checkpoints saved from a wrapped model carry that
    prefix. Only the *leading* prefix is stripped: a blanket
    ``str.replace`` would also rewrite legitimate inner ``"module."``
    substrings (e.g. ``"encoder.module.weight"``) and silently break loading.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_checkpoint(model, checkpoint_path, device):
    """Load weights from *checkpoint_path* into *model* (non-strict) and report mismatches.

    Args:
        model: The freshly built model to receive the weights.
        checkpoint_path: Path to the checkpoint file.
        device: ``torch.device`` passed as ``map_location``.
    """
    import torch

    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints (consider weights_only=True where the
    # installed torch supports it).
    state_dict = torch.load(checkpoint_path, map_location=device)
    # The file is treated as a bare state dict (parameter name -> tensor).
    state_dict = _strip_module_prefix(state_dict)
    # strict=False tolerates key mismatches; they are reported below instead.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {checkpoint_path}")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print(" Perfect match! All weights loaded successfully.")


def _run_eval(args):
    """Evaluate ``args.checkpoint`` on ``args.split`` using the config at ``args.config``."""
    if not args.checkpoint:
        _fail("Error: --checkpoint required for evaluation")
    cfg = load_config(args.config)
    evaluate_model(cfg, args.checkpoint, args.split)


def _run_inference(args):
    """Run the checkpointed model on a single image and save the prediction to ``args.output``."""
    if not args.checkpoint or not args.image:
        _fail("Error: --checkpoint and --image required for inference")

    # Validate inputs before the expensive work (config load, torch import,
    # model construction) so bad paths fail fast.
    if not Path(args.checkpoint).is_file():
        _fail(f"Error: Checkpoint file not found: {args.checkpoint}")
    if not Path(args.image).is_file():
        _fail(f"Error: Image file not found: {args.image}")

    import torch  # deferred so train/eval and argument errors don't pay for it here

    cfg = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = build_model(cfg).to(device)
    _load_checkpoint(model, args.checkpoint, device)
    # Put dropout/normalization layers into inference behaviour. Harmless if
    # inference_single_image also does this (TODO confirm whether it does).
    model.eval()

    transform = build_transforms()
    inference_single_image(model, args.image, device, transform, cfg, args.output)


def main():
    """Parse command-line arguments and dispatch to train, eval or inference.

    Modes:
        train      -- call ``train_main`` with the YAML config at ``--config``.
        eval       -- evaluate ``--checkpoint`` on ``--split`` (``test`` or ``val``).
        inference  -- run ``--checkpoint`` on ``--image`` and save to ``--output``.

    Exits with status 1 (message on stderr) when a mode's required arguments
    are missing, or when the inference checkpoint or image file does not exist.
    """
    args = _build_parser().parse_args()

    if args.mode == "train":
        train_main(args.config)
    elif args.mode == "eval":
        _run_eval(args)
    elif args.mode == "inference":
        _run_inference(args)
| |
|
| |
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|
| |
|