# Sri Lankan Cooked Vegetable Recognition Model

A custom Prototypical Network trained for few-shot recognition of Sri Lankan vegetables across cooking transformation states. The model identifies the same vegetable in its raw form and after traditional Sri Lankan cooking methods — red curry, white curry, and tempering — where colour, texture, and shape change dramatically.

This is the model behind the FoodShot Android application, with on-device TFLite inference under 100 ms.


## Model Details

| Property | Value |
|---|---|
| Architecture | Custom Prototypical Network |
| Embedding dimension | 128 |
| Input size | 224 × 224 RGB |
| Training regime | 4-way 2-shot episodic |
| Number of classes | 8 |
| Training images per class | 27–33 |
| Framework | PyTorch → TFLite (mobile) |
| TFLite model size | 492.9 KB |
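
The 4-way 2-shot episodic regime samples small classification episodes: each class prototype is the mean of that class's support embeddings, and queries are classified by their distance to the prototypes. A minimal sketch of the episode loss (illustrative shapes and function names, not the actual training script):

```python
import torch
import torch.nn.functional as F

def prototypical_loss(support_emb, query_emb, n_way=4, k_shot=2):
    """One episode. support_emb: [n_way * k_shot, D] grouped by class;
    query_emb: [n_way * n_query, D] grouped the same way."""
    d = support_emb.size(-1)
    # Class prototype = mean of that class's support embeddings
    prototypes = support_emb.view(n_way, k_shot, d).mean(dim=1)  # [n_way, D]
    # Negative squared Euclidean distance acts as the classification logit
    logits = -torch.cdist(query_emb, prototypes) ** 2            # [queries, n_way]
    n_query = query_emb.size(0) // n_way
    labels = torch.arange(n_way).repeat_interleave(n_query)
    return F.cross_entropy(logits, labels)

# Toy episode with random 128-d embeddings: 4-way, 2-shot, 3 queries per class
loss = prototypical_loss(torch.randn(8, 128), torch.randn(12, 128))
```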

## Performance

| Metric | Value |
|---|---|
| Validation accuracy | 90.25% |
| Few-shot test accuracy | 87.75% |
| Full test accuracy | 84.91% |
| Mean per-class accuracy | 84.15% |
| Random Forest baseline | 70.83% |

The model outperforms all classical baselines (k-NN, SVM with colour histogram and HOG features, Random Forest) on the same test set.


## Classes

The model recognises 8 vegetable–state combinations across 3 vegetables:

| # | Class | Vegetable | Cooking state |
|---|---|---|---|
| 1 | `carrot_raw` | Carrot | Raw |
| 2 | `carrot_white_curry` | Carrot | White curry (coconut milk) |
| 3 | `greenbeans_raw` | Green beans | Raw |
| 4 | `greenbeans_tempered` | Green beans | Tempered (high-heat stir-fry) |
| 5 | `greenbeans_white_curry` | Green beans | White curry (coconut milk) |
| 6 | `pumpkin_raw` | Pumpkin | Raw |
| 7 | `pumpkin_red_curry` | Pumpkin | Red curry (turmeric-based) |
| 8 | `pumpkin_white_curry` | Pumpkin | White curry (coconut milk) |

## Files

| File | Description |
|---|---|
| `best_model.pth` | PyTorch checkpoint (model weights + metadata) |
| `model.tflite` | Quantised TFLite model for mobile/on-device inference |
| `prototypes.json` | Per-class prototype embeddings for nearest-prototype classification |

## Python Usage

### Installation

```bash
pip install torch torchvision pillow
```

### Load and run inference

```python
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import json

# --- Define model architecture ---
class PrototypicalNetwork(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.fc = nn.Linear(256 * 4 * 4, embedding_dim)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# --- Load model ---
# weights_only=False is required on PyTorch >= 2.6 because the checkpoint
# stores metadata alongside the weights
checkpoint = torch.load('best_model.pth', map_location='cpu', weights_only=False)
model = PrototypicalNetwork(embedding_dim=128)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# --- Load class prototypes ---
with open('prototypes.json', 'r') as f:
    prototypes = json.load(f)

class_names = list(prototypes.keys())
prototype_tensors = torch.tensor([prototypes[c] for c in class_names],
                                 dtype=torch.float32)

# --- Preprocessing (ImageNet normalisation) ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# --- Predict ---
def predict(image_path):
    img = Image.open(image_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        embedding = model(img_tensor)

    # Nearest prototype (Euclidean distance)
    distances = torch.cdist(embedding, prototype_tensors)
    predicted_idx = distances.argmin().item()
    predicted_class = class_names[predicted_idx]
    confidence = torch.softmax(-distances, dim=1)[0][predicted_idx].item()

    return predicted_class, confidence

label, conf = predict('your_image.jpg')
print(f"Predicted: {label}  (confidence: {conf:.2%})")
```

## Mobile Usage (TFLite)

Use model.tflite with the Flutter-based FoodShot Android application. Class prototypes are stored in prototypes.json for on-device nearest-prototype classification. Inference runs in under 100 ms on a standard Android device.

```dart
// Flutter / TFLite example
final interpreter = await Interpreter.fromAsset('model.tflite');
// Pre-process image to [1, 224, 224, 3] float32
// Run interpreter
// Compare output embedding against prototypes.json using Euclidean distance
```
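
The matching step the app performs is a plain nearest-prototype lookup over the embedding the interpreter outputs. A minimal NumPy sketch (assuming, as in the Python example above, that prototypes.json maps each class name to its 128-value embedding list; names here are illustrative):

```python
import numpy as np

def nearest_prototype(embedding, prototypes):
    """embedding: [D] array; prototypes: dict mapping class name -> D-value list
    (the parsed contents of prototypes.json)."""
    names = list(prototypes)
    protos = np.array([prototypes[c] for c in names], dtype=np.float32)  # [C, D]
    emb = np.asarray(embedding, dtype=np.float32)
    dists = np.linalg.norm(protos - emb, axis=1)                         # [C]
    idx = int(dists.argmin())
    # Softmax over negative distances gives a confidence score, matching
    # the PyTorch predict() above
    neg = -dists
    probs = np.exp(neg - neg.max())
    probs /= probs.sum()
    return names[idx], float(probs[idx])
```

Usage mirrors the desktop path: `json.load` the prototypes file, run the interpreter to get the embedding, then call `nearest_prototype(embedding, prototypes)`.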

## Extending to New Vegetables

This model is designed for extensibility. You can add a new vegetable class with 20–30 images without retraining the full network, using the companion Python library:

```bash
pip install srilankan-food-trainer
```

```python
from srilankan_food_trainer import extend_model

extend_model(
    model_path='best_model.pth',
    new_class_name='beetroot_white_curry',
    image_folder='./beetroot_images/',   # 20–30 images minimum
    output_path='extended_model.pth'
)
```

## Dataset

This model was trained on the Sri Lankan Cooked Vegetable Dataset — 331 images across 8 vegetable–state combinations, collected via controlled cooking sessions (40%) and web scraping (60%).


## Citation

If you use this model or dataset in your research, please cite:

```bibtex
@misc{ranasinghe2026foodshot,
  author    = {Ranasinghe, G. H. C.},
  title     = {A Transformation-Aware Deep Learning Approach for Recognizing
               Cooked Vegetable Ingredients in Sri Lankan Cuisine},
  year      = {2026},
  note      = {BSc Hons in Computing, Coventry University / NIBM},
}
```

## License

MIT License — free to use, modify, and distribute with attribution.
