Spaces:
Sleeping
Sleeping
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from datasets import load_dataset | |
| from torch import nn | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from calority_nutrition_model import ( | |
| TARGET_COLUMNS, | |
| CalorityNutritionCNN, | |
| save_nutrition_checkpoint, | |
| ) | |
| from calority_scratch_model import image_to_tensor | |
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``. Dashes in option names become
        underscores (``--source-split`` -> ``source_split``).
    """
    parser = argparse.ArgumentParser(
        description="Train Calority's calorie and macro predictor from scratch on mmathys/food-nutrients."
    )

    # Data source.
    parser.add_argument("--dataset", default="mmathys/food-nutrients")
    parser.add_argument(
        "--source-split",
        default="test",
        help="This dataset currently ships with only a test split.",
    )
    parser.add_argument("--image-column", default="image")

    # Output location.
    parser.add_argument("--output-dir", default="./calority-nutrition-model")

    # Optimisation hyper-parameters.
    parser.add_argument("--epochs", type=int, default=40)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--validation-size", type=float, default=0.15)
    parser.add_argument("--num-workers", type=int, default=0)

    # 0 means "use everything".
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Optional small limit for quick smoke tests",
    )

    # Passing argv through lets callers (and tests) drive the parser without
    # touching sys.argv; argparse falls back to sys.argv[1:] when it is None.
    return parser.parse_args(argv)
def make_targets(dataset_split) -> torch.Tensor:
    """Collect the regression targets of a split into one float32 tensor.

    Args:
        dataset_split: Iterable of row mappings that contain every column
            named in ``TARGET_COLUMNS`` (for example a Hugging Face
            ``Dataset``).

    Returns:
        A tensor of shape ``(number of rows, len(TARGET_COLUMNS))`` whose
        columns follow the order of ``TARGET_COLUMNS``.
    """
    # Only the numeric columns are needed here. When the split supports it
    # (Hugging Face ``Dataset.select_columns``), drop the other columns first
    # so that iterating does not also load/decode the image of every row.
    if hasattr(dataset_split, "select_columns"):
        dataset_split = dataset_split.select_columns(list(TARGET_COLUMNS))

    rows = [[float(item[column]) for column in TARGET_COLUMNS] for item in dataset_split]

    # The explicit reshape is a no-op for non-empty input and keeps the
    # two-dimensional layout for an empty split (which would otherwise
    # collapse to shape ``(0,)``).
    return torch.tensor(rows, dtype=torch.float32).reshape(len(rows), len(TARGET_COLUMNS))
def make_collate_fn(image_column: str, target_mean: torch.Tensor, target_std: torch.Tensor):
    """Build a DataLoader ``collate_fn`` that batches images and targets.

    Args:
        image_column: Key under which each row stores its image.
        target_mean: Value subtracted from the raw targets.
        target_std: Value the centred targets are divided by.

    Returns:
        A function mapping a list of rows to the tuple
        ``(images, normalized_targets, raw_targets)``.
    """

    def collate_batch(samples):
        # Each image is converted with image_to_tensor and the results are
        # stacked along a new leading (batch) dimension.
        stacked_images = torch.stack(
            [image_to_tensor(sample[image_column]) for sample in samples]
        )

        # One row per sample, columns in TARGET_COLUMNS order.
        target_rows = [
            [float(sample[name]) for name in TARGET_COLUMNS] for sample in samples
        ]
        raw_targets = torch.tensor(target_rows, dtype=torch.float32)

        # The raw targets are returned alongside the normalised ones so the
        # caller can compute errors in the original scale.
        standardized = (raw_targets - target_mean) / target_std
        return stacked_images, standardized, raw_targets

    return collate_batch
def evaluate(model, loader, loss_fn, target_mean, target_std, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(images, normalized_targets, raw_targets)``.
        loss_fn: Loss applied to predictions and normalised targets.
        target_mean: Offset used to undo target normalisation.
        target_std: Scale used to undo target normalisation.
        device: Device the images and normalised targets are moved to.

    Returns:
        ``(loss, mae)`` where ``loss`` is the sum of per-batch losses weighted
        by batch size and divided by the number of samples, and ``mae`` is a
        tensor with one mean absolute error per target column, measured on
        de-normalised predictions clamped to be non-negative. An empty loader
        yields ``0.0`` and an all-zero tensor.
    """
    model.eval()

    loss_sum = 0.0
    abs_error_sum = torch.zeros(len(TARGET_COLUMNS))
    sample_count = 0

    with torch.no_grad():
        for batch_images, batch_normalized, batch_raw in loader:
            batch_images = batch_images.to(device)
            batch_normalized = batch_normalized.to(device)

            outputs = model(batch_images)
            batch_loss = loss_fn(outputs, batch_normalized)

            # Undo the normalisation on the CPU (the raw targets are never
            # moved to ``device``) and clip negative predictions to zero.
            denormalized = (outputs.cpu() * target_std) + target_mean
            clipped = torch.clamp(denormalized, min=0)

            batch_size = batch_images.size(0)
            loss_sum += batch_loss.item() * batch_size
            abs_error_sum += torch.abs(clipped - batch_raw).sum(dim=0)
            sample_count += batch_size

    # Guard against division by zero when the loader produced no batches.
    divisor = max(sample_count, 1)
    return loss_sum / divisor, abs_error_sum / divisor
