calority-model-api / train_from_scratch.py
okd06's picture
Deploy Calority model API
cecd1f0 verified
import argparse
from pathlib import Path
import torch
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from calority_scratch_model import CalorityFoodCNN, image_to_tensor, save_checkpoint
def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the dataset name, image/label column names,
        train/eval split names, output directory, optimization
        hyperparameters (epochs, batch size, learning rate), the DataLoader
        worker count, and the optional train/eval row limits (a value of 0
        leaves the split untouched in ``main``).
    """
    cli = argparse.ArgumentParser(
        description="Train Calority's food model from scratch on a Hugging Face dataset."
    )
    smoke_help = "Optional small limit for quick smoke tests"
    # (flag, add_argument keyword arguments), registered in this order so the
    # --help listing keeps the same sequence.
    option_specs = (
        ("--dataset", dict(default="food101", help="Hugging Face dataset name, for example food101")),
        ("--image-column", dict(default="image")),
        ("--label-column", dict(default="label")),
        ("--train-split", dict(default="train")),
        ("--eval-split", dict(default="validation")),
        ("--output-dir", dict(default="./calority-scratch-model")),
        ("--epochs", dict(type=int, default=12)),
        ("--batch-size", dict(type=int, default=32)),
        ("--learning-rate", dict(type=float, default=3e-4)),
        ("--num-workers", dict(type=int, default=0)),
        ("--limit-train", dict(type=int, default=0, help=smoke_help)),
        ("--limit-eval", dict(type=int, default=0, help=smoke_help)),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def get_labels(dataset, split: str, label_column: str) -> list[str]:
    """Return the ordered list of class labels for ``label_column`` in ``dataset[split]``.

    If the column's feature object exposes a non-empty ``names`` attribute
    (such as a Hugging Face ``ClassLabel``), those names are returned in their
    stored order. Otherwise the distinct values present in the column are
    sorted and each is converted with ``str``.
    """
    split_data = dataset[split]
    class_names = getattr(split_data.features[label_column], "names", None)
    if class_names:
        return list(class_names)
    distinct_values = sorted(set(split_data[label_column]))
    return list(map(str, distinct_values))
class _BatchCollator:
    """Turn one DataLoader batch of dataset rows into ``(images, labels)`` tensors.

    This is a module-level class rather than a nested closure so that
    instances are picklable: DataLoader pickles ``collate_fn`` when it starts
    worker processes with the ``spawn`` start method (the default on Windows
    and macOS), and a locally defined function cannot be pickled.
    """

    def __init__(self, image_column: str, label_column: str):
        self.image_column = image_column  # row key holding the image
        self.label_column = label_column  # row key holding an int()-convertible label

    def __call__(self, batch):
        """Return ``(images, labels)`` for a list of dataset rows.

        ``images`` is ``torch.stack`` of ``image_to_tensor`` applied to each
        row's image, so every converted image must have the same shape.
        ``labels`` is a 1-D ``torch.long`` tensor of ``int(row[label_column])``.
        """
        images = torch.stack([image_to_tensor(item[self.image_column]) for item in batch])
        labels = torch.tensor([int(item[self.label_column]) for item in batch], dtype=torch.long)
        return images, labels


def make_collate_fn(image_column: str, label_column: str):
    """Create a picklable DataLoader ``collate_fn`` reading the given row columns.

    Args:
        image_column: name of the row field passed to ``image_to_tensor``.
        label_column: name of the row field converted with ``int`` to a class id.

    Returns:
        A callable ``collate(batch) -> (images, labels)``; see ``_BatchCollator``.
    """
    return _BatchCollator(image_column, label_column)
def evaluate(model, loader, loss_fn, device):
    """Score ``model`` on every batch of ``loader`` without tracking gradients.

    Switches the model to eval mode (it is left in eval mode on return), moves
    each ``(images, labels)`` batch to ``device``, and accumulates the loss
    weighted by batch size together with the count of top-1 correct predictions.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen; both are 0.0
        when the loader yields no samples (the denominator is clamped to 1).
    """
    model.eval()
    loss_sum, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            batch_size = batch_labels.size(0)
            loss_sum += loss_fn(logits, batch_labels).item() * batch_size
            correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            count += batch_size
    denominator = max(count, 1)
    return loss_sum / denominator, correct / denominator
def _limit_split(split_data, limit: int):
    """Return a seed-42 shuffled subset of ``split_data`` with at most ``limit`` rows.

    The requested size is clamped to the split length so ``select`` is never
    asked for indices past the end of the split.
    """
    count = min(limit, len(split_data))
    return split_data.shuffle(seed=42).select(range(count))


def _train_one_epoch(model, loader, loss_fn, optimizer, device, desc: str):
    """Run one optimization pass over ``loader``; return ``(mean_loss, accuracy)``.

    Loss is accumulated as batch loss x batch size, so the running values in
    the progress bar are per-sample averages, assuming ``loss_fn`` averages
    over the batch (``nn.CrossEntropyLoss`` does by default).
    """
    model.train()
    running_loss = 0.0
    total_seen = 0
    total_correct = 0
    progress = tqdm(loader, desc=desc, leave=False)
    for images, labels_batch in progress:
        images = images.to(device)
        labels_batch = labels_batch.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = loss_fn(logits, labels_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * labels_batch.size(0)
        total_correct += (logits.argmax(dim=1) == labels_batch).sum().item()
        total_seen += labels_batch.size(0)
        progress.set_postfix(
            loss=round(running_loss / max(total_seen, 1), 4),
            acc=round(total_correct / max(total_seen, 1), 4),
        )
    return running_loss / max(total_seen, 1), total_correct / max(total_seen, 1)


def main():
    """Train ``CalorityFoodCNN`` from scratch and keep the best checkpoint.

    Loads the dataset named by ``--dataset`` and optionally subsamples the
    train/eval splits (``--limit-train`` / ``--limit-eval``; 0 keeps the whole
    split). It then trains with AdamW and cosine LR annealing stepped once
    per epoch, evaluates after every epoch, and calls ``save_checkpoint``
    whenever eval accuracy matches or beats the best seen so far.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = load_dataset(args.dataset)
    # Derive the label space from the FULL train split, before smoke-test
    # subsampling. When the label column has no class names, a small subset
    # can miss some values, and the model would get too few output units.
    labels = get_labels(dataset, args.train_split, args.label_column)
    if args.limit_train:
        dataset[args.train_split] = _limit_split(dataset[args.train_split], args.limit_train)
    if args.limit_eval:
        dataset[args.eval_split] = _limit_split(dataset[args.eval_split], args.limit_eval)
    model = CalorityFoodCNN(num_labels=len(labels)).to(device)
    collate_fn = make_collate_fn(args.image_column, args.label_column)
    train_loader = DataLoader(
        dataset[args.train_split],
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )
    eval_loader = DataLoader(
        dataset[args.eval_split],
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
    # scheduler.step() is called once per epoch, so T_max=epochs spans the run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    best_acc = 0.0
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for epoch in range(1, args.epochs + 1):
        _train_one_epoch(
            model,
            train_loader,
            loss_fn,
            optimizer,
            device,
            desc=f"epoch {epoch}/{args.epochs}",
        )
        scheduler.step()
        eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn, device)
        print(f"epoch={epoch} eval_loss={eval_loss:.4f} eval_acc={eval_acc:.4f}")
        # ">=" means the first epoch always saves (even at 0.0 accuracy), and
        # on ties the later epoch's weights overwrite the earlier checkpoint.
        if eval_acc >= best_acc:
            best_acc = eval_acc
            save_checkpoint(model, labels, output_dir)
            print(f"saved best model to {output_dir} with eval_acc={best_acc:.4f}")
    print(f"done. best_eval_acc={best_acc:.4f}")
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    main()