tzhang62 committed on
Commit
d632886
·
verified ·
1 Parent(s): e16c4b5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +184 -60
app.py CHANGED
@@ -1,10 +1,11 @@
1
  """
2
- Hugging Face Space API for IQL Fire Rescue Model
3
- Deploy this to HF Space to serve your custom IQL model
4
 
5
- Upload to: https://huggingface.co/spaces/tzhang62/iql-fire-rescue-api
 
6
  """
7
 
 
8
  from fastapi import FastAPI
9
  from pydantic import BaseModel
10
  from typing import Dict, List, Optional
@@ -21,12 +22,38 @@ app = FastAPI(title="IQL Fire Rescue API")
21
  # Config
22
  EMBED_MODEL = "all-MiniLM-L6-v2"
23
  N_LAST = 3
 
 
 
 
 
24
 
25
  # ============================================================================
26
- # IQL Model (same as your server.py)
27
  # ============================================================================
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class QNetworkEmbed(nn.Module):
 
30
  def __init__(self, state_dim: int, action_embeds: torch.Tensor, hidden_dim: int, p_drop: float = 0.3):
31
  super().__init__()
32
  self.action_embeds = nn.Parameter(action_embeds, requires_grad=False)
@@ -45,34 +72,75 @@ class QNetworkEmbed(nn.Module):
45
  x = self.f2(x); x = F.relu(x); x = self.ln2(x); x = self.drop(x)
46
  return self.head(x).squeeze(-1)
47
 
48
- def embed_state(model, texts):
49
- if not texts:
 
 
50
  return np.zeros((model.get_sentence_embedding_dimension(),), dtype=np.float32)
51
- embs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
52
  return np.mean(embs, axis=0).astype(np.float32)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  class IQLSelector:
55
  def __init__(self, pt_path, policy_names):
56
  self.device = torch.device("cpu")
57
  self.embed_model = SentenceTransformer(EMBED_MODEL)
58
- state_dict = torch.load(pt_path, map_location="cpu")
59
-
60
- ae_key = next((k for k in state_dict.keys() if k.endswith("action_embeds")), None)
61
- action_embeds_ckpt = state_dict[ae_key]
62
- if isinstance(action_embeds_ckpt, np.ndarray):
63
- action_embeds_ckpt = torch.tensor(action_embeds_ckpt)
64
-
65
- num_actions, action_dim = action_embeds_ckpt.shape
66
- f1w = state_dict["f1.weight"]
67
- hidden_dim = f1w.shape[0]
68
- state_dim = f1w.shape[1] - action_dim
69
-
70
- dummy = torch.zeros((num_actions, action_dim), dtype=torch.float32)
71
- self.qnet = QNetworkEmbed(state_dim, dummy, hidden_dim=hidden_dim).to(self.device)
72
- self.qnet.load_state_dict(state_dict, strict=True)
73
- self.qnet.eval()
74
  self.policy_names = policy_names
