# serviceadvisor/training/train_classifier.py
# viswanani's picture
# Upload 13 files
# b9985cf verified
import argparse, os, pandas as pd
import torch, torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
import timm
class CarsDataset(Dataset):
    """Image-classification dataset described by a CSV annotation file.

    The CSV must contain an ``image_path`` column and a ``label`` column.
    Relative image paths are joined onto ``img_root``; absolute paths are
    used unchanged. Each image is converted to RGB, resized to 224x224 and
    passed through ``ToTensor``.
    """

    # Columns that __getitem__ reads from every annotation row.
    _REQUIRED_COLUMNS = ("image_path", "label")

    def __init__(self, csv_path, img_root):
        """Load the annotations and build the label vocabulary.

        Args:
            csv_path: Path of the CSV annotation file.
            img_root: Directory that relative ``image_path`` entries are
                resolved against.

        Raises:
            KeyError: If a required column is missing from the CSV.
        """
        self.df = pd.read_csv(csv_path)
        # Fail here rather than later inside __getitem__, where the error
        # would only show up once a sample is actually requested.
        missing = [c for c in self._REQUIRED_COLUMNS if c not in self.df.columns]
        if missing:
            raise KeyError(
                f"{csv_path} is missing required column(s): {', '.join(missing)}"
            )
        self.img_root = img_root
        # Sorted unique labels give a stable class ordering for a given CSV.
        self.labels = sorted(self.df["label"].unique().tolist())
        self.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.df)

    def __getitem__(self, i):
        """Return ``(image_tensor, class_index)`` for annotation row ``i``."""
        row = self.df.iloc[i]
        path = row["image_path"]
        if not os.path.isabs(path):
            path = os.path.join(self.img_root, path)
        # Convert inside the ``with`` block so the image file is closed
        # deterministically; ``convert`` returns an independent copy.
        with Image.open(path) as source:
            img = source.convert("RGB")
        x = self.transform(img)
        y = self.label_to_idx[row["label"]]
        return x, y
def main(args):
    """Fine-tune a pretrained ViT classifier and save the best checkpoint.

    Roughly 20% of the samples (at least one) are held out for validation.
    After every epoch the validation accuracy is printed, and whenever it
    improves the model weights and label list are written to
    ``<out_dir>/best.pt``. The first evaluated epoch is always saved so a
    run never finishes without a checkpoint.

    Args:
        args: Namespace providing ``annotations`` (CSV path), ``out_dir``
            (checkpoint directory) and ``epochs`` (number of epochs).
            ``args.data_root`` is not read here: relative image paths are
            resolved against the directory containing the annotations file.

    Raises:
        ValueError: If the dataset has fewer than two samples, since it
            cannot be split into non-empty training and validation sets.
    """
    ds = CarsDataset(args.annotations, os.path.dirname(args.annotations))
    n = len(ds)
    if n < 2:
        raise ValueError(
            f"Need at least 2 annotated samples to build a train/val split, got {n}"
        )
    n_val = max(1, int(0.2 * n))
    tr, va = random_split(ds, [n - n_val, n_val])
    tl = DataLoader(tr, batch_size=32, shuffle=True)
    vl = DataLoader(va, batch_size=32)

    model = timm.create_model(
        "vit_base_patch16_224", pretrained=True, num_classes=len(ds.labels)
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=2e-4)
    crit = nn.CrossEntropyLoss()

    best = 0.0
    saved = False  # becomes True once this run has written a checkpoint
    os.makedirs(args.out_dir, exist_ok=True)
    ckpt_path = os.path.join(args.out_dir, "best.pt")

    for epoch in range(args.epochs):
        model.train()
        for xb, yb in tl:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()

        acc = _evaluate(model, vl, device)
        print(f"Epoch {epoch+1}: val_acc={acc:.3f}")
        # ``not saved`` covers runs whose accuracy never rises above 0.0,
        # which would otherwise end without any checkpoint on disk.
        if acc > best or not saved:
            best = max(best, acc)
            saved = True
            torch.save({"model": model.state_dict(), "labels": ds.labels}, ckpt_path)
    print("Done. Best acc:", best)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` (0 if it is empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            pred = model(xb).argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.numel()
    return correct / total if total else 0
if __name__ == "__main__":
    # Script entry point: parse the command-line options and start training.
    parser = argparse.ArgumentParser()
    # Both dataset locations are mandatory.
    # NOTE(review): main() does not read --data_root; images are resolved
    # relative to the annotations file's directory. Confirm intended use.
    for required_flag in ("--data_root", "--annotations"):
        parser.add_argument(required_flag, required=True)
    parser.add_argument("--out_dir", default="checkpoints/vision")
    parser.add_argument("--epochs", default=10, type=int)
    main(parser.parse_args())