# SMAI_A3 / utils / predict.py
# Caerus256's picture
# Upload 16 files
# 3156443 verified
"""
Inference helper for T2.6 Regional Thali Identifier.
Loads model + metadata, predicts dish from a PIL image.
"""
import torch
import timm
from PIL import Image
from .data_utils import (
load_metadata, load_nutrition, get_class_mappings,
get_inference_transform, MODEL_DIR, NUM_CLASSES,
)
def load_model(model_path=None, device=None):
    """
    Build an EfficientNet-B0 classifier and restore its saved weights.

    Args:
        model_path: Location of the state-dict checkpoint. When None, the
            file "efficientnet_b0_best.pth" inside MODEL_DIR is used.
        device: Device to load the weights onto and run the model on. When
            None, CUDA is selected if available, otherwise CPU.

    Returns:
        Tuple (model, device); the model is on `device` and in eval mode.
    """
    if device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(backend)

    checkpoint = model_path
    if checkpoint is None:
        checkpoint = MODEL_DIR / "efficientnet_b0_best.pth"

    # Architecture only (pretrained=False): the weights come from the checkpoint.
    net = timm.create_model(
        "efficientnet_b0",
        pretrained=False,
        num_classes=NUM_CLASSES,
    )

    # weights_only=True restricts unpickling to tensors / plain containers.
    weights = torch.load(str(checkpoint), map_location=device, weights_only=True)
    net.load_state_dict(weights)

    net = net.to(device)
    net = net.eval()
    return net, device
def predict_image(image: Image.Image, model, device, top_k=3):
    """
    Predict the dish shown in a PIL image.

    Args:
        image: Input picture; it is converted to RGB before the inference
            transform is applied.
        model: Classifier called on a single-image batch; its output is
            treated as per-class logits (softmax over dim 1).
        device: Device the input tensor is moved to before the forward pass.
        top_k: Maximum number of predictions to return. If it exceeds the
            number of classes the model outputs, it is clamped to that
            number instead of letting ``topk`` raise.

    Returns:
        List of dicts, one per prediction, in the order produced by
        ``Tensor.topk``. Each dict has the keys: folder_name, display_name,
        confidence, region, state, diet, course, flavor_profile,
        ingredients, prep_time, cook_time, allergens, calories.
        Values come from the nutrition entry for the dish when present,
        otherwise from the metadata row (allergens falls back to {} and
        calories to "N/A").
    """
    df = load_metadata()
    nutrition = load_nutrition()
    mappings = get_class_mappings(df)
    transform = get_inference_transform()

    # Add a batch dimension of 1 and move the input to the target device.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        probs = torch.softmax(logits, dim=1).squeeze(0)

    # topk raises if k is larger than the number of scores, so cap it.
    k = min(top_k, probs.numel())
    top_probs, top_indices = probs.topk(k)

    results = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        folder_name = mappings["idx_to_class"][idx]
        region = mappings["dish_to_region"][folder_name]
        # Nutrition data is optional per dish; metadata row is the fallback.
        info = nutrition.get(folder_name, {})
        row = df[df["folder_name"] == folder_name].iloc[0]

        results.append({
            "folder_name": folder_name,
            "display_name": info.get("display_name", row["display_name"]),
            "confidence": prob,
            "region": region,
            "state": info.get("state", row["state"]),
            "diet": info.get("diet", row["diet"]),
            "course": info.get("course", row["course"]),
            "flavor_profile": info.get("flavor_profile", row["flavor_profile"]),
            "ingredients": info.get("ingredients", row["ingredients"]),
            "prep_time": info.get("prep_time_min", str(row["prep_time"])),
            "cook_time": info.get("cook_time_min", str(row["cook_time"])),
            "allergens": info.get("allergens", {}),
            "calories": info.get("approx_calories_kcal", "N/A"),
        })
    return results