75
- print(f"[IQL] Loaded: {num_actions} policies, state_dim={state_dim}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def select_policy(self, history, n_last=N_LAST):
78
  texts = [h["text"] for h in history if h.get("role") == "resident"]
@@ -80,11 +148,15 @@ class IQLSelector:
80
  s_vec = embed_state(self.embed_model, last_n)
81
  s = torch.tensor(s_vec, dtype=torch.float32, device=self.device).unsqueeze(0)
82
 
83
- q_vals = []
84
  with torch.no_grad():
85
- for a_id in range(len(self.policy_names)):
86
- a = torch.tensor([a_id], dtype=torch.long, device=self.device)
87
- q_vals.append(float(self.qnet(s, a).item()))
 
 
 
 
 
88
 
89
  best_idx = int(np.argmax(q_vals))
90
  return self.policy_names[best_idx], dict(zip(self.policy_names, q_vals))
@@ -94,17 +166,70 @@ class IQLSelector:
94
  # ============================================================================
95
 
96
  iql_selector = None
 
 
 
97
 
98
  @app.on_event("startup")
99
  async def load_model():
100
- global iql_selector
 
 
 
 
 
 
 
 
 
101
  try:
102
  base = Path(__file__).parent
103
- label_map = json.loads((base / "label_map.json").read_text())
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  policies = [k for k, _ in sorted(label_map.items(), key=lambda x: x[1])]
105
- iql_selector = IQLSelector(base / "iql_model_embed.pt", policies)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  print("[Space] Model loaded!")
107
  except Exception as e:
 
108
  print(f"[Space] Load failed: {e}")
109
  import traceback
110
  traceback.print_exc()
@@ -123,73 +248,72 @@ class Response(BaseModel):
123
 
124
  @app.post("/", response_model=Response)
125
  async def predict(req: Request):
126
- # Parse state: "msg1 | msg2 | msg3"
 
 
127
  if req.inputs == "START" or not req.inputs:
128
  messages = []
129
  else:
130
  messages = [m.strip() for m in req.inputs.split("|")]
131
-
132
  history = [{"role": "resident", "text": m} for m in messages]
133
  policy, q_vals = iql_selector.select_policy(history, n_last=N_LAST)
134
-
135
  return {"policy": policy, "q_values": q_vals}
136
 
137
  @app.get("/health")
138
  async def health():
139
- return {"status": "ok", "model_loaded": iql_selector is not None}
 
 
 
 
 
 
 
140
 
141
  # ============================================================================
142
- # Embedding API
143
  # ============================================================================
144
 
145
  class EmbedRequest(BaseModel):
146
  texts: List[str]
147
  normalize: Optional[bool] = True
148
 
 
149
  class EmbedResponse(BaseModel):
150
  embeddings: List[List[float]]
151
  model: str = EMBED_MODEL
152
  dimension: int
153
 
 
154
  @app.post("/embed", response_model=EmbedResponse)
155
  async def embed_texts(req: EmbedRequest):
156
  """
157
- Embed texts using sentence-transformers (GPU-accelerated)
158
-
159
- Input:
160
- texts: List of strings to embed
161
- normalize: Whether to normalize embeddings (default: True)
162
-
163
- Output:
164
- embeddings: List of embedding vectors (384-dim)
165
  """
166
- if not iql_selector:
167
- return {"embeddings": [], "dimension": 384}
168
-
 
 
169
  try:
170
- # Use the already-loaded sentence-transformers model
171
- embeddings = iql_selector.embed_model.encode(
172
  req.texts,
173
  convert_to_numpy=True,
174
  normalize_embeddings=req.normalize,
175
- show_progress_bar=False
176
  )
177
-
178
- # Convert to list for JSON serialization
179
  embeddings_list = embeddings.tolist()
180
-
181
- return EmbedResponse(
182
- embeddings=embeddings_list,
183
- dimension=embeddings.shape[1]
184
- )
185
-
186
  except Exception as e:
187
  print(f"[EMBED] Error: {e}")
188
  import traceback
189
  traceback.print_exc()
190
- return {"embeddings": [], "dimension": 384}
 
191
 
192
  if __name__ == "__main__":
193
  import uvicorn
194
  uvicorn.run(app, host="0.0.0.0", port=7860)
195
-
 
1
  """
2
+ Hugging Face Space API for IQL Fire Rescue Model (iql_model_state.pt)
 
3
 
4
+ Serves the state-mode IQL model uploaded via upload_iql_to_hf.py.
5
+ Deploy to: https://huggingface.co/spaces/YOUR_USERNAME/iql-fire-rescue-api
6
  """
7
 
8
+ import os
9
  from fastapi import FastAPI
10
  from pydantic import BaseModel
11
  from typing import Dict, List, Optional
 
22
  # Config
23
  EMBED_MODEL = "all-MiniLM-L6-v2"
24
  N_LAST = 3
25
+ IQL_P_DROP = 0.3
26
+ IQL_HIDDEN_DIM = 1024
27
+
28
+ # Optional: load model from HF model repo (set HF_IQL_REPO=username/iql-fire-rescue)
29
+ HF_IQL_REPO = os.getenv("HF_IQL_REPO", "")
30
 
31
  # ============================================================================
32
+ # IQL Model (state mode - iql_model_state.pt)
33
  # ============================================================================
34
 
35
class QNetworkState(nn.Module):
    """State-only Q-network mapping a state vector to per-action values.

    forward(s) returns a tensor of shape (batch, num_actions), i.e.
    [Q(s, a_1), ..., Q(s, a_N)] in a single pass.
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = IQL_HIDDEN_DIM, p_drop: float = IQL_P_DROP):
        super().__init__()

        def _hidden(in_dim: int) -> list:
            # Repeated hidden unit: Linear -> ReLU -> LayerNorm -> Dropout.
            return [
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(p_drop),
            ]

        # Same module order as two stacked hidden units plus the output
        # head, so checkpoint keys ("net.0.*", "net.4.*", "net.8.*") match.
        modules = _hidden(state_dim) + _hidden(hidden_dim) + [nn.Linear(hidden_dim, num_actions)]
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for every action given a batch of states."""
        return self.net(x)
53
+
54
+
55
  class QNetworkEmbed(nn.Module):
56
+ """Embedding-based Q-network (legacy, for iql_model_state.pt)"""
57
  def __init__(self, state_dim: int, action_embeds: torch.Tensor, hidden_dim: int, p_drop: float = 0.3):
58
  super().__init__()
59
  self.action_embeds = nn.Parameter(action_embeds, requires_grad=False)
 
72
  x = self.f2(x); x = F.relu(x); x = self.ln2(x); x = self.drop(x)
73
  return self.head(x).squeeze(-1)
74
 
75
+
76
def embed_state(model: SentenceTransformer, last_n_res_texts: List[str]) -> np.ndarray:
    """Mean-pool sentence embeddings of the last N resident messages.

    Returns a zero vector of the model's embedding dimension when no
    messages exist yet (conversation start).
    """
    dim = model.get_sentence_embedding_dimension()
    if not last_n_res_texts:
        return np.zeros((dim,), dtype=np.float32)
    vectors = model.encode(last_n_res_texts, convert_to_numpy=True, normalize_embeddings=True)
    return vectors.mean(axis=0).astype(np.float32)
82
 
83
+
84
def _load_state_dict(pt_path):
    """Extract a flat parameter dict from a checkpoint file.

    Handles raw state dicts, checkpoints wrapped under a "model" or
    "state_dict" key, and a single-entry wrapper dict whose only value
    is itself a non-empty dict.
    """
    checkpoint = torch.load(pt_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        return checkpoint

    state = checkpoint
    if "model" in checkpoint:
        state = checkpoint["model"]
    elif "state_dict" in checkpoint:
        state = checkpoint["state_dict"]

    # Unwrap one more level when the dict holds a single nested dict.
    if isinstance(state, dict) and len(state) == 1:
        (candidate,) = state.values()
        if isinstance(candidate, dict) and candidate:
            state = candidate
    return state
97
+
98
+
99
def _strip_prefix(sd: dict, prefix: str) -> dict:
    """Remove *prefix* from matching keys; keys without it are dropped."""
    cut = len(prefix)
    stripped = {}
    for key, value in sd.items():
        if key.startswith(prefix):
            stripped[key[cut:]] = value
    return stripped
101
+
102
+
103
  class IQLSelector:
104
  def __init__(self, pt_path, policy_names):
105
  self.device = torch.device("cpu")
106
  self.embed_model = SentenceTransformer(EMBED_MODEL)
107
+ state_dim = self.embed_model.get_sentence_embedding_dimension()
108
+ num_actions = len(policy_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  self.policy_names = policy_names
110
+
111
+ state_dict = _load_state_dict(pt_path)
112
+ keys = list(state_dict.keys()) if isinstance(state_dict, dict) else []
113
+ for prefix in ("qnet.", "module.", "model."):
114
+ if any(k.startswith(prefix) for k in keys):
115
+ state_dict = _strip_prefix(state_dict, prefix)
116
+ keys = list(state_dict.keys())
117
+ break
118
+
119
+ # State mode: QNetworkState has "net." keys
120
+ is_state_mode = any(k.startswith("net.") for k in keys)
121
+ if is_state_mode:
122
+ self.mode = "state"
123
+ self.qnet = QNetworkState(state_dim, num_actions).to(self.device)
124
+ model_keys = set(self.qnet.state_dict().keys())
125
+ state_dict_filtered = {k: v for k, v in state_dict.items() if k in model_keys}
126
+ self.qnet.load_state_dict(state_dict_filtered, strict=True)
127
+ self.qnet.eval()
128
+ print(f"[IQL] Loaded state-mode model: {num_actions} policies, state_dim={state_dim}")
129
+ else:
130
+ # Embed mode (legacy)
131
+ self.mode = "embed"
132
+ ae_key = next((k for k in state_dict.keys() if k.endswith("action_embeds")), None)
133
+ action_embeds_ckpt = state_dict[ae_key]
134
+ if isinstance(action_embeds_ckpt, np.ndarray):
135
+ action_embeds_ckpt = torch.tensor(action_embeds_ckpt)
136
+ num_a, action_dim = action_embeds_ckpt.shape
137
+ f1w = state_dict["f1.weight"]
138
+ hidden_dim = f1w.shape[0]
139
+ dummy = torch.zeros((num_a, action_dim), dtype=torch.float32)
140
+ self.qnet = QNetworkEmbed(state_dim, dummy, hidden_dim=hidden_dim).to(self.device)
141
+ self.qnet.load_state_dict(state_dict, strict=True)
142
+ self.qnet.eval()
143
+ print(f"[IQL] Loaded embed-mode model: {num_actions} policies")
144
 
145
  def select_policy(self, history, n_last=N_LAST):
146
  texts = [h["text"] for h in history if h.get("role") == "resident"]
 
148
  s_vec = embed_state(self.embed_model, last_n)
149
  s = torch.tensor(s_vec, dtype=torch.float32, device=self.device).unsqueeze(0)
150
 
 
151
  with torch.no_grad():
152
+ if self.mode == "state":
153
+ q_out = self.qnet(s)
154
+ q_vals = q_out.cpu().numpy().flatten().tolist()
155
+ else:
156
+ q_vals = []
157
+ for a_id in range(len(self.policy_names)):
158
+ a = torch.tensor([a_id], dtype=torch.long, device=self.device)
159
+ q_vals.append(float(self.qnet(s, a).item()))
160
 
161
  best_idx = int(np.argmax(q_vals))
162
  return self.policy_names[best_idx], dict(zip(self.policy_names, q_vals))
 
166
  # ============================================================================
167
 
168
  iql_selector = None
169
+ embed_model = None # Standalone embed model for /embed endpoint (works even if IQL fails to load)
170
+ load_error = None # Captured error when IQL model fails to load
171
+
172
 
173
@app.on_event("startup")
async def load_model():
    """Load the embedding model and the IQL checkpoint at startup.

    The embedding model is loaded first so the /embed endpoint keeps
    working even when the IQL checkpoint is missing or fails to load;
    any IQL failure is captured in the module-level `load_error` for
    the /health endpoint to report.
    """
    global iql_selector, embed_model, load_error

    # Standalone embedder for /embed (lightweight, independent of IQL).
    try:
        embed_model = SentenceTransformer(EMBED_MODEL)
        print("[Space] Embedding model loaded (for /embed endpoint)")
    except Exception as e:
        print(f"[Space] Embedding model failed to load: {e}")
        import traceback
        traceback.print_exc()

    try:
        base = Path(__file__).parent
        print(f"[Space] Base dir: {base}")
        print(f"[Space] Files in base: {list(base.iterdir())}")

        label_map_path = base / "label_map.json"

        # Prefer the label map from the HF model repo when configured;
        # fall back to the local copy on any download error.
        if HF_IQL_REPO:
            try:
                from huggingface_hub import hf_hub_download
                label_map_path = Path(hf_hub_download(repo_id=HF_IQL_REPO, filename="label_map.json"))
            except Exception as e:
                print(f"[Space] Could not load label_map from {HF_IQL_REPO}: {e}")

        label_map = json.loads(label_map_path.read_text())
        policies = [k for k, _ in sorted(label_map.items(), key=lambda x: x[1])]

        # Checkpoint lookup order: HF repo download first, then the
        # local fallbacks iql_model_state.pt and state.pt.
        pt_path = None
        if HF_IQL_REPO:
            try:
                from huggingface_hub import hf_hub_download
                pt_path = Path(hf_hub_download(repo_id=HF_IQL_REPO, filename="iql_model_state.pt"))
                print(f"[Space] Loaded iql_model_state.pt from {HF_IQL_REPO}")
            except Exception as e:
                print(f"[Space] Could not load from HF repo: {e}")

        if pt_path is None:
            for name in ("iql_model_state.pt", "state.pt"):
                candidate = base / name
                if candidate.exists():
                    pt_path = candidate
                    print(f"[Space] Using local {name}")
                    break

        if pt_path is None:
            # List what .pt files exist to make the failure debuggable.
            existing = [f.name for f in base.iterdir() if f.suffix == ".pt"]
            raise FileNotFoundError(
                f"No iql_model_state.pt found. "
                f"Add to Space or set HF_IQL_REPO. Found .pt files: {existing}"
            )

        iql_selector = IQLSelector(pt_path, policies)
        print("[Space] Model loaded!")
    except Exception as e:
        load_error = str(e)
        print(f"[Space] Load failed: {e}")
        import traceback
        traceback.print_exc()
 
248
 
249
@app.post("/", response_model=Response)
async def predict(req: Request):
    """Select the best policy for the conversation encoded in req.inputs.

    req.inputs is a pipe-separated resident message history, or "START"
    (or empty) at the beginning of a conversation.
    """
    if not iql_selector:
        # Model never loaded; answer with the default policy and no Q-values.
        return {"policy": "niki", "q_values": {}}

    raw = req.inputs
    if not raw or raw == "START":
        messages = []
    else:
        messages = [part.strip() for part in raw.split("|")]

    history = [{"role": "resident", "text": msg} for msg in messages]
    policy, q_vals = iql_selector.select_policy(history, n_last=N_LAST)
    return {"policy": policy, "q_values": q_vals}
263
 
264
@app.get("/health")
async def health():
    """Report readiness of the IQL selector and the embedding model."""
    iql_ready = iql_selector is not None
    standalone_embed_ready = embed_model is not None
    return {
        "status": "ok",
        "model_loaded": iql_ready,
        "embed_model_loaded": standalone_embed_ready,
        "embed_ready": standalone_embed_ready or iql_ready,
        "load_error": load_error,  # why the IQL model failed to load, if it did
    }
273
+
274
 
275
  # ============================================================================
276
+ # Embedding API (for policy retrieval - works even if IQL model fails to load)
277
  # ============================================================================
278
 
279
class EmbedRequest(BaseModel):
    # Strings to embed, and whether to L2-normalize the resulting vectors.
    texts: List[str]
    normalize: Optional[bool] = True
282
 
283
+
284
class EmbedResponse(BaseModel):
    # One embedding vector per input text, plus the model name and dimension.
    embeddings: List[List[float]]
    model: str = EMBED_MODEL
    dimension: int
288
 
289
+
290
@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(req: EmbedRequest):
    """
    Embed texts using sentence-transformers (GPU-accelerated).
    Uses standalone embed_model so it works even if IQL model fails to load.
    """
    # Prefer the standalone embedder; fall back to the one inside the selector.
    model = embed_model or (iql_selector.embed_model if iql_selector else None)
    if not model:
        print("[EMBED] No embedding model available")
        return {"embeddings": [], "model": EMBED_MODEL, "dimension": 384}

    try:
        vectors = model.encode(
            req.texts,
            convert_to_numpy=True,
            normalize_embeddings=req.normalize,
            show_progress_bar=False,
        )
        return EmbedResponse(embeddings=vectors.tolist(), dimension=vectors.shape[1])
    except Exception as e:
        print(f"[EMBED] Error: {e}")
        import traceback
        traceback.print_exc()
        # Degrade gracefully so callers get a well-formed (empty) response.
        return {"embeddings": [], "model": EMBED_MODEL, "dimension": 384}
315
+
316
 
317
if __name__ == "__main__":
    # Local development entry point; HF Spaces expects port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)