import argparse from pathlib import Path import torch from datasets import load_dataset from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm from calority_scratch_model import CalorityFoodCNN, image_to_tensor, save_checkpoint def parse_args(): parser = argparse.ArgumentParser(description="Train Calority's food model from scratch on a Hugging Face dataset.") parser.add_argument("--dataset", default="food101", help="Hugging Face dataset name, for example food101") parser.add_argument("--image-column", default="image") parser.add_argument("--label-column", default="label") parser.add_argument("--train-split", default="train") parser.add_argument("--eval-split", default="validation") parser.add_argument("--output-dir", default="./calority-scratch-model") parser.add_argument("--epochs", type=int, default=12) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--num-workers", type=int, default=0) parser.add_argument("--limit-train", type=int, default=0, help="Optional small limit for quick smoke tests") parser.add_argument("--limit-eval", type=int, default=0, help="Optional small limit for quick smoke tests") return parser.parse_args() def get_labels(dataset, split: str, label_column: str) -> list[str]: feature = dataset[split].features[label_column] if hasattr(feature, "names") and feature.names: return list(feature.names) values = sorted(set(dataset[split][label_column])) return [str(value) for value in values] def make_collate_fn(image_column: str, label_column: str): def collate(batch): images = torch.stack([image_to_tensor(item[image_column]) for item in batch]) labels = torch.tensor([int(item[label_column]) for item in batch], dtype=torch.long) return images, labels return collate def evaluate(model, loader, loss_fn, device): model.eval() total_loss = 0.0 total_correct = 0 total_seen = 0 with torch.no_grad(): for images, labels in loader: images = images.to(device) labels = labels.to(device) logits = model(images) loss = loss_fn(logits, labels) total_loss += loss.item() * labels.size(0) total_correct += (logits.argmax(dim=1) == labels).sum().item() total_seen += labels.size(0) return total_loss / max(total_seen, 1), total_correct / max(total_seen, 1) def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dataset = load_dataset(args.dataset) if args.limit_train: dataset[args.train_split] = dataset[args.train_split].shuffle(seed=42).select(range(args.limit_train)) if args.limit_eval: dataset[args.eval_split] = dataset[args.eval_split].shuffle(seed=42).select(range(args.limit_eval)) labels = get_labels(dataset, args.train_split, args.label_column) model = CalorityFoodCNN(num_labels=len(labels)).to(device) collate_fn = make_collate_fn(args.image_column, args.label_column) train_loader = DataLoader( dataset[args.train_split], batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, ) eval_loader = DataLoader( dataset[args.eval_split], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn, ) loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) best_acc = 0.0 output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) for epoch in range(1, args.epochs + 1): model.train() running_loss = 0.0 total_seen = 0 total_correct = 0 progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False) for images, labels_batch in progress: images = images.to(device) labels_batch = labels_batch.to(device) optimizer.zero_grad(set_to_none=True) logits = model(images) loss = loss_fn(logits, labels_batch) loss.backward() optimizer.step() running_loss += loss.item() * labels_batch.size(0) total_correct += (logits.argmax(dim=1) == labels_batch).sum().item() total_seen += labels_batch.size(0) progress.set_postfix( loss=round(running_loss / max(total_seen, 1), 4), acc=round(total_correct / max(total_seen, 1), 4), ) scheduler.step() eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn, device) print(f"epoch={epoch} eval_loss={eval_loss:.4f} eval_acc={eval_acc:.4f}") if eval_acc >= best_acc: best_acc = eval_acc save_checkpoint(model, labels, output_dir) print(f"saved best model to {output_dir} with eval_acc={best_acc:.4f}") print(f"done. best_eval_acc={best_acc:.4f}") if __name__ == "__main__": main()