"""
HF Space model loader — updated for SAKTWithDecay (v0.2.0 weights).

Drop this file into your HF Space as `model_loader.py` and call
`load_model_from_hub()` in app.py instead of the old loading logic.

The v0.2.0 weights (sakt_decay_best.pt) are saved in the new format:

    {
        "state_dict": {...},
        "model_type": "SAKTWithDecay",
        "config": {"num_skills": 20, "embed_dim": 64, ...},
    }

Older v0.1.0 checkpoints are a raw state_dict whose hyper-parameters are read
from a separate `config.json` at the Hub repo root.

If no checkpoint can be loaded, `load_model_from_hub()` returns `None` as the
model so the caller can fall back gracefully to mastery-dict mode.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path

import torch

HF_REPO = "Clementio/PLRS"

logger = logging.getLogger(__name__)

# Checkpoints tried by load_model_from_hub(), highest priority first. The
# second element is the model type assumed when the file itself does not
# record one.
_CANDIDATE_FILES = (
    ("models/sakt_decay_best.pt", "SAKTWithDecay"),
    ("models/sakt_vanilla_best.pt", "SAKTModel"),
    ("models/sakt_model.pt", "SAKTModel"),
    ("sakt_model.pt", "SAKTModel"),  # Backwards compatibility
)

# Hyper-parameters used when a checkpoint's config omits a key, and the full
# config used for legacy checkpoints when config.json is unavailable.
# Treat as read-only; callers receive copies.
_DEFAULT_CONFIG = {
    "num_skills": 5737,
    "embed_dim": 128,
    "num_heads": 8,
    "num_layers": 2,
    "max_seq_len": 200,
    "dropout": 0.2,
}
_DEFAULT_DECAY_INIT = 1.0


def load_model_from_hub(
    device: str = "cpu", token: str | None = None
) -> tuple[torch.nn.Module | None, str]:
    """Load the best available SAKT checkpoint from the HuggingFace Hub.

    Candidate files in ``_CANDIDATE_FILES`` are tried in order; the first one
    that downloads and loads successfully is returned.

    Args:
        device: Torch device the model is moved to (e.g. ``"cpu"``).
        token: Optional HuggingFace access token.

    Returns:
        ``(model, model_type)`` on success, where ``model_type`` is the class
        name of the model actually built.
        ``(None, reason)`` on failure, where ``reason`` is
        ``"huggingface_hub not installed"`` or ``"unavailable"``.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        return None, "huggingface_hub not installed"

    for filename, fallback_type in _CANDIDATE_FILES:
        try:
            path = hf_hub_download(repo_id=HF_REPO, filename=filename, token=token)
        except Exception:
            logger.debug("Could not download %s from %s", filename, HF_REPO, exc_info=True)
            continue

        model = _load_weights(path, fallback_type, device, token=token)
        if model is not None:
            # Report the class actually built: a new-format checkpoint's own
            # "model_type" field can override the filename-based guess.
            return model, type(model).__name__
        logger.warning("Downloaded %s but could not load it; trying next candidate", filename)

    return None, "unavailable"


def _load_weights(
    path: str, preferred_type: str, device: str, token: str | None = None
) -> torch.nn.Module | None:
    """Build a model from a ``.pt`` checkpoint, handling both old and new formats.

    Args:
        path: Local path of the downloaded checkpoint.
        preferred_type: Model type built for a new-format checkpoint that has
            no ``"model_type"`` field.
        device: Torch device the model is moved to.
        token: HuggingFace token, used only to fetch ``config.json`` for
            legacy (raw state_dict) checkpoints.

    Returns:
        The model in eval mode on ``device``, or ``None`` if the file cannot
        be read, the model cannot be built, or the weights do not fit it.
        Never raises.
    """
    try:
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, which can execute code. Only acceptable if HF_REPO is
        # trusted; consider weights_only=True if the checkpoints hold only
        # tensors and plain containers.
        payload = torch.load(path, map_location=device, weights_only=False)
    except Exception:
        logger.debug("torch.load failed for %s", path, exc_info=True)
        return None

    try:
        if isinstance(payload, dict) and "state_dict" in payload:
            # New format (v0.2.0): {"state_dict", "model_type", "config"}.
            cfg = payload.get("config") or {}
            model_type = payload.get("model_type", preferred_type)
            state_dict = payload["state_dict"]
        else:
            # Old format (v0.1.0 FYP): raw state_dict, always a vanilla
            # SAKTModel, with hyper-parameters fetched from config.json.
            cfg = _fetch_legacy_config(token)
            model_type = "SAKTModel"
            state_dict = payload
        model = _build_model(model_type, cfg)
    except Exception:
        logger.debug("Could not build %s model for %s", preferred_type, path, exc_info=True)
        return None

    return _apply_state_dict(model, state_dict, device)


