Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import torch | |
| from PIL import Image | |
| from torch import nn | |
# Side length (pixels) of the square input fed to the model; also the default
# for image_to_tensor and the value recorded in config.json by save_checkpoint.
IMAGE_SIZE = 224

# Per-channel RGB normalisation statistics (the commonly used ImageNet values),
# reshaped to (3, 1, 1) so they broadcast over a channels-first (3, H, W) tensor.
MEAN = torch.tensor((0.485, 0.456, 0.406)).reshape(3, 1, 1)
STD = torch.tensor((0.229, 0.224, 0.225)).reshape(3, 1, 1)
class CalorityFoodCNN(nn.Module):
    """Compact convolutional image classifier.

    Five convolutional blocks each halve the spatial resolution (every block
    ends in a 2x2 max-pool) while widening the channels
    3 -> 32 -> 64 -> 128 -> 256 -> 384. Global average pooling collapses the
    feature map to a 384-d vector, which a two-layer MLP head with dropout maps
    to ``num_labels`` logits.

    Args:
        num_labels: Number of output classes (width of the final linear layer).
    """

    def __init__(self, num_labels: int):
        super().__init__()
        self.features = nn.Sequential(
            self._block(3, 32),
            self._block(32, 64),
            self._block(64, 128),
            self._block(128, 256),
            self._block(256, 384),
        )
        # Adaptive pooling to 1x1 makes the head independent of input H/W.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.35),
            nn.Linear(384, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_labels),
        )

    @staticmethod
    def _block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build two 3x3 conv-BN-ReLU layers followed by a 2x2 max-pool.

        This must be a ``staticmethod``: ``__init__`` calls it as
        ``self._block(a, b)``, and without the decorator ``self`` would be
        bound to ``in_channels``, making model construction raise ``TypeError``.

        Args:
            in_channels: Channels entering the block.
            out_channels: Channels produced by both convolutions.

        Returns:
            A ``nn.Sequential`` that keeps channels at ``out_channels`` and
            halves the spatial size (padding=1 preserves it through the convs).
        """
        return nn.Sequential(
            # bias=False: the following BatchNorm already learns a per-channel shift.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            pixel_values: Channels-first image batch with 3 channels (the first
                conv expects 3 inputs). Presumably ``(batch, 3, H, W)`` normalised
                like ``image_to_tensor`` output — TODO confirm with callers.
                H and W must be at least 32, since five 2x2 max-pools run first.

        Returns:
            Logits of shape ``(batch, num_labels)``.
        """
        x = self.features(pixel_values)
        x = self.pool(x)
        return self.classifier(x)
def image_to_tensor(image: Image.Image, image_size: int = IMAGE_SIZE) -> torch.Tensor:
    """Convert a PIL image into a normalised, channels-first float tensor.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    with bilinear resampling, scaled from 0..255 to 0..1, and standardised with
    the module-level per-channel ``MEAN`` / ``STD``.

    Args:
        image: Source image in any PIL mode; it is converted to RGB first.
        image_size: Side length in pixels of the square output.

    Returns:
        A float32 tensor of shape ``(3, image_size, image_size)``.
    """
    resized = image.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
    # ``torch.frombuffer`` replaces the deprecated ``torch.ByteStorage.from_buffer``
    # (typed storages are deprecated and warn in PyTorch 2.x). ``bytearray`` gives
    # it a writable buffer, avoiding the non-writable-buffer warning on ``bytes``.
    raw = torch.frombuffer(bytearray(resized.tobytes()), dtype=torch.uint8)
    # RGB ``tobytes`` is row-major, interleaved: reshape to HWC, then move to CHW.
    tensor = raw.view(image_size, image_size, 3).permute(2, 0, 1).float() / 255.0
    return (tensor - MEAN) / STD
def save_checkpoint(model: nn.Module, labels: list[str], output_dir: str | Path) -> None:
    """Persist a trained model and its metadata into ``output_dir``.

    Creates the directory (and parents) if needed, then writes, in order:
    ``model.pt`` (the model's ``state_dict`` via ``torch.save``),
    ``labels.json`` (the label list) and ``config.json`` (architecture name and
    input image size). Both JSON files are UTF-8 with 2-space indentation.

    Args:
        model: Module whose ``state_dict`` is saved.
        labels: Class labels, stored in order.
        output_dir: Destination directory.
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), target / "model.pt")

    config = {"architecture": "CalorityFoodCNN", "image_size": IMAGE_SIZE}
    for filename, payload in (("labels.json", labels), ("config.json", config)):
        (target / filename).write_text(json.dumps(payload, indent=2), encoding="utf-8")
def load_checkpoint(model_dir: str | Path, device: str | torch.device = "cpu") -> tuple[CalorityFoodCNN, list[str]]:
    """Load a model directory in the layout written by ``save_checkpoint``.

    Reads ``labels.json`` to size the classifier head, restores the weights
    from ``model.pt`` mapped onto ``device``, moves the model there and puts it
    in eval mode.

    Args:
        model_dir: Directory containing ``labels.json`` and ``model.pt``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, labels)``: the inference-ready model and its label list.

    Raises:
        FileNotFoundError: If ``labels.json`` or ``model.pt`` is missing.
    """
    model_path = Path(model_dir)
    labels = json.loads((model_path / "labels.json").read_text(encoding="utf-8"))
    model = CalorityFoodCNN(num_labels=len(labels))
    # save_checkpoint stores only a tensor state_dict, so restrict unpickling to
    # weights. The full unpickler can run arbitrary code from a tampered file.
    state = torch.load(model_path / "model.pt", map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model, labels