Clementina Tom (via Gemini) commited on
Commit
0fe2aca
Β·
1 Parent(s): e5cd6dd

Stability Patch: Improved model loading and error handling

Browse files
Files changed (2) hide show
  1. app.py +21 -11
  2. plrs/model/model_loader.py +25 -24
app.py CHANGED
@@ -118,25 +118,35 @@ html, body, [class*="css"] {
118
 
119
  # ── Model + pipeline loading ──────────────────────────────────────────────────
120
 
121
- @st.cache_resource(show_spinner="Loading curriculum & model from HuggingFace...")
122
  def load_pipelines():
123
  from plrs.model.model_loader import load_model_from_hub
 
 
 
 
124
 
125
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
  maps = ROOT / "data" / "knowledge_maps"
127
 
128
- # Load model (tries decay, vanilla, then base)
129
- model, model_type = load_model_from_hub(device=str(device))
 
 
 
130
 
131
  pipelines = {}
132
- for domain, fname in [("math", "math_dag.json"), ("cs", "cs_dag.json")]:
133
- path = maps / fname
134
- if path.exists():
135
- curriculum = load_dag(path)
136
- pipeline = PLRSPipeline(curriculum)
137
- if model:
138
- pipeline._model = model
139
- pipelines[domain] = pipeline
 
 
 
140
 
141
  return pipelines, model is not None, model_type
142
 
 
118
 
119
  # ── Model + pipeline loading ──────────────────────────────────────────────────
120
 
121
+ @st.cache_resource(show_spinner="Connecting to Logic Engine...")
122
  def load_pipelines():
123
  from plrs.model.model_loader import load_model_from_hub
124
+ import os
125
+
126
+ # Check for token in environment (HF Spaces allow setting secrets)
127
+ token = os.environ.get("HF_TOKEN")
128
 
129
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
  maps = ROOT / "data" / "knowledge_maps"
131
 
132
+ # Load model with potential token for private/restricted access
133
+ try:
134
+ model, model_type = load_model_from_hub(device=str(device), token=token)
135
+ except Exception as e:
136
+ model, model_type = None, f"Error: {str(e)}"
137
 
138
  pipelines = {}
139
+ try:
140
+ for domain, fname in [("math", "math_dag.json"), ("cs", "cs_dag.json")]:
141
+ path = maps / fname
142
+ if path.exists():
143
+ curriculum = load_dag(path)
144
+ pipeline = PLRSPipeline(curriculum)
145
+ if model:
146
+ pipeline._model = model
147
+ pipelines[domain] = pipeline
148
+ except Exception as e:
149
+ st.error(f"Curriculum load error: {e}")
150
 
151
  return pipelines, model is not None, model_type
152
 
plrs/model/model_loader.py CHANGED
@@ -24,30 +24,25 @@ import torch
24
  HF_REPO = "Clementio/PLRS"
25
 
26
 
27
- def load_model_from_hub(device: str = "cpu"):
28
  """
29
  Load SAKT model weights from HuggingFace Hub.
30
-
31
- Tries files in priority order:
32
- 1. sakt_decay_best.pt (v0.2.0 β€” decay attention)
33
- 2. sakt_vanilla_best.pt (v0.2.0 β€” vanilla transformer)
34
- 3. sakt_model.pt (v0.1.0 β€” synthetic baseline)
35
-
36
- Returns (model, model_type_str) or (None, "unavailable").
37
  """
38
  try:
39
  from huggingface_hub import hf_hub_download
40
  except ImportError:
41
  return None, "huggingface_hub not installed"
42
 
 
43
  for filename, model_type in [
44
  ("models/sakt_decay_best.pt", "SAKTWithDecay"),
45
  ("models/sakt_vanilla_best.pt", "SAKTModel"),
46
  ("models/sakt_model.pt", "SAKTModel"),
 
47
  ]:
48
  try:
49
- path = hf_hub_download(repo_id=HF_REPO, filename=filename)
50
- model = _load_weights(path, model_type, device)
51
  if model is not None:
52
  return model, model_type
53
  except Exception:
@@ -56,8 +51,9 @@ def load_model_from_hub(device: str = "cpu"):
56
  return None, "unavailable"
57
 
58
 
59
- def _load_weights(path: str, preferred_type: str, device: str):
60
  """Load model weights from a .pt file, handling both old and new formats."""
 
61
  try:
62
  payload = torch.load(path, map_location=device, weights_only=False)
63
  except Exception:
@@ -65,27 +61,27 @@ def _load_weights(path: str, preferred_type: str, device: str):
65
 
66
  # ── New format (v0.2.0): {"state_dict": ..., "model_type": ..., "config": ...}
67
  if isinstance(payload, dict) and "state_dict" in payload:
68
- cfg = payload.get("config", {})
69
  model_type = payload.get("model_type", preferred_type)
70
 
71
  if model_type == "SAKTWithDecay":
72
  from plrs.model.sakt_decay import SAKTWithDecay
73
  model = SAKTWithDecay(
74
  num_skills=cfg.get("num_skills", 5737),
75
- embed_dim=cfg.get("embed_dim", 64),
76
  num_heads=cfg.get("num_heads", 8),
77
  dropout=cfg.get("dropout", 0.2),
78
- max_seq_len=cfg.get("max_seq_len", 100),
79
  decay_init=cfg.get("decay_init", 1.0),
80
  )
81
  else:
82
  from plrs.model.sakt import SAKTModel
83
  model = SAKTModel(
84
  num_skills=cfg.get("num_skills", 5737),
85
- embed_dim=cfg.get("embed_dim", 64),
86
  num_heads=cfg.get("num_heads", 8),
87
  dropout=cfg.get("dropout", 0.2),
88
- max_seq_len=cfg.get("max_seq_len", 100),
89
  )
90
 
91
  try:
@@ -96,21 +92,26 @@ def _load_weights(path: str, preferred_type: str, device: str):
96
  except Exception:
97
  return None
98
 
99
- # ── Old format (v0.1.0 FYP): raw state_dict + separate config.json
100
  try:
101
- config_path = Path(path).parent / "config.json"
102
- if config_path.exists():
103
- config = json.loads(config_path.read_text())
104
- else:
105
- config = {"num_skills": 5736, "embed_dim": 64}
 
 
106
 
107
  from plrs.model.sakt import SAKTModel
108
  model = SAKTModel(
109
- num_skills=config.get("num_skills", 5736),
110
- embed_dim=config.get("embed_dim", 64),
 
 
111
  )
112
  model.load_state_dict(payload, strict=False)
113
  model.eval()
 
114
  return model
115
  except Exception:
116
  return None
 
24
  HF_REPO = "Clementio/PLRS"
25
 
26
 
27
+ def load_model_from_hub(device: str = "cpu", token: str | None = None):
28
  """
29
  Load SAKT model weights from HuggingFace Hub.
 
 
 
 
 
 
 
30
  """
31
  try:
32
  from huggingface_hub import hf_hub_download
33
  except ImportError:
34
  return None, "huggingface_hub not installed"
35
 
36
+ # Try files in priority order
37
  for filename, model_type in [
38
  ("models/sakt_decay_best.pt", "SAKTWithDecay"),
39
  ("models/sakt_vanilla_best.pt", "SAKTModel"),
40
  ("models/sakt_model.pt", "SAKTModel"),
41
+ ("sakt_model.pt", "SAKTModel"), # Backwards compatibility
42
  ]:
43
  try:
44
+ path = hf_hub_download(repo_id=HF_REPO, filename=filename, token=token)
45
+ model = _load_weights(path, model_type, device, token=token)
46
  if model is not None:
47
  return model, model_type
48
  except Exception:
 
51
  return None, "unavailable"
52
 
53
 
54
+ def _load_weights(path: str, preferred_type: str, device: str, token: str | None = None):
55
  """Load model weights from a .pt file, handling both old and new formats."""
56
+ from huggingface_hub import hf_hub_download
57
  try:
58
  payload = torch.load(path, map_location=device, weights_only=False)
59
  except Exception:
 
61
 
62
  # ── New format (v0.2.0): {"state_dict": ..., "model_type": ..., "config": ...}
63
  if isinstance(payload, dict) and "state_dict" in payload:
64
+ cfg = payload.get("config", {})
65
  model_type = payload.get("model_type", preferred_type)
66
 
67
  if model_type == "SAKTWithDecay":
68
  from plrs.model.sakt_decay import SAKTWithDecay
69
  model = SAKTWithDecay(
70
  num_skills=cfg.get("num_skills", 5737),
71
+ embed_dim=cfg.get("embed_dim", 128),
72
  num_heads=cfg.get("num_heads", 8),
73
  dropout=cfg.get("dropout", 0.2),
74
+ max_seq_len=cfg.get("max_seq_len", 200),
75
  decay_init=cfg.get("decay_init", 1.0),
76
  )
77
  else:
78
  from plrs.model.sakt import SAKTModel
79
  model = SAKTModel(
80
  num_skills=cfg.get("num_skills", 5737),
81
+ embed_dim=cfg.get("embed_dim", 128),
82
  num_heads=cfg.get("num_heads", 8),
83
  dropout=cfg.get("dropout", 0.2),
84
+ max_seq_len=cfg.get("max_seq_len", 200),
85
  )
86
 
87
  try:
 
92
  except Exception:
93
  return None
94
 
95
+ # ── Old format (v0.1.0 FYP): raw state_dict + fetch config.json from Hub
96
  try:
97
+ # Try to download config.json from the Hub root
98
+ try:
99
+ cfg_path = hf_hub_download(repo_id=HF_REPO, filename="config.json", token=token)
100
+ with open(cfg_path) as f:
101
+ config = json.load(f)
102
+ except Exception:
103
+ config = {"num_skills": 5737, "embed_dim": 128, "num_heads": 8, "num_layers": 2, "max_seq_len": 200, "dropout": 0.2}
104
 
105
  from plrs.model.sakt import SAKTModel
106
  model = SAKTModel(
107
+ num_skills=config.get("num_skills", 5737),
108
+ embed_dim=config.get("embed_dim", 128),
109
+ num_heads=config.get("num_heads", 8),
110
+ max_seq_len=config.get("max_seq_len", 200),
111
  )
112
  model.load_state_dict(payload, strict=False)
113
  model.eval()
114
+ model.to(device)
115
  return model
116
  except Exception:
117
  return None