# LSTM-5win-Keystrokes / inference.py
# Uploaded by NourFakih in commit 711c28e ("Upload LSTM window=5 artifacts").
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import joblib
from typing import Optional, Dict, Any
from huggingface_hub import hf_hub_download
class LSTMClassifier(nn.Module):
    """Two-class sequence classifier built on a unidirectional LSTM.

    The final hidden state of the top LSTM layer is projected by a linear
    layer to two unnormalized scores (logits).

    Args:
        input_size: Number of features per time step.
        hidden_size: LSTM hidden state width.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked LSTM layers; ignored (forced to 0.0)
            when ``num_layers`` is 1.
    """

    def __init__(self, input_size: int, hidden_size: int = 64, num_layers: int = 1, dropout: float = 0.0):
        super().__init__()
        # nn.LSTM only applies dropout between stacked layers, so a single-layer
        # model gets 0.0 instead of a value PyTorch would warn about and not use.
        between_layers = 0.0 if num_layers <= 1 else dropout
        # Attribute names `lstm` / `head` define the state_dict keys of saved weights.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,  # inputs are (batch, seq, input_size)
            dropout=between_layers,
            bidirectional=False,
        )
        self.head = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Return (batch, 2) logits for a batch-first input sequence tensor."""
        final_hidden = self.lstm(x)[1][0]  # h_n: (num_layers, batch, hidden_size)
        top_layer_state = final_hidden[-1]
        return self.head(top_layer_state)
def load_model_and_scaler(repo_id: str, revision: Optional[str] = None, device: Optional[str] = None):
    """Download config, scaler and weights from a Hugging Face Hub repo and build the model.

    Args:
        repo_id: Hub repository containing ``config.json``, ``scaler.joblib`` and
            the weights file named by ``cfg["weights_file"]`` (default
            ``"model.safetensors"``).
        revision: Optional revision forwarded to every ``hf_hub_download`` call.
        device: Device for the model; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(model, scaler, cfg)``: the ``LSTMClassifier`` on ``device`` in eval
        mode, the object loaded from ``scaler.joblib``, and the parsed config dict.

    Raises:
        KeyError: if the config lacks ``input_size``, ``hidden_size``,
            ``num_layers`` or ``dropout``.
        RuntimeError: if the weights do not match the model exactly
            (``load_state_dict(strict=True)``).
    """
    cfg_path = hf_hub_download(repo_id, "config.json", revision=revision)
    scaler_path = hf_hub_download(repo_id, "scaler.joblib", revision=revision)
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = LSTMClassifier(
        input_size=int(cfg["input_size"]),
        hidden_size=int(cfg["hidden_size"]),
        num_layers=int(cfg["num_layers"]),
        dropout=float(cfg["dropout"]),
    ).to(device)

    weights_name = cfg.get("weights_file", "model.safetensors")
    weights_path = hf_hub_download(repo_id, weights_name, revision=revision)
    if weights_name.endswith(".safetensors"):
        # safetensors files hold raw tensors only; loading them cannot run code.
        from safetensors.torch import load_file
        state = load_file(weights_path)
    else:
        # Pickle checkpoint fetched from a remote repo: weights_only=True limits
        # unpickling to tensors and plain containers so a tampered file cannot
        # execute code (needs torch >= 1.13; it is the default from torch 2.6).
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
    # Tensors are on CPU here; load_state_dict copies them into the parameters on `device`.
    model.load_state_dict(state, strict=True)
    model.eval()

    # NOTE(security): joblib.load unpickles the file, which can execute code —
    # only load scalers from repositories you trust.
    scaler = joblib.load(scaler_path)
    return model, scaler, cfg
def predict_df(
    df: pd.DataFrame,
    model: nn.Module,
    scaler,
    cfg: Dict[str, Any],
    *,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    """Classify sliding windows of ``df`` rows and return one label per window.

    Window ``k`` covers rows ``[k*stride, k*stride + window_size)`` of ``df``
    (positional order), so the result has one entry per window start.

    Args:
        df: Input frame; must contain every column in ``cfg["feature_cols"]``.
        model: Network taking a float32 ``(batch, window_size, n_features)``
            tensor and returning ``(batch, n_classes)`` scores (as
            ``LSTMClassifier`` does). It is not switched to eval mode here.
        scaler: Fitted object whose ``transform`` maps a ``(rows, n_features)``
            array to scaled rows. Assumes rows are scaled independently (true
            for sklearn scalers such as ``StandardScaler``), because rows are
            scaled once before being windowed.
        cfg: Needs ``"feature_cols"`` and ``"window_size"``; ``"stride"`` is
            optional (default 1).
        batch_size: Optional maximum number of windows per forward pass, to
            bound device memory. ``None`` (default) runs all windows at once.

    Returns:
        int64 array of argmax class indices, empty when ``df`` has fewer rows
        than ``window_size``.

    Raises:
        ValueError: if ``window_size``, ``stride`` or ``batch_size`` is < 1.
        KeyError: if a feature column is missing from ``df``.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    feature_cols = cfg["feature_cols"]
    window = int(cfg["window_size"])
    stride = int(cfg.get("stride", 1))
    if window < 1:
        raise ValueError(f"window_size must be >= 1, got {window}")
    if stride < 1:
        # A negative step would silently reverse the window order; 0 is meaningless.
        raise ValueError(f"stride must be >= 1, got {stride}")
    if batch_size is not None and int(batch_size) < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    X = df[feature_cols].to_numpy(np.float32)
    if len(X) < window:
        return np.empty((0,), dtype=np.int64)

    # Scale each row once, then window the scaled rows. This gives the same
    # values as scaling the windows, without transforming every row `window` times.
    X_scaled = np.asarray(scaler.transform(X), dtype=np.float32)
    n_features = X_scaled.shape[1]
    # (n_rows - window + 1, window, n_features) read-only view; no copy yet.
    windows = sliding_window_view(X_scaled, window_shape=(window, n_features)).squeeze(1)[::stride]

    n_windows = windows.shape[0]  # >= 1 because len(X) >= window
    chunk = n_windows if batch_size is None else int(batch_size)
    device = next(model.parameters()).device
    preds = []
    with torch.no_grad():
        for start in range(0, n_windows, chunk):
            # torch.tensor copies this slice of the view into a fresh tensor on `device`.
            xb = torch.tensor(windows[start:start + chunk], device=device)
            preds.append(torch.argmax(model(xb), dim=1).cpu().numpy())
    return np.concatenate(preds)