Tasfiya025 commited on
Commit
79393ab
·
verified ·
1 Parent(s): 01effba

Create train_infer.py

Browse files
Files changed (1) hide show
  1. train_infer.py +80 -0
train_infer.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ running_loss += loss.item() * imgs.size(0)
3
+
4
+
5
+ # Validation
6
+ model.eval()
7
+ correct = 0
8
+ total = 0
9
+ with torch.no_grad():
10
+ for imgs, labels in val_loader:
11
+ imgs, labels = imgs.to(device), labels.to(device)
12
+ outputs = model(imgs)
13
+ _, preds = outputs.max(1)
14
+ correct += (preds == labels).sum().item()
15
+ total += labels.size(0)
16
+ val_acc = correct / total if total else 0
17
+ avg_loss = running_loss / (len(train_loader.dataset) if len(train_loader.dataset) else 1)
18
+ print(f"Epoch {epoch+1}/{args.epochs} val_acc={val_acc:.4f} train_loss={avg_loss:.4f}")
19
+
20
+
21
+ if val_acc > best_acc:
22
+ best_acc = val_acc
23
+ torch.save({'model_state': model.state_dict(), 'classes': classes, 'base': args.base}, args.ckpt)
24
+ print("Saved best checkpoint ->", args.ckpt)
25
+
26
+
27
+ print("Training finished. Best val acc:", best_acc)
28
+
29
+
30
+
31
+
32
+ def predict(args):
33
+ if not os.path.exists(args.ckpt):
34
+ raise SystemExit(f"Checkpoint not found: {args.ckpt}")
35
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
+ ck = torch.load(args.ckpt, map_location=device)
37
+ classes = ck['classes']
38
+ base = ck.get('base', 'resnet18')
39
+ model = build_model(len(classes), base=base, pretrained=False)
40
+ model.load_state_dict(ck['model_state'])
41
+ model.to(device).eval()
42
+
43
+
44
+ tf = transforms.Compose([
45
+ transforms.Resize(256),
46
+ transforms.CenterCrop(args.img),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
49
+ ])
50
+ img = Image.open(args.image).convert('RGB')
51
+ x = tf(img).unsqueeze(0).to(device)
52
+ with torch.no_grad():
53
+ out = model(x)
54
+ pred = out.argmax(1).item()
55
+ print('Predicted:', classes[pred])
56
+
57
+
58
+
59
+
60
+ if __name__ == '__main__':
61
+ p = argparse.ArgumentParser()
62
+ p.add_argument('--mode', choices=['train', 'predict'], required=True)
63
+ p.add_argument('--train_dir', default='data/train')
64
+ p.add_argument('--val_dir', default='data/val')
65
+ p.add_argument('--image', help='Image path for prediction')
66
+ p.add_argument('--ckpt', default='ckpt.pth')
67
+ p.add_argument('--epochs', type=int, default=3)
68
+ p.add_argument('--batch', type=int, default=16)
69
+ p.add_argument('--lr', type=float, default=1e-4)
70
+ p.add_argument('--img', type=int, default=224)
71
+ p.add_argument('--base', default='resnet18')
72
+ args = p.parse_args()
73
+
74
+
75
+ if args.mode == 'train':
76
+ train(args)
77
+ else:
78
+ if not args.image:
79
+ raise SystemExit('Provide --image for predict mode')
80
+ predict(args)