|
|
|
|
|
running_loss += loss.item() * imgs.size(0) |
|
|
|
|
|
|
|
|
|
|
|
model.eval() |
|
|
correct = 0 |
|
|
total = 0 |
|
|
with torch.no_grad(): |
|
|
for imgs, labels in val_loader: |
|
|
imgs, labels = imgs.to(device), labels.to(device) |
|
|
outputs = model(imgs) |
|
|
_, preds = outputs.max(1) |
|
|
correct += (preds == labels).sum().item() |
|
|
total += labels.size(0) |
|
|
val_acc = correct / total if total else 0 |
|
|
avg_loss = running_loss / (len(train_loader.dataset) if len(train_loader.dataset) else 1) |
|
|
print(f"Epoch {epoch+1}/{args.epochs} val_acc={val_acc:.4f} train_loss={avg_loss:.4f}") |
|
|
|
|
|
|
|
|
if val_acc > best_acc: |
|
|
best_acc = val_acc |
|
|
torch.save({'model_state': model.state_dict(), 'classes': classes, 'base': args.base}, args.ckpt) |
|
|
print("Saved best checkpoint ->", args.ckpt) |
|
|
|
|
|
|
|
|
print("Training finished. Best val acc:", best_acc) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(args):
    """Classify one image with a saved checkpoint and print the predicted class.

    Args:
        args: Namespace-like object providing ``ckpt`` (checkpoint path),
            ``image`` (path of the image to classify) and ``img`` (side
            length passed to the center crop).

    Raises:
        SystemExit: If ``args.ckpt`` does not exist on disk.
    """
    if not os.path.exists(args.ckpt):
        raise SystemExit(f"Checkpoint not found: {args.ckpt}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here (consider weights_only=True if the
    # installed PyTorch version supports it).
    ck = torch.load(args.ckpt, map_location=device)
    classes = ck['classes']
    # Checkpoints that carry no 'base' entry fall back to resnet18.
    base = ck.get('base', 'resnet18')

    # Rebuild the architecture without pretrained weights; the checkpoint's
    # state dict supplies every parameter.
    model = build_model(len(classes), base=base, pretrained=False)
    model.load_state_dict(ck['model_state'])
    model.to(device).eval()

    tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(args.img),
        transforms.ToTensor(),
        # Per-channel mean / std; presumably the ImageNet statistics used
        # during training -- confirm against the training transforms.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Image.open is lazy and may keep the underlying file open; the context
    # manager releases the handle as soon as the RGB copy has been made.
    with Image.open(args.image) as raw:
        img = raw.convert('RGB')

    # Add a leading batch dimension of size 1 before moving to the device.
    x = tf(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(x)
        pred = out.argmax(1).item()

    print('Predicted:', classes[pred])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    _CLI_OPTIONS = (
        ('--mode', {'choices': ['train', 'predict'], 'required': True}),
        ('--train_dir', {'default': 'data/train'}),
        ('--val_dir', {'default': 'data/val'}),
        ('--image', {'help': 'Image path for prediction'}),
        ('--ckpt', {'default': 'ckpt.pth'}),
        ('--epochs', {'type': int, 'default': 3}),
        ('--batch', {'type': int, 'default': 16}),
        ('--lr', {'type': float, 'default': 1e-4}),
        ('--img', {'type': int, 'default': 224}),
        ('--base', {'default': 'resnet18'}),
    )

    p = argparse.ArgumentParser()
    for _flag, _options in _CLI_OPTIONS:
        p.add_argument(_flag, **_options)
    args = p.parse_args()

    # Dispatch on the selected mode; predict mode additionally needs --image.
    if args.mode == 'train':
        train(args)
    elif args.image:
        predict(args)
    else:
        raise SystemExit('Provide --image for predict mode')