#!/usr/bin/env python3 running_loss += loss.item() * imgs.size(0) # Validation model.eval() correct = 0 total = 0 with torch.no_grad(): for imgs, labels in val_loader: imgs, labels = imgs.to(device), labels.to(device) outputs = model(imgs) _, preds = outputs.max(1) correct += (preds == labels).sum().item() total += labels.size(0) val_acc = correct / total if total else 0 avg_loss = running_loss / (len(train_loader.dataset) if len(train_loader.dataset) else 1) print(f"Epoch {epoch+1}/{args.epochs} val_acc={val_acc:.4f} train_loss={avg_loss:.4f}") if val_acc > best_acc: best_acc = val_acc torch.save({'model_state': model.state_dict(), 'classes': classes, 'base': args.base}, args.ckpt) print("Saved best checkpoint ->", args.ckpt) print("Training finished. Best val acc:", best_acc)


def predict(args):
    """Classify one image with a saved checkpoint and print the predicted class.

    Args:
        args: Parsed command-line namespace. Reads ``args.ckpt`` (checkpoint
            path), ``args.image`` (path of the image to classify) and
            ``args.img`` (side length handed to ``CenterCrop``).

    Raises:
        SystemExit: If ``args.ckpt`` does not exist on disk.
    """
    if not os.path.exists(args.ckpt):
        raise SystemExit(f"Checkpoint not found: {args.ckpt}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load unpickles the file; only point --ckpt at
    # checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location=device)
    classes = ck['classes']
    # Checkpoints without a 'base' entry fall back to 'resnet18'.
    base = ck.get('base', 'resnet18')

    # Weights come from the checkpoint, so no pretrained download is requested.
    model = build_model(len(classes), base=base, pretrained=False)
    model.load_state_dict(ck['model_state'])
    model.to(device).eval()

    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(args.img),
        transforms.ToTensor(),
        # Presumably the ImageNet channel mean/std; should match whatever the
        # training pipeline used -- TODO confirm against the train transforms.
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    # Open the image in a context manager so the underlying file handle is
    # released as soon as the pixels have been converted to a tensor.
    with Image.open(args.image) as img:
        x = tf(img.convert('RGB')).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        out = model(x)
        pred = out.argmax(1).item()
    print('Predicted:', classes[pred])


if __name__ == '__main__':
    # Command-line entry point: --mode selects training or single-image
    # prediction; the remaining options are shared between both modes.
    p = argparse.ArgumentParser()
    p.add_argument('--mode', choices=['train', 'predict'], required=True)
    p.add_argument('--train_dir', default='data/train')
    p.add_argument('--val_dir', default='data/val')
    p.add_argument('--image', help='Image path for prediction')
    p.add_argument('--ckpt', default='ckpt.pth')
    p.add_argument('--epochs', type=int, default=3)
    p.add_argument('--batch', type=int, default=16)
    p.add_argument('--lr', type=float, default=1e-4)
    p.add_argument('--img', type=int, default=224)
    p.add_argument('--base', default='resnet18')
    args = p.parse_args()

    if args.mode == 'train':
        train(args)
    else:
        # --image is optional at parse time, so it is enforced here for
        # predict mode only.
        if not args.image:
            raise SystemExit('Provide --image for predict mode')
        predict(args)