"""RunPod deployment router.

Handles: deployment-blob ingest, pod create/start/stop/delete via the RunPod
REST API, status polling with endpoint resolution, health/predict probing,
and inference proxying (direct IP or RunPod proxy host).

NOTE(review): original file defined `_ingest_blob` twice (identical bodies) and
called a `_probe` helper that was never defined (NameError at runtime). Both
are fixed here: one canonical `_ingest_blob`, and `_probe` is implemented below.
"""

from fastapi import APIRouter, HTTPException, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
import os, json, requests, time, re, pathlib

router = APIRouter()

# ---------------------------------------------------------------------
# In-memory job + instance store
# ---------------------------------------------------------------------
_JOBS = {}  # job_id -> {"status", "logs", "image_b64", "timings"}
_INST = {
    "podId": "",
    "status": "",
    "ip": "",
    "port": "",
    "blob": None,
    "model_id": "",
    "container_image_hint": "",
    "predictRoute": None,
    "healthRoute": None,
    "readinessRoute": None,
    "livenessRoute": None,
}


def _now_ms():
    """Current wall-clock time in milliseconds."""
    return int(time.time() * 1000)


def _job_log(job_id, msg):
    """Append a timestamped message to a job's log, creating the record if needed."""
    j = _JOBS.setdefault(job_id, {"status": "created", "logs": [], "image_b64": None, "timings": {}})
    j["logs"].append({"t": _now_ms(), "msg": msg})
    print(f"[{job_id}] {msg}", flush=True)


# Convenience wrappers that tag log lines by lifecycle phase.
def _log_create(msg): _job_log("compute", f"[CREATE] {msg}")
def _log_status(msg): _job_log("compute", f"[STATUS] {msg}")
def _log_delete(msg): _job_log("compute", f"[DELETE] {msg}")
def _log_id(prefix, pid): _job_log("compute", f"{prefix} ID: {pid}")


# --- local blob ingest (landing page only) ---
_LOCAL_BLOB_PATH = os.getenv("MODEL_BLOB_PATH", "model_blob.json")


def _load_local_blob():
    """Best-effort load of the deployment blob from the local JSON file; None on failure."""
    try:
        if os.path.exists(_LOCAL_BLOB_PATH):
            with open(_LOCAL_BLOB_PATH, "r", encoding="utf-8") as f:
                return json.load(f)
    except Exception as e:
        _job_log("compute", f"ERROR LocalBlobLoad: {e}")
    return None


