Image Segmentation
English
antoine.carreaud67 commited on
Commit
ca50374
·
1 Parent(s): f228688

update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -9
main.py CHANGED
@@ -15,7 +15,7 @@ sys.path.insert(0, str(Path(__file__).parent))
15
  from train.train import main as train_main, load_config
16
  from train.eval import evaluate_model
17
  from train.inference import inference_single_image
18
- from model.CASWiT import CASWiT
19
  from dataset.definition_dataset import build_transforms
20
 
21
 
@@ -60,13 +60,7 @@ def main():
60
  print(f"Error: Checkpoint file not found: {args.checkpoint}")
61
  sys.exit(1)
62
 
63
- model = CASWiT(
64
- num_head_xa=cfg.cross_attention_heads,
65
- num_classes=cfg.num_classes,
66
- model_name=cfg.model_name,
67
- mlp_ratio=cfg.fusion_mlp_ratio,
68
- drop_path=cfg.fusion_drop_path
69
- ).to(device)
70
 
71
  print(f"Loading checkpoint from: {args.checkpoint}")
72
  state_dict = torch.load(args.checkpoint, map_location=device)
@@ -81,7 +75,7 @@ def main():
81
  print(f" Perfect match! All weights loaded successfully.")
82
 
83
  transform = build_transforms()
84
- inference_single_image(model, args.image, device, transform, args.output)
85
 
86
 
87
  if __name__ == "__main__":
 
15
  from train.train import main as train_main, load_config
16
  from train.eval import evaluate_model
17
  from train.inference import inference_single_image
18
+ from model.build_model import build_model
19
  from dataset.definition_dataset import build_transforms
20
 
21
 
 
60
  print(f"Error: Checkpoint file not found: {args.checkpoint}")
61
  sys.exit(1)
62
 
63
+ model = build_model(cfg).to(device)
 
 
 
 
 
 
64
 
65
  print(f"Loading checkpoint from: {args.checkpoint}")
66
  state_dict = torch.load(args.checkpoint, map_location=device)
 
75
  print(f" Perfect match! All weights loaded successfully.")
76
 
77
  transform = build_transforms()
78
+ inference_single_image(model, args.image, device, transform, cfg, args.output)
79
 
80
 
81
  if __name__ == "__main__":