Spaces:
Sleeping
Sleeping
Deploy Calority model API
Browse files- .dockerignore +8 -0
- .env.example +6 -0
- Dockerfile +21 -9
- README.md +18 -5
- calority_nutrition_model.py +112 -0
- calority_scratch_model.py +78 -0
- main.py +297 -0
- requirements.txt +12 -2
- train_food_model.py +78 -0
- train_from_scratch.py +142 -0
- train_nutrients_from_scratch.py +154 -0
- upload_model_to_hf.py +40 -0
- upload_space_to_hf.py +47 -0
.dockerignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
__pycache__
|
| 3 |
+
*.pyc
|
| 4 |
+
.env
|
| 5 |
+
calority-nutrition-model
|
| 6 |
+
calority-scratch-model
|
| 7 |
+
calority-food-model
|
| 8 |
+
*.log
|
.env.example
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ID=nateraw/food
|
| 2 |
+
MODEL_TASK=nutrition-regression
|
| 3 |
+
MODEL_DIR=./calority-nutrition-model
|
| 4 |
+
HF_MODEL_REPO_ID=
|
| 5 |
+
MODEL_API_KEY=change-me-before-deploy
|
| 6 |
+
PORT=8000
|
Dockerfile
CHANGED
|
@@ -1,16 +1,28 @@
|
|
| 1 |
-
|
| 2 |
-
# you will also find guides on how best to write your Dockerfile
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
RUN useradd -m -u 1000 user
|
| 7 |
USER user
|
| 8 |
-
ENV PATH="/home/user/.local/bin:$PATH"
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 14 |
|
| 15 |
-
|
| 16 |
-
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
|
|
|
| 2 |
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
ENV PORT=7860
|
| 6 |
+
ENV MODEL_TASK=nutrition-regression
|
| 7 |
+
|
| 8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 9 |
+
libgl1 \
|
| 10 |
+
libglib2.0-0 \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
|
| 13 |
RUN useradd -m -u 1000 user
|
| 14 |
USER user
|
|
|
|
| 15 |
|
| 16 |
+
ENV HOME=/home/user
|
| 17 |
+
ENV PATH="/home/user/.local/bin:${PATH}"
|
| 18 |
+
|
| 19 |
+
WORKDIR /home/user/app
|
| 20 |
+
|
| 21 |
+
COPY requirements.txt .
|
| 22 |
+
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt
|
| 23 |
+
|
| 24 |
+
COPY --chown=user . .
|
| 25 |
|
| 26 |
+
EXPOSE 7860
|
|
|
|
| 27 |
|
| 28 |
+
CMD ["sh", "-c", "uvicorn main:app --host 0.0.0.0 --port ${PORT}"]
|
|
|
README.md
CHANGED
|
@@ -1,10 +1,23 @@
|
|
| 1 |
---
|
| 2 |
-
title: Calority Model
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Calority Model API
|
| 3 |
+
emoji: 🥗
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# Calority Model API
|
| 11 |
+
|
| 12 |
+
FastAPI service for Calority meal nutrition analysis.
|
| 13 |
+
|
| 14 |
+
Set these Space secrets:
|
| 15 |
+
|
| 16 |
+
- `MODEL_TASK=nutrition-regression`
|
| 17 |
+
- `HF_MODEL_REPO_ID=<your-hf-username>/<your-model-repo>`
|
| 18 |
+
- `MODEL_API_KEY=<same-value-you-store-in-supabase>`
|
| 19 |
+
|
| 20 |
+
The API exposes:
|
| 21 |
+
|
| 22 |
+
- `GET /health`
|
| 23 |
+
- `POST /analyze-meal`
|
calority_nutrition_model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from calority_scratch_model import IMAGE_SIZE, image_to_tensor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
TARGET_COLUMNS = ["total_calories", "total_mass", "total_fat", "total_carb", "total_protein"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CalorityNutritionCNN(nn.Module):
|
| 15 |
+
def __init__(self, output_size: int = len(TARGET_COLUMNS)):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.features = nn.Sequential(
|
| 18 |
+
self._block(3, 32),
|
| 19 |
+
self._block(32, 64),
|
| 20 |
+
self._block(64, 128),
|
| 21 |
+
self._block(128, 256),
|
| 22 |
+
self._block(256, 384),
|
| 23 |
+
)
|
| 24 |
+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 25 |
+
self.regressor = nn.Sequential(
|
| 26 |
+
nn.Flatten(),
|
| 27 |
+
nn.Dropout(0.35),
|
| 28 |
+
nn.Linear(384, 256),
|
| 29 |
+
nn.ReLU(inplace=True),
|
| 30 |
+
nn.Dropout(0.2),
|
| 31 |
+
nn.Linear(256, output_size),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def _block(in_channels: int, out_channels: int) -> nn.Sequential:
|
| 36 |
+
return nn.Sequential(
|
| 37 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 38 |
+
nn.BatchNorm2d(out_channels),
|
| 39 |
+
nn.ReLU(inplace=True),
|
| 40 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 41 |
+
nn.BatchNorm2d(out_channels),
|
| 42 |
+
nn.ReLU(inplace=True),
|
| 43 |
+
nn.MaxPool2d(2),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
x = self.features(pixel_values)
|
| 48 |
+
x = self.pool(x)
|
| 49 |
+
return self.regressor(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def predict_nutrients(
|
| 53 |
+
model: CalorityNutritionCNN,
|
| 54 |
+
image: Image.Image,
|
| 55 |
+
target_mean: torch.Tensor,
|
| 56 |
+
target_std: torch.Tensor,
|
| 57 |
+
device: torch.device,
|
| 58 |
+
) -> dict:
|
| 59 |
+
tensor = image_to_tensor(image, IMAGE_SIZE).unsqueeze(0).to(device)
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
normalized = model(tensor)[0].cpu()
|
| 62 |
+
values = torch.clamp((normalized * target_std) + target_mean, min=0)
|
| 63 |
+
return {column: round(float(value), 2) for column, value in zip(TARGET_COLUMNS, values)}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def save_nutrition_checkpoint(
|
| 67 |
+
model: nn.Module,
|
| 68 |
+
target_mean: torch.Tensor,
|
| 69 |
+
target_std: torch.Tensor,
|
| 70 |
+
output_dir: str | Path,
|
| 71 |
+
) -> None:
|
| 72 |
+
output_path = Path(output_dir)
|
| 73 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 74 |
+
torch.save(model.state_dict(), output_path / "model.pt")
|
| 75 |
+
(output_path / "target_stats.json").write_text(
|
| 76 |
+
json.dumps(
|
| 77 |
+
{
|
| 78 |
+
"target_columns": TARGET_COLUMNS,
|
| 79 |
+
"target_mean": [float(value) for value in target_mean],
|
| 80 |
+
"target_std": [float(value) for value in target_std],
|
| 81 |
+
},
|
| 82 |
+
indent=2,
|
| 83 |
+
),
|
| 84 |
+
encoding="utf-8",
|
| 85 |
+
)
|
| 86 |
+
(output_path / "config.json").write_text(
|
| 87 |
+
json.dumps(
|
| 88 |
+
{
|
| 89 |
+
"architecture": "CalorityNutritionCNN",
|
| 90 |
+
"task": "nutrition-regression",
|
| 91 |
+
"image_size": IMAGE_SIZE,
|
| 92 |
+
},
|
| 93 |
+
indent=2,
|
| 94 |
+
),
|
| 95 |
+
encoding="utf-8",
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_nutrition_checkpoint(
|
| 100 |
+
model_dir: str | Path,
|
| 101 |
+
device: str | torch.device = "cpu",
|
| 102 |
+
) -> tuple[CalorityNutritionCNN, torch.Tensor, torch.Tensor]:
|
| 103 |
+
model_path = Path(model_dir)
|
| 104 |
+
stats = json.loads((model_path / "target_stats.json").read_text(encoding="utf-8"))
|
| 105 |
+
model = CalorityNutritionCNN(output_size=len(stats["target_columns"]))
|
| 106 |
+
state = torch.load(model_path / "model.pt", map_location=device)
|
| 107 |
+
model.load_state_dict(state)
|
| 108 |
+
model.to(device)
|
| 109 |
+
model.eval()
|
| 110 |
+
target_mean = torch.tensor(stats["target_mean"], dtype=torch.float32)
|
| 111 |
+
target_std = torch.tensor(stats["target_std"], dtype=torch.float32)
|
| 112 |
+
return model, target_mean, target_std
|
calority_scratch_model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
IMAGE_SIZE = 224
|
| 10 |
+
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
| 11 |
+
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CalorityFoodCNN(nn.Module):
|
| 15 |
+
def __init__(self, num_labels: int):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.features = nn.Sequential(
|
| 18 |
+
self._block(3, 32),
|
| 19 |
+
self._block(32, 64),
|
| 20 |
+
self._block(64, 128),
|
| 21 |
+
self._block(128, 256),
|
| 22 |
+
self._block(256, 384),
|
| 23 |
+
)
|
| 24 |
+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 25 |
+
self.classifier = nn.Sequential(
|
| 26 |
+
nn.Flatten(),
|
| 27 |
+
nn.Dropout(0.35),
|
| 28 |
+
nn.Linear(384, 256),
|
| 29 |
+
nn.ReLU(inplace=True),
|
| 30 |
+
nn.Dropout(0.2),
|
| 31 |
+
nn.Linear(256, num_labels),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def _block(in_channels: int, out_channels: int) -> nn.Sequential:
|
| 36 |
+
return nn.Sequential(
|
| 37 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 38 |
+
nn.BatchNorm2d(out_channels),
|
| 39 |
+
nn.ReLU(inplace=True),
|
| 40 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
| 41 |
+
nn.BatchNorm2d(out_channels),
|
| 42 |
+
nn.ReLU(inplace=True),
|
| 43 |
+
nn.MaxPool2d(2),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
| 47 |
+
x = self.features(pixel_values)
|
| 48 |
+
x = self.pool(x)
|
| 49 |
+
return self.classifier(x)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def image_to_tensor(image: Image.Image, image_size: int = IMAGE_SIZE) -> torch.Tensor:
|
| 53 |
+
resized = image.convert("RGB").resize((image_size, image_size), Image.Resampling.BILINEAR)
|
| 54 |
+
raw = torch.ByteTensor(torch.ByteStorage.from_buffer(resized.tobytes()))
|
| 55 |
+
tensor = raw.view(image_size, image_size, 3).permute(2, 0, 1).float() / 255.0
|
| 56 |
+
return (tensor - MEAN) / STD
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def save_checkpoint(model: nn.Module, labels: list[str], output_dir: str | Path) -> None:
|
| 60 |
+
output_path = Path(output_dir)
|
| 61 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
torch.save(model.state_dict(), output_path / "model.pt")
|
| 63 |
+
(output_path / "labels.json").write_text(json.dumps(labels, indent=2), encoding="utf-8")
|
| 64 |
+
(output_path / "config.json").write_text(
|
| 65 |
+
json.dumps({"architecture": "CalorityFoodCNN", "image_size": IMAGE_SIZE}, indent=2),
|
| 66 |
+
encoding="utf-8",
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_checkpoint(model_dir: str | Path, device: str | torch.device = "cpu") -> tuple[CalorityFoodCNN, list[str]]:
|
| 71 |
+
model_path = Path(model_dir)
|
| 72 |
+
labels = json.loads((model_path / "labels.json").read_text(encoding="utf-8"))
|
| 73 |
+
model = CalorityFoodCNN(num_labels=len(labels))
|
| 74 |
+
state = torch.load(model_path / "model.pt", map_location=device)
|
| 75 |
+
model.load_state_dict(state)
|
| 76 |
+
model.to(device)
|
| 77 |
+
model.eval()
|
| 78 |
+
return model, labels
|
main.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from fastapi import FastAPI, Header, HTTPException
|
| 11 |
+
from huggingface_hub import snapshot_download
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
from transformers import pipeline
|
| 15 |
+
|
| 16 |
+
from calority_nutrition_model import load_nutrition_checkpoint, predict_nutrients
|
| 17 |
+
from calority_scratch_model import image_to_tensor, load_checkpoint
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
MODEL_ID = os.getenv("MODEL_ID", "nateraw/food")
|
| 21 |
+
MODEL_DIR = os.getenv("MODEL_DIR", "")
|
| 22 |
+
HF_MODEL_REPO_ID = os.getenv("HF_MODEL_REPO_ID", "")
|
| 23 |
+
MODEL_TASK = os.getenv("MODEL_TASK", "classification")
|
| 24 |
+
MODEL_API_KEY = os.getenv("MODEL_API_KEY", "")
|
| 25 |
+
|
| 26 |
+
app = FastAPI(title="Calority Meal Model", version="0.1.0")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AnalyzeMealRequest(BaseModel):
|
| 30 |
+
imageBase64: str = Field(min_length=1)
|
| 31 |
+
mimeType: str = "image/jpeg"
|
| 32 |
+
portionContext: str = ""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass(frozen=True)
|
| 36 |
+
class NutritionProfile:
|
| 37 |
+
serving_g: int
|
| 38 |
+
calories_100g: int
|
| 39 |
+
protein_100g: float
|
| 40 |
+
carbs_100g: float
|
| 41 |
+
fat_100g: float
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
NUTRITION = {
|
| 45 |
+
"apple pie": NutritionProfile(140, 237, 1.9, 34.0, 11.0),
|
| 46 |
+
"baby back ribs": NutritionProfile(220, 290, 20.0, 6.0, 21.0),
|
| 47 |
+
"baklava": NutritionProfile(80, 428, 6.0, 54.0, 21.0),
|
| 48 |
+
"beef carpaccio": NutritionProfile(120, 160, 22.0, 1.0, 7.0),
|
| 49 |
+
"beef tartare": NutritionProfile(150, 190, 20.0, 2.0, 12.0),
|
| 50 |
+
"beet salad": NutritionProfile(180, 90, 3.0, 12.0, 4.0),
|
| 51 |
+
"bibimbap": NutritionProfile(450, 145, 6.0, 20.0, 4.0),
|
| 52 |
+
"bread pudding": NutritionProfile(160, 220, 5.0, 32.0, 8.0),
|
| 53 |
+
"breakfast burrito": NutritionProfile(280, 210, 10.0, 23.0, 9.0),
|
| 54 |
+
"bruschetta": NutritionProfile(120, 190, 6.0, 25.0, 7.0),
|
| 55 |
+
"caesar salad": NutritionProfile(220, 170, 8.0, 8.0, 12.0),
|
| 56 |
+
"cannoli": NutritionProfile(90, 310, 7.0, 33.0, 16.0),
|
| 57 |
+
"caprese salad": NutritionProfile(180, 170, 9.0, 5.0, 13.0),
|
| 58 |
+
"carrot cake": NutritionProfile(120, 415, 4.0, 50.0, 22.0),
|
| 59 |
+
"cheesecake": NutritionProfile(125, 321, 6.0, 26.0, 22.0),
|
| 60 |
+
"chicken curry": NutritionProfile(300, 165, 13.0, 7.0, 9.0),
|
| 61 |
+
"chicken quesadilla": NutritionProfile(250, 260, 14.0, 22.0, 13.0),
|
| 62 |
+
"chicken wings": NutritionProfile(180, 290, 24.0, 1.0, 20.0),
|
| 63 |
+
"chocolate cake": NutritionProfile(120, 371, 5.0, 53.0, 16.0),
|
| 64 |
+
"club sandwich": NutritionProfile(260, 240, 13.0, 22.0, 12.0),
|
| 65 |
+
"cup cakes": NutritionProfile(80, 305, 4.0, 47.0, 12.0),
|
| 66 |
+
"donuts": NutritionProfile(80, 452, 5.0, 51.0, 25.0),
|
| 67 |
+
"dumplings": NutritionProfile(220, 190, 9.0, 26.0, 6.0),
|
| 68 |
+
"edamame": NutritionProfile(160, 121, 11.0, 9.0, 5.0),
|
| 69 |
+
"falafel": NutritionProfile(180, 333, 13.0, 32.0, 18.0),
|
| 70 |
+
"filet mignon": NutritionProfile(180, 250, 26.0, 0.0, 16.0),
|
| 71 |
+
"fish and chips": NutritionProfile(350, 230, 11.0, 24.0, 10.0),
|
| 72 |
+
"french fries": NutritionProfile(150, 312, 3.4, 41.0, 15.0),
|
| 73 |
+
"fried rice": NutritionProfile(300, 165, 5.0, 25.0, 5.0),
|
| 74 |
+
"greek salad": NutritionProfile(220, 110, 4.0, 7.0, 8.0),
|
| 75 |
+
"grilled cheese sandwich": NutritionProfile(180, 350, 12.0, 28.0, 21.0),
|
| 76 |
+
"hamburger": NutritionProfile(250, 295, 17.0, 24.0, 14.0),
|
| 77 |
+
"hot dog": NutritionProfile(150, 290, 11.0, 24.0, 17.0),
|
| 78 |
+
"hummus": NutritionProfile(120, 166, 8.0, 14.0, 10.0),
|
| 79 |
+
"lasagna": NutritionProfile(320, 170, 10.0, 16.0, 8.0),
|
| 80 |
+
"macaroni and cheese": NutritionProfile(250, 164, 7.0, 20.0, 6.0),
|
| 81 |
+
"omelette": NutritionProfile(180, 154, 11.0, 1.0, 12.0),
|
| 82 |
+
"pancakes": NutritionProfile(220, 227, 6.0, 28.0, 10.0),
|
| 83 |
+
"pizza": NutritionProfile(250, 266, 11.0, 33.0, 10.0),
|
| 84 |
+
"ramen": NutritionProfile(500, 90, 4.0, 12.0, 3.0),
|
| 85 |
+
"samosa": NutritionProfile(150, 260, 6.0, 30.0, 13.0),
|
| 86 |
+
"sashimi": NutritionProfile(160, 130, 22.0, 0.0, 4.0),
|
| 87 |
+
"spaghetti bolognese": NutritionProfile(350, 150, 8.0, 20.0, 5.0),
|
| 88 |
+
"steak": NutritionProfile(220, 250, 26.0, 0.0, 15.0),
|
| 89 |
+
"sushi": NutritionProfile(220, 145, 7.0, 24.0, 2.0),
|
| 90 |
+
"tacos": NutritionProfile(220, 210, 10.0, 21.0, 10.0),
|
| 91 |
+
"waffles": NutritionProfile(180, 291, 8.0, 33.0, 14.0),
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
DEFAULT_PROFILE = NutritionProfile(250, 180, 8.0, 20.0, 6.0)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@lru_cache(maxsize=1)
|
| 98 |
+
def classifier():
|
| 99 |
+
return pipeline("image-classification", model=MODEL_ID)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@lru_cache(maxsize=1)
|
| 103 |
+
def resolved_model_dir() -> str:
|
| 104 |
+
if MODEL_DIR:
|
| 105 |
+
return MODEL_DIR
|
| 106 |
+
if HF_MODEL_REPO_ID:
|
| 107 |
+
return snapshot_download(repo_id=HF_MODEL_REPO_ID)
|
| 108 |
+
return ""
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@lru_cache(maxsize=1)
|
| 112 |
+
def scratch_classifier():
|
| 113 |
+
model_dir = resolved_model_dir()
|
| 114 |
+
if not model_dir or MODEL_TASK != "classification":
|
| 115 |
+
return None
|
| 116 |
+
model_path = Path(model_dir)
|
| 117 |
+
if not (model_path / "model.pt").exists():
|
| 118 |
+
return None
|
| 119 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 120 |
+
model, labels = load_checkpoint(model_path, device=device)
|
| 121 |
+
return model, labels, device
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@lru_cache(maxsize=1)
|
| 125 |
+
def nutrition_regressor():
|
| 126 |
+
model_dir = resolved_model_dir()
|
| 127 |
+
if not model_dir or MODEL_TASK != "nutrition-regression":
|
| 128 |
+
return None
|
| 129 |
+
model_path = Path(model_dir)
|
| 130 |
+
if not (model_path / "model.pt").exists() or not (model_path / "target_stats.json").exists():
|
| 131 |
+
return None
|
| 132 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 133 |
+
model, target_mean, target_std = load_nutrition_checkpoint(model_path, device=device)
|
| 134 |
+
return model, target_mean, target_std, device
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def classify_image(image: Image.Image) -> list[dict]:
|
| 138 |
+
scratch = scratch_classifier()
|
| 139 |
+
if scratch is None:
|
| 140 |
+
return classifier()(image, top_k=3)
|
| 141 |
+
|
| 142 |
+
model, labels, device = scratch
|
| 143 |
+
tensor = image_to_tensor(image).unsqueeze(0).to(device)
|
| 144 |
+
with torch.no_grad():
|
| 145 |
+
probabilities = torch.softmax(model(tensor), dim=1)[0]
|
| 146 |
+
top_scores, top_indices = torch.topk(probabilities, k=min(3, len(labels)))
|
| 147 |
+
return [
|
| 148 |
+
{"label": labels[index.item()], "score": score.item()}
|
| 149 |
+
for score, index in zip(top_scores, top_indices)
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def analyze_nutrients(image: Image.Image, portion_context: str) -> dict | None:
|
| 154 |
+
regressor = nutrition_regressor()
|
| 155 |
+
if regressor is None:
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
model, target_mean, target_std, device = regressor
|
| 159 |
+
nutrients = predict_nutrients(model, image, target_mean, target_std, device)
|
| 160 |
+
|
| 161 |
+
calories = round(nutrients["total_calories"])
|
| 162 |
+
mass = round(nutrients["total_mass"])
|
| 163 |
+
fat = round(nutrients["total_fat"])
|
| 164 |
+
carbs = round(nutrients["total_carb"])
|
| 165 |
+
protein = round(nutrients["total_protein"])
|
| 166 |
+
macro_calories = (protein * 4) + (carbs * 4) + (fat * 9)
|
| 167 |
+
macro_gap = abs(macro_calories - calories)
|
| 168 |
+
confidence = "medium" if calories > 0 else "low"
|
| 169 |
+
confidence_note = (
|
| 170 |
+
f"Estimated from image using Calority nutrition regression. Macro calories differ by {macro_gap} kcal."
|
| 171 |
+
)
|
| 172 |
+
if portion_context:
|
| 173 |
+
confidence_note = f"{confidence_note} User context: {portion_context}."
|
| 174 |
+
|
| 175 |
+
return {
|
| 176 |
+
"name": "Food Plate",
|
| 177 |
+
"calories": calories,
|
| 178 |
+
"protein": protein,
|
| 179 |
+
"carbs": carbs,
|
| 180 |
+
"fat": fat,
|
| 181 |
+
"ingredients": [
|
| 182 |
+
f"Estimated total mass {mass}g",
|
| 183 |
+
f"Protein {protein}g - {protein * 4} kcal",
|
| 184 |
+
f"Carbs {carbs}g - {carbs * 4} kcal",
|
| 185 |
+
f"Fat {fat}g - {fat * 9} kcal",
|
| 186 |
+
],
|
| 187 |
+
"confidence": confidence,
|
| 188 |
+
"confidenceNote": confidence_note,
|
| 189 |
+
"nutritionDetails": {
|
| 190 |
+
"totalMass": mass,
|
| 191 |
+
"calories": calories,
|
| 192 |
+
"protein": protein,
|
| 193 |
+
"carbs": carbs,
|
| 194 |
+
"fat": fat,
|
| 195 |
+
"macroCalories": macro_calories,
|
| 196 |
+
},
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def require_auth(authorization: str | None) -> None:
|
| 201 |
+
if not MODEL_API_KEY:
|
| 202 |
+
return
|
| 203 |
+
expected = f"Bearer {MODEL_API_KEY}"
|
| 204 |
+
if authorization != expected:
|
| 205 |
+
raise HTTPException(status_code=401, detail="Invalid model service token")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def decode_image(image_base64: str) -> Image.Image:
|
| 209 |
+
try:
|
| 210 |
+
raw = base64.b64decode(image_base64)
|
| 211 |
+
return Image.open(io.BytesIO(raw)).convert("RGB")
|
| 212 |
+
except Exception as exc:
|
| 213 |
+
raise HTTPException(status_code=400, detail="Invalid imageBase64") from exc
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def normalize_label(label: str) -> str:
|
| 217 |
+
return label.lower().replace("_", " ").replace("-", " ").strip()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def grams_from_context(portion_context: str, fallback: int) -> int:
|
| 221 |
+
match = re.search(r"(\d{2,4})\s*(g|gram|grams)\b", portion_context.lower())
|
| 222 |
+
if match:
|
| 223 |
+
return max(30, min(1200, int(match.group(1))))
|
| 224 |
+
return fallback
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def nutrition_for(label: str, grams: int) -> dict:
|
| 228 |
+
profile = NUTRITION.get(label, DEFAULT_PROFILE)
|
| 229 |
+
factor = grams / 100
|
| 230 |
+
calories = round(profile.calories_100g * factor)
|
| 231 |
+
protein = round(profile.protein_100g * factor)
|
| 232 |
+
carbs = round(profile.carbs_100g * factor)
|
| 233 |
+
fat = round(profile.fat_100g * factor)
|
| 234 |
+
return {
|
| 235 |
+
"calories": calories,
|
| 236 |
+
"protein": protein,
|
| 237 |
+
"carbs": carbs,
|
| 238 |
+
"fat": fat,
|
| 239 |
+
"ingredient": f"{label.title()} estimated {grams}g - {calories} kcal",
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def confidence_from(score: float) -> tuple[str, str]:
|
| 244 |
+
if score >= 0.75:
|
| 245 |
+
return "high", ""
|
| 246 |
+
if score >= 0.45:
|
| 247 |
+
return "medium", "The food is visible, but the model is not fully certain."
|
| 248 |
+
return "low", "The model could not confidently identify the meal."
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@app.get("/health")
|
| 252 |
+
def health() -> dict:
|
| 253 |
+
if nutrition_regressor():
|
| 254 |
+
model_source = f"nutrition-regression:{HF_MODEL_REPO_ID or MODEL_DIR}"
|
| 255 |
+
elif scratch_classifier():
|
| 256 |
+
model_source = f"classification:{HF_MODEL_REPO_ID or MODEL_DIR}"
|
| 257 |
+
else:
|
| 258 |
+
model_source = f"pipeline:{MODEL_ID}"
|
| 259 |
+
return {"status": "ok", "model": model_source}
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@app.post("/analyze-meal")
|
| 263 |
+
def analyze_meal(payload: AnalyzeMealRequest, authorization: str | None = Header(default=None)) -> dict:
|
| 264 |
+
require_auth(authorization)
|
| 265 |
+
image = decode_image(payload.imageBase64)
|
| 266 |
+
nutrient_result = analyze_nutrients(image, payload.portionContext)
|
| 267 |
+
if nutrient_result:
|
| 268 |
+
return nutrient_result
|
| 269 |
+
|
| 270 |
+
predictions = classify_image(image)
|
| 271 |
+
best = predictions[0]
|
| 272 |
+
label = normalize_label(best["label"])
|
| 273 |
+
score = float(best["score"])
|
| 274 |
+
|
| 275 |
+
profile = NUTRITION.get(label, DEFAULT_PROFILE)
|
| 276 |
+
grams = grams_from_context(payload.portionContext, profile.serving_g)
|
| 277 |
+
macros = nutrition_for(label, grams)
|
| 278 |
+
confidence, confidence_note = confidence_from(score)
|
| 279 |
+
|
| 280 |
+
alternatives = [
|
| 281 |
+
f"{normalize_label(item['label']).title()} ({round(float(item['score']) * 100)}%)"
|
| 282 |
+
for item in predictions[1:]
|
| 283 |
+
]
|
| 284 |
+
|
| 285 |
+
if alternatives and confidence != "high":
|
| 286 |
+
confidence_note = f"{confidence_note} Alternatives: {', '.join(alternatives)}".strip()
|
| 287 |
+
|
| 288 |
+
return {
|
| 289 |
+
"name": label.title(),
|
| 290 |
+
"calories": macros["calories"],
|
| 291 |
+
"protein": macros["protein"],
|
| 292 |
+
"carbs": macros["carbs"],
|
| 293 |
+
"fat": macros["fat"],
|
| 294 |
+
"ingredients": [macros["ingredient"]],
|
| 295 |
+
"confidence": confidence,
|
| 296 |
+
"confidenceNote": confidence_note,
|
| 297 |
+
}
|
requirements.txt
CHANGED
|
@@ -1,2 +1,12 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
uvicorn[standard]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.6
|
| 2 |
+
uvicorn[standard]==0.32.1
|
| 3 |
+
pillow==11.0.0
|
| 4 |
+
pydantic==2.10.4
|
| 5 |
+
python-multipart==0.0.20
|
| 6 |
+
transformers==4.47.1
|
| 7 |
+
torch==2.5.1
|
| 8 |
+
accelerate==1.2.1
|
| 9 |
+
datasets==3.2.0
|
| 10 |
+
huggingface_hub==0.27.0
|
| 11 |
+
scikit-learn==1.6.0
|
| 12 |
+
tqdm==4.67.1
|
train_food_model.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
from sklearn.metrics import accuracy_score
|
| 6 |
+
from transformers import (
|
| 7 |
+
AutoImageProcessor,
|
| 8 |
+
AutoModelForImageClassification,
|
| 9 |
+
Trainer,
|
| 10 |
+
TrainingArguments,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def parse_args():
|
| 15 |
+
parser = argparse.ArgumentParser(description="Fine-tune Calority's food image classifier.")
|
| 16 |
+
parser.add_argument("--base-model", default="google/vit-base-patch16-224-in21k")
|
| 17 |
+
parser.add_argument("--dataset", default="food101")
|
| 18 |
+
parser.add_argument("--output-dir", default="./calority-food-model")
|
| 19 |
+
parser.add_argument("--epochs", type=int, default=3)
|
| 20 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 21 |
+
return parser.parse_args()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main():
|
| 25 |
+
args = parse_args()
|
| 26 |
+
dataset = load_dataset(args.dataset)
|
| 27 |
+
labels = dataset["train"].features["label"].names
|
| 28 |
+
processor = AutoImageProcessor.from_pretrained(args.base_model)
|
| 29 |
+
|
| 30 |
+
def transform(batch):
|
| 31 |
+
images = [image.convert("RGB") for image in batch["image"]]
|
| 32 |
+
encoded = processor(images=images, return_tensors="pt")
|
| 33 |
+
encoded["labels"] = batch["label"]
|
| 34 |
+
return encoded
|
| 35 |
+
|
| 36 |
+
train_ds = dataset["train"].with_transform(transform)
|
| 37 |
+
eval_ds = dataset["validation"].with_transform(transform)
|
| 38 |
+
|
| 39 |
+
model = AutoModelForImageClassification.from_pretrained(
|
| 40 |
+
args.base_model,
|
| 41 |
+
num_labels=len(labels),
|
| 42 |
+
id2label={i: label for i, label in enumerate(labels)},
|
| 43 |
+
label2id={label: i for i, label in enumerate(labels)},
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def compute_metrics(eval_pred):
|
| 47 |
+
logits, labels_np = eval_pred
|
| 48 |
+
preds = np.argmax(logits, axis=-1)
|
| 49 |
+
return {"accuracy": accuracy_score(labels_np, preds)}
|
| 50 |
+
|
| 51 |
+
training_args = TrainingArguments(
|
| 52 |
+
output_dir=args.output_dir,
|
| 53 |
+
learning_rate=5e-5,
|
| 54 |
+
per_device_train_batch_size=args.batch_size,
|
| 55 |
+
per_device_eval_batch_size=args.batch_size,
|
| 56 |
+
num_train_epochs=args.epochs,
|
| 57 |
+
eval_strategy="epoch",
|
| 58 |
+
save_strategy="epoch",
|
| 59 |
+
load_best_model_at_end=True,
|
| 60 |
+
metric_for_best_model="accuracy",
|
| 61 |
+
remove_unused_columns=False,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
trainer = Trainer(
|
| 65 |
+
model=model,
|
| 66 |
+
args=training_args,
|
| 67 |
+
train_dataset=train_ds,
|
| 68 |
+
eval_dataset=eval_ds,
|
| 69 |
+
compute_metrics=compute_metrics,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
trainer.train()
|
| 73 |
+
trainer.save_model(args.output_dir)
|
| 74 |
+
processor.save_pretrained(args.output_dir)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
main()
|
train_from_scratch.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from calority_scratch_model import CalorityFoodCNN, image_to_tensor, save_checkpoint
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
|
| 14 |
+
parser = argparse.ArgumentParser(description="Train Calority's food model from scratch on a Hugging Face dataset.")
|
| 15 |
+
parser.add_argument("--dataset", default="food101", help="Hugging Face dataset name, for example food101")
|
| 16 |
+
parser.add_argument("--image-column", default="image")
|
| 17 |
+
parser.add_argument("--label-column", default="label")
|
| 18 |
+
parser.add_argument("--train-split", default="train")
|
| 19 |
+
parser.add_argument("--eval-split", default="validation")
|
| 20 |
+
parser.add_argument("--output-dir", default="./calority-scratch-model")
|
| 21 |
+
parser.add_argument("--epochs", type=int, default=12)
|
| 22 |
+
parser.add_argument("--batch-size", type=int, default=32)
|
| 23 |
+
parser.add_argument("--learning-rate", type=float, default=3e-4)
|
| 24 |
+
parser.add_argument("--num-workers", type=int, default=0)
|
| 25 |
+
parser.add_argument("--limit-train", type=int, default=0, help="Optional small limit for quick smoke tests")
|
| 26 |
+
parser.add_argument("--limit-eval", type=int, default=0, help="Optional small limit for quick smoke tests")
|
| 27 |
+
return parser.parse_args()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_labels(dataset, split: str, label_column: str) -> list[str]:
|
| 31 |
+
feature = dataset[split].features[label_column]
|
| 32 |
+
if hasattr(feature, "names") and feature.names:
|
| 33 |
+
return list(feature.names)
|
| 34 |
+
|
| 35 |
+
values = sorted(set(dataset[split][label_column]))
|
| 36 |
+
return [str(value) for value in values]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_collate_fn(image_column: str, label_column: str):
|
| 40 |
+
def collate(batch):
|
| 41 |
+
images = torch.stack([image_to_tensor(item[image_column]) for item in batch])
|
| 42 |
+
labels = torch.tensor([int(item[label_column]) for item in batch], dtype=torch.long)
|
| 43 |
+
return images, labels
|
| 44 |
+
|
| 45 |
+
return collate
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def evaluate(model, loader, loss_fn, device):
|
| 49 |
+
model.eval()
|
| 50 |
+
total_loss = 0.0
|
| 51 |
+
total_correct = 0
|
| 52 |
+
total_seen = 0
|
| 53 |
+
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
for images, labels in loader:
|
| 56 |
+
images = images.to(device)
|
| 57 |
+
labels = labels.to(device)
|
| 58 |
+
logits = model(images)
|
| 59 |
+
loss = loss_fn(logits, labels)
|
| 60 |
+
total_loss += loss.item() * labels.size(0)
|
| 61 |
+
total_correct += (logits.argmax(dim=1) == labels).sum().item()
|
| 62 |
+
total_seen += labels.size(0)
|
| 63 |
+
|
| 64 |
+
return total_loss / max(total_seen, 1), total_correct / max(total_seen, 1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main():
|
| 68 |
+
args = parse_args()
|
| 69 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 70 |
+
dataset = load_dataset(args.dataset)
|
| 71 |
+
|
| 72 |
+
if args.limit_train:
|
| 73 |
+
dataset[args.train_split] = dataset[args.train_split].shuffle(seed=42).select(range(args.limit_train))
|
| 74 |
+
if args.limit_eval:
|
| 75 |
+
dataset[args.eval_split] = dataset[args.eval_split].shuffle(seed=42).select(range(args.limit_eval))
|
| 76 |
+
|
| 77 |
+
labels = get_labels(dataset, args.train_split, args.label_column)
|
| 78 |
+
model = CalorityFoodCNN(num_labels=len(labels)).to(device)
|
| 79 |
+
|
| 80 |
+
collate_fn = make_collate_fn(args.image_column, args.label_column)
|
| 81 |
+
train_loader = DataLoader(
|
| 82 |
+
dataset[args.train_split],
|
| 83 |
+
batch_size=args.batch_size,
|
| 84 |
+
shuffle=True,
|
| 85 |
+
num_workers=args.num_workers,
|
| 86 |
+
collate_fn=collate_fn,
|
| 87 |
+
)
|
| 88 |
+
eval_loader = DataLoader(
|
| 89 |
+
dataset[args.eval_split],
|
| 90 |
+
batch_size=args.batch_size,
|
| 91 |
+
shuffle=False,
|
| 92 |
+
num_workers=args.num_workers,
|
| 93 |
+
collate_fn=collate_fn,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 97 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
|
| 98 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
| 99 |
+
|
| 100 |
+
best_acc = 0.0
|
| 101 |
+
output_dir = Path(args.output_dir)
|
| 102 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 103 |
+
|
| 104 |
+
for epoch in range(1, args.epochs + 1):
|
| 105 |
+
model.train()
|
| 106 |
+
running_loss = 0.0
|
| 107 |
+
total_seen = 0
|
| 108 |
+
total_correct = 0
|
| 109 |
+
|
| 110 |
+
progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False)
|
| 111 |
+
for images, labels_batch in progress:
|
| 112 |
+
images = images.to(device)
|
| 113 |
+
labels_batch = labels_batch.to(device)
|
| 114 |
+
|
| 115 |
+
optimizer.zero_grad(set_to_none=True)
|
| 116 |
+
logits = model(images)
|
| 117 |
+
loss = loss_fn(logits, labels_batch)
|
| 118 |
+
loss.backward()
|
| 119 |
+
optimizer.step()
|
| 120 |
+
|
| 121 |
+
running_loss += loss.item() * labels_batch.size(0)
|
| 122 |
+
total_correct += (logits.argmax(dim=1) == labels_batch).sum().item()
|
| 123 |
+
total_seen += labels_batch.size(0)
|
| 124 |
+
progress.set_postfix(
|
| 125 |
+
loss=round(running_loss / max(total_seen, 1), 4),
|
| 126 |
+
acc=round(total_correct / max(total_seen, 1), 4),
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
scheduler.step()
|
| 130 |
+
eval_loss, eval_acc = evaluate(model, eval_loader, loss_fn, device)
|
| 131 |
+
print(f"epoch={epoch} eval_loss={eval_loss:.4f} eval_acc={eval_acc:.4f}")
|
| 132 |
+
|
| 133 |
+
if eval_acc >= best_acc:
|
| 134 |
+
best_acc = eval_acc
|
| 135 |
+
save_checkpoint(model, labels, output_dir)
|
| 136 |
+
print(f"saved best model to {output_dir} with eval_acc={best_acc:.4f}")
|
| 137 |
+
|
| 138 |
+
print(f"done. best_eval_acc={best_acc:.4f}")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
main()
|
train_nutrients_from_scratch.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from calority_nutrition_model import (
|
| 11 |
+
TARGET_COLUMNS,
|
| 12 |
+
CalorityNutritionCNN,
|
| 13 |
+
save_nutrition_checkpoint,
|
| 14 |
+
)
|
| 15 |
+
from calority_scratch_model import image_to_tensor
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_args():
|
| 19 |
+
parser = argparse.ArgumentParser(
|
| 20 |
+
description="Train Calority's calorie and macro predictor from scratch on mmathys/food-nutrients."
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument("--dataset", default="mmathys/food-nutrients")
|
| 23 |
+
parser.add_argument("--source-split", default="test", help="This dataset currently ships with only a test split.")
|
| 24 |
+
parser.add_argument("--image-column", default="image")
|
| 25 |
+
parser.add_argument("--output-dir", default="./calority-nutrition-model")
|
| 26 |
+
parser.add_argument("--epochs", type=int, default=40)
|
| 27 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 28 |
+
parser.add_argument("--learning-rate", type=float, default=3e-4)
|
| 29 |
+
parser.add_argument("--validation-size", type=float, default=0.15)
|
| 30 |
+
parser.add_argument("--num-workers", type=int, default=0)
|
| 31 |
+
parser.add_argument("--limit", type=int, default=0, help="Optional small limit for quick smoke tests")
|
| 32 |
+
return parser.parse_args()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def make_targets(dataset_split) -> torch.Tensor:
|
| 36 |
+
rows = [[float(item[column]) for column in TARGET_COLUMNS] for item in dataset_split]
|
| 37 |
+
return torch.tensor(rows, dtype=torch.float32)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def make_collate_fn(image_column: str, target_mean: torch.Tensor, target_std: torch.Tensor):
|
| 41 |
+
def collate(batch):
|
| 42 |
+
images = torch.stack([image_to_tensor(item[image_column]) for item in batch])
|
| 43 |
+
targets = torch.tensor(
|
| 44 |
+
[[float(item[column]) for column in TARGET_COLUMNS] for item in batch],
|
| 45 |
+
dtype=torch.float32,
|
| 46 |
+
)
|
| 47 |
+
normalized_targets = (targets - target_mean) / target_std
|
| 48 |
+
return images, normalized_targets, targets
|
| 49 |
+
|
| 50 |
+
return collate
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def evaluate(model, loader, loss_fn, target_mean, target_std, device):
|
| 54 |
+
model.eval()
|
| 55 |
+
total_loss = 0.0
|
| 56 |
+
total_mae = torch.zeros(len(TARGET_COLUMNS))
|
| 57 |
+
total_seen = 0
|
| 58 |
+
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
for images, normalized_targets, raw_targets in loader:
|
| 61 |
+
images = images.to(device)
|
| 62 |
+
normalized_targets = normalized_targets.to(device)
|
| 63 |
+
predictions = model(images)
|
| 64 |
+
loss = loss_fn(predictions, normalized_targets)
|
| 65 |
+
|
| 66 |
+
raw_predictions = torch.clamp(
|
| 67 |
+
(predictions.cpu() * target_std) + target_mean,
|
| 68 |
+
min=0,
|
| 69 |
+
)
|
| 70 |
+
total_loss += loss.item() * images.size(0)
|
| 71 |
+
total_mae += torch.abs(raw_predictions - raw_targets).sum(dim=0)
|
| 72 |
+
total_seen += images.size(0)
|
| 73 |
+
|
| 74 |
+
mae = total_mae / max(total_seen, 1)
|
| 75 |
+
return total_loss / max(total_seen, 1), mae
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main():
|
| 79 |
+
args = parse_args()
|
| 80 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
+
dataset = load_dataset(args.dataset)
|
| 82 |
+
source = dataset[args.source_split].shuffle(seed=42)
|
| 83 |
+
|
| 84 |
+
if args.limit:
|
| 85 |
+
source = source.select(range(min(args.limit, len(source))))
|
| 86 |
+
|
| 87 |
+
split = source.train_test_split(test_size=args.validation_size, seed=42)
|
| 88 |
+
train_ds = split["train"]
|
| 89 |
+
eval_ds = split["test"]
|
| 90 |
+
|
| 91 |
+
train_targets = make_targets(train_ds)
|
| 92 |
+
target_mean = train_targets.mean(dim=0)
|
| 93 |
+
target_std = torch.clamp(train_targets.std(dim=0), min=1.0)
|
| 94 |
+
|
| 95 |
+
model = CalorityNutritionCNN(output_size=len(TARGET_COLUMNS)).to(device)
|
| 96 |
+
collate_fn = make_collate_fn(args.image_column, target_mean, target_std)
|
| 97 |
+
train_loader = DataLoader(
|
| 98 |
+
train_ds,
|
| 99 |
+
batch_size=args.batch_size,
|
| 100 |
+
shuffle=True,
|
| 101 |
+
num_workers=args.num_workers,
|
| 102 |
+
collate_fn=collate_fn,
|
| 103 |
+
)
|
| 104 |
+
eval_loader = DataLoader(
|
| 105 |
+
eval_ds,
|
| 106 |
+
batch_size=args.batch_size,
|
| 107 |
+
shuffle=False,
|
| 108 |
+
num_workers=args.num_workers,
|
| 109 |
+
collate_fn=collate_fn,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
loss_fn = nn.SmoothL1Loss()
|
| 113 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
|
| 114 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
|
| 115 |
+
output_dir = Path(args.output_dir)
|
| 116 |
+
best_calorie_mae = float("inf")
|
| 117 |
+
|
| 118 |
+
for epoch in range(1, args.epochs + 1):
|
| 119 |
+
model.train()
|
| 120 |
+
running_loss = 0.0
|
| 121 |
+
total_seen = 0
|
| 122 |
+
|
| 123 |
+
progress = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False)
|
| 124 |
+
for images, normalized_targets, _ in progress:
|
| 125 |
+
images = images.to(device)
|
| 126 |
+
normalized_targets = normalized_targets.to(device)
|
| 127 |
+
|
| 128 |
+
optimizer.zero_grad(set_to_none=True)
|
| 129 |
+
predictions = model(images)
|
| 130 |
+
loss = loss_fn(predictions, normalized_targets)
|
| 131 |
+
loss.backward()
|
| 132 |
+
optimizer.step()
|
| 133 |
+
|
| 134 |
+
running_loss += loss.item() * images.size(0)
|
| 135 |
+
total_seen += images.size(0)
|
| 136 |
+
progress.set_postfix(loss=round(running_loss / max(total_seen, 1), 4))
|
| 137 |
+
|
| 138 |
+
scheduler.step()
|
| 139 |
+
eval_loss, mae = evaluate(model, eval_loader, loss_fn, target_mean, target_std, device)
|
| 140 |
+
metric_line = ", ".join(
|
| 141 |
+
f"{column}_mae={mae[index]:.2f}" for index, column in enumerate(TARGET_COLUMNS)
|
| 142 |
+
)
|
| 143 |
+
print(f"epoch={epoch} eval_loss={eval_loss:.4f} {metric_line}")
|
| 144 |
+
|
| 145 |
+
if mae[0].item() <= best_calorie_mae:
|
| 146 |
+
best_calorie_mae = mae[0].item()
|
| 147 |
+
save_nutrition_checkpoint(model, target_mean, target_std, output_dir)
|
| 148 |
+
print(f"saved best nutrition model to {output_dir} with calorie_mae={best_calorie_mae:.2f}")
|
| 149 |
+
|
| 150 |
+
print(f"done. best_calorie_mae={best_calorie_mae:.2f}")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
if __name__ == "__main__":
|
| 154 |
+
main()
|
upload_model_to_hf.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import HfApi, create_repo, upload_folder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def parse_args():
|
| 9 |
+
parser = argparse.ArgumentParser(description="Upload a trained Calority checkpoint to Hugging Face Hub.")
|
| 10 |
+
parser.add_argument("--model-dir", default="./calority-nutrition-model")
|
| 11 |
+
parser.add_argument("--repo-id", required=True, help="Example: your-username/calority-nutrition-model")
|
| 12 |
+
parser.add_argument("--private", action="store_true")
|
| 13 |
+
return parser.parse_args()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
args = parse_args()
|
| 18 |
+
model_dir = Path(args.model_dir)
|
| 19 |
+
if not (model_dir / "model.pt").exists():
|
| 20 |
+
raise SystemExit(f"Missing checkpoint: {model_dir / 'model.pt'}")
|
| 21 |
+
if not (model_dir / "target_stats.json").exists():
|
| 22 |
+
raise SystemExit(f"Missing target stats: {model_dir / 'target_stats.json'}")
|
| 23 |
+
|
| 24 |
+
token = os.getenv("HF_TOKEN")
|
| 25 |
+
api = HfApi(token=token)
|
| 26 |
+
create_repo(args.repo_id, repo_type="model", private=args.private, exist_ok=True, token=token)
|
| 27 |
+
upload_folder(
|
| 28 |
+
repo_id=args.repo_id,
|
| 29 |
+
repo_type="model",
|
| 30 |
+
folder_path=str(model_dir),
|
| 31 |
+
path_in_repo=".",
|
| 32 |
+
commit_message="Upload Calority nutrition model checkpoint",
|
| 33 |
+
token=token,
|
| 34 |
+
)
|
| 35 |
+
info = api.model_info(args.repo_id, token=token)
|
| 36 |
+
print(f"Uploaded checkpoint to https://huggingface.co/{info.modelId}")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
main()
|
upload_space_to_hf.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import create_repo, upload_folder
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def parse_args():
|
| 9 |
+
parser = argparse.ArgumentParser(description="Upload the Calority model API to a Hugging Face Docker Space.")
|
| 10 |
+
parser.add_argument("--space-id", required=True, help="Example: your-username/calority-model-api")
|
| 11 |
+
parser.add_argument("--private", action="store_true")
|
| 12 |
+
return parser.parse_args()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
args = parse_args()
|
| 17 |
+
token = os.getenv("HF_TOKEN")
|
| 18 |
+
create_repo(
|
| 19 |
+
args.space_id,
|
| 20 |
+
repo_type="space",
|
| 21 |
+
space_sdk="docker",
|
| 22 |
+
private=args.private,
|
| 23 |
+
exist_ok=True,
|
| 24 |
+
token=token,
|
| 25 |
+
)
|
| 26 |
+
upload_folder(
|
| 27 |
+
repo_id=args.space_id,
|
| 28 |
+
repo_type="space",
|
| 29 |
+
folder_path=str(Path(__file__).parent),
|
| 30 |
+
path_in_repo=".",
|
| 31 |
+
commit_message="Deploy Calority model API",
|
| 32 |
+
ignore_patterns=[
|
| 33 |
+
".env",
|
| 34 |
+
"__pycache__/*",
|
| 35 |
+
"*.pyc",
|
| 36 |
+
"calority-model-api/*",
|
| 37 |
+
"calority-nutrition-model/*",
|
| 38 |
+
"calority-scratch-model/*",
|
| 39 |
+
"calority-food-model/*",
|
| 40 |
+
],
|
| 41 |
+
token=token,
|
| 42 |
+
)
|
| 43 |
+
print(f"Uploaded Space files to https://huggingface.co/spaces/{args.space_id}")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":
|
| 47 |
+
main()
|