| """ | |
| Inference helper for T2.6 Regional Thali Identifier. | |
| Loads model + metadata, predicts dish from a PIL image. | |
| """ | |
| import torch | |
| import timm | |
| from PIL import Image | |
| from .data_utils import ( | |
| load_metadata, load_nutrition, get_class_mappings, | |
| get_inference_transform, MODEL_DIR, NUM_CLASSES, | |
| ) | |
def load_model(model_path=None, device=None):
    """Build an EfficientNet-B0 classifier and restore its trained weights.

    Args:
        model_path: Location of the saved ``state_dict``. When omitted,
            ``MODEL_DIR / "efficientnet_b0_best.pth"`` is used.
        device: Target device. When omitted, CUDA is chosen if available,
            otherwise the CPU.

    Returns:
        A ``(model, device)`` tuple; the model has been moved to ``device``
        and switched to eval mode.
    """
    if model_path is None:
        model_path = MODEL_DIR / "efficientnet_b0_best.pth"
    if device is None:
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

    # Architecture only (pretrained=False): the weights come from model_path.
    network = timm.create_model(
        "efficientnet_b0",
        pretrained=False,
        num_classes=NUM_CLASSES,
    )
    # weights_only=True restricts unpickling to tensors and plain containers.
    weights = torch.load(str(model_path), map_location=device, weights_only=True)
    network.load_state_dict(weights)

    network = network.to(device)
    network.eval()
    return network, device
def predict_image(image: Image.Image, model, device, top_k=3):
    """
    Predict the dish shown in a PIL image.

    Args:
        image: Input PIL image; it is converted to RGB before preprocessing.
        model: Classifier to run (e.g. the model returned by ``load_model``).
        device: Device the model lives on; the input tensor is moved there.
        top_k: Number of predictions to return. Values larger than the number
            of classes produced by the model are clamped to that number.

    Returns:
        A list of at most ``top_k`` dicts, highest confidence first, each with
        keys ``folder_name``, ``display_name``, ``confidence``, ``region``,
        ``state``, ``diet``, ``course``, ``flavor_profile``, ``ingredients``,
        ``prep_time``, ``cook_time``, ``allergens`` and ``calories``. Most
        fields are taken from the nutrition entry for the dish when present
        and fall back to the dish's metadata row otherwise; ``allergens``
        defaults to ``{}`` and ``calories`` to ``"N/A"``.
    """
    # NOTE(review): metadata, nutrition, class mappings and the transform are
    # reloaded on every call; consider caching them if this runs per request.
    df = load_metadata()
    nutrition = load_nutrition()
    mappings = get_class_mappings(df)
    transform = get_inference_transform()

    # unsqueeze(0) adds the batch dimension expected by the model.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
    probs = torch.softmax(logits, dim=1).squeeze(0)

    # torch.topk raises when k exceeds the number of elements, so clamp k to
    # the number of classes instead of failing on an oversized request.
    k = min(top_k, probs.numel())
    top_probs, top_indices = probs.topk(k)

    idx_to_class = mappings["idx_to_class"]
    dish_to_region = mappings["dish_to_region"]

    results = []
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        folder_name = idx_to_class[idx]
        region = dish_to_region[folder_name]
        info = nutrition.get(folder_name, {})
        # Metadata row supplies fallbacks for fields missing from nutrition.
        row = df[df["folder_name"] == folder_name].iloc[0]
        results.append({
            "folder_name": folder_name,
            "display_name": info.get("display_name", row["display_name"]),
            "confidence": prob,
            "region": region,
            "state": info.get("state", row["state"]),
            "diet": info.get("diet", row["diet"]),
            "course": info.get("course", row["course"]),
            "flavor_profile": info.get("flavor_profile", row["flavor_profile"]),
            "ingredients": info.get("ingredients", row["ingredients"]),
            "prep_time": info.get("prep_time_min", str(row["prep_time"])),
            "cook_time": info.get("cook_time_min", str(row["cook_time"])),
            "allergens": info.get("allergens", {}),
            "calories": info.get("approx_calories_kcal", "N/A"),
        })
    return results