""" Walnut Rancidity Predictor — Inference Script Usage: from model.predict import predict_storage_risk result = predict_storage_risk(sequence) """ import sys, os from pathlib import Path import numpy as np import torch import torch.nn as nn import joblib # ── Adjust sys.path so this works when called from repo root ────────────────── ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) MODEL_PATH = ROOT / "models" / "walnut_rancidity_lstm_attention.pt" SCALER_PATH = ROOT / "models" / "feature_scaler.pkl" FEATURE_COLS = [ "temperature", "humidity", "moisture", "oxygen", "peroxide_value", "free_fatty_acids", "hexanal_level", "oxidation_index", ] SEQ_LEN = 30 # ── Model (must match train.py) ─────────────────────────────────────────────── class Attention(nn.Module): def __init__(self, hidden_size: int): super().__init__() self.attn = nn.Linear(hidden_size, 1) def forward(self, lstm_out: torch.Tensor) -> torch.Tensor: scores = self.attn(lstm_out).squeeze(-1) weights = torch.softmax(scores, dim=-1) context = (weights.unsqueeze(-1) * lstm_out).sum(dim=1) return context class WalnutLSTMAttention(nn.Module): def __init__(self, n_features: int, hidden: int, n_layers: int, dropout: float): super().__init__() self.lstm = nn.LSTM( input_size=n_features, hidden_size=hidden, num_layers=n_layers, dropout=dropout if n_layers > 1 else 0.0, batch_first=True, ) self.attn = Attention(hidden) self.dropout = nn.Dropout(dropout) self.head_rancidity = nn.Sequential( nn.Linear(hidden, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid(), ) self.head_shelf_life = nn.Sequential( nn.Linear(hidden, 32), nn.ReLU(), nn.Linear(32, 1), ) self.head_decay = nn.Sequential( nn.Linear(hidden, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid(), ) def forward(self, x: torch.Tensor): lstm_out, _ = self.lstm(x) context = self.attn(lstm_out) context = self.dropout(context) rp = self.head_rancidity(context).squeeze(-1) sl = self.head_shelf_life(context).squeeze(-1) dc = self.head_decay(context).squeeze(-1) return rp, sl, dc # ── Lazy-loaded globals ─────────────────────────────────────────────────────── _model = None _scaler = None def _load_artifacts(): global _model, _scaler if _model is not None: return ckpt = torch.load(MODEL_PATH, map_location="cpu") cfg = ckpt["config"] _model = WalnutLSTMAttention( n_features=cfg["n_features"], hidden=cfg["hidden"], n_layers=cfg["n_layers"], dropout=cfg["dropout"], ) _model.load_state_dict(ckpt["model_state"]) _model.eval() _scaler = joblib.load(SCALER_PATH) # ── Public API ───────────────────────────────────────────────────────────────── def predict_storage_risk(sequence: list | np.ndarray) -> dict: """ Predict walnut storage risk from a time-series sequence. Parameters ---------- sequence : array-like of shape (SEQ_LEN, 8) or (N, 8) Each row contains the 8 features in order: [temperature, humidity, moisture, oxygen, peroxide_value, free_fatty_acids, hexanal_level, oxidation_index] If more than SEQ_LEN rows are provided, the last SEQ_LEN rows are used. If fewer rows are provided, the sequence is zero-padded at the front. Returns ------- dict with keys: rancidity_probability : float [0, 1] shelf_life_remaining_days : float (days) risk_level : "LOW" | "MEDIUM" | "HIGH" """ _load_artifacts() seq = np.array(sequence, dtype=np.float32) if seq.ndim == 1: seq = seq.reshape(1, -1) # Pad or truncate to SEQ_LEN if len(seq) > SEQ_LEN: seq = seq[-SEQ_LEN:] elif len(seq) < SEQ_LEN: pad = np.zeros((SEQ_LEN - len(seq), seq.shape[1]), dtype=np.float32) seq = np.vstack([pad, seq]) # Scale seq_scaled = _scaler.transform(seq) # (SEQ_LEN, 8) x = torch.tensor(seq_scaled[np.newaxis], dtype=torch.float32) # (1, SEQ_LEN, 8) with torch.no_grad(): rp_pred, sl_pred, dc_pred = _model(x) rancidity_prob = float(rp_pred.item()) shelf_life = float(sl_pred.item()) * 180.0 # denormalise if rancidity_prob < 0.3: risk_level = "LOW" elif rancidity_prob <= 0.7: risk_level = "MEDIUM" else: risk_level = "HIGH" return { "rancidity_probability": round(rancidity_prob, 4), "shelf_life_remaining_days": round(max(shelf_life, 0.0), 2), "risk_level": risk_level, } # ── CLI demo ────────────────────────────────────────────────────────────────── if __name__ == "__main__": print("Running demo inference …") # Cold-storage scenario (low risk) cold_seq = np.column_stack([ np.full(30, 5.0), # temperature np.full(30, 50.0), # humidity np.full(30, 4.0), # moisture np.full(30, 0.20), # oxygen np.linspace(0.5, 1.2, 30), # peroxide_value np.linspace(0.05, 0.10, 30), # free_fatty_acids np.linspace(0.1, 0.3, 30), # hexanal_level np.linspace(0.2, 0.5, 30), # oxidation_index ]) result = predict_storage_risk(cold_seq) print(f"Cold storage → {result}") # Hot transport scenario (high risk) hot_seq = np.column_stack([ np.full(30, 38.0), np.full(30, 80.0), np.full(30, 7.5), np.full(30, 0.22), np.linspace(2.0, 8.0, 30), np.linspace(0.2, 0.6, 30), np.linspace(0.8, 2.5, 30), np.linspace(1.0, 3.5, 30), ]) result = predict_storage_risk(hot_seq) print(f"Hot transport → {result}")