# calority-model-api / train_nutrients_from_scratch.py
# okd06's picture
# Deploy Calority model API
# cecd1f0 verified
import argparse
from pathlib import Path
import torch
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from calority_nutrition_model import (
TARGET_COLUMNS,
CalorityNutritionCNN,
save_nutrition_checkpoint,
)
from calority_scratch_model import image_to_tensor
def parse_args():
    """Parse the command-line options of the nutrition training script.

    Returns:
        argparse.Namespace: parsed options with the attributes ``dataset``,
        ``source_split``, ``image_column``, ``output_dir``, ``epochs``,
        ``batch_size``, ``learning_rate``, ``validation_size``,
        ``num_workers`` and ``limit``.
    """
    cli = argparse.ArgumentParser(
        description="Train Calority's calorie and macro predictor from scratch on mmathys/food-nutrients."
    )
    # One (flag, add_argument keyword arguments) entry per option, registered
    # in order so the generated --help output keeps the same layout.
    option_table = (
        ("--dataset", {"default": "mmathys/food-nutrients"}),
        (
            "--source-split",
            {"default": "test", "help": "This dataset currently ships with only a test split."},
        ),
        ("--image-column", {"default": "image"}),
        ("--output-dir", {"default": "./calority-nutrition-model"}),
        ("--epochs", {"type": int, "default": 40}),
        ("--batch-size", {"type": int, "default": 16}),
        ("--learning-rate", {"type": float, "default": 3e-4}),
        ("--validation-size", {"type": float, "default": 0.15}),
        ("--num-workers", {"type": int, "default": 0}),
        (
            "--limit",
            {"type": int, "default": 0, "help": "Optional small limit for quick smoke tests"},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def make_targets(dataset_split) -> torch.Tensor:
    """Collect the regression targets of a dataset split into one float tensor.

    Args:
        dataset_split: Iterable of row mappings that expose every name in
            ``TARGET_COLUMNS``. If the object offers a callable
            ``select_columns`` (as Hugging Face ``datasets.Dataset`` does), it
            is used to project the split down to the target columns first.

    Returns:
        A ``float32`` tensor of shape ``(num_rows, len(TARGET_COLUMNS))`` with
        one column per entry of ``TARGET_COLUMNS``, in that order. An empty
        split yields shape ``(0, len(TARGET_COLUMNS))``.
    """
    # Only the numeric target columns are read here. Projecting first avoids
    # materialising the other columns for every row -- for an image dataset
    # that presumably means skipping a full image decode per row.
    select_columns = getattr(dataset_split, "select_columns", None)
    if callable(select_columns):
        row_source = select_columns(list(TARGET_COLUMNS))
    else:
        row_source = dataset_split

    rows = [[float(item[column]) for column in TARGET_COLUMNS] for item in row_source]
    # The explicit reshape keeps the result two-dimensional even with zero
    # rows, so per-column reductions such as mean(dim=0) stay well-formed.
    return torch.tensor(rows, dtype=torch.float32).reshape(len(rows), len(TARGET_COLUMNS))
def make_collate_fn(image_column: str, target_mean: torch.Tensor, target_std: torch.Tensor):
    """Build a DataLoader collate function bound to fixed target statistics.

    Args:
        image_column: Key under which each row stores its image.
        target_mean: Per-target offset subtracted from the raw targets.
        target_std: Per-target scale the centred targets are divided by.

    Returns:
        A callable mapping a list of rows to the tuple
        ``(images, normalized_targets, raw_targets)``.
    """

    def collate(batch):
        # Convert every image first, then read the targets, in two passes.
        image_tensors = [image_to_tensor(row[image_column]) for row in batch]
        images = torch.stack(image_tensors)

        target_rows = []
        for row in batch:
            target_rows.append([float(row[name]) for name in TARGET_COLUMNS])
        raw = torch.tensor(target_rows, dtype=torch.float32)

        # Standardise with the statistics captured by the enclosing factory.
        standardized = (raw - target_mean) / target_std
        return images, standardized, raw

    return collate
def evaluate(model, loader, loss_fn, target_mean, target_std, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        loader: Iterable yielding ``(images, normalized_targets, raw_targets)``.
        loss_fn: Loss applied to the predictions and the normalized targets.
        target_mean: Per-target offset used to undo normalisation.
        target_std: Per-target scale used to undo normalisation.
        device: Device the images and normalized targets are moved to.

    Returns:
        ``(mean_loss, mae)``: the sample-weighted mean loss in normalized
        space, and a tensor of per-target mean absolute errors computed on
        de-normalized predictions clamped at zero. Both are zero-valued when
        the loader yields nothing.
    """
    model.eval()
    loss_sum = 0.0
    abs_error_sum = torch.zeros(len(TARGET_COLUMNS))
    sample_count = 0

    with torch.no_grad():
        for images, normalized_targets, raw_targets in loader:
            batch_size = images.size(0)
            outputs = model(images.to(device))
            batch_loss = loss_fn(outputs, normalized_targets.to(device))

            # Undo the normalisation on the CPU and clip negative predictions
            # to zero before comparing against the raw targets.
            denormalized = (outputs.cpu() * target_std + target_mean).clamp(min=0)

            loss_sum += batch_loss.item() * batch_size
            abs_error_sum += (denormalized - raw_targets).abs().sum(dim=0)
            sample_count += batch_size

    # Guard against an empty loader so the averages never divide by zero.
    divisor = max(sample_count, 1)
    return loss_sum / divisor, abs_error_sum / divisor
def _build_loader(dataset_split, args, collate_fn, shuffle):
    """Wrap ``dataset_split`` in a DataLoader using the CLI batch/worker options."""
    return DataLoader(
        dataset_split,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )


def _train_one_epoch(model, loader, loss_fn, optimizer, device, description):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    total_seen = 0
    progress = tqdm(loader, desc=description, leave=False)
    for images, normalized_targets, _ in progress:
        images = images.to(device)
        normalized_targets = normalized_targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        predictions = model(images)
        loss = loss_fn(predictions, normalized_targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += loss.item() * images.size(0)
        total_seen += images.size(0)
        progress.set_postfix(loss=round(running_loss / max(total_seen, 1), 4))
    return running_loss / max(total_seen, 1)


def main():
    """Train the nutrition regressor and keep the best checkpoint.

    Steps:
        1. Load the dataset named by ``--dataset`` and shuffle the
           ``--source-split`` split with a fixed seed (42).
        2. Optionally truncate it to ``--limit`` rows, then split it into
           train/eval parts with ``--validation-size`` (seed 42).
        3. Derive the target mean/std from the training part only.
        4. Train for ``--epochs`` epochs with SmoothL1 loss, AdamW and a
           cosine learning-rate schedule, evaluating after every epoch.
        5. Save a checkpoint whenever the MAE of the first target improves
           or ties the best value seen so far (the variable names assume
           ``TARGET_COLUMNS[0]`` is the calorie column).
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = load_dataset(args.dataset)
    source = dataset[args.source_split].shuffle(seed=42)
    # Only a positive limit truncates; a negative value would otherwise
    # select an empty range and leave nothing to split.
    if args.limit > 0:
        source = source.select(range(min(args.limit, len(source))))
    split = source.train_test_split(test_size=args.validation_size, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]

    train_targets = make_targets(train_ds)
    target_mean = train_targets.mean(dim=0)
    # std() over a single training row is NaN (possible with a tiny --limit).
    # Map NaN to the 1.0 floor before clamping so normalisation stays finite.
    target_std = torch.clamp(
        torch.nan_to_num(train_targets.std(dim=0), nan=1.0),
        min=1.0,
    )

    model = CalorityNutritionCNN(output_size=len(TARGET_COLUMNS)).to(device)
    collate_fn = make_collate_fn(args.image_column, target_mean, target_std)
    train_loader = _build_loader(train_ds, args, collate_fn, shuffle=True)
    eval_loader = _build_loader(eval_ds, args, collate_fn, shuffle=False)

    loss_fn = nn.SmoothL1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    output_dir = Path(args.output_dir)
    best_calorie_mae = float("inf")

    for epoch in range(1, args.epochs + 1):
        _train_one_epoch(
            model,
            train_loader,
            loss_fn,
            optimizer,
            device,
            description=f"epoch {epoch}/{args.epochs}",
        )
        # The schedule advances once per epoch, after the optimizer updates.
        scheduler.step()

        eval_loss, mae = evaluate(model, eval_loader, loss_fn, target_mean, target_std, device)
        metric_line = ", ".join(
            f"{column}_mae={mae[index]:.2f}" for index, column in enumerate(TARGET_COLUMNS)
        )
        print(f"epoch={epoch} eval_loss={eval_loss:.4f} {metric_line}")

        # "<=" means a tie also overwrites the checkpoint with the later epoch.
        current_mae = mae[0].item()
        if current_mae <= best_calorie_mae:
            best_calorie_mae = current_mae
            save_nutrition_checkpoint(model, target_mean, target_std, output_dir)
            print(f"saved best nutrition model to {output_dir} with calorie_mae={best_calorie_mae:.2f}")

    print(f"done. best_calorie_mae={best_calorie_mae:.2f}")
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()