def main():
    """Train the nutrition model and keep the checkpoint with the best MAE.

    Steps: parse the CLI options, load and shuffle the requested split,
    optionally truncate it, carve off a validation part, compute target
    normalisation statistics on the training part, then train with
    AdamW + cosine annealing. After every epoch the model is evaluated and a
    checkpoint is saved whenever the MAE of the first target column is at
    least as good as the best seen so far.

    Raises:
        ValueError: If ``--limit`` is negative.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # A negative limit used to be truthy and select an empty range, which only
    # failed later inside train_test_split with an unrelated-looking error.
    if args.limit < 0:
        raise ValueError(f"--limit must be >= 0, got {args.limit}")

    dataset = load_dataset(args.dataset)
    source = dataset[args.source_split].shuffle(seed=42)
    if args.limit > 0:
        source = source.select(range(min(args.limit, len(source))))

    split = source.train_test_split(test_size=args.validation_size, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]

    # Normalisation statistics come from the training part only.
    train_targets = make_targets(train_ds)
    target_mean = train_targets.mean(dim=0)
    target_std = train_targets.std(dim=0)
    # With a single training row (e.g. a tiny --limit smoke test) the sample
    # standard deviation is NaN, and clamp() would pass that NaN through into
    # every normalised target. Treat an undefined spread as 1.0 instead.
    target_std = torch.where(torch.isnan(target_std), torch.ones_like(target_std), target_std)
    # Floor at 1.0 so near-constant targets are not blown up by the division.
    target_std = torch.clamp(target_std, min=1.0)

    model = CalorityNutritionCNN(output_size=len(TARGET_COLUMNS)).to(device)

    collate_fn = make_collate_fn(args.image_column, target_mean, target_std)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )
    eval_loader = DataLoader(
        eval_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    loss_fn = nn.SmoothL1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
    # The scheduler is stepped once per epoch, hence T_max=args.epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    output_dir = Path(args.output_dir)
    best_calorie_mae = float("inf")

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        total_seen = 0
        progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False)
        for images, normalized_targets, _ in progress:
            images = images.to(device)
            normalized_targets = normalized_targets.to(device)

            optimizer.zero_grad(set_to_none=True)
            predictions = model(images)
            loss = loss_fn(predictions, normalized_targets)
            loss.backward()
            optimizer.step()

            # Running mean of the training loss, weighted by batch size.
            running_loss += loss.item() * images.size(0)
            total_seen += images.size(0)
            progress.set_postfix(loss=round(running_loss / max(total_seen, 1), 4))
        scheduler.step()

        eval_loss, mae = evaluate(model, eval_loader, loss_fn, target_mean, target_std, device)
        metric_line = ", ".join(
            f"{column}_mae={mae[index]:.2f}" for index, column in enumerate(TARGET_COLUMNS)
        )
        print(f"epoch={epoch} eval_loss={eval_loss:.4f} {metric_line}")

        # Model selection uses the first target column, which the log messages
        # below call the calorie MAE (assumes TARGET_COLUMNS[0] is calories -
        # TODO confirm against calority_nutrition_model).
        if mae[0].item() <= best_calorie_mae:
            best_calorie_mae = mae[0].item()
            save_nutrition_checkpoint(model, target_mean, target_std, output_dir)
            print(f"saved best nutrition model to {output_dir} with calorie_mae={best_calorie_mae:.2f}")

    print(f"done. best_calorie_mae={best_calorie_mae:.2f}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()