"""
HF Space model loader — updated for SAKTWithDecay (v0.2.0 weights).
Drop this file into your HF Space as `model_loader.py` and call
`load_model_from_hub()` in app.py instead of the old loading logic.
The v0.2.0 weights (sakt_decay_best.pt) are saved with our new format:
{
"state_dict": {...},
"model_type": "SAKTWithDecay",
"config": {"num_skills": 20, "embed_dim": 64, ...}
}
Falls back gracefully to mastery-dict mode if weights can't be loaded.
"""
from __future__ import annotations
import json
from pathlib import Path
import torch
# HuggingFace Hub repository that hosts the PLRS model checkpoints.
HF_REPO = "Clementio/PLRS"
def load_model_from_hub(device: str = "cpu"):
"""
Load SAKT model weights from HuggingFace Hub.
Tries files in priority order:
1. sakt_decay_best.pt (v0.2.0 — decay attention)
2. sakt_vanilla_best.pt (v0.2.0 — vanilla transformer)
3. sakt_model.pt (v0.1.0 — synthetic baseline)
Returns (model, model_type_str) or (None, "unavailable").
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
return None, "huggingface_hub not installed"
for filename, model_type in [
("models/sakt_decay_best.pt", "SAKTWithDecay"),
("models/sakt_vanilla_best.pt", "SAKTModel"),
("models/sakt_model.pt", "SAKTModel"),
]:
try:
path = hf_hub_download(repo_id=HF_REPO, filename=filename)
model = _load_weights(path, model_type, device)
if model is not None:
return model, model_type
except Exception:
continue
return None, "unavailable"
def _load_weights(path: str, preferred_type: str, device: str):
"""Load model weights from a .pt file, handling both old and new formats."""
try:
payload = torch.load(path, map_location=device, weights_only=False)
except Exception:
return None
# ── New format (v0.2.0): {"state_dict": ..., "model_type": ..., "config": ...}
if isinstance(payload, dict) and "state_dict" in payload:
cfg = payload.get("config", {})
model_type = payload.get("model_type", preferred_type)
if model_type == "SAKTWithDecay":
from plrs.model.sakt_decay import SAKTWithDecay
model = SAKTWithDecay(
num_skills=cfg.get("num_skills", 5737),
embed_dim=cfg.get("embed_dim", 64),
num_heads=cfg.get("num_heads", 8),
dropout=cfg.get("dropout", 0.2),
max_seq_len=cfg.get("max_seq_len", 100),
decay_init=cfg.get("decay_init", 1.0),
)
else:
from plrs.model.sakt import SAKTModel
model = SAKTModel(
num_skills=cfg.get("num_skills", 5737),
embed_dim=cfg.get("embed_dim", 64),
num_heads=cfg.get("num_heads", 8),
dropout=cfg.get("dropout", 0.2),
max_seq_len=cfg.get("max_seq_len", 100),
)
try:
model.load_state_dict(payload["state_dict"], strict=False)
model.eval()
model.to(device)
return model
except Exception:
return None
# ── Old format (v0.1.0 FYP): raw state_dict + separate config.json
try:
config_path = Path(path).parent / "config.json"
if config_path.exists():
config = json.loads(config_path.read_text())
else:
config = {"num_skills": 5736, "embed_dim": 64}
from plrs.model.sakt import SAKTModel
model = SAKTModel(
num_skills=config.get("num_skills", 5736),
embed_dim=config.get("embed_dim", 64),
)
model.load_state_dict(payload, strict=False)
model.eval()
return model
except Exception:
return None