"""Fine-tune a pretrained timm ViT on a CSV-annotated image classification set.

The annotations CSV must contain an ``image_path`` column (absolute, or
relative to ``--data_root``) and a ``label`` column.  The best checkpoint by
validation accuracy is written to ``<out_dir>/best.pt`` together with the
ordered label list.
"""
import argparse
import os

import pandas as pd
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from timm.data import resolve_data_config
from torch.utils.data import DataLoader, Dataset, random_split


class CarsDataset(Dataset):
    """Image-classification dataset backed by a CSV annotation file.

    Args:
        csv_path: path to a CSV with ``image_path`` and ``label`` columns.
        img_root: directory that relative ``image_path`` values are joined to.
        transform: optional callable applied to each RGB PIL image.  Defaults
            to a 224x224 resize followed by ``ToTensor`` (no normalization).

    Labels are mapped to contiguous integer indices in sorted label order;
    ``self.labels[i]`` is the label for class index ``i``.
    """

    def __init__(self, csv_path, img_root, transform=None):
        self.df = pd.read_csv(csv_path)
        missing = {"image_path", "label"} - set(self.df.columns)
        if missing:
            raise ValueError(
                f"{csv_path} is missing required column(s): {sorted(missing)}"
            )
        self.img_root = img_root
        self.labels = sorted(self.df["label"].unique().tolist())
        self.label_to_idx = {label: i for i, label in enumerate(self.labels)}
        if transform is None:
            transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        path = row["image_path"]
        if not os.path.isabs(path):
            path = os.path.join(self.img_root, path)
        # Close the file handle promptly; convert() loads the pixels and
        # returns a new image, so it stays valid after the file is closed.
        with Image.open(path) as im:
            img = im.convert("RGB")
        return self.transform(img), self.label_to_idx[row["label"]]


def _build_transform(model):
    """Return a resize/ToTensor/Normalize pipeline matching *model*'s pretrained config."""
    cfg = resolve_data_config({}, model=model)
    height, width = cfg["input_size"][-2:]
    return T.Compose([
        T.Resize((height, width)),
        T.ToTensor(),
        T.Normalize(mean=cfg["mean"], std=cfg["std"]),
    ])


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return top-1 accuracy of *model* over *loader* (0.0 if the loader is empty)."""
    model.eval()
    correct = total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        correct += (model(xb).argmax(1) == yb).sum().item()
        total += yb.numel()
    return correct / total if total else 0.0


def main(args):
    """Train, validate each epoch, and save the best checkpoint.

    Args:
        args: namespace with ``annotations``, ``out_dir`` and ``epochs``;
            optionally ``data_root``, ``batch_size``, ``lr``, ``seed`` and
            ``num_workers`` (defaults are used when absent).

    Raises:
        ValueError: if the annotations contain fewer than 2 rows, or lack
            the required columns.
    """
    batch_size = getattr(args, "batch_size", 32)
    lr = getattr(args, "lr", 2e-4)
    seed = getattr(args, "seed", 0)
    num_workers = getattr(args, "num_workers", 0)
    # --data_root is where images live; fall back to the annotations' directory
    # for callers that do not provide it.
    img_root = getattr(args, "data_root", None) or os.path.dirname(args.annotations)

    ds = CarsDataset(args.annotations, img_root)
    n = len(ds)
    if n < 2:
        raise ValueError(
            f"need at least 2 annotated images for a train/val split, got {n}"
        )
    n_val = max(1, int(0.2 * n))
    # A seeded generator keeps the validation subset identical across runs.
    split_gen = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(ds, [n - n_val, n_val], generator=split_gen)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = timm.create_model(
        "vit_base_patch16_224", pretrained=True, num_classes=len(ds.labels)
    )
    model.to(device)
    # Pretrained weights expect inputs normalized like their training data.
    # The subsets share `ds`, so this applies to both train and val.
    ds.transform = _build_transform(model)

    pin = device == "cuda"
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin)
    val_loader = DataLoader(val_set, batch_size=batch_size,
                            num_workers=num_workers, pin_memory=pin)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss()
    os.makedirs(args.out_dir, exist_ok=True)
    ckpt_path = os.path.join(args.out_dir, "best.pt")

    # None means "nothing saved yet", so the first epoch always checkpoints
    # (even at 0% accuracy).
    best_acc = None
    for epoch in range(args.epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()

        acc = _evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}: val_acc={acc:.3f}")
        if best_acc is None or acc > best_acc:
            best_acc = acc
            torch.save({"model": model.state_dict(), "labels": ds.labels}, ckpt_path)

    print("Done. Best acc:", 0.0 if best_acc is None else best_acc)


if __name__ == "__main__":
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--data_root", required=True,
                    help="directory that relative image paths are resolved against")
    ap.add_argument("--annotations", required=True,
                    help="CSV with image_path and label columns")
    ap.add_argument("--out_dir", default="checkpoints/vision")
    ap.add_argument("--epochs", type=int, default=10)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--seed", type=int, default=0,
                    help="seed for the train/val split")
    ap.add_argument("--num_workers", type=int, default=0)
    args = ap.parse_args()
    main(args)