def _build_model(model_type: str, cfg: dict) -> torch.nn.Module:
    """Instantiate an untrained model of ``model_type`` configured from ``cfg``.

    Keys missing from ``cfg`` fall back to ``_DEFAULT_CONFIG``. Any type other
    than ``"SAKTWithDecay"`` builds a vanilla ``SAKTModel``.
    """
    def option(key: str):
        return cfg.get(key, _DEFAULT_CONFIG[key])

    # Both model classes share these constructor arguments; the legacy path
    # now passes ``dropout`` too, matching the new-format path.
    common = {
        "num_skills": option("num_skills"),
        "embed_dim": option("embed_dim"),
        "num_heads": option("num_heads"),
        "dropout": option("dropout"),
        "max_seq_len": option("max_seq_len"),
    }
    if model_type == "SAKTWithDecay":
        from plrs.model.sakt_decay import SAKTWithDecay

        return SAKTWithDecay(**common, decay_init=cfg.get("decay_init", _DEFAULT_DECAY_INIT))

    from plrs.model.sakt import SAKTModel

    return SAKTModel(**common)


def _fetch_legacy_config(token: str | None) -> dict:
    """Return hyper-parameters for a legacy raw-state_dict checkpoint.

    Reads ``config.json`` from the root of ``HF_REPO``. Falls back to a copy of
    ``_DEFAULT_CONFIG`` if the file cannot be downloaded or parsed, or is not
    a JSON object.
    """
    try:
        from huggingface_hub import hf_hub_download

        cfg_path = hf_hub_download(repo_id=HF_REPO, filename="config.json", token=token)
        with open(cfg_path, encoding="utf-8") as fh:
            config = json.load(fh)
    except Exception:
        logger.debug("config.json unavailable; using default hyper-parameters", exc_info=True)
        return dict(_DEFAULT_CONFIG)

    if not isinstance(config, dict):
        logger.warning("config.json is not a JSON object; using default hyper-parameters")
        return dict(_DEFAULT_CONFIG)
    return config


def _apply_state_dict(model: torch.nn.Module, state_dict, device: str) -> torch.nn.Module | None:
    """Load ``state_dict`` into ``model`` and prepare it for inference.

    Loading is non-strict, so checkpoints with some extra or missing entries
    still load. However, a checkpoint that matches *none* of the model's keys
    is rejected; otherwise a randomly initialised model would be returned as
    if it were trained.

    Returns:
        ``model`` in eval mode on ``device``, or ``None`` if loading fails
        (e.g. a tensor shape mismatch) or no key matched.
    """
    try:
        result = model.load_state_dict(state_dict, strict=False)
    except Exception:
        logger.debug("load_state_dict failed", exc_info=True)
        return None

    expected = set(model.state_dict())
    if not expected - set(result.missing_keys):
        logger.warning("Checkpoint matched none of the model's %d keys", len(expected))
        return None
    if result.missing_keys:
        logger.warning(
            "Checkpoint is missing %d model keys: %s",
            len(result.missing_keys),
            result.missing_keys,
        )

    try:
        model.eval()
        model.to(device)
    except Exception:
        logger.debug("Could not move model to %s", device, exc_info=True)
        return None
    return model