calority-model-api / calority_scratch_model.py
okd06's picture
Deploy Calority model API
cecd1f0 verified
import json
from pathlib import Path
import torch
from PIL import Image
from torch import nn
# Default side length, in pixels, of the square image that image_to_tensor
# produces; save_checkpoint also records it in config.json.
IMAGE_SIZE = 224

# Per-channel normalisation statistics, shaped (3, 1, 1) so they broadcast over
# a channels-first (3, H, W) image tensor. The values match the widely used
# ImageNet statistics; channel order is RGB because image_to_tensor converts
# every image to RGB before normalising.
MEAN = torch.tensor((0.485, 0.456, 0.406)).reshape(3, 1, 1)
STD = torch.tensor((0.229, 0.224, 0.225)).reshape(3, 1, 1)
class CalorityFoodCNN(nn.Module):
    """Convolutional image classifier built from five conv/pool stages.

    The network is a stack of five stages (3 -> 32 -> 64 -> 128 -> 256 -> 384
    channels), each halving the spatial resolution, followed by global average
    pooling and a two-layer fully connected head that emits one score per
    label.

    Submodule names and ordering (``features``, ``pool``, ``classifier``)
    determine the ``state_dict`` keys, so they must stay stable for saved
    checkpoints to remain loadable.
    """

    def __init__(self, num_labels: int):
        """Build the network.

        Args:
            num_labels: Number of output scores produced by the final layer.
        """
        super().__init__()

        # Channel width at the input and after each of the five stages.
        widths = (3, 32, 64, 128, 256, 384)
        stages = [
            self._block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.features = nn.Sequential(*stages)

        # Collapse whatever spatial extent remains to 1x1 so the head sees a
        # fixed-length vector regardless of input resolution.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        hidden = 256
        head = [
            nn.Flatten(),
            nn.Dropout(0.35),
            nn.Linear(widths[-1], hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden, num_labels),
        ]
        self.classifier = nn.Sequential(*head)

    @staticmethod
    def _block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Return one stage: two 3x3 conv/batch-norm/ReLU units, then 2x2 max-pool.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels produced by both convolutions.
        """
        layers: list[nn.Module] = []
        channels = in_channels
        for _ in range(2):
            # The convolutions carry no bias; each is followed directly by a
            # batch-norm layer, which has its own learnable shift.
            layers.append(
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Compute unnormalised label scores for a batch of images.

        Args:
            pixel_values: Image batch with 3 channels, expected as
                ``(N, 3, H, W)``.

        Returns:
            Tensor of shape ``(N, num_labels)``; no softmax is applied.
        """
        feature_maps = self.features(pixel_values)
        pooled = self.pool(feature_maps)
        return self.classifier(pooled)
def image_to_tensor(image: Image.Image, image_size: int = IMAGE_SIZE) -> torch.Tensor:
    """Convert a PIL image into a normalised channels-first float tensor.

    The image is converted to RGB, resized to a square with bilinear
    resampling (aspect ratio is not preserved), scaled to ``[0, 1]`` and then
    normalised per channel with the module-level ``MEAN`` and ``STD``.

    Args:
        image: Source image in any mode PIL can convert to RGB.
        image_size: Side length, in pixels, of the square output.

    Returns:
        A ``float32`` tensor of shape ``(3, image_size, image_size)``.
    """
    resized = image.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)

    # torch.frombuffer replaces the deprecated ByteTensor/ByteStorage.from_buffer
    # route. The bytes are copied into a bytearray because frombuffer warns when
    # handed a read-only buffer such as the `bytes` returned by tobytes().
    pixels = torch.frombuffer(bytearray(resized.tobytes()), dtype=torch.uint8)

    # PIL emits interleaved rows (H, W, C); reorder to channels-first (C, H, W).
    # .float() makes an independent copy, so the result does not alias the buffer.
    tensor = pixels.view(image_size, image_size, 3).permute(2, 0, 1).float() / 255.0
    return (tensor - MEAN) / STD
def save_checkpoint(model: nn.Module, labels: list[str], output_dir: str | Path) -> None:
    """Write a model's weights, label list and a small config to a directory.

    Three files are produced inside ``output_dir`` (created, with parents, if
    it does not exist):

    * ``model.pt`` - the model's ``state_dict`` saved with ``torch.save``.
    * ``labels.json`` - ``labels`` as a JSON array.
    * ``config.json`` - the architecture name and the module-level
      ``IMAGE_SIZE``.

    Args:
        model: Module whose ``state_dict`` is saved.
        labels: Label names to store alongside the weights.
        output_dir: Destination directory.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), target_dir / "model.pt")

    labels_json = json.dumps(labels, indent=2)
    target_dir.joinpath("labels.json").write_text(labels_json, encoding="utf-8")

    config = {"architecture": "CalorityFoodCNN", "image_size": IMAGE_SIZE}
    config_json = json.dumps(config, indent=2)
    target_dir.joinpath("config.json").write_text(config_json, encoding="utf-8")
def load_checkpoint(model_dir: str | Path, device: str | torch.device = "cpu") -> tuple[CalorityFoodCNN, list[str]]:
    """Rebuild a ``CalorityFoodCNN`` from a checkpoint directory.

    Reads ``labels.json`` to size the output layer, loads the weights from
    ``model.pt``, moves the model to ``device`` and switches it to eval mode.

    Args:
        model_dir: Directory containing ``labels.json`` and ``model.pt``.
        device: Device the weights are mapped to and the model is moved onto.

    Returns:
        ``(model, labels)`` - the model in eval mode and the label list read
        from ``labels.json``.

    Raises:
        FileNotFoundError: If ``labels.json`` or ``model.pt`` is missing.
        RuntimeError: If the stored weights do not match the architecture.
    """
    model_path = Path(model_dir)
    labels = json.loads((model_path / "labels.json").read_text(encoding="utf-8"))
    model = CalorityFoodCNN(num_labels=len(labels))

    # A state dict contains only tensors, so use the restricted loader:
    # weights_only=True stops torch.load from unpickling arbitrary objects,
    # which could execute code from a tampered checkpoint file.
    state = torch.load(model_path / "model.pt", map_location=device, weights_only=True)
    model.load_state_dict(state)

    model.to(device)
    # Inference mode: disables dropout and uses batch-norm running statistics.
    model.eval()
    return model, labels