Spaces:
Sleeping
Sleeping
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from datasets import load_dataset | |
| from torch import nn | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from calority_scratch_model import CalorityFoodCNN, image_to_tensor, save_checkpoint | |
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding dataset, column, split, output and
        optimisation settings.
    """
    parser = argparse.ArgumentParser(description="Train Calority's food model from scratch on a Hugging Face dataset.")

    # Where the data comes from and which columns/splits to read.
    parser.add_argument("--dataset", default="food101", help="Hugging Face dataset name, for example food101")
    parser.add_argument("--image-column", default="image")
    parser.add_argument("--label-column", default="label")
    parser.add_argument("--train-split", default="train")
    parser.add_argument("--eval-split", default="validation")
    parser.add_argument("--output-dir", default="./calority-scratch-model")

    # Optimisation hyper-parameters.
    parser.add_argument("--epochs", type=int, default=12)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--num-workers", type=int, default=0)

    # A value of 0 means "use the whole split".
    parser.add_argument("--limit-train", type=int, default=0, help="Optional small limit for quick smoke tests")
    parser.add_argument("--limit-eval", type=int, default=0, help="Optional small limit for quick smoke tests")

    return parser.parse_args(argv)
def get_labels(dataset, split: str, label_column: str) -> list[str]:
    """Return the class names for ``label_column`` of ``dataset[split]``.

    When the column's feature exposes a non-empty ``names`` attribute, those
    names are returned as a list. Otherwise the distinct values found in the
    column are sorted and converted to strings.
    """
    split_data = dataset[split]
    class_names = getattr(split_data.features[label_column], "names", None)
    if class_names:
        return list(class_names)

    # No usable names on the feature: derive labels from the column values.
    distinct_values = set(split_data[label_column])
    return list(map(str, sorted(distinct_values)))
def _collate_batch(batch, image_column: str, label_column: str):
    """Convert a list of dataset rows into an ``(images, labels)`` tensor pair.

    Each row's image goes through ``image_to_tensor`` and the results are
    stacked along a new leading dimension; each row's label is converted with
    ``int`` and collected into a ``torch.long`` tensor.
    """
    # torch.stack needs equally shaped tensors — presumably image_to_tensor
    # resizes every image to a fixed size; confirm in calority_scratch_model.
    images = torch.stack([image_to_tensor(item[image_column]) for item in batch])
    labels = torch.tensor([int(item[label_column]) for item in batch], dtype=torch.long)
    return images, labels


def make_collate_fn(image_column: str, label_column: str):
    """Build the DataLoader ``collate_fn`` for the given column names.

    Args:
        image_column: Key under which each row stores its image.
        label_column: Key under which each row stores its label; the value
            must be convertible with ``int``.

    Returns:
        A callable taking a list of rows and returning ``(images, labels)``.
    """
    # Local import keeps this block self-contained.
    from functools import partial

    # A partial over a module-level function can be pickled, unlike a nested
    # closure, so it can be handed to DataLoader worker processes that are
    # started with the "spawn" method (num_workers > 0).
    return partial(_collate_batch, image_column=image_column, label_column=label_column)
def evaluate(model, loader, loss_fn, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    The model is switched to eval mode and is left in that mode on return.

    Returns:
        A ``(loss, accuracy)`` tuple. The loss is the sum of each batch's loss
        value times its batch size, divided by the number of samples seen
        (a per-sample mean, assuming ``loss_fn`` returns a batch mean). The
        accuracy is the fraction of samples whose arg-max logit equals the
        label. Both are ``0.0`` when the loader yields nothing.
    """
    model.eval()
    loss_sum, hits, count = 0.0, 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            outputs = model(batch_images)
            batch_loss = loss_fn(outputs, batch_labels)
            batch_size = batch_labels.size(0)

            loss_sum += batch_loss.item() * batch_size
            predictions = outputs.argmax(dim=1)
            hits += (predictions == batch_labels).sum().item()
            count += batch_size

    # Guard the division so an empty loader gives zeros instead of an error.
    denominator = max(count, 1)
    return loss_sum / denominator, hits / denominator
def _limit_split(split, limit: int, seed: int = 42):
    """Return a shuffled subset of ``split`` holding at most ``limit`` rows.

    The request is clamped to the split's length, so an oversized limit uses
    the whole (shuffled) split rather than asking ``select`` for indices
    beyond its end.
    """
    return split.shuffle(seed=seed).select(range(min(limit, len(split))))


def main():
    """Train the food classifier and keep the best-scoring checkpoint.

    Steps: parse the CLI options, load the dataset, optionally subsample the
    train/eval splits, build the model and data loaders, then run
    ``args.epochs`` epochs of AdamW training with a cosine learning-rate
    schedule. After every epoch the model is evaluated, and a checkpoint is
    written to the output directory whenever the eval accuracy matches or
    beats the best seen so far.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_dataset(args.dataset)
    # Only positive limits subsample. A negative value used to be truthy and
    # produced an empty split via range(-n); it now means "no limit", like 0.
    if args.limit_train > 0:
        dataset[args.train_split] = _limit_split(dataset[args.train_split], args.limit_train)
    if args.limit_eval > 0:
        dataset[args.eval_split] = _limit_split(dataset[args.eval_split], args.limit_eval)

    labels = get_labels(dataset, args.train_split, args.label_column)
    model = CalorityFoodCNN(num_labels=len(labels)).to(device)

    # Both loaders share everything except the split and the shuffle flag.
    loader_options = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "collate_fn": make_collate_fn(args.image_column, args.label_column),
    }
    train_loader = DataLoader(dataset[args.train_split], shuffle=True, **loader_options)
    eval_loader = DataLoader(dataset[args.eval_split], shuffle=False, **loader_options)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
    # T_max equals the epoch count, so the scheduler is stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    best_acc = 0.0
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(1, args.epochs + 1):
        model.train()  # evaluate() leaves the model in eval mode
        running_loss = 0.0
        total_seen = 0
        total_correct = 0

        progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False)
        for images, labels_batch in progress:
            images = images.to(device)
            labels_batch = labels_batch.to(device)

            optimizer.zero_grad(set_to_none=True)
            logits = model(images)
            loss = loss_fn(logits, labels_batch)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the running value
            # is a per-sample average even when the last batch is smaller.
            running_loss += loss.item() * labels_batch.size(0)
            total_correct += (logits.argmax(dim=1) == labels_batch).sum().item()
            total_seen += labels_batch.size(0)
            progress.set_postfix(
                loss=round(running_loss / max(total_seen, 1), 4),
                acc=round(total_correct / max(total_seen, 1), 4),
            )

        scheduler.step()

        eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn, device)
        print(f"epoch={epoch} eval_loss={eval_loss:.4f} eval_acc={eval_acc:.4f}")

        # ">=" means a tie overwrites the checkpoint with the later epoch,
        # and the first epoch always writes one (best_acc starts at 0.0).
        if eval_acc >= best_acc:
            best_acc = eval_acc
            save_checkpoint(model, labels, output_dir)
            print(f"saved best model to {output_dir} with eval_acc={best_acc:.4f}")

    print(f"done. best_eval_acc={best_acc:.4f}")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()