import json from pathlib import Path import torch from PIL import Image from torch import nn IMAGE_SIZE = 224 MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) class CalorityFoodCNN(nn.Module): def __init__(self, num_labels: int): super().__init__() self.features = nn.Sequential( self._block(3, 32), self._block(32, 64), self._block(64, 128), self._block(128, 256), self._block(256, 384), ) self.pool = nn.AdaptiveAvgPool2d((1, 1)) self.classifier = nn.Sequential( nn.Flatten(), nn.Dropout(0.35), nn.Linear(384, 256), nn.ReLU(inplace=True), nn.Dropout(0.2), nn.Linear(256, num_labels), ) @staticmethod def _block(in_channels: int, out_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.MaxPool2d(2), ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: x = self.features(pixel_values) x = self.pool(x) return self.classifier(x) def image_to_tensor(image: Image.Image, image_size: int = IMAGE_SIZE) -> torch.Tensor: resized = image.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR) raw = torch.ByteTensor(torch.ByteStorage.from_buffer(resized.tobytes())) tensor = raw.view(image_size, image_size, 3).permute(2, 0, 1).float() / 255.0 return (tensor - MEAN) / STD def save_checkpoint(model: nn.Module, labels: list[str], output_dir: str | Path) -> None: output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) torch.save(model.state_dict(), output_path / "model.pt") (output_path / "labels.json").write_text(json.dumps(labels, indent=2), encoding="utf-8") (output_path / "config.json").write_text( json.dumps({"architecture": "CalorityFoodCNN", "image_size": IMAGE_SIZE}, indent=2), encoding="utf-8", ) def load_checkpoint(model_dir: str | Path, device: str | torch.device = "cpu") -> tuple[CalorityFoodCNN, list[str]]: model_path = Path(model_dir) labels = json.loads((model_path / "labels.json").read_text(encoding="utf-8")) model = CalorityFoodCNN(num_labels=len(labels)) state = torch.load(model_path / "model.pt", map_location=device) model.load_state_dict(state) model.to(device) model.eval() return model, labels