import argparse from pathlib import Path import torch from datasets import load_dataset from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm from calority_nutrition_model import ( TARGET_COLUMNS, CalorityNutritionCNN, save_nutrition_checkpoint, ) from calority_scratch_model import image_to_tensor def parse_args(): parser = argparse.ArgumentParser( description="Train Calority's calorie and macro predictor from scratch on mmathys/food-nutrients." ) parser.add_argument("--dataset", default="mmathys/food-nutrients") parser.add_argument("--source-split", default="test", help="This dataset currently ships with only a test split.") parser.add_argument("--image-column", default="image") parser.add_argument("--output-dir", default="./calority-nutrition-model") parser.add_argument("--epochs", type=int, default=40) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--learning-rate", type=float, default=3e-4) parser.add_argument("--validation-size", type=float, default=0.15) parser.add_argument("--num-workers", type=int, default=0) parser.add_argument("--limit", type=int, default=0, help="Optional small limit for quick smoke tests") return parser.parse_args() def make_targets(dataset_split) -> torch.Tensor: rows = [[float(item[column]) for column in TARGET_COLUMNS] for item in dataset_split] return torch.tensor(rows, dtype=torch.float32) def make_collate_fn(image_column: str, target_mean: torch.Tensor, target_std: torch.Tensor): def collate(batch): images = torch.stack([image_to_tensor(item[image_column]) for item in batch]) targets = torch.tensor( [[float(item[column]) for column in TARGET_COLUMNS] for item in batch], dtype=torch.float32, ) normalized_targets = (targets - target_mean) / target_std return images, normalized_targets, targets return collate def evaluate(model, loader, loss_fn, target_mean, target_std, device): model.eval() total_loss = 0.0 total_mae = torch.zeros(len(TARGET_COLUMNS)) total_seen = 0 with torch.no_grad(): for images, normalized_targets, raw_targets in loader: images = images.to(device) normalized_targets = normalized_targets.to(device) predictions = model(images) loss = loss_fn(predictions, normalized_targets) raw_predictions = torch.clamp( (predictions.cpu() * target_std) + target_mean, min=0, ) total_loss += loss.item() * images.size(0) total_mae += torch.abs(raw_predictions - raw_targets).sum(dim=0) total_seen += images.size(0) mae = total_mae / max(total_seen, 1) return total_loss / max(total_seen, 1), mae def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dataset = load_dataset(args.dataset) source = dataset[args.source_split].shuffle(seed=42) if args.limit: source = source.select(range(min(args.limit, len(source)))) split = source.train_test_split(test_size=args.validation_size, seed=42) train_ds = split["train"] eval_ds = split["test"] train_targets = make_targets(train_ds) target_mean = train_targets.mean(dim=0) target_std = torch.clamp(train_targets.std(dim=0), min=1.0) model = CalorityNutritionCNN(output_size=len(TARGET_COLUMNS)).to(device) collate_fn = make_collate_fn(args.image_column, target_mean, target_std) train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, ) eval_loader = DataLoader( eval_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn, ) loss_fn = nn.SmoothL1Loss() optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) output_dir = Path(args.output_dir) best_calorie_mae = float("inf") for epoch in range(1, args.epochs + 1): model.train() running_loss = 0.0 total_seen = 0 progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False) for images, normalized_targets, _ in progress: images = images.to(device) normalized_targets = normalized_targets.to(device) optimizer.zero_grad(set_to_none=True) predictions = model(images) loss = loss_fn(predictions, normalized_targets) loss.backward() optimizer.step() running_loss += loss.item() * images.size(0) total_seen += images.size(0) progress.set_postfix(loss=round(running_loss / max(total_seen, 1), 4)) scheduler.step() eval_loss, mae = evaluate(model, eval_loader, loss_fn, target_mean, target_std, device) metric_line = ", ".join( f"{column}_mae={mae[index]:.2f}" for index, column in enumerate(TARGET_COLUMNS) ) print(f"epoch={epoch} eval_loss={eval_loss:.4f} {metric_line}") if mae[0].item() <= best_calorie_mae: best_calorie_mae = mae[0].item() save_nutrition_checkpoint(model, target_mean, target_std, output_dir) print(f"saved best nutrition model to {output_dir} with calorie_mae={best_calorie_mae:.2f}") print(f"done. best_calorie_mae={best_calorie_mae:.2f}") if __name__ == "__main__": main()