# Image Segmentation
# English
# CASWiT / main.py
# antoine.carreaud67
# update main.py
# ca50374
"""
Main entry point for CASWiT training and evaluation.
This script provides a unified interface for training, evaluation, and inference.
"""
import argparse
import sys
import logging
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))
from train.train import main as train_main, load_config
from train.eval import evaluate_model
from train.inference import inference_single_image
from model.build_model import build_model
from dataset.definition_dataset import build_transforms
def _build_parser():
    """Build the command-line parser for the train / eval / inference modes."""
    parser = argparse.ArgumentParser(description="CASWiT: Context-Aware Swin Transformer")
    parser.add_argument("mode", choices=["train", "eval", "inference"],
                        help="Mode: train, eval, or inference")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config YAML file")
    parser.add_argument("--checkpoint", type=str, default="",
                        help="Path to model checkpoint (for eval/inference)")
    parser.add_argument("--image", type=str, default="",
                        help="Path to input image (for inference)")
    parser.add_argument("--output", type=str, default="prediction.png",
                        help="Path to save output (for inference)")
    parser.add_argument("--split", type=str, default="test", choices=["test", "val"],
                        help="Dataset split for evaluation")
    return parser


def _fail(message):
    """Print ``message`` to stderr and terminate the process with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _require_checkpoint_file(checkpoint):
    """Exit with an error unless ``checkpoint`` names an existing regular file."""
    # Path.is_file() is False for missing paths as well, so one test covers
    # both "does not exist" and "exists but is not a file".
    if not Path(checkpoint).is_file():
        _fail(f"Error: Checkpoint file not found: {checkpoint}")


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from each key.

    The ``"module."`` prefix presumably comes from checkpoints saved through a
    (Distributed)DataParallel wrapper. Only the *leading* prefix is removed:
    ``str.replace`` would also rewrite keys that merely contain the substring,
    e.g. ``"encoder.submodule.weight"`` -> ``"encoder.subweight"``, and those
    weights would then silently fail to load under ``strict=False``.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _run_inference(args):
    """Validate inputs, build the model, load the checkpoint, and predict one image."""
    if not args.checkpoint or not args.image:
        _fail("Error: --checkpoint and --image required for inference")
    # Validate paths before importing torch, loading the config, and building
    # the model, so a mistyped path fails immediately instead of after setup.
    _require_checkpoint_file(args.checkpoint)
    if not Path(args.image).is_file():
        _fail(f"Error: Image file not found: {args.image}")

    import torch  # imported here: this is the only code path in this file that uses torch

    cfg = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(cfg).to(device)

    print(f"Loading checkpoint from: {args.checkpoint}")
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code when the checkpoint is untrusted (depending on the
    # installed torch version's default for ``weights_only``). Consider passing
    # weights_only=True if the checkpoints in use are plain state dicts.
    state_dict = torch.load(args.checkpoint, map_location=device)
    state_dict = _strip_module_prefix(state_dict)
    # strict=False tolerates mismatches; report them instead of failing.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"Successfully loaded checkpoint from: {args.checkpoint}")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    if not missing and not unexpected:
        print(" Perfect match! All weights loaded successfully.")

    transform = build_transforms()
    inference_single_image(model, args.image, device, transform, cfg, args.output)


def main():
    """Parse command-line arguments and dispatch to training, evaluation, or inference.

    Modes:
        train:     calls ``train_main`` with the config path.
        eval:      requires ``--checkpoint``; evaluates on ``--split``.
        inference: requires ``--checkpoint`` and ``--image``; writes ``--output``.

    Exits with status 1, after printing an error to stderr, when a mode's
    required arguments are missing or a checkpoint/image path is not a file.
    """
    args = _build_parser().parse_args()
    if args.mode == "train":
        train_main(args.config)
    elif args.mode == "eval":
        if not args.checkpoint:
            _fail("Error: --checkpoint required for evaluation")
        # Same up-front validation as inference, before the config is loaded.
        _require_checkpoint_file(args.checkpoint)
        cfg = load_config(args.config)
        evaluate_model(cfg, args.checkpoint, args.split)
    elif args.mode == "inference":
        _run_inference(args)
# Launch the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()