Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException, Request, Form | |
| from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse | |
| import os, json, requests, time, re, pathlib | |
router = APIRouter()
# ---------------------------------------------------------------------
# In-memory job + instance store
# ---------------------------------------------------------------------
# _JOBS: job_id -> {"status", "logs": [{"t": ms, "msg": str}], "image_b64", "timings"}
_JOBS = {}
# _INST: single cached record for the most recent pod. Route fields stay
# None until resolved from the blob or from image-name heuristics.
_INST = {
    "podId": "", "status": "", "ip": "", "port": "",
    "blob": None, "model_id": "", "container_image_hint": "",
    "predictRoute": None, "healthRoute": None,
    "readinessRoute": None, "livenessRoute": None,
}
| def _now_ms(): return int(time.time() * 1000) | |
def _job_log(job_id, msg):
    """Append a timestamped message to a job's log, creating the record lazily."""
    record = _JOBS.setdefault(
        job_id,
        {"status": "created", "logs": [], "image_b64": None, "timings": {}},
    )
    record["logs"].append({"t": _now_ms(), "msg": msg})
    print(f"[{job_id}] {msg}", flush=True)
def _log_create(msg):
    """Log a CREATE-phase message under the shared 'compute' job."""
    _job_log("compute", f"[CREATE] {msg}")

def _log_status(msg):
    """Log a STATUS-phase message under the shared 'compute' job."""
    _job_log("compute", f"[STATUS] {msg}")

def _log_delete(msg):
    """Log a DELETE-phase message under the shared 'compute' job."""
    _job_log("compute", f"[DELETE] {msg}")

def _log_id(prefix, pid):
    """Log a pod id with a phase prefix, e.g. 'CREATE_SET ID: <pid>'."""
    _job_log("compute", f"{prefix} ID: {pid}")