def _ingest_blob(parsed: dict, model_id_hint: str = "", container_image_hint: str = ""):
    """Cache a deployment blob in _INST and resolve predict/health routes.

    Routes come first from the blob's containerSpec; missing ones are inferred
    from the image URI. (FIX: this was defined twice in the original file;
    deduplicated to this single canonical version.)

    Raises:
        HTTPException(400): if *parsed* is not a JSON object.
    """
    if not isinstance(parsed, dict):
        raise HTTPException(400, "Invalid blob (expected JSON object).")
    _INST.update({
        "blob": parsed,
        "model_id": model_id_hint or "",
        "container_image_hint": container_image_hint or "",
    })
    # containerSpec may live under supportedActions.deploy or at top level.
    c = (((parsed.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec")
         or parsed.get("container") or {}) or {}
    for k in ("predictRoute", "healthRoute", "readinessRoute", "livenessRoute"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            _INST[k] = v.strip()
    # Fall back to image-based inference only for routes the blob didn't declare.
    image_uri = (c.get("imageUri") or "").strip().lower()
    pr, hr = _infer_routes_from_image(image_uri)
    if pr and not _INST.get("predictRoute"):
        _INST["predictRoute"] = pr
    if hr and not _INST.get("healthRoute"):
        _INST["healthRoute"] = hr
    return True


# ---------------------------------------------------------------------
# Disk persistence for recovery
# ---------------------------------------------------------------------
_STATE_PATH = "/tmp/pod_state.json"


def _save_state():
    """Persist the minimal pod identity (podId/status/ip/port) to disk; best-effort."""
    try:
        pathlib.Path("/tmp").mkdir(parents=True, exist_ok=True)
        with open(_STATE_PATH, "w") as f:
            json.dump({k: _INST.get(k, "") for k in ("podId", "status", "ip", "port")}, f)
    except Exception as e:
        _job_log("compute", f"ERROR SaveState: {e}")


def _load_state():
    """Restore pod identity from disk into _INST; best-effort."""
    try:
        if os.path.exists(_STATE_PATH):
            with open(_STATE_PATH) as f:
                d = json.load(f)
            for k in ("podId", "status", "ip", "port"):
                if k in d:
                    _INST[k] = d[k]
    except Exception as e:
        _job_log("compute", f"ERROR LoadState: {e}")


# ---------------------------------------------------------------------
# RunPod helpers
# ---------------------------------------------------------------------
_RP_BASE = "https://rest.runpod.io/v1"


def _rp_headers():
    """Auth headers for the RunPod REST API.

    Raises:
        HTTPException(500): if the 'RunPod' env var is unset/empty.
    """
    key = os.getenv("RunPod", "").strip()
    if not key:
        raise HTTPException(500, "Missing RunPod API key (env var 'RunPod').")
    return {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}


def _as_json(r):
    """Decode a requests.Response to a dict; non-JSON bodies land under '_raw'."""
    c = (r.headers.get("content-type") or "").lower()
    if "json" in c:
        try:
            return r.json()
        except Exception:
            return {"_raw": r.text}
    return {"_raw": r.text}


# ---------------------------------------------------------------------
# Probes and route discovery
# ---------------------------------------------------------------------
# Expanded set: will try these against https://pod:port/
_POSSIBLE_ROUTES = [
    "/invocations",  # <— added and placed first
    "/generate",
    "/predict",
    "/predictions",
    "/v1/chat/completions",
    "/v1/models/model:predict",
]


def _probe(method: str, url: str, timeout: float = 5.0):
    """Issue a single HTTP probe; never raises.

    FIX: this helper was referenced throughout the original file but never
    defined, so every status/debug endpoint raised NameError.

    Returns:
        (status_code | None, elapsed_ms, body_snippet) — status is None and the
        snippet holds the error text when the request failed at transport level.
    """
    t0 = _now_ms()
    try:
        r = requests.request(method, url, timeout=timeout)
        return r.status_code, _now_ms() - t0, (r.text or "")[:200]
    except Exception as e:
        return None, _now_ms() - t0, str(e)[:200]


def _infer_routes_from_image(image_uri: str):
    """Infer (predict_route, health_route) from known image patterns."""
    iu = (image_uri or "").lower()
    # vLLM images
    if "vllm-serve" in iu:
        return ("/generate", "/ping")
    # HuggingFace / Vertex HF Inference Toolkit
    # changed from "/predict" → "/invocations"
    if "hf-inference-toolkit" in iu or "huggingface-pytorch-inference" in iu:
        return ("/invocations", "/ping")
    # Unknown image → allow route scanning fallback
    return (None, None)


async def _probe_all_routes(base: str, port: str, session):
    """Try all known routes until one responds 200/OK-ish.

    Returns (predict_route, health_route or None).
    NOTE(review): expects an async *session* with a .get(url, timeout=...) API;
    currently unused by the routes below.
    """
    from urllib.parse import urljoin
    proto_base = f"{base}:{port}"
    for route in _POSSIBLE_ROUTES:
        url = urljoin(proto_base + "/", route.lstrip("/"))
        try:
            r = await session.get(url, timeout=3)
            if r.status_code < 500:
                return route, ("/ping" if "/ping" in route else None)
        except Exception:
            pass
    return None, None


# ---------------------------------------------------------------------
# Blob ingest via Model Blob page JSON (with blob_url override)
# ---------------------------------------------------------------------
_HF_SPACE_PORT = os.getenv("PORT", "7860")
_LOCAL_BASE = f"http://127.0.0.1:{_HF_SPACE_PORT}"


def _normalize_blob_url(u: str | None) -> str | None:
    """Resolve *u* to an absolute URL; relative paths are same-origin as this app."""
    if not u:
        return None
    u = str(u).strip()
    if u.startswith(("http://", "https://")):
        return u
    # Treat '/x' or 'x' as local to this app (same origin as FE)
    if u.startswith("/"):
        return f"{_LOCAL_BASE}{u}"
    return f"{_LOCAL_BASE}/{u}"


def _fetch_url(u: str):
    """GET *u* and return parsed JSON, or None on any failure (logged)."""
    try:
        r = requests.get(u, timeout=8)
        if r.ok:
            return r.json()
        _job_log("compute", f"ERROR BlobFetch code={r.status_code} url={u} body={r.text[:200]}")
    except Exception as e:
        _job_log("compute", f"ERROR BlobFetch url={u}: {e}")
    return None


def _fetch_blob_from_page():
    """Fetch the blob from this app's own /modelblob.json page."""
    return _fetch_url(f"{_LOCAL_BASE}/modelblob.json")


@router.post("/api/ingest/from_landing")
def api_ingest_from_landing(blob_url: str | None = None):
    """Ingest the deployment blob for downstream use.

    Mirrors FE behavior: resolve relative paths like '/modelblob.json'
    against the app origin.
    """
    u = _normalize_blob_url(blob_url) or _normalize_blob_url("/modelblob.json")
    parsed = _fetch_url(u)
    if not parsed:
        return JSONResponse({"error": "Blob not available"}, 404)
    _ingest_blob(parsed, model_id_hint="", container_image_hint="")
    return JSONResponse({"ok": True, "source": u})


# (Optional compatibility: UI posting to /Deployment_UI; accepts blob_url via query)
@router.post("/Deployment_UI")
async def deployment_ui_ingest(request: Request, model_id: str = Form(""), container_image: str = Form(""), blob: str = Form("")):
    """Legacy entry used by the Deployment UI page.

    Prefers blob_url from query string; falls back to the modelblob page JSON.
    """
    blob_url = request.query_params.get("blob_url")
    u = _normalize_blob_url(blob_url) if blob_url else _normalize_blob_url("/modelblob.json")
    parsed = _fetch_url(u)
    if not parsed:
        return HTMLResponse("Missing blob (no /modelblob.json and no blob_url)", 400)
    _ingest_blob(parsed, model_id_hint=model_id, container_image_hint=container_image)
    return RedirectResponse("/Deployment_UI", 303)


# ---------------------------------------------------------------------
# Create instance
# ---------------------------------------------------------------------
@router.post("/api/compute/create_instance")
async def api_create_instance(req: Request):
    """Create (and immediately start) a RunPod pod from the cached deployment blob."""
    # Ensure blob is present (lazy-load from landing file if needed)
    if not _INST.get("blob"):
        lb = _load_local_blob()
        if lb:
            _ingest_blob(lb)
    blob = _INST.get("blob")
    if not blob:
        return JSONResponse({"error": "No deployment blob provided."}, 400)
    c = ((blob.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
        or blob.get("container")
    if not isinstance(c, dict) or not c:
        return JSONResponse({"error": "Blob missing containerSpec."}, 400)
    image = (c.get("imageUri") or "").strip()
    if not image:
        return JSONResponse({"error": "containerSpec.imageUri missing."}, 400)
    _log_create(f"imageName: {image}")

    # env: list of {name, value} -> flat dict
    env_list = c.get("env") or []
    env_obj = {e.get("name"): e.get("value") for e in env_list if isinstance(e, dict) and e.get("name")}
    _log_create(f"env: {json.dumps(env_obj, ensure_ascii=False)}")

    # ports: RunPod expects "port/proto" strings; only http/tcp are accepted.
    ports_list = c.get("ports") or []
    rp_ports = []
    for p in ports_list:
        if isinstance(p, dict):
            cp = p.get("containerPort")
            proto = (p.get("protocol") or "http").lower()
            if proto not in ("http", "tcp"):
                proto = "http"
            if isinstance(cp, int):
                rp_ports.append(f"{cp}/{proto}")
    if not rp_ports:
        return JSONResponse({"error": "ports[].containerPort required."}, 400)
    _log_create(f"ports: {rp_ports}")

    command = c.get("command") if isinstance(c.get("command"), list) else None
    args = c.get("args") if isinstance(c.get("args"), list) else None
    if command:
        _log_create(f"command: {command}")
    if args:
        _log_create(f"args: {args}")

    # GPU normalization (enum -> pretty string); include only if non-empty
    dr = c.get("dedicatedResources") or {}
    gpu_ids = None
    gpu_count = 1
    if isinstance(dr, dict):
        typ = (dr.get("machineSpec", {}) or {}).get("acceleratorType")
        cnt = (dr.get("machineSpec", {}) or {}).get("acceleratorCount")
        if typ:
            gpu_ids = [typ] if isinstance(typ, str) else typ
        if isinstance(cnt, int) and cnt > 0:
            gpu_count = cnt

    def _normalize_gpu_enum(s: str) -> str:
        # e.g. "NVIDIA_A100_80GB" -> "NVIDIA A100 80 GB"
        if not isinstance(s, str) or not s.strip():
            return ""
        t = s.strip().upper().replace("_", " ")
        vendor = "NVIDIA"
        if t.startswith("NVIDIA "):
            t = t[len("NVIDIA "):]
        elif t.startswith("AMD "):
            vendor = "AMD"
            t = t[len("AMD "):]
        t = re.sub(r"(\d)(GB\b)", r"\1 \2", t)  # 80GB -> 80 GB
        return f"{vendor} {t}".strip()

    rp_gpu = None
    if gpu_ids:
        rp_gpu = _normalize_gpu_enum(gpu_ids[0]).strip() or None
        _log_create(f"GPU_TRANSLATION original={gpu_ids[0]} -> runpod='{rp_gpu}'")
    _log_create("SECURE_PLACEMENT interruptible=false")

    payload = {
        "name": re.sub(r"[^a-z0-9-]", "-", f"ephemeral-{int(time.time())}".lower()),
        "computeType": "GPU",
        "interruptible": False,  # On-Demand (not community)
        "imageName": image,
        "gpuCount": gpu_count,
        "ports": rp_ports,
        "supportPublicIp": True,
        **({"gpuTypeIds": [rp_gpu]} if rp_gpu else {}),
        **({"env": env_obj} if env_obj else {}),
        **({"command": command} if command else {}),
        **({"args": args} if args else {}),
    }
    _log_create(f"PAYLOAD_SENT {json.dumps(payload, ensure_ascii=False)}")

    content = {}
    pid = None
    try:
        r = requests.post(f"{_RP_BASE}/pods", headers=_rp_headers(), json=payload, timeout=60)
        content = _as_json(r)
        _log_create(f"RUNPOD_RESPONSE {json.dumps(content, ensure_ascii=False)}")
        pid = content.get("id")
        # Some responses nest the pod object; scan one level down for an id.
        if not pid and isinstance(content, dict):
            for v in content.values():
                if isinstance(v, dict) and "id" in v:
                    pid = v["id"]
                    break
    except Exception as e:
        _log_create(f"ERROR Create: {e}")
        return JSONResponse({"error": f"RunPod create failed: {e}"}, 500)
    _log_create(f"ID: {pid}")
    if not isinstance(r, requests.Response):
        return JSONResponse({"error": "No HTTP response from RunPod create."}, 502)
    if not r.ok:
        return JSONResponse(content if isinstance(content, dict) else {"_raw": str(content)}, r.status_code)
    if not pid:
        return JSONResponse({"error": "Create succeeded but no pod ID in response.", "raw": content}, 502)

    # cache pod id
    try:
        _INST["podId"] = str(pid).strip()
        _log_id("CREATE_SET", _INST["podId"])
        _save_state()
    except Exception as e:
        return JSONResponse({"error": f"Could not cache pod ID: {e}"}, 502)

    # start the pod immediately so networking/IP can come up
    try:
        sr = requests.post(f"{_RP_BASE}/pods/{_INST['podId']}/start", headers=_rp_headers(), timeout=30)
        scontent = _as_json(sr)
        _log_status(f"START_RESPONSE {json.dumps(scontent, ensure_ascii=False)}")
    except Exception as e:
        _log_status(f"ERROR Start: {e}")

    # initial status snapshot
    try:
        rs = requests.get(f"{_RP_BASE}/pods/{_INST['podId']}", headers=_rp_headers(), timeout=30)
        st = _as_json(rs)
        _log_status(f"STATUS_POLL {json.dumps(st, ensure_ascii=False)}")
        content["_status"] = st
    except Exception as e:
        content["_status_error"] = str(e)
        _log_status(f"ERROR Status: {e}")

    _INST["status"] = content.get("desiredStatus") or content.get("status") or ""
    _INST["ip"] = _INST.get("ip") or ""
    _INST["port"] = _INST.get("port") or ""
    return JSONResponse(content, r.status_code)


# ---------------------------------------------------------------------
# Poll / read instance status + explicit readiness fields
# ---------------------------------------------------------------------
@router.get("/api/compute/pods/{pod_id}")
def api_get_instance(pod_id: str = None):
    """Poll a pod, resolve its public ip:port mapping, probe health, and return status."""
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, 400)
    _log_id("STATUS_USES", pid)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=30)
        last = _as_json(r)
        _log_status(f"STATUS_POLL {json.dumps(last, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"poll failed: {e}"}, 502)

    # Declared internal port from the blob (used to pick the right mapping).
    declared = None
    try:
        c = (((_INST.get("blob") or {}).get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
            or (_INST.get("blob") or {}).get("container") or {}
        declared = int((c.get("ports") or [])[0].get("containerPort"))
    except Exception:
        c = {}

    if isinstance(last, dict):
        ip = last.get("publicIp") or ""
        pm = last.get("portMappings") or {}
        if ip and isinstance(pm, dict) and pm:
            # choose mapped public port for the declared internal port; else first mapping
            if isinstance(declared, int) and str(declared) in pm:
                chosen = str(pm[str(declared)])
            else:
                k = next(iter(pm.keys()))
                chosen = str(pm[k])
                _log_status(f"PORT_MAPPING declared={declared} not_found_using_first key={k}")
            _INST.update({"podId": pid, "status": last.get("desiredStatus", ""), "ip": ip, "port": chosen})
            _save_state()
            base = f"http://{_INST['ip']}:{_INST['port']}"
            _log_status(f"PORT_MAPPING declared={declared} chosen={chosen} all={pm}")
            _log_status(f"RESOLVED_ENDPOINT base={base}")
            # health-first readiness (mirror Vertex); logging only — status
            # promotion to READY is intentionally disabled.
            hr = (_INST.get("healthRoute") or "/health").strip()
            pr = (_INST.get("predictRoute") or "/predict").strip()
            code_h, ms_h, snippet_h = _probe("GET", f"{base}{hr}")
            _log_status(f"HEALTH_PROBE path={hr} code={code_h} ms={ms_h} body_snippet={snippet_h[:120]}")

    # Final prompt URL (prefer IP; else proxy host)
    proute = _INST.get("predictRoute") or "/predict"
    if _INST.get("ip") and _INST.get("port"):
        prompt_url = f"http://{_INST['ip']}:{_INST['port']}{proute}"
    else:
        proxy_base = f"https://{pid}-{declared}.proxy.runpod.net" if declared else ""
        prompt_url = f"{proxy_base}{proute}" if proxy_base else ""
    if prompt_url:
        _log_status(f"PROMPT_ENDPOINT {prompt_url}")

    # Always include cached readiness data for the UI
    merged = {**last, "cachedState": {
        "podId": _INST.get("podId"),
        "status": _INST.get("status"),
        "ip": _INST.get("ip"),
        "port": _INST.get("port"),
        "predictRoute": _INST.get("predictRoute"),
        "healthRoute": _INST.get("healthRoute"),
    }}
    return JSONResponse(merged)


# ---------------------------------------------------------------------
# Start, Stop, End All
# ---------------------------------------------------------------------
@router.post("/api/compute/pods/{pod_id}/start")
def api_start_instance(pod_id: str):
    """Start an existing pod by id."""
    _log_id("START_USES", pod_id)
    try:
        r = requests.post(f"{_RP_BASE}/pods/{pod_id}/start", headers=_rp_headers(), timeout=30)
        payload = _as_json(r)
        _log_status(f"START_RESPONSE {json.dumps(payload, ensure_ascii=False)}")
        return JSONResponse(payload, r.status_code)
    except Exception as e:
        _log_status(f"ERROR Start: {e}")
        return JSONResponse({"error": f"RunPod start failed: {e}"}, 500)


@router.delete("/api/compute/delete_instance")
async def api_delete_instance():
    """Stop the cached pod (despite the route name, this stops — it does not delete)."""
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing and no cached pod found."}, status_code=400)
    _log_id("STOP_USES", pid)
    try:
        _log_delete(">>> STOP endpoint triggered")
        r = requests.post(f"{_RP_BASE}/pods/{pid}/stop", headers=_rp_headers(), timeout=60)
        payload = _as_json(r)
        _log_delete(f"STOP_RESPONSE {json.dumps(payload, ensure_ascii=False)}")
        return JSONResponse(status_code=r.status_code, content=payload)
    except Exception as e:
        _log_delete(f"ERROR Stop: {e}")
        return JSONResponse(status_code=500, content={"error": f"RunPod stop failed: {e}"})


@router.delete("/api/compute/end_all")
async def api_end_all():
    """Delete the cached pod and clear local state on success."""
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing and no cached pod found."}, status_code=400)
    _log_id("DELETE_USES", pid)
    try:
        _log_delete(">>> END-ALL endpoint triggered")
        r = requests.delete(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=60)
        payload = _as_json(r)
        _log_delete(f"DELETE_RESPONSE {json.dumps(payload, ensure_ascii=False)}")
        if r.status_code in (200, 202, 204):
            _INST.update({"podId": "", "status": "", "ip": "", "port": ""})
            _save_state()
        return JSONResponse(status_code=r.status_code, content=payload)
    except Exception as e:
        _log_delete(f"ERROR Delete: {e}")
        return JSONResponse(status_code=500, content={"error": f"RunPod delete failed: {e}"})


# ---------------------------------------------------------------------
# Wait instance
# ---------------------------------------------------------------------
@router.get("/api/compute/wait_instance")
def api_wait_instance(pod_id: str = None):
    """Single wait-style poll: resolve ip/port, probe health, cache a proxy base URL."""
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, status_code=400)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=30)
        last = _as_json(r)
        _log_status(f"WAIT_STATUS {json.dumps(last, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"wait poll failed: {e}"}, status_code=502)

    ip = last.get("publicIp") or _INST.get("ip")
    pm = last.get("portMappings") or {}
    port = None
    declared = None
    try:
        c = (((_INST.get("blob") or {}).get("supportedActions") or {}).get("deploy") or {}).get("containerSpec") \
            or (_INST.get("blob") or {}).get("container") or {}
        declared = int((c.get("ports") or [])[0].get("containerPort"))
    except Exception:
        c = {}

    if ip and pm:
        # Prefer the declared internal port's mapping, then 8080, then any mapping.
        try:
            if isinstance(declared, int) and str(declared) in pm:
                port = str(pm[str(declared)])
        except Exception:
            pass
        if not port and "8080" in pm:
            port = str(pm["8080"])
        elif not port and pm:
            port = str(pm[next(iter(pm.keys()))])

    if ip and port:
        base = f"http://{ip}:{port}"
        _log_status(f"RESOLVED_IP {base}")
        # health-first readiness (mirror Vertex); logging only — status
        # promotion to READY is intentionally disabled.
        hr = (_INST.get("healthRoute") or "/health").strip()
        pr = (_INST.get("predictRoute") or "/predict").strip()
        code_h, ms_h, snippet_h = _probe("GET", f"{base}{hr}")
        _log_status(f"HEALTH_PROBE path={hr} code={code_h} ms={ms_h} body_snippet={snippet_h[:120]}")
        if _INST.get("predictRoute"):
            _log_status(f"PROMPT_ENDPOINT {base}{_INST['predictRoute']}")

    # Also resolve and cache the RunPod proxy base (works without a public IP).
    try:
        cspec = _get_container_spec()
        internal, _ = _get_port_and_proto(cspec)
        if internal:
            proxy_base = f"https://{pid}-{internal}.proxy.runpod.net"
            _log_status(f"RESOLVED_PROXY {proxy_base}")
            _INST["base"] = proxy_base
            _save_state()
    except Exception:
        pass

    _INST.update({"ip": ip or "", "port": port or "", "status": last.get("desiredStatus", "")})
    _save_state()
    merged = {
        **last,
        "cachedState": {
            "podId": _INST.get("podId"),
            "status": _INST.get("status"),
            "ip": _INST.get("ip"),
            "port": _INST.get("port"),
            "base": _INST.get("base"),
            "predictRoute": _INST.get("predictRoute"),
            "healthRoute": _INST.get("healthRoute"),
        },
    }
    return JSONResponse(merged)


# ---------------------------------------------------------------------
# Debug: live probes against the instance (IP + Proxy)
# ---------------------------------------------------------------------
@router.get("/api/compute/debug/probes")
def api_debug_probes(pod_id: str = None):
    """Probe a matrix of candidate health/predict paths over both IP and proxy bases."""
    pid = (pod_id or _INST.get("podId") or "").strip()
    if not pid:
        return JSONResponse({"error": "pod_id missing."}, 400)
    # latest pod object (for portMappings/publicIp)
    try:
        r = requests.get(f"{_RP_BASE}/pods/{pid}", headers=_rp_headers(), timeout=20)
        pod = _as_json(r)
        _log_status(f"DEBUG_POD_OBJ {json.dumps(pod, ensure_ascii=False)}")
    except Exception as e:
        return JSONResponse({"error": f"pod fetch failed: {e}"}, 502)

    ip = pod.get("publicIp") or _INST.get("ip")
    pm = pod.get("portMappings") or {}

    # choose internal/public ports
    internal = None
    try:
        cs = _get_container_spec()
        internal = int((cs.get("ports") or [])[0].get("containerPort"))
    except Exception:
        pass
    if internal and str(internal) in pm:
        public = str(pm[str(internal)])
    elif "8080" in pm:
        internal, public = 8080, str(pm["8080"])
    elif pm:
        k = next(iter(pm.keys()))
        internal, public = int(k), str(pm[k])
    else:
        public = None

    # candidate paths (cached routes first, then common defaults)
    healths = [(_INST.get("healthRoute") or "").strip(), "/health", "/ping", "/healthz", "/v1/models"]
    healths = [p for p in healths if p]
    predicts = [(_INST.get("predictRoute") or "").strip(), "/generate", "/predict", "/predictions",
                "/v1/chat/completions", "/v1/models/model:predict"]
    predicts = [p for p in predicts if p]

    results = {"podId": pid, "ip": ip, "internalPort": internal, "publicPort": public, "probes": []}

    # base URLs (IP and proxy)
    bases = []
    if ip and public:
        bases.append(f"http://{ip}:{public}")
    if internal:
        bases.append(f"https://{pid}-{internal}.proxy.runpod.net")

    # probe health
    for base in bases:
        for hp in healths:
            code, ms, snippet = _probe("GET", f"{base}{hp}")
            _log_status(f"DEBUG_HEALTH base={base} path={hp} code={code} ms={ms}")
            results["probes"].append({"base": base, "path": hp, "kind": "health",
                                      "code": code, "ms": ms, "snippet": snippet})
    # probe predict (HEAD)
    for base in bases:
        for pp in predicts:
            code, ms, _ = _probe("HEAD", f"{base}{pp}")
            _log_status(f"DEBUG_PREDICT base={base} path={pp} code={code} ms={ms}")
            results["probes"].append({"base": base, "path": pp, "kind": "predict",
                                      "code": code, "ms": ms})
    return JSONResponse(results, 200)


# ---------------------------------------------------------------------
# Helper functions for containerSpec parsing
# ---------------------------------------------------------------------
def _get_container_spec():
    """Return the containerSpec dict from the cached blob, lazy-loading the local file.

    NOTE(review): if no blob can be loaded this raises AttributeError on
    blob.get — original behavior preserved.
    """
    blob = _INST.get("blob")
    if not blob:
        lb = _load_local_blob()
        if lb:
            _ingest_blob(lb)
            blob = lb
    return (((blob.get("supportedActions") or {}).get("deploy") or {}).get("containerSpec")
            or blob.get("container") or {})


def _get_port_and_proto(cspec: dict):
    """Extract (containerPort, protocol) from the first ports[] entry; (None, None) if absent."""
    try:
        ports = cspec.get("ports") or []
        if isinstance(ports, list) and ports:
            p0 = ports[0]
            internal = p0.get("containerPort")
            proto = (p0.get("protocol") or "").lower() or None
            return (int(internal) if str(internal).isdigit() else None, proto)
    except Exception:
        pass
    return (None, None)


def _build_proxy_url(route: str) -> str:
    """Build the RunPod proxy URL for *route*; raises HTTPException(400) if unresolvable."""
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        raise HTTPException(status_code=400, detail="No podId in cache. Create/Start the instance first.")
    cspec = _get_container_spec()
    internal_port, _ = _get_port_and_proto(cspec)
    if not internal_port:
        raise HTTPException(status_code=400, detail="Cannot resolve internal port from containerSpec.ports[].")
    return f"https://{pid}-{internal_port}.proxy.runpod.net{route}"


def _build_ip_url(route: str) -> str:
    """Build the direct-IP URL for *route*; raises HTTPException(400) if ip/port missing."""
    ip, port = _INST.get("ip"), _INST.get("port")
    if not (ip and port):
        raise HTTPException(status_code=400, detail="No running instance (ip/port missing).")
    return f"http://{ip}:{port}{route}"


def _resolve_infer_url(route: str) -> str:
    """Pick the inference URL: direct IP when cached, otherwise the RunPod proxy."""
    cspec = _get_container_spec()
    _, proto = _get_port_and_proto(cspec)
    try:
        if _INST.get("ip") and _INST.get("port"):
            url = _build_ip_url(route)
            _job_log("compute", f"[MW] Using IP path: {url}")
            return url
    except HTTPException:
        pass
    # NOTE(review): 'or True' deliberately forces the proxy path for every
    # protocol; the IP fallback below is currently unreachable. Kept as-is.
    if proto == "http" or True:
        url = _build_proxy_url(route)
        _job_log("compute", f"[MW] Using Proxy path: {url}")
        return url
    return _build_ip_url(route)


# ---------------------------------------------------------------------
# /api/infer — uses resolver
# ---------------------------------------------------------------------
@router.post("/api/infer")
async def api_infer(req: Request):
    """Forward the request body to the instance's predict route as-is."""
    route = _INST.get("predictRoute")
    if not route:
        return JSONResponse(
            {"error": "predictRoute unresolved; check ROUTE_PROBE logs and HEALTH results."},
            status_code=428
        )
    body = await req.json()
    try:
        url = _resolve_infer_url(route)
        r = requests.post(url, json=body, timeout=120)
        ct = (r.headers.get("content-type") or "").lower()
        if "application/json" in ct:
            return JSONResponse(status_code=r.status_code, content=r.json())
        return HTMLResponse(status_code=r.status_code, content=r.text)
    except HTTPException as he:
        return JSONResponse({"error": he.detail}, status_code=he.status_code)
    except Exception as e:
        return JSONResponse({"error": f"inference request failed: {e}"}, status_code=502)


# ---------------------------------------------------------------------
# /api/middleware/infer — middleware prompt routing and normalization
# ---------------------------------------------------------------------
@router.post("/api/middleware/infer")
async def api_middleware_infer(req: Request):
    """Normalize a {'prompt': ...} request and forward it via the RunPod proxy.

    Handles a text-to-image special case (HF_TASK env) with a canonical payload,
    an HF text-classification payload shim, and a multi-shape fallback for
    everything else.
    """
    # Always ensure predictRoute exists
    route = _INST.get("predictRoute") or "/predict"
    _INST["predictRoute"] = route

    # Build deterministic proxy URL instead of waiting on readiness
    pid = (_INST.get("podId") or "").strip()
    if not pid:
        try:
            _load_state()
            pid = (_INST.get("podId") or "").strip()
        except Exception:
            pass
    if not pid:
        return JSONResponse({"error": "no podId yet (create/start first)"}, status_code=400)
    cspec = _get_container_spec()
    internal, _ = _get_port_and_proto(cspec)
    if not internal:
        return JSONResponse({"error": "cannot resolve internal port from blob"}, status_code=400)

    # routing override for HF_TASK=text-to-image
    env = (cspec.get("env") or [])
    kv = {e.get("name"): e.get("value") for e in env if isinstance(e, dict) and e.get("name")}
    hf_task = (kv.get("HF_TASK") or "").strip().lower()
    model_id = (kv.get("MODEL_ID") or kv.get("HF_MODEL_ID") or "").strip()
    if hf_task == "text-to-image" and model_id:
        route = f"/predictions/{model_id}"
        _INST["predictRoute"] = route

    base = f"https://{pid}-{internal}.proxy.runpod.net"
    url = f"{base}{route}"
    _log_status(f"PROMPT_ENDPOINT {url}")
    _job_log("compute", f"[MW] Forwarding infer to {url}")

    payload = await req.json()
    prompt = payload.get("prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        return JSONResponse({"error": "Missing 'prompt' in request body."}, 400)

    # canonical text-to-image payload
    if hf_task == "text-to-image":
        body = {
            "inputs": prompt,
            "parameters": {
                "num_inference_steps": 30,
                "guidance_scale": 7.5,
                "width": 1024,
                "height": 1024
            }
        }
        try:
            rp = requests.post(url, json=body, timeout=120)
            _log_status(f"PREDICT_RESP code={rp.status_code} len={len(rp.text)}")
            if rp.ok:
                ct = (rp.headers.get("content-type") or "").lower()
                data = _as_json(rp) if "application/json" in ct else {"_raw": rp.text}
                if isinstance(data, dict) and "image_b64" in data:
                    return JSONResponse({"image_b64": data["image_b64"], "timings": data.get("timings")}, rp.status_code)
                return JSONResponse(data, rp.status_code)
            return JSONResponse({"error": rp.text[:400]}, status_code=rp.status_code)
        except Exception as e:
            _log_status(f"PREDICT_ERR {e}")
            return JSONResponse({"error": f"inference request failed: {e}"}, status_code=502)

    # HF text-classification shim
    img = (_get_container_spec().get("imageUri", "")).lower()
    if "huggingface-pytorch-inference" in img and isinstance(payload.get("prompt"), str):
        payload = {"instances": [payload["prompt"]]}

    # Non-image fallback: try several payload shapes until one succeeds.
    bodies = [payload, {"prompt": prompt}, {"text": prompt}, {"inputs": prompt}, {"input": prompt}]
    for body in bodies:
        try:
            rp = requests.post(url, json=body, timeout=120)
            _log_status(f"PREDICT_RESP code={rp.status_code} len={len(rp.text)}")
            if rp.ok:
                ct = (rp.headers.get("content-type") or "").lower()
                data = _as_json(rp) if "application/json" in ct else {"_raw": rp.text}
                if isinstance(data, dict):
                    if "image_b64" in data:
                        return JSONResponse({"image_b64": data["image_b64"], "timings": data.get("timings")}, rp.status_code)
                    if isinstance(data.get("output"), str):
                        return JSONResponse({"output": data["output"]}, rp.status_code)
                    if "_raw" in data:
                        return JSONResponse({"output": data["_raw"]}, rp.status_code)
                    return JSONResponse({"output": json.dumps(data, ensure_ascii=False)}, rp.status_code)
                return JSONResponse({"output": str(data)}, rp.status_code)
        except Exception as e:
            _log_status(f"PREDICT_ERR {e}")

    # Fallthrough: show last response or generic error
    try:
        return JSONResponse({"error": rp.text[:400]}, status_code=rp.status_code)
    except Exception:
        return JSONResponse({"error": "no response from model"}, status_code=504)


# ---------------------------------------------------------------------
# Job progress + callback routes
# ---------------------------------------------------------------------
@router.post("/api/job/ready")
async def api_job_ready(req: Request):
    """Liveness callback; always acknowledges."""
    return JSONResponse({"ok": True})


@router.post("/api/job/progress")
async def api_job_progress(req: Request):
    """Record a progress message against a job's log."""
    data = await req.json()
    job_id = str(data.get("job_id", "unknown"))
    msg = data.get("message", "")
    _job_log(job_id, msg or "progress")
    return JSONResponse({"ok": True})


@router.post("/api/job/done")
async def api_job_done(req: Request):
    """Mark a job done and stash its result image/timings."""
    data = await req.json()
    job_id = str(data.get("job_id", "unknown"))
    j = _JOBS.setdefault(job_id, {"status": "created", "logs": [], "image_b64": None, "timings": {}})
    j["status"] = "done"
    j["image_b64"] = data.get("image_b64")
    j["timings"] = data.get("timings", {})
    _job_log(job_id, "completed")
    return JSONResponse({"ok": True})


@router.get("/api/job/status")
def api_job_status(job_id: str):
    """Return the cached job record, or {'status': 'unknown'}."""
    return JSONResponse(_JOBS.get(job_id, {"status": "unknown"}))