Spaces:
Upload app.py with huggingface_hub
app.py
CHANGED
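Note: a commit like this one can be produced from a local checkout with huggingface_hub's upload_file. A minimal sketch, assuming the placeholder Space id YOUR_USERNAME/iql-fire-rescue-api from the module docstring below:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` or HF_TOKEN
api.upload_file(
    path_or_fileobj="app.py",                     # local file to push
    path_in_repo="app.py",                        # destination path in the Space
    repo_id="YOUR_USERNAME/iql-fire-rescue-api",  # placeholder Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)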
@@ -1,10 +1,11 @@
 """
-Hugging Face Space API for IQL Fire Rescue Model
-Deploy this to HF Space to serve your custom IQL model
+Hugging Face Space API for IQL Fire Rescue Model (iql_model_state.pt)
 
-
+Serves the state-mode IQL model uploaded via upload_iql_to_hf.py.
+Deploy to: https://huggingface.co/spaces/YOUR_USERNAME/iql-fire-rescue-api
 """
 
+import os
 from fastapi import FastAPI
 from pydantic import BaseModel
 from typing import Dict, List, Optional
@@ -21,12 +22,38 @@ app = FastAPI(title="IQL Fire Rescue API")
 # Config
 EMBED_MODEL = "all-MiniLM-L6-v2"
 N_LAST = 3
+IQL_P_DROP = 0.3
+IQL_HIDDEN_DIM = 1024
+
+# Optional: load model from HF model repo (set HF_IQL_REPO=username/iql-fire-rescue)
+HF_IQL_REPO = os.getenv("HF_IQL_REPO", "")
 
 # ============================================================================
-# IQL Model (
+# IQL Model (state mode - iql_model_state.pt)
 # ============================================================================
 
+class QNetworkState(nn.Module):
+    """State-only Q-network: Q(s) = [Q(s,a1)..Q(s,aN)]"""
+    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = IQL_HIDDEN_DIM, p_drop: float = IQL_P_DROP):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(state_dim, hidden_dim),
+            nn.ReLU(),
+            nn.LayerNorm(hidden_dim),
+            nn.Dropout(p_drop),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+            nn.LayerNorm(hidden_dim),
+            nn.Dropout(p_drop),
+            nn.Linear(hidden_dim, num_actions),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+
 class QNetworkEmbed(nn.Module):
+    """Embedding-based Q-network (legacy embed-mode checkpoints)"""
     def __init__(self, state_dim: int, action_embeds: torch.Tensor, hidden_dim: int, p_drop: float = 0.3):
         super().__init__()
         self.action_embeds = nn.Parameter(action_embeds, requires_grad=False)
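Note: a quick shape check for the QNetworkState added above; a minimal sketch, assuming the 384-dimensional all-MiniLM-L6-v2 embedding space and a hypothetical five-policy action set:

import torch

net = QNetworkState(state_dim=384, num_actions=5)  # hidden_dim=1024, p_drop=0.3 come from the config above
net.eval()                                         # disable dropout for deterministic inference
s = torch.zeros(1, 384)                            # one embedded conversation state
q = net(s)                                         # one Q-value per policy
assert q.shape == (1, 5)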
@@ -45,34 +72,75 @@ class QNetworkEmbed(nn.Module):
         x = self.f2(x); x = F.relu(x); x = self.ln2(x); x = self.drop(x)
         return self.head(x).squeeze(-1)
 
-def embed_state(model, last_n_res_texts):
-    if not last_n_res_texts:
+
+def embed_state(model: SentenceTransformer, last_n_res_texts: List[str]) -> np.ndarray:
+    """Embed conversation state from last N resident messages"""
+    if not last_n_res_texts:
         return np.zeros((model.get_sentence_embedding_dimension(),), dtype=np.float32)
-    embs = model.encode(
+    embs = model.encode(last_n_res_texts, convert_to_numpy=True, normalize_embeddings=True)
     return np.mean(embs, axis=0).astype(np.float32)
 
+
+def _load_state_dict(pt_path):
+    raw = torch.load(pt_path, map_location="cpu")
+    state_dict = raw
+    if isinstance(raw, dict):
+        if "model" in raw:
+            state_dict = raw["model"]
+        elif "state_dict" in raw:
+            state_dict = raw["state_dict"]
+    if isinstance(state_dict, dict) and len(state_dict) == 1:
+        only_val = next(iter(state_dict.values()))
+        if isinstance(only_val, dict) and only_val:
+            state_dict = only_val
+    return state_dict
+
+
+def _strip_prefix(sd: dict, prefix: str) -> dict:
+    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
+
+
 class IQLSelector:
     def __init__(self, pt_path, policy_names):
         self.device = torch.device("cpu")
         self.embed_model = SentenceTransformer(EMBED_MODEL)
-        state_dict = torch.load(pt_path, map_location="cpu")
-
-        ae_key = next((k for k in state_dict.keys() if k.endswith("action_embeds")), None)
-        action_embeds_ckpt = state_dict[ae_key]
-        if isinstance(action_embeds_ckpt, np.ndarray):
-            action_embeds_ckpt = torch.tensor(action_embeds_ckpt)
-
-        num_actions, action_dim = action_embeds_ckpt.shape
-        f1w = state_dict["f1.weight"]
-        hidden_dim = f1w.shape[0]
-        state_dim = f1w.shape[1] - action_dim
-
-        dummy = torch.zeros((num_actions, action_dim), dtype=torch.float32)
-        self.qnet = QNetworkEmbed(state_dim, dummy, hidden_dim=hidden_dim).to(self.device)
-        self.qnet.load_state_dict(state_dict, strict=True)
-        self.qnet.eval()
+        state_dim = self.embed_model.get_sentence_embedding_dimension()
+        num_actions = len(policy_names)
         self.policy_names = policy_names
-
+
+        state_dict = _load_state_dict(pt_path)
+        keys = list(state_dict.keys()) if isinstance(state_dict, dict) else []
+        for prefix in ("qnet.", "module.", "model."):
+            if any(k.startswith(prefix) for k in keys):
+                state_dict = _strip_prefix(state_dict, prefix)
+                keys = list(state_dict.keys())
+                break
+
+        # State mode: QNetworkState has "net." keys
+        is_state_mode = any(k.startswith("net.") for k in keys)
+        if is_state_mode:
+            self.mode = "state"
+            self.qnet = QNetworkState(state_dim, num_actions).to(self.device)
+            model_keys = set(self.qnet.state_dict().keys())
+            state_dict_filtered = {k: v for k, v in state_dict.items() if k in model_keys}
+            self.qnet.load_state_dict(state_dict_filtered, strict=True)
+            self.qnet.eval()
+            print(f"[IQL] Loaded state-mode model: {num_actions} policies, state_dim={state_dim}")
+        else:
+            # Embed mode (legacy)
+            self.mode = "embed"
+            ae_key = next((k for k in state_dict.keys() if k.endswith("action_embeds")), None)
+            action_embeds_ckpt = state_dict[ae_key]
+            if isinstance(action_embeds_ckpt, np.ndarray):
+                action_embeds_ckpt = torch.tensor(action_embeds_ckpt)
+            num_a, action_dim = action_embeds_ckpt.shape
+            f1w = state_dict["f1.weight"]
+            hidden_dim = f1w.shape[0]
+            dummy = torch.zeros((num_a, action_dim), dtype=torch.float32)
+            self.qnet = QNetworkEmbed(state_dim, dummy, hidden_dim=hidden_dim).to(self.device)
+            self.qnet.load_state_dict(state_dict, strict=True)
+            self.qnet.eval()
+            print(f"[IQL] Loaded embed-mode model: {num_actions} policies")
 
     def select_policy(self, history, n_last=N_LAST):
         texts = [h["text"] for h in history if h.get("role") == "resident"]
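Note: the new loader tolerates several checkpoint layouts: a bare state_dict, a dict nested under "model" or "state_dict", and an optional "qnet."/"module."/"model." key prefix. A minimal sketch of the two helpers on a hypothetical wrapped checkpoint (demo.pt is a made-up filename):

import torch

ckpt = {"model": {"qnet.net.0.weight": torch.zeros(1024, 384)}}  # hypothetical wrapper plus prefix
torch.save(ckpt, "demo.pt")

sd = _load_state_dict("demo.pt")  # unwraps the "model" key
sd = _strip_prefix(sd, "qnet.")   # drops the prefix -> {"net.0.weight": ...}
assert any(k.startswith("net.") for k in sd)  # would be detected as state mode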
@@ -80,11 +148,15 @@ class IQLSelector:
         s_vec = embed_state(self.embed_model, last_n)
         s = torch.tensor(s_vec, dtype=torch.float32, device=self.device).unsqueeze(0)
 
-        q_vals = []
         with torch.no_grad():
-            for a_id in range(len(self.policy_names)):
-                a = torch.tensor([a_id], dtype=torch.long, device=self.device)
-                q_vals.append(float(self.qnet(s, a).item()))
+            if self.mode == "state":
+                q_out = self.qnet(s)
+                q_vals = q_out.cpu().numpy().flatten().tolist()
+            else:
+                q_vals = []
+                for a_id in range(len(self.policy_names)):
+                    a = torch.tensor([a_id], dtype=torch.long, device=self.device)
+                    q_vals.append(float(self.qnet(s, a).item()))
 
         best_idx = int(np.argmax(q_vals))
         return self.policy_names[best_idx], dict(zip(self.policy_names, q_vals))
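Note: offline, the selector can be exercised directly; a minimal sketch, assuming a local iql_model_state.pt and hypothetical policy names (only "resident" turns feed the state embedding):

selector = IQLSelector("iql_model_state.pt", ["empathy", "directive", "niki"])
history = [
    {"role": "resident", "text": "There's smoke in the hallway"},
    {"role": "assistant", "text": "Stay low and move toward the exit"},
    {"role": "resident", "text": "I can't find the stairs"},
]
policy, q_values = selector.select_policy(history, n_last=N_LAST)
print(policy, q_values)  # best policy name plus a {policy_name: Q-value} dict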
@@ -94,17 +166,70 @@ class IQLSelector:
 # ============================================================================
 
 iql_selector = None
+embed_model = None  # Standalone embed model for /embed endpoint (works even if IQL fails to load)
+load_error = None  # Captured error when IQL model fails to load
+
 
 @app.on_event("startup")
 async def load_model():
-    global iql_selector
+    global iql_selector, embed_model, load_error
+    # Load embedding model first (lightweight, used by /embed endpoint)
+    try:
+        embed_model = SentenceTransformer(EMBED_MODEL)
+        print("[Space] Embedding model loaded (for /embed endpoint)")
+    except Exception as e:
+        print(f"[Space] Embedding model failed to load: {e}")
+        import traceback
+        traceback.print_exc()
+
     try:
         base = Path(__file__).parent
-        label_map = json.loads((base / "label_map.json").read_text())
+        print(f"[Space] Base dir: {base}")
+        print(f"[Space] Files in base: {list(base.iterdir())}")
+
+        label_map_path = base / "label_map.json"
+
+        # Load label_map from HF repo if specified
+        if HF_IQL_REPO:
+            try:
+                from huggingface_hub import hf_hub_download
+                label_map_path = Path(hf_hub_download(repo_id=HF_IQL_REPO, filename="label_map.json"))
+            except Exception as e:
+                print(f"[Space] Could not load label_map from {HF_IQL_REPO}: {e}")
+
+        label_map = json.loads(label_map_path.read_text())
         policies = [k for k, _ in sorted(label_map.items(), key=lambda x: x[1])]
-
+
+        # Locate the checkpoint: try the HF repo first, then local files
+        pt_path = None
+        if HF_IQL_REPO:
+            try:
+                from huggingface_hub import hf_hub_download
+                pt_path = Path(hf_hub_download(repo_id=HF_IQL_REPO, filename="iql_model_state.pt"))
+                print(f"[Space] Loaded iql_model_state.pt from {HF_IQL_REPO}")
+            except Exception as e:
+                print(f"[Space] Could not load from HF repo: {e}")
+
+        if pt_path is None:
+            for name in ["iql_model_state.pt", "state.pt"]:
+                candidate = base / name
+                if candidate.exists():
+                    pt_path = candidate
+                    print(f"[Space] Using local {name}")
+                    break
+
+        if pt_path is None:
+            # List what files exist for debugging
+            existing = [f.name for f in base.iterdir() if f.suffix == ".pt"]
+            raise FileNotFoundError(
+                f"No iql_model_state.pt found. "
+                f"Add to Space or set HF_IQL_REPO. Found .pt files: {existing}"
+            )
+
+        iql_selector = IQLSelector(pt_path, policies)
         print("[Space] Model loaded!")
     except Exception as e:
+        load_error = str(e)
         print(f"[Space] Load failed: {e}")
         import traceback
         traceback.print_exc()
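Note: the startup hook reads label_map.json as a policy-name-to-action-id mapping and orders policies by id, so the ids must line up with the Q-network's output order. A hypothetical three-policy label_map.json:

{
  "empathy": 0,
  "directive": 1,
  "niki": 2
}

To serve the checkpoint and label map from a model repo instead of bundling them in the Space, set HF_IQL_REPO (e.g. username/iql-fire-rescue) in the Space's environment variables.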
@@ -123,73 +248,72 @@ class Response(BaseModel):
 
 @app.post("/", response_model=Response)
 async def predict(req: Request):
-
+    if not iql_selector:
+        return {"policy": "niki", "q_values": {}}
+
     if req.inputs == "START" or not req.inputs:
         messages = []
     else:
         messages = [m.strip() for m in req.inputs.split("|")]
-
+
     history = [{"role": "resident", "text": m} for m in messages]
     policy, q_vals = iql_selector.select_policy(history, n_last=N_LAST)
-
+
     return {"policy": policy, "q_values": q_vals}
 
 @app.get("/health")
 async def health():
-    return {
+    return {
+        "status": "ok",
+        "model_loaded": iql_selector is not None,
+        "embed_model_loaded": embed_model is not None,
+        "embed_ready": (embed_model is not None) or (iql_selector is not None),
+        "load_error": load_error,  # Why IQL failed (if any)
+    }
+
 
 # ============================================================================
-# Embedding API
+# Embedding API (for policy retrieval - works even if IQL model fails to load)
 # ============================================================================
 
 class EmbedRequest(BaseModel):
     texts: List[str]
     normalize: Optional[bool] = True
 
+
 class EmbedResponse(BaseModel):
     embeddings: List[List[float]]
     model: str = EMBED_MODEL
     dimension: int
 
+
 @app.post("/embed", response_model=EmbedResponse)
 async def embed_texts(req: EmbedRequest):
     """
-    Embed texts using sentence-transformers (GPU-accelerated)
-
-    Input:
-        texts: List of strings to embed
-        normalize: Whether to normalize embeddings (default: True)
-
-    Output:
-        embeddings: List of embedding vectors (384-dim)
+    Embed texts using sentence-transformers (GPU-accelerated).
+    Uses standalone embed_model so it works even if IQL model fails to load.
     """
-
-
-
+    model = embed_model or (iql_selector.embed_model if iql_selector else None)
+    if not model:
+        print("[EMBED] No embedding model available")
+        return {"embeddings": [], "model": EMBED_MODEL, "dimension": 384}
+
     try:
-
-        embeddings = iql_selector.embed_model.encode(
+        embeddings = model.encode(
             req.texts,
             convert_to_numpy=True,
             normalize_embeddings=req.normalize,
-            show_progress_bar=False
+            show_progress_bar=False,
         )
-
-        # Convert to list for JSON serialization
         embeddings_list = embeddings.tolist()
-
-        return EmbedResponse(
-            embeddings=embeddings_list,
-            dimension=embeddings.shape[1]
-        )
-
+        return EmbedResponse(embeddings=embeddings_list, dimension=embeddings.shape[1])
     except Exception as e:
         print(f"[EMBED] Error: {e}")
         import traceback
         traceback.print_exc()
-        return {"embeddings": [], "dimension": 384}
+        return {"embeddings": [], "model": EMBED_MODEL, "dimension": 384}
+
 
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
-
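Note: once the Space is up, the endpoints can be exercised with any HTTP client; a minimal sketch using requests, assuming a local run on port 7860 (swap in the Space URL once deployed):

import requests

BASE = "http://localhost:7860"

# Health: reports model_loaded / embed_ready and the captured load_error, if any
print(requests.get(f"{BASE}/health").json())

# Policy selection: "|"-separated resident messages, or "START" for an empty history
r = requests.post(f"{BASE}/", json={"inputs": "There's smoke everywhere | I can't get out"})
print(r.json())  # {"policy": ..., "q_values": {...}}

# Embeddings: 384-dim all-MiniLM-L6-v2 vectors, normalized by default
r = requests.post(f"{BASE}/embed", json={"texts": ["stay calm"], "normalize": True})
print(r.json()["dimension"])  # 384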