Arko007's picture
Upload model/predict.py with huggingface_hub
8beb241 verified
"""
Walnut Rancidity Predictor — Inference Script
Usage:
from model.predict import predict_storage_risk
result = predict_storage_risk(sequence)
"""
import sys, os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import joblib
# ── Adjust sys.path so this works when called from repo root ──────────────────
# Two directory levels above this file.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

# Serialized artifacts read at inference time.
MODEL_PATH = ROOT.joinpath("models", "walnut_rancidity_lstm_attention.pt")
SCALER_PATH = ROOT.joinpath("models", "feature_scaler.pkl")

# Column order of a single input row.
FEATURE_COLS = [
    "temperature",
    "humidity",
    "moisture",
    "oxygen",
    "peroxide_value",
    "free_fatty_acids",
    "hexanal_level",
    "oxidation_index",
]

# Number of rows (time steps) in the window fed to the model.
SEQ_LEN = 30
# ── Model (must match train.py) ───────────────────────────────────────────────
class Attention(nn.Module):
    """Collapse a sequence of LSTM outputs into a single context vector.

    One linear layer assigns a scalar score to every time step; the softmax
    of those scores is used to weight the time steps in a sum.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Maps each hidden vector to one scalar score.  The attribute name is
        # part of the state-dict keys, so it must stay ``attn``.
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, lstm_out: torch.Tensor) -> torch.Tensor:
        """Return the softmax-weighted sum of ``lstm_out`` over dim 1.

        ``lstm_out`` is assumed to be (batch, seq, hidden), which is what a
        ``batch_first=True`` LSTM produces for batched input.
        """
        per_step_score = self.attn(lstm_out).squeeze(-1)
        # Normalise over the last remaining axis, then restore a trailing
        # axis so the weights broadcast across the hidden dimension.
        alpha = torch.softmax(per_step_score, dim=-1).unsqueeze(-1)
        weighted = alpha * lstm_out
        return weighted.sum(dim=1)
class WalnutLSTMAttention(nn.Module):
    """LSTM encoder, attention pooling, and three scalar output heads.

    Attribute names, layer order and the positions of layers inside each
    ``nn.Sequential`` determine the state-dict keys the checkpoint is loaded
    against, so they are kept exactly as they are (see the "must match
    train.py" note in this file).
    """

    def __init__(self, n_features: int, hidden: int, n_layers: int, dropout: float):
        super().__init__()
        # Dropout is handed to the LSTM only when there is more than one layer.
        inter_layer_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            n_features,
            hidden,
            n_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.attn = Attention(hidden)
        self.dropout = nn.Dropout(dropout)
        # Heads are built in a fixed order: rancidity, shelf life, decay.
        self.head_rancidity = self._make_head(hidden, squash=True)
        self.head_shelf_life = self._make_head(hidden, squash=False)
        self.head_decay = self._make_head(hidden, squash=True)

    @staticmethod
    def _make_head(hidden: int, squash: bool) -> nn.Sequential:
        """Build a hidden -> 32 -> 1 head, optionally ending in a sigmoid."""
        layers = [nn.Linear(hidden, 32), nn.ReLU(), nn.Linear(32, 1)]
        if squash:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Run the network and return ``(rancidity, shelf_life, decay)``.

        Each output has its trailing size-1 axis squeezed away.  The
        rancidity and decay heads end in a sigmoid; the shelf-life head is a
        plain linear output.
        """
        encoded, _state = self.lstm(x)
        pooled = self.dropout(self.attn(encoded))
        rp = self.head_rancidity(pooled).squeeze(-1)
        sl = self.head_shelf_life(pooled).squeeze(-1)
        dc = self.head_decay(pooled).squeeze(-1)
        return rp, sl, dc
# ── Lazy-loaded globals ───────────────────────────────────────────────────────
_model = None
_scaler = None


def _load_artifacts() -> None:
    """Load the model checkpoint and feature scaler once, caching both.

    The checkpoint at ``MODEL_PATH`` must be a mapping with a ``"config"``
    entry (keys ``n_features``, ``hidden``, ``n_layers``, ``dropout``) and a
    ``"model_state"`` entry holding the state dict.  The scaler is read from
    ``SCALER_PATH`` with joblib.

    The two module globals are assigned together, and only after both
    artifacts have loaded.  A failure while reading the scaler therefore
    leaves the cache empty, so the next call retries instead of returning
    early with ``_scaler`` still ``None``.

    Any exception raised by ``torch.load``, ``load_state_dict`` or
    ``joblib.load`` (for example a missing file) propagates to the caller.
    """
    global _model, _scaler
    if _model is not None and _scaler is not None:
        return

    # NOTE(security): torch.load and joblib.load are pickle-based and can
    # execute code embedded in the file.  Only point MODEL_PATH / SCALER_PATH
    # at artifacts from a trusted source.
    ckpt = torch.load(MODEL_PATH, map_location="cpu")
    cfg = ckpt["config"]
    model = WalnutLSTMAttention(
        n_features=cfg["n_features"],
        hidden=cfg["hidden"],
        n_layers=cfg["n_layers"],
        dropout=cfg["dropout"],
    )
    model.load_state_dict(ckpt["model_state"])
    model.eval()  # inference mode: disables dropout
    scaler = joblib.load(SCALER_PATH)

    # Publish both at once so the cache is never half-initialised.
    _model, _scaler = model, scaler
# ── Public API ─────────────────────────────────────────────────────────────────
def predict_storage_risk(sequence: list | np.ndarray) -> dict:
    """
    Predict walnut storage risk from a time-series sequence.

    Parameters
    ----------
    sequence : array-like of shape (SEQ_LEN, 8) or (N, 8)
        Each row contains the 8 features in order:
        [temperature, humidity, moisture, oxygen,
        peroxide_value, free_fatty_acids, hexanal_level, oxidation_index]
        A 1-D input is treated as a single row.
        If more than SEQ_LEN rows are provided, the last SEQ_LEN rows are used.
        If fewer rows are provided, the sequence is zero-padded at the front.

    Returns
    -------
    dict with keys:
        rancidity_probability : float [0, 1]
        shelf_life_remaining_days : float (days), clamped at 0
        risk_level : "LOW" | "MEDIUM" | "HIGH"

    The model's third (decay) output is computed but not returned.

    Raises
    ------
    ValueError
        If the input is not 1-D/2-D with one column per entry of
        FEATURE_COLS, or if it contains NaN or infinite values.
    """
    _load_artifacts()

    # np.array (not asarray) so the caller's data is never modified downstream.
    seq = np.array(sequence, dtype=np.float32)
    if seq.ndim == 1:
        seq = seq.reshape(1, -1)

    n_features = len(FEATURE_COLS)
    if seq.ndim != 2 or seq.shape[1] != n_features:
        raise ValueError(
            f"sequence must have shape (N, {n_features}) with columns "
            f"{FEATURE_COLS}; got shape {seq.shape}"
        )
    # A NaN probability fails both threshold comparisons below and would be
    # reported as "HIGH", so reject non-finite input explicitly.
    if not np.isfinite(seq).all():
        raise ValueError("sequence contains NaN or infinite values")

    # Pad or truncate to SEQ_LEN
    n_rows = seq.shape[0]
    if n_rows > SEQ_LEN:
        seq = seq[-SEQ_LEN:]
    elif n_rows < SEQ_LEN:
        pad = np.zeros((SEQ_LEN - n_rows, n_features), dtype=np.float32)
        seq = np.vstack([pad, seq])

    # Scale
    seq_scaled = _scaler.transform(seq)  # (SEQ_LEN, 8)
    x = torch.tensor(seq_scaled[np.newaxis], dtype=torch.float32)  # (1, SEQ_LEN, 8)

    with torch.no_grad():
        rp_pred, sl_pred, _ = _model(x)

    rancidity_prob = float(rp_pred.item())
    # Denormalise.  NOTE(review): 180 is presumably the shelf-life scale (in
    # days) used during training; confirm against train.py.
    shelf_life = float(sl_pred.item()) * 180.0

    if rancidity_prob < 0.3:
        risk_level = "LOW"
    elif rancidity_prob <= 0.7:
        risk_level = "MEDIUM"
    else:
        risk_level = "HIGH"

    return {
        "rancidity_probability": round(rancidity_prob, 4),
        "shelf_life_remaining_days": round(max(shelf_life, 0.0), 2),
        "risk_level": risk_level,
    }
# ── CLI demo ──────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Demo: run the predictor on two synthetic SEQ_LEN-step windows and
    # print the resulting dicts.
    print("Running demo inference …")

    # Cold-storage scenario (low risk)
    cold_seq = np.column_stack([
        np.full(SEQ_LEN, 5.0),               # temperature
        np.full(SEQ_LEN, 50.0),              # humidity
        np.full(SEQ_LEN, 4.0),               # moisture
        np.full(SEQ_LEN, 0.20),              # oxygen
        np.linspace(0.5, 1.2, SEQ_LEN),      # peroxide_value
        np.linspace(0.05, 0.10, SEQ_LEN),    # free_fatty_acids
        np.linspace(0.1, 0.3, SEQ_LEN),      # hexanal_level
        np.linspace(0.2, 0.5, SEQ_LEN),      # oxidation_index
    ])
    result = predict_storage_risk(cold_seq)
    print(f"Cold storage → {result}")

    # Hot transport scenario (high risk); same column order as above
    hot_seq = np.column_stack([
        np.full(SEQ_LEN, 38.0),              # temperature
        np.full(SEQ_LEN, 80.0),              # humidity
        np.full(SEQ_LEN, 7.5),               # moisture
        np.full(SEQ_LEN, 0.22),              # oxygen
        np.linspace(2.0, 8.0, SEQ_LEN),      # peroxide_value
        np.linspace(0.2, 0.6, SEQ_LEN),      # free_fatty_acids
        np.linspace(0.8, 2.5, SEQ_LEN),      # hexanal_level
        np.linspace(1.0, 3.5, SEQ_LEN),      # oxidation_index
    ])
    result = predict_storage_risk(hot_seq)
    print(f"Hot transport → {result}")