# NOTE: removed page-scrape artifact ("Spaces: / Sleeping / Sleeping" — a
# Hugging Face Space status banner, not part of the program).
| import argparse, os, pandas as pd | |
| import torch, torch.nn as nn | |
| import torchvision.transforms as T | |
| from torch.utils.data import DataLoader, random_split, Dataset | |
| from PIL import Image | |
| import timm | |
class CarsDataset(Dataset):
    """Image-classification dataset backed by a CSV of annotations.

    The CSV must have an 'image_path' column (absolute, or relative to
    ``img_root``) and a 'label' column. Labels are mapped to contiguous
    integer indices in sorted order, so the mapping is deterministic
    across runs for the same label set.
    """

    def __init__(self, csv_path, img_root):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.labels = sorted(self.df['label'].unique().tolist())
        self.label_to_idx = {l: i for i, l in enumerate(self.labels)}
        # FIX: pretrained timm ViT models (vit_base_patch16_224) expect
        # inputs normalized with mean/std = 0.5 per timm's default config;
        # feeding raw [0, 1] ToTensor output degrades fine-tuning accuracy.
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        p = row['image_path']
        if not os.path.isabs(p):
            # Relative paths are resolved against the dataset image root.
            p = os.path.join(self.img_root, p)
        # convert("RGB") forces 3 channels (handles grayscale/RGBA files).
        img = Image.open(p).convert("RGB")
        x = self.transform(img)
        y = self.label_to_idx[row['label']]
        return x, y
def main(args):
    """Fine-tune a pretrained ViT on the CSV-annotated dataset.

    Trains for ``args.epochs`` epochs, evaluates on a held-out 20% split
    after each epoch, and saves the best checkpoint (weights + label list)
    to ``args.out_dir/best.pt``.
    """
    # BUG FIX: --data_root is a required CLI argument but was previously
    # ignored — the image root was derived from the annotations path
    # (os.path.dirname(args.annotations)) instead of args.data_root.
    ds = CarsDataset(args.annotations, args.data_root)

    # Hold out 20% (at least one sample) for validation; seed the split so
    # train/val membership is reproducible across runs.
    n = len(ds)
    n_val = max(1, int(0.2 * n))
    tr, va = random_split(
        ds, [n - n_val, n_val],
        generator=torch.Generator().manual_seed(0),
    )
    tl = DataLoader(tr, batch_size=32, shuffle=True)
    vl = DataLoader(va, batch_size=32)

    model = timm.create_model(
        "vit_base_patch16_224", pretrained=True, num_classes=len(ds.labels)
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
    crit = nn.CrossEntropyLoss()

    best = 0.0
    os.makedirs(args.out_dir, exist_ok=True)
    for epoch in range(args.epochs):
        # --- train ---
        model.train()
        for xb, yb in tl:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()

        # --- validate ---
        model.eval()
        corr = 0
        tot = 0
        with torch.no_grad():
            for xb, yb in vl:
                xb = xb.to(device)
                yb = yb.to(device)
                pred = model(xb).argmax(1)
                corr += (pred == yb).sum().item()
                tot += yb.numel()
        acc = corr / tot if tot else 0
        print(f"Epoch {epoch+1}: val_acc={acc:.3f}")

        # Checkpoint only on improvement; persist the label order so
        # inference can map logits back to label names.
        if acc > best:
            best = acc
            torch.save(
                {"model": model.state_dict(), "labels": ds.labels},
                os.path.join(args.out_dir, "best.pt"),
            )
    print("Done. Best acc:", best)
| if __name__ == "__main__": | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--data_root", required=True) | |
| ap.add_argument("--annotations", required=True) | |
| ap.add_argument("--out_dir", default="checkpoints/vision") | |
| ap.add_argument("--epochs", type=int, default=10) | |
| args = ap.parse_args() | |
| main(args) | |