| # --- local blob ingest (landing page only) --- | |
| _LOCAL_BLOB_PATH = os.getenv("MODEL_BLOB_PATH", "model_blob.json") | |
| def _load_local_blob(): | |
| try: | |
| if os.path.exists(_LOCAL_BLOB_PATH): | |
| with open(_LOCAL_BLOB_PATH, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception as e: | |
| _job_log("compute", f"ERROR LocalBlobLoad: {e}") | |
| return None | |
# NOTE(review): this definition is shadowed by a later redefinition of
# `_ingest_blob` further down in this module; the later copy is the one
# bound at import time. Consider deleting one of the two.
def _ingest_blob(parsed: dict, model_id_hint: str = "", container_image_hint: str = ""):
    """Cache a deployment blob in _INST and derive serving routes.

    Stores the blob plus optional model/container hints, copies any explicit
    route fields from the containerSpec, then fills in predict/health routes
    from image-name heuristics without overriding explicit values.
    Raises HTTPException(400) when `parsed` is not a dict.
    """
    if not isinstance(parsed, dict):
        raise HTTPException(400, "Invalid blob (expected JSON object).")
    _INST.update({
        "blob": parsed,
        "model_id": model_id_hint or "",
        "container_image_hint": container_image_hint or "",
    })
    # containerSpec may be nested Vertex-style under supportedActions.deploy
    # or live directly under "container".
    c = (((parsed.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec")
         or parsed.get("container") or {}) or {}
    for k in ("predictRoute", "healthRoute", "readinessRoute", "livenessRoute"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            _INST[k] = v.strip()
    # Heuristic fallback from the image name; explicit blob routes win.
    image_uri = (c.get("imageUri") or "").strip().lower()
    pr, hr = _infer_routes_from_image(image_uri)
    if pr and not _INST.get("predictRoute"):
        _INST["predictRoute"] = pr
    if hr and not _INST.get("healthRoute"):
        _INST["healthRoute"] = hr
    return True
# ---------------------------------------------------------------------
# Disk persistence for recovery
# ---------------------------------------------------------------------
_STATE_PATH = "/tmp/pod_state.json"

def _save_state():
    """Persist the minimal pod identity (podId/status/ip/port) to /tmp."""
    try:
        pathlib.Path("/tmp").mkdir(parents=True, exist_ok=True)
        snapshot = {key: _INST.get(key, "") for key in ("podId", "status", "ip", "port")}
        with open(_STATE_PATH, "w") as fh:
            json.dump(snapshot, fh)
    except Exception as exc:
        _job_log("compute", f"ERROR SaveState: {exc}")
def _load_state():
    """Restore podId/status/ip/port from the on-disk snapshot, if present."""
    try:
        if not os.path.exists(_STATE_PATH):
            return
        with open(_STATE_PATH) as fh:
            saved = json.load(fh)
        for key in ("podId", "status", "ip", "port"):
            if key in saved:
                _INST[key] = saved[key]
    except Exception as exc:
        _job_log("compute", f"ERROR LoadState: {exc}")
| # --------------------------------------------------------------------- | |
| # RunPod helpers | |
| # --------------------------------------------------------------------- | |
| _RP_BASE = "https://rest.runpod.io/v1" | |
| def _rp_headers(): | |
| key=os.getenv("RunPod","").strip() | |
| if not key: | |
| raise HTTPException(500,"Missing RunPod API key (env var 'RunPod').") | |
| return {"Authorization":f"Bearer {key}","Content-Type":"application/json"} | |
| def _as_json(r): | |
| c=(r.headers.get("content-type") or "").lower() | |
| if "json" in c: | |
| try: return r.json() | |
| except Exception: return {"_raw":r.text} | |
| return {"_raw":r.text} | |
# ---------------------------------------------------------------------
# Probes and route discovery (new)
# ---------------------------------------------------------------------
# Candidate predict routes, tried in order against https://pod:port/<route>
# by _probe_all_routes until one answers with a non-5xx status.
_POSSIBLE_ROUTES = [
    "/invocations",  # <— added and placed first
    "/generate",
    "/predict",
    "/predictions",
    "/v1/chat/completions",
    "/v1/models/model:predict",
]
| def _infer_routes_from_image(image_uri: str): | |
| """ | |
| Infer (predict_route, health_route) from known image patterns. | |
| """ | |
| iu = (image_uri or "").lower() | |
| # vLLM images | |
| if "vllm-serve" in iu: | |
| return ("/generate", "/ping") | |
| # HuggingFace / Vertex HF Inference Toolkit | |
| # changed from "/predict" → "/invocations" | |
| if "hf-inference-toolkit" in iu or "huggingface-pytorch-inference" in iu: | |
| return ("/invocations", "/ping") | |
| # Unknown image → allow route scanning fallback | |
| return (None, None) | |
async def _probe_all_routes(base: str, port: str, session):
    """
    Try all known routes until one responds 200/OK-ish.
    Returns (predict_route, health_route or None).

    `session` is awaited via .get(url, timeout=3), so it is assumed to be an
    async HTTP client (e.g. httpx.AsyncClient) — TODO confirm at call site.
    """
    from urllib.parse import urljoin
    proto_base = f"{base}:{port}"
    for route in _POSSIBLE_ROUTES:
        url = urljoin(proto_base + "/", route.lstrip("/"))
        try:
            r = await session.get(url, timeout=3)
            # Any non-5xx reply counts as "route exists" (4xx often just
            # means wrong method/payload, not a missing endpoint).
            if r.status_code < 500:
                # NOTE(review): none of _POSSIBLE_ROUTES contains "/ping",
                # so this health-route expression always evaluates to None.
                return route, ("/ping" if "/ping" in route else None)
        except Exception:
            pass
    return None, None
| # --------------------------------------------------------------------- | |
| # Blob ingest via Model Blob page JSON (with blob_url override) | |
| # --------------------------------------------------------------------- | |
| _HF_SPACE_PORT = os.getenv("PORT", "7860") | |
| _LOCAL_BASE = f"http://127.0.0.1:{_HF_SPACE_PORT}" | |
| def _normalize_blob_url(u: str | None) -> str | None: | |
| if not u: | |
| return None | |
| u = str(u).strip() | |
| if u.startswith(("http://", "https://")): | |
| return u | |
| # Treat '/x' or 'x' as local to this app (same origin as FE) | |
| if u.startswith("/"): | |
| return f"{_LOCAL_BASE}{u}" | |
| return f"{_LOCAL_BASE}/{u}" | |
def _fetch_url(u: str):
    """GET a URL and return its parsed JSON; on any failure log and return None."""
    try:
        resp = requests.get(u, timeout=8)
        if resp.ok:
            return resp.json()
        _job_log("compute", f"ERROR BlobFetch code={resp.status_code} url={u} body={resp.text[:200]}")
    except Exception as exc:
        _job_log("compute", f"ERROR BlobFetch url={u}: {exc}")
    return None
def _fetch_blob_from_page():
    """Fetch /modelblob.json from this app's own origin (landing page blob)."""
    url = f"{_LOCAL_BASE}/modelblob.json"
    return _fetch_url(url)
def _ingest_blob(parsed: dict, model_id_hint: str = "", container_image_hint: str = ""):
    """Cache a deployment blob in _INST and resolve serving routes.

    Explicit route fields in the containerSpec take precedence; otherwise
    predict/health routes are inferred from the image name. Raises
    HTTPException(400) when `parsed` is not a JSON object.
    """
    if not isinstance(parsed, dict):
        raise HTTPException(400, "Invalid blob (expected JSON object).")
    _INST.update({
        "blob": parsed,
        "model_id": model_id_hint or "",
        "container_image_hint": container_image_hint or "",
    })
    # containerSpec lives either under supportedActions.deploy or "container".
    deploy = (parsed.get("supportedActions") or {}).get("deploy") or {}
    spec = deploy.get("containerSpec") or parsed.get("container") or {}
    for field in ("predictRoute", "healthRoute", "readinessRoute", "livenessRoute"):
        value = spec.get(field)
        if isinstance(value, str) and value.strip():
            _INST[field] = value.strip()
    # Image-name heuristics fill only the routes the blob left unset.
    inferred_predict, inferred_health = _infer_routes_from_image(
        (spec.get("imageUri") or "").strip().lower()
    )
    if inferred_predict and not _INST.get("predictRoute"):
        _INST["predictRoute"] = inferred_predict
    if inferred_health and not _INST.get("healthRoute"):
        _INST["healthRoute"] = inferred_health
    return True
def api_ingest_from_landing(blob_url: str | None = None):
    """
    Ingest the deployment blob for downstream use.
    Mirrors FE behavior: resolve relative paths like '/modelblob.json'
    against the app origin.
    """
    target = _normalize_blob_url(blob_url) or _normalize_blob_url("/modelblob.json")
    parsed = _fetch_url(target)
    if not parsed:
        return JSONResponse({"error": "Blob not available"}, 404)
    _ingest_blob(parsed, model_id_hint="", container_image_hint="")
    return JSONResponse({"ok": True, "source": target})
# (Optional compatibility: UI posting to /Deployment_UI; accepts blob_url via query)
async def deployment_ui_ingest(request: Request,
                               model_id: str = Form(""),
                               container_image: str = Form(""),
                               blob: str = Form("")):
    """
    Legacy entry used by the Deployment UI page.
    Prefers blob_url from query string; falls back to the modelblob page JSON.
    """
    override = request.query_params.get("blob_url")
    target = _normalize_blob_url(override) if override else _normalize_blob_url("/modelblob.json")
    parsed = _fetch_url(target)
    if not parsed:
        return HTMLResponse("<pre>Missing blob (no /modelblob.json and no blob_url)</pre>", 400)
    _ingest_blob(parsed, model_id_hint=model_id, container_image_hint=container_image)
    return RedirectResponse("/Deployment_UI", 303)
# ---------------------------------------------------------------------
# Create instance
# ---------------------------------------------------------------------
async def api_create_instance(req: Request):
    """Create (and immediately start) a RunPod GPU pod from the cached blob.

    Translates the blob's containerSpec (image, env, ports, command, args,
    GPU resources) into a RunPod /pods payload, creates the pod, starts it,
    takes one status snapshot, and caches the pod id. Returns the RunPod
    create response (with "_status" attached) and the upstream status code.
    `req` is currently unused.
    """
    # Ensure blob is present (lazy-load from landing file if needed)
    if not _INST.get("blob"):
        lb = _load_local_blob()
        if lb:
            _ingest_blob(lb)
    blob = _INST.get("blob")
    if not blob:
        return JSONResponse({"error": "No deployment blob provided."}, 400)
    # containerSpec may be Vertex-style (supportedActions.deploy) or flat.
    c = ((blob.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
        or blob.get("container")
    if not isinstance(c, dict) or not c:
        return JSONResponse({"error": "Blob missing containerSpec."}, 400)
    image = (c.get("imageUri") or "").strip()
    if not image:
        return JSONResponse({"error": "containerSpec.imageUri missing."}, 400)
    _log_create(f"imageName: {image}")
    # env: list of {"name","value"} dicts -> flat {name: value} mapping
    env_list = c.get("env") or []
    env_obj = {e.get("name"): e.get("value") for e in env_list
               if isinstance(e, dict) and e.get("name")}
    _log_create(f"env: {json.dumps(env_obj, ensure_ascii=False)}")
    # ports: [{"containerPort": 8080, "protocol": "http"}] -> ["8080/http"]
    ports_list = c.get("ports") or []
    rp_ports = []
    for p in ports_list:
        if isinstance(p, dict):
            cp = p.get("containerPort")
            proto = (p.get("protocol") or "http").lower()
            if proto not in ("http", "tcp"):
                proto = "http"  # RunPod accepts only http/tcp port specs here
            if isinstance(cp, int):
                rp_ports.append(f"{cp}/{proto}")
    if not rp_ports:
        return JSONResponse({"error": "ports[].containerPort required."}, 400)
    _log_create(f"ports: {rp_ports}")
    command = c.get("command") if isinstance(c.get("command"), list) else None
    args = c.get("args") if isinstance(c.get("args"), list) else None
    if command: _log_create(f"command: {command}")
    if args: _log_create(f"args: {args}")
    # GPU normalization (enum -> pretty string); include only if non-empty
    dr = c.get("dedicatedResources") or {}
    gpu_ids = None
    gpu_count = 1
    if isinstance(dr, dict):
        typ = (dr.get("machineSpec", {}) or {}).get("acceleratorType")
        cnt = (dr.get("machineSpec", {}) or {}).get("acceleratorCount")
        if typ: gpu_ids = [typ] if isinstance(typ, str) else typ
        if isinstance(cnt, int) and cnt > 0: gpu_count = cnt
    def _normalize_gpu_enum(s: str) -> str:
        # e.g. "NVIDIA_A100_80GB" -> "NVIDIA A100 80 GB"
        if not isinstance(s, str) or not s.strip():
            return ""
        t = s.strip().upper().replace("_", " ")
        vendor = "NVIDIA"
        if t.startswith("NVIDIA "):
            t = t[len("NVIDIA "):]
        elif t.startswith("AMD "):
            vendor = "AMD"; t = t[len("AMD "):]
        t = re.sub(r"(\d)(GB\b)", r"\1 \2", t)  # 80GB -> 80 GB
        return f"{vendor} {t}".strip()
    rp_gpu = None
    if gpu_ids:
        rp_gpu = _normalize_gpu_enum(gpu_ids[0]).strip() or None
        _log_create(f"GPU_TRANSLATION original={gpu_ids[0]} -> runpod='{rp_gpu}'")
    _log_create("SECURE_PLACEMENT interruptible=false")
    payload = {
        "name": re.sub(r"[^a-z0-9-]", "-", f"ephemeral-{int(time.time())}".lower()),
        "computeType": "GPU",
        "interruptible": False,  # On-Demand (not community)
        "imageName": image,
        "gpuCount": gpu_count,
        "ports": rp_ports,
        "supportPublicIp": True,
        **({"gpuTypeIds": [rp_gpu]} if rp_gpu else {}),
        **({"env": env_obj} if env_obj else {}),
        **({"command": command} if command else {}),
        **({"args": args} if args else {}),
    }
    _log_create(f"PAYLOAD_SENT {json.dumps(payload, ensure_ascii=False)}")
    content = {}
    pid = None
    try:
        r = requests.post(f"{_RP_BASE}/pods", headers=_rp_headers(), json=payload, timeout=60)
        content = _as_json(r)
        _log_create(f"RUNPOD_RESPONSE {json.dumps(content, ensure_ascii=False)}")
        pid = content.get("id")
        # Some responses nest the created pod one level down; scan values.
        if not pid and isinstance(content, dict):
            for v in content.values():
                if isinstance(v, dict) and "id" in v:
                    pid = v["id"]; break
    except Exception as e:
        _log_create(f"ERROR Create: {e}")
        return JSONResponse({"error": f"RunPod create failed: {e}"}, 500)
    _log_create(f"ID: {pid}")
    if not isinstance(r, requests.Response):
        return JSONResponse({"error": "No HTTP response from RunPod create."}, 502)
    if not r.ok:
        return JSONResponse(content if isinstance(content, dict) else {"_raw": str(content)}, r.status_code)
    if not pid:
        return JSONResponse({"error": "Create succeeded but no pod ID in response.", "raw": content}, 502)
    # cache pod id
    try:
        _INST["podId"] = str(pid).strip()
        _log_id("CREATE_SET", _INST["podId"])
        _save_state()
    except Exception as e:
        return JSONResponse({"error": f"Could not cache pod ID: {e}"}, 502)
    # start the pod immediately so networking/IP can come up
    try:
        sr = requests.post(f"{_RP_BASE}/pods/{_INST['podId']}/start", headers=_rp_headers(), timeout=30)
        scontent = _as_json(sr)
        _log_status(f"START_RESPONSE {json.dumps(scontent, ensure_ascii=False)}")
    except Exception as e:
        _log_status(f"ERROR Start: {e}")
    # initial status snapshot (best-effort; errors are recorded, not fatal)
    try:
        rs = requests.get(f"{_RP_BASE}/pods/{_INST['podId']}", headers=_rp_headers(), timeout=30)
        st = _as_json(rs)
        _log_status(f"STATUS_POLL {json.dumps(st, ensure_ascii=False)}")
        content["_status"] = st
    except Exception as e:
        content["_status_error"] = str(e)
        _log_status(f"ERROR Status: {e}")
    _INST["status"] = content.get("desiredStatus") or content.get("status") or ""
    _INST["ip"] = _INST.get("ip") or ""
    _INST["port"] = _INST.get("port") or ""
    return JSONResponse(content, r.status_code)
# ---------------------------------------------------------------------
# Poll / read instance status + explicit readiness fields
# ---------------------------------------------------------------------
def api_get_instance(pod_id: str | None = None):
    """Poll RunPod for the pod's status and resolve its public endpoint.

    Uses the given pod_id or the cached one; picks the public port mapped to
    the blob-declared container port (falling back to the first mapping),
    caches ip/port, fires a health probe for logging, and returns the pod
    object merged with a "cachedState" snapshot for the UI.
    """
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, 400)
    _log_id("STATUS_USES", pid)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=30)
        last = _as_json(r)
        _log_status(f"STATUS_POLL {json.dumps(last, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"poll failed: {e}"}, 502)
    # Declared internal container port from the blob (None if unavailable).
    declared = None
    try:
        c = (((_INST.get("blob") or {}).get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
            or (_INST.get("blob") or {}).get("container") or {}
        declared = int((c.get("ports") or [])[0].get("containerPort"))
    except Exception:
        c = {}
        pass
    if isinstance(last, dict):
        ip = last.get("publicIp") or ""
        pm = last.get("portMappings") or {}
        if ip and isinstance(pm, dict) and pm:
            # choose mapped public port for the declared internal port; else first mapping
            if isinstance(declared, int) and str(declared) in pm:
                chosen = str(pm[str(declared)])
            else:
                k = next(iter(pm.keys()))
                chosen = str(pm[k])
                _log_status(f"PORT_MAPPING declared={declared} not_found_using_first key={k}")
            _INST.update({"podId": pid,
                          "status": last.get("desiredStatus", ""),
                          "ip": ip,
                          "port": chosen})
            _save_state()
            base = f"http://{_INST['ip']}:{_INST['port']}"
            _log_status(f"PORT_MAPPING declared={declared} chosen={chosen} all={pm}")
            _log_status(f"RESOLVED_ENDPOINT base={base}")
            # --- NEW: health-first readiness (mirror Vertex), fallback to predict existence ---
            # NOTE(review): `_probe` is not defined in this chunk — presumably a
            # module-level helper returning (status_code, elapsed_ms, body_snippet);
            # confirm against the rest of the module.
            hr = (_INST.get("healthRoute") or "/health").strip()
            pr = (_INST.get("predictRoute") or "/predict").strip()
            code_h, ms_h, snippet_h = _probe("GET", f"{base}{hr}")
            _log_status(f"HEALTH_PROBE path={hr} code={code_h} ms={ms_h} body_snippet={snippet_h[:120]}")
            #if code_h in (200, 204):
            #_INST["status"] = "READY"
            #else:
            #code_p, ms_p, _ = _probe("HEAD", f"{base}{pr}")
            #_log_status(f"PREDICT_PROBE path={pr} code={code_p} ms={ms_p}")
            #if code_p in (200, 204, 400, 405):
            #    _INST["status"] = "READY"
    # Final prompt URL (prefer IP; else proxy host)
    proute = _INST.get("predictRoute") or "/predict"
    if _INST.get("ip") and _INST.get("port"):
        prompt_url = f"http://{_INST['ip']}:{_INST['port']}{proute}"
    else:
        proxy_base = f"https://{pid}-{declared}.proxy.runpod.net" if declared else ""
        prompt_url = f"{proxy_base}{proute}" if proxy_base else ""
    if prompt_url:
        _log_status(f"PROMPT_ENDPOINT {prompt_url}")
    # Always include cached readiness data for the UI
    merged = {**last, "cachedState": {
        "podId": _INST.get("podId"),
        "status": _INST.get("status"),
        "ip": _INST.get("ip"),
        "port": _INST.get("port"),
        "predictRoute": _INST.get("predictRoute"),
        "healthRoute": _INST.get("healthRoute"),
    }}
    return JSONResponse(merged)
# ---------------------------------------------------------------------
# Start, Stop, End All — same as before
# ---------------------------------------------------------------------
def api_start_instance(pod_id: str):
    """POST /pods/{id}/start on RunPod and relay its JSON body + status code."""
    _log_id("START_USES", pod_id)
    try:
        resp = requests.post(
            f"{_RP_BASE}/pods/{pod_id}/start",
            headers=_rp_headers(),
            timeout=30,
        )
        body = _as_json(resp)
        _log_status(f"START_RESPONSE {json.dumps(body, ensure_ascii=False)}")
        return JSONResponse(body, resp.status_code)
    except Exception as exc:
        _log_status(f"ERROR Start: {exc}")
        return JSONResponse({"error": f"RunPod start failed: {exc}"}, 500)
async def api_delete_instance():
    """Stop (not delete) the cached pod via RunPod's /stop endpoint."""
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing and no cached pod found."}, status_code=400)
    _log_id("STOP_USES", pid)
    try:
        _log_delete(">>> STOP endpoint triggered")
        resp = requests.post(f"{_RP_BASE}/pods/{pid}/stop", headers=_rp_headers(), timeout=60)
        body = _as_json(resp)
        _log_delete(f"STOP_RESPONSE {json.dumps(body, ensure_ascii=False)}")
        return JSONResponse(status_code=resp.status_code, content=body)
    except Exception as exc:
        _log_delete(f"ERROR Stop: {exc}")
        return JSONResponse(status_code=500, content={"error": f"RunPod stop failed: {exc}"})
async def api_end_all():
    """Delete the cached pod entirely; clear the cached identity on success."""
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing and no cached pod found."}, status_code=400)
    _log_id("DELETE_USES", pid)
    try:
        _log_delete(">>> END-ALL endpoint triggered")
        resp = requests.delete(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=60)
        body = _as_json(resp)
        _log_delete(f"DELETE_RESPONSE {json.dumps(body, ensure_ascii=False)}")
        if resp.status_code in (200, 202, 204):
            # Pod is gone — drop cached identity so later calls fail fast.
            _INST.update({"podId": "", "status": "", "ip": "", "port": ""})
            _save_state()
        return JSONResponse(status_code=resp.status_code, content=body)
    except Exception as exc:
        _log_delete(f"ERROR Delete: {exc}")
        return JSONResponse(status_code=500, content={"error": f"RunPod delete failed: {exc}"})
# ---------------------------------------------------------------------
# Wait instance
# ---------------------------------------------------------------------
def api_wait_instance(pod_id: str | None = None):
    """Single wait-step poll: refresh pod status, resolve ip/port and proxy base.

    Unlike api_get_instance this also caches a RunPod proxy URL under
    _INST["base"], derived from the declared internal port, as a fallback
    transport when the public-IP path is unusable.
    """
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, status_code=400)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=30)
        last = _as_json(r)
        _log_status(f"WAIT_STATUS {json.dumps(last, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"wait poll failed: {e}"}, status_code=502)
    ip = last.get("publicIp") or _INST.get("ip")
    pm = last.get("portMappings") or {}
    port = None
    # Declared internal container port from the blob (None if unavailable).
    declared = None
    try:
        c = (((_INST.get("blob") or {}).get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
            or (_INST.get("blob") or {}).get("container") or {}
        declared = int((c.get("ports") or [])[0].get("containerPort"))
    except Exception:
        c = {}
        pass
    if ip and pm:
        # Prefer the mapping for the declared port; then 8080; then the first one.
        try:
            if isinstance(declared, int) and str(declared) in pm:
                port = str(pm[str(declared)])
        except Exception:
            pass
        if not port and "8080" in pm:
            port = str(pm["8080"])
        elif not port and pm:
            port = str(pm[next(iter(pm.keys()))])
    if ip and port:
        base = f"http://{ip}:{port}"
        _log_status(f"RESOLVED_IP {base}")
        # --- NEW: health-first readiness (mirror Vertex), fallback to predict existence ---
        # NOTE(review): `_probe` is not defined in this chunk — presumably a
        # module-level helper returning (status_code, elapsed_ms, body_snippet);
        # confirm against the rest of the module.
        hr = (_INST.get("healthRoute") or "/health").strip()
        pr = (_INST.get("predictRoute") or "/predict").strip()
        code_h, ms_h, snippet_h = _probe("GET", f"{base}{hr}")
        _log_status(f"HEALTH_PROBE path={hr} code={code_h} ms={ms_h} body_snippet={snippet_h[:120]}")
        # if code_h in (200, 204):
        #     _INST["status"] = "READY"
        #else:
        #code_p, ms_p, _ = _probe("HEAD", f"{base}{pr}")
        #_log_status(f"PREDICT_PROBE path={pr} code={code_p} ms={ms_p}")
        #if code_p in (200, 204, 400, 405):
        #_INST["status"] = "READY"
        if _INST.get("predictRoute"):
            _log_status(f"PROMPT_ENDPOINT {base}{_INST['predictRoute']}")
        # Cache the RunPod proxy base as a fallback transport (best-effort).
        try:
            cspec = _get_container_spec()
            internal, _ = _get_port_and_proto(cspec)
            if internal:
                proxy_base = f"https://{pid}-{internal}.proxy.runpod.net"
                _log_status(f"RESOLVED_PROXY {proxy_base}")
                _INST["base"] = proxy_base
                _save_state()
        except Exception:
            pass
    _INST.update({"ip": ip or "", "port": port or "", "status": last.get("desiredStatus", "")})
    _save_state()
    merged = {
        **last,
        "cachedState": {
            "podId": _INST.get("podId"),
            "status": _INST.get("status"),
            "ip": _INST.get("ip"),
            "port": _INST.get("port"),
            "base": _INST.get("base"),
            "predictRoute": _INST.get("predictRoute"),
            "healthRoute": _INST.get("healthRoute"),
        },
    }
    return JSONResponse(merged)
# ---------------------------------------------------------------------
# Debug: live probes against the instance (IP + Proxy)
# ---------------------------------------------------------------------
def api_debug_probes(pod_id: str | None = None):
    """Probe a battery of health/predict paths over both IP and proxy bases.

    Returns a JSON report {podId, ip, internalPort, publicPort, probes[]}
    where each probe entry records base, path, kind, status code and latency.
    """
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, 400)
    # latest pod object (for portMappings/publicIp)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=20)
        pod = _as_json(r)
        _log_status(f"DEBUG_POD_OBJ {json.dumps(pod, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"pod fetch failed: {e}"}, 502)
    ip = pod.get("publicIp") or _INST.get("ip")
    pm = pod.get("portMappings") or {}
    # choose internal/public ports: declared port first, then 8080, then first mapping
    internal = None
    try:
        cs = _get_container_spec()
        internal = int((cs.get("ports") or [])[0].get("containerPort"))
    except Exception:
        pass
    if internal and str(internal) in pm:
        public = str(pm[str(internal)])
    elif "8080" in pm:
        internal, public = 8080, str(pm["8080"])
    elif pm:
        k = next(iter(pm.keys()))
        internal, public = int(k), str(pm[k])
    else:
        public = None
    # candidate paths (cached routes first, then common conventions)
    healths = [(_INST.get("healthRoute") or "").strip(), "/health", "/ping", "/healthz", "/v1/models"]
    healths = [p for p in healths if p]
    predicts = [(_INST.get("predictRoute") or "").strip(), "/generate", "/predict", "/predictions",
                "/v1/chat/completions", "/v1/models/model:predict"]
    predicts = [p for p in predicts if p]
    results = {"podId": pid, "ip": ip, "internalPort": internal, "publicPort": public, "probes": []}
    # base URLs (IP and proxy)
    bases = []
    if ip and public:
        bases.append(f"http://{ip}:{public}")
    if internal:
        bases.append(f"https://{pid}-{internal}.proxy.runpod.net")
    # probe health (GET)
    # NOTE(review): `_probe` is defined elsewhere in this module; assumed to
    # return (status_code, elapsed_ms, body_snippet) — confirm.
    for base in bases:
        for hp in healths:
            code, ms, snippet = _probe("GET", f"{base}{hp}")
            _log_status(f"DEBUG_HEALTH base={base} path={hp} code={code} ms={ms}")
            results["probes"].append({"base": base, "path": hp, "kind": "health", "code": code, "ms": ms, "snippet": snippet})
    # probe predict (HEAD)
    for base in bases:
        for pp in predicts:
            code, ms, _ = _probe("HEAD", f"{base}{pp}")
            _log_status(f"DEBUG_PREDICT base={base} path={pp} code={code} ms={ms}")
            results["probes"].append({"base": base, "path": pp, "kind": "predict", "code": code, "ms": ms})
    return JSONResponse(results, 200)
# ---------------------------------------------------------------------
# Helper functions for containerSpec parsing
# ---------------------------------------------------------------------
def _get_container_spec():
    """Return the containerSpec dict from the cached blob.

    Lazy-loads the local landing blob when nothing is cached yet. Returns {}
    when no blob can be found anywhere — previously this path crashed with
    AttributeError (`None.get`) because `blob` stayed None when both the
    cache and the local file were empty.
    """
    blob = _INST.get("blob")
    if not blob:
        lb = _load_local_blob()
        if lb:
            _ingest_blob(lb)
            blob = lb
    if not isinstance(blob, dict):
        # No blob cached and none on disk — nothing to parse.
        return {}
    # containerSpec may be Vertex-style (supportedActions.deploy) or flat.
    return (((blob.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec")
            or blob.get("container") or {})
| def _get_port_and_proto(cspec: dict): | |
| try: | |
| ports = cspec.get("ports") or [] | |
| if isinstance(ports, list) and ports: | |
| p0 = ports[0] | |
| internal = p0.get("containerPort") | |
| proto = (p0.get("protocol") or "").lower() or None | |
| return (int(internal) if str(internal).isdigit() else None, proto) | |
| except Exception: | |
| pass | |
| return (None, None) | |
def _build_proxy_url(route: str) -> str:
    """Build the RunPod proxy URL https://{podId}-{port}.proxy.runpod.net{route}.

    Raises HTTPException(400) when no pod id is cached or the internal port
    cannot be resolved from the blob's containerSpec.
    """
    pod = (_INST.get("podId") or "").strip()
    if not pod:
        raise HTTPException(status_code=400, detail="No podId in cache. Create/Start the instance first.")
    internal_port, _ = _get_port_and_proto(_get_container_spec())
    if not internal_port:
        raise HTTPException(status_code=400, detail="Cannot resolve internal port from containerSpec.ports[].")
    return f"https://{pod}-{internal_port}.proxy.runpod.net{route}"
def _build_ip_url(route: str) -> str:
    """Build http://{ip}:{port}{route} from the cached instance address.

    Raises HTTPException(400) when either ip or port is not cached yet.
    """
    ip = _INST.get("ip")
    port = _INST.get("port")
    if not ip or not port:
        raise HTTPException(status_code=400, detail="No running instance (ip/port missing).")
    return f"http://{ip}:{port}{route}"
def _resolve_infer_url(route: str) -> str:
    """Resolve the URL to POST inference to: direct IP when cached, else proxy.

    The previous version ended with `if proto == "http" or True:` — an
    always-true condition that made the final IP fallback unreachable dead
    code. The effective behavior (IP first, then proxy unconditionally) is
    preserved here, just written explicitly.
    """
    cspec = _get_container_spec()
    _, proto = _get_port_and_proto(cspec)  # currently unused; kept for parity/logging
    # Prefer the direct public-IP path when ip/port are cached.
    try:
        if _INST.get("ip") and _INST.get("port"):
            url = _build_ip_url(route)
            _job_log("compute", f"[MW] Using IP path: {url}")
            return url
    except HTTPException:
        pass
    # Otherwise always fall back to the RunPod proxy URL.
    url = _build_proxy_url(route)
    _job_log("compute", f"[MW] Using Proxy path: {url}")
    return url
# ---------------------------------------------------------------------
# /api/infer — updated to use resolver
# ---------------------------------------------------------------------
async def api_infer(req: Request):
    """Forward the request body to the instance's predictRoute and relay the reply."""
    route = _INST.get("predictRoute")
    if not route:
        # 428 Precondition Required: routes must be resolved before inference.
        return JSONResponse(
            {"error": "predictRoute unresolved; check ROUTE_PROBE logs and HEALTH results."},
            status_code=428
        )
    body = await req.json()
    try:
        target = _resolve_infer_url(route)
        resp = requests.post(target, json=body, timeout=120)
        content_type = (resp.headers.get("content-type") or "").lower()
        if "application/json" in content_type:
            return JSONResponse(status_code=resp.status_code, content=resp.json())
        return HTMLResponse(status_code=resp.status_code, content=resp.text)
    except HTTPException as he:
        return JSONResponse({"error": he.detail}, status_code=he.status_code)
    except Exception as e:
        return JSONResponse({"error": f"inference request failed: {e}"}, status_code=502)
# ---------------------------------------------------------------------
# /api/middleware/infer — middleware prompt routing and normalization
# ---------------------------------------------------------------------
async def api_middleware_infer(req: Request):
    """Normalize a {"prompt": ...} request and forward it to the pod via proxy.

    Builds a deterministic https://{podId}-{port}.proxy.runpod.net URL
    (without waiting for readiness), special-cases HF_TASK=text-to-image
    payloads, and otherwise retries a list of common payload shapes until
    one succeeds.
    """
    # Always ensure predictRoute exists
    route = _INST.get("predictRoute") or "/predict"
    _INST["predictRoute"] = route
    # Build deterministic proxy URL instead of waiting on readiness
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        # Pod id may have been persisted by an earlier process — try disk.
        try:
            _load_state()
            pid = (_INST.get("podId") or "").strip()
        except Exception:
            pass
    if not pid:
        return JSONResponse({"error": "no podId yet (create/start first)"}, status_code=400)
    cspec = _get_container_spec()
    internal, _ = _get_port_and_proto(cspec)
    if not internal:
        return JSONResponse({"error": "cannot resolve internal port from blob"}, status_code=400)
    # ---------------- NEW: routing override for HF_TASK=text-to-image ----------------
    env = (cspec.get("env") or [])
    kv = {e.get("name"): e.get("value") for e in env if isinstance(e, dict) and e.get("name")}
    hf_task = (kv.get("HF_TASK") or "").strip().lower()
    model_id = (kv.get("MODEL_ID") or kv.get("HF_MODEL_ID") or "").strip()
    if hf_task == "text-to-image" and model_id:
        route = f"/predictions/{model_id}"
        _INST["predictRoute"] = route
    # -------------------------------------------------------------------------------
    base = f"https://{pid}-{internal}.proxy.runpod.net"
    url = f"{base}{route}"
    _log_status(f"PROMPT_ENDPOINT {url}")
    _job_log("compute", f"[MW] Forwarding infer to {url}")
    payload = await req.json()
    prompt = payload.get("prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        return JSONResponse({"error": "Missing 'prompt' in request body."}, 400)
    # ---------------- NEW: canonical text-to-image payload ----------------
    if hf_task == "text-to-image":
        body = {
            "inputs": prompt,
            "parameters": {
                "num_inference_steps": 30,
                "guidance_scale": 7.5,
                "width": 1024,
                "height": 1024
            }
        }
        try:
            rp = requests.post(url, json=body, timeout=120)
            _log_status(f"PREDICT_RESP code={rp.status_code} len={len(rp.text)}")
            if rp.ok:
                ct = (rp.headers.get("content-type") or "").lower()
                data = _as_json(rp) if "application/json" in ct else {"_raw": rp.text}
                if isinstance(data, dict) and "image_b64" in data:
                    return JSONResponse({"image_b64": data["image_b64"], "timings": data.get("timings")}, rp.status_code)
                return JSONResponse(data, rp.status_code)
            return JSONResponse({"error": rp.text[:400]}, status_code=rp.status_code)
        except Exception as e:
            _log_status(f"PREDICT_ERR {e}")
            return JSONResponse({"error": f"inference request failed: {e}"}, status_code=502)
    # ---------------------------------------------------------------------
    # HF text-classification shim: the HF inference container expects
    # {"instances": [...]} rather than {"prompt": ...}.
    img = (_get_container_spec().get("imageUri","")).lower()
    if "huggingface-pytorch-inference" in img and isinstance(payload.get("prompt"), str):
        payload = {"instances": [payload["prompt"]]}
    # Non-image fallback (unchanged): try several common payload shapes in order.
    bodies = [payload, {"prompt": prompt}, {"text": prompt}, {"inputs": prompt}, {"input": prompt}]
    for body in bodies:
        try:
            rp = requests.post(url, json=body, timeout=120)
            _log_status(f"PREDICT_RESP code={rp.status_code} len={len(rp.text)}")
            if rp.ok:
                ct = (rp.headers.get("content-type") or "").lower()
                data = _as_json(rp) if "application/json" in ct else {"_raw": rp.text}
                if isinstance(data, dict):
                    if "image_b64" in data:
                        return JSONResponse({"image_b64": data["image_b64"], "timings": data.get("timings")}, rp.status_code)
                    if isinstance(data.get("output"), str):
                        return JSONResponse({"output": data["output"]}, rp.status_code)
                    if "_raw" in data:
                        return JSONResponse({"output": data["_raw"]}, rp.status_code)
                    return JSONResponse({"output": json.dumps(data, ensure_ascii=False)}, rp.status_code)
                return JSONResponse({"output": str(data)}, rp.status_code)
        except Exception as e:
            _log_status(f"PREDICT_ERR {e}")
    # Fallthrough: show last response or generic error
    # NOTE(review): `rp` may be unbound here if every attempt raised before
    # the assignment; the resulting NameError is caught below and mapped to 504.
    try:
        return JSONResponse({"error": rp.text[:400]}, status_code=rp.status_code)
    except Exception:
        return JSONResponse({"error": "no response from model"}, status_code=504)
# ---------------------------------------------------------------------
# Job progress + callback routes
# ---------------------------------------------------------------------
async def api_job_ready(req: Request):
    """Readiness callback: simply acknowledge the worker."""
    return JSONResponse({"ok": True})
async def api_job_progress(req: Request):
    """Record a progress message posted by the worker for a given job."""
    data = await req.json()
    job_id = str(data.get("job_id", "unknown"))
    message = data.get("message", "")
    _job_log(job_id, message or "progress")
    return JSONResponse({"ok": True})
async def api_job_done(req: Request):
    """Completion callback: store the job's image/timings and mark it done."""
    data = await req.json()
    job_id = str(data.get("job_id", "unknown"))
    job = _JOBS.setdefault(job_id, {"status": "created", "logs": [], "image_b64": None, "timings": {}})
    job.update(
        status="done",
        image_b64=data.get("image_b64"),
        timings=data.get("timings", {}),
    )
    _job_log(job_id, "completed")
    return JSONResponse({"ok": True})
def api_job_status(job_id: str):
    """Return the stored record for a job, or {"status": "unknown"} if absent."""
    record = _JOBS.get(job_id)
    return JSONResponse(record if record is not None else {"status": "unknown"})