# Sri Lankan Cooked Vegetable Recognition Model
A custom Prototypical Network trained for few-shot recognition of Sri Lankan vegetables across cooking transformation states. The model identifies the same vegetable in its raw form and after traditional Sri Lankan cooking methods — red curry, white curry, and tempering — where colour, texture, and shape change dramatically.
This is the model behind the FoodShot Android application; on-device TFLite inference completes in under 100 ms.
## Model Details
| Property | Value |
|---|---|
| Architecture | Custom Prototypical Network |
| Embedding dimension | 128 |
| Input size | 224 × 224 RGB |
| Training regime | 4-way 2-shot episodic |
| Number of classes | 8 |
| Training images per class | 27–33 |
| Framework | PyTorch → TFLite (mobile) |
| TFLite model size | 492.9 KB |
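In the 4-way 2-shot episodic regime, each training step samples 4 classes with 2 labelled support images each, forms one prototype per class by averaging the support embeddings, and classifies query images by Euclidean distance to the nearest prototype. A minimal NumPy sketch of one such episode, using random stand-in embeddings (in the real model these come from the 128-d CNN encoder):

```python
import numpy as np

rng = np.random.default_rng(0)
n_way, k_shot, embed_dim = 4, 2, 128

# Stand-in support/query embeddings; in practice these come from the encoder.
support = rng.normal(size=(n_way, k_shot, embed_dim))            # [4, 2, 128]
query = support.mean(axis=1) + 0.01 * rng.normal(size=(n_way, embed_dim))

# One prototype per class: the mean of its support embeddings.
prototypes = support.mean(axis=1)                                # [4, 128]

# Classify each query by nearest prototype (Euclidean distance).
dists = np.linalg.norm(query[:, None, :] - prototypes[None, :, :], axis=-1)
pred = dists.argmin(axis=1)
print(pred)  # each query matches its own class: [0 1 2 3]
```

During training, the negative distances are typically fed through a softmax and optimised with cross-entropy, which is exactly the confidence computation the inference code below uses.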
## Performance
| Metric | Value |
|---|---|
| Validation accuracy | 90.25% |
| Few-shot test accuracy | 87.75% |
| Full test accuracy | 84.91% |
| Mean per-class accuracy | 84.15% |
| Random Forest baseline | 70.83% |
The model outperforms all classical baselines (k-NN, SVM with colour histogram and HOG features, Random Forest) on the same test set.
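Mean per-class accuracy averages each class's recall, so classes with fewer test images count as much as larger ones; with 27–33 images per class it is a stricter check than overall accuracy. A small illustrative sketch of the metric (toy labels, not the actual test set):

```python
import numpy as np

# Toy ground-truth and predicted labels for three classes.
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2])

# Recall of each class, then the unweighted mean over classes.
classes = np.unique(y_true)
per_class = [(y_pred[y_true == c] == c).mean() for c in classes]

print([round(a, 3) for a in per_class])       # [0.667, 1.0, 0.75]
print(round(float(np.mean(per_class)), 3))    # 0.806
```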
## Classes
The model recognises 8 vegetable–state combinations across 3 vegetables:
| # | Class | Vegetable | Cooking state |
|---|---|---|---|
| 1 | carrot_raw | Carrot | Raw |
| 2 | carrot_white_curry | Carrot | White curry (coconut milk) |
| 3 | greenbeans_raw | Green beans | Raw |
| 4 | greenbeans_tempered | Green beans | Tempered (high-heat stir-fry) |
| 5 | greenbeans_white_curry | Green beans | White curry (coconut milk) |
| 6 | pumpkin_raw | Pumpkin | Raw |
| 7 | pumpkin_red_curry | Pumpkin | Red curry (turmeric-based) |
| 8 | pumpkin_white_curry | Pumpkin | White curry (coconut milk) |
## Files
| File | Description |
|---|---|
| best_model.pth | PyTorch checkpoint (model weights + metadata) |
| model.tflite | Quantised TFLite model for mobile/on-device inference |
| prototypes.json | Per-class prototype embeddings for nearest-prototype classification |
## Python Usage
### Installation

```bash
pip install torch torchvision pillow
```
### Load and run inference

```python
import json

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# --- Define model architecture ---
class PrototypicalNetwork(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.fc = nn.Linear(256 * 4 * 4, embedding_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# --- Load model ---
checkpoint = torch.load('best_model.pth', map_location='cpu')
model = PrototypicalNetwork(embedding_dim=128)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# --- Load class prototypes ---
with open('prototypes.json', 'r') as f:
    prototypes = json.load(f)
class_names = list(prototypes.keys())
prototype_tensors = torch.tensor(
    [prototypes[c] for c in class_names], dtype=torch.float32
)

# --- Preprocessing (ImageNet normalisation) ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# --- Predict ---
def predict(image_path):
    img = Image.open(image_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        embedding = model(img_tensor)
    # Nearest prototype (Euclidean distance); softmax over the negative
    # distances turns them into a confidence score.
    distances = torch.cdist(embedding, prototype_tensors)
    predicted_idx = distances.argmin().item()
    confidence = torch.softmax(-distances, dim=1)[0][predicted_idx].item()
    return class_names[predicted_idx], confidence

label, conf = predict('your_image.jpg')
print(f"Predicted: {label} (confidence: {conf:.2%})")
```
## Mobile Usage (TFLite)
Use model.tflite with the Flutter-based FoodShot Android application. Class prototypes are stored in prototypes.json for on-device nearest-prototype classification. Inference runs in under 100 ms on a standard Android device.
```dart
// Flutter / TFLite example
final interpreter = await Interpreter.fromAsset('model.tflite');
// 1. Pre-process the image to a [1, 224, 224, 3] float32 tensor.
// 2. Run the interpreter to obtain the 128-d embedding.
// 3. Compare the output embedding against prototypes.json using
//    Euclidean distance.
```
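Whichever runtime produces the 128-d embedding (TFLite on Android here), the final classification step is the same nearest-prototype lookup. A minimal NumPy sketch of that step, assuming prototypes.json maps class names to 128-element lists; the embedding below is a random-free stand-in, not real model output:

```python
import numpy as np

# Stand-in for prototypes.json: class name -> 128-d prototype embedding.
prototypes = {
    "carrot_raw": np.full(128, 1.0),
    "pumpkin_red_curry": np.full(128, -1.0),
}

embedding = np.full(128, 0.9)  # stand-in for the TFLite output embedding

# Euclidean distance to each prototype; predict the closest class.
dists = {name: float(np.linalg.norm(embedding - proto))
         for name, proto in prototypes.items()}
label = min(dists, key=dists.get)
print(label)  # carrot_raw
```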
## Extending to New Vegetables
This model is designed for extensibility. You can add a new vegetable class with 20–30 images without retraining the full network, using the companion Python library:
```bash
pip install srilankan-food-trainer
```
```python
from srilankan_food_trainer import extend_model

extend_model(
    model_path='best_model.pth',
    new_class_name='beetroot_white_curry',
    image_folder='./beetroot_images/',  # 20–30 images minimum
    output_path='extended_model.pth',
)
```
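Extension without retraining is possible because the encoder stays frozen: adding a class only requires embedding its 20–30 images and storing their mean as a new prototype. A hedged NumPy sketch of that idea — `embed` here is a hypothetical stand-in for the frozen encoder, and this is an illustration of the principle, not the companion library's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(1)

def embed(images):
    # Hypothetical stand-in for the frozen 128-d encoder.
    return rng.normal(size=(len(images), 128))

# Existing prototype table (one entry shown for brevity).
prototype_table = {"carrot_raw": rng.normal(size=128)}

# "Embed" 25 images of the new class and average them into one prototype.
new_images = [f"beetroot_{i}.jpg" for i in range(25)]
embeddings = embed(new_images)
prototype_table["beetroot_white_curry"] = embeddings.mean(axis=0)

print(sorted(prototype_table))  # ['beetroot_white_curry', 'carrot_raw']
```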
## Dataset
This model was trained on the Sri Lankan Cooked Vegetable Dataset — 331 images across 8 vegetable–state combinations, collected via controlled cooking sessions (40%) and web scraping (60%).
- Dataset: ranasinghehashini/srilankan-food-recognition
- Train / val / test split: 70% / 15% / 15%
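With only 27–33 images per class, a 70/15/15 split works best when done per class (stratified), so every class appears in all three splits. A plain-Python sketch of such a split, using hypothetical file lists rather than the dataset's actual script:

```python
import random

random.seed(42)

# Hypothetical per-class file lists (30 images per class, 8 classes).
by_class = {f"class_{c}": [f"class_{c}/img_{i}.jpg" for i in range(30)]
            for c in range(8)}

train, val, test = [], [], []
for cls, files in by_class.items():
    files = files[:]
    random.shuffle(files)
    n_train = int(0.70 * len(files))
    n_val = int(0.15 * len(files))
    train += files[:n_train]
    val += files[n_train:n_train + n_val]
    test += files[n_train + n_val:]   # remainder goes to test

print(len(train), len(val), len(test))  # 168 32 40
```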
## Citation
If you use this model or dataset in your research, please cite:
```bibtex
@misc{ranasinghe2026foodshot,
  author = {Ranasinghe, G. H. C.},
  title  = {A Transformation-Aware Deep Learning Approach for Recognizing
            Cooked Vegetable Ingredients in Sri Lankan Cuisine},
  year   = {2026},
  note   = {BSc Hons in Computing, Coventry University / NIBM},
}
```
## License
MIT License — free to use, modify, and distribute with attribution.