"""
pod_api.py — RunPod-side FastAPI server that delegates generation to a local
trtllm-serve while keeping your existing api.py contract (job pattern,
Pydantic validation, normalizer routing, auto-save, error handling).

Architecture in this pod:

    Client ──POST /v1/jobs──▶ pod_api.py (this file, port 5000)
                                  │
                                  │ enqueues job
                                  ▼
                          ThreadPoolExecutor
                                  │
                                  │ 1. normalize via Anthropic API
                                  │ 2. POST to trtllm-serve
                                  ▼
                  trtllm-serve (port 8000, local) ──▶ model on GPU

Why this layout:
  - Your reliability layer (job pattern, validation, GC, auto-save) stays.
  - TRT-LLM does the actual generation — 2.85× faster than transformers, and
    ready to add NGram speculative decoding on top via the existing
    spec_config.yaml.
  - The Anthropic-based normalizer and dashboard routing keep working
    unchanged because we import your existing inference_edited_chat_opt
    module.

Setup:
    pip install fastapi "uvicorn[standard]" pydantic requests anthropic
    export ANTHROPIC_API_KEY=...

    # Make sure trtllm-serve is already running on :8000.
    # Then start this:
    uvicorn pod_api:app --host 0.0.0.0 --port 5000 --workers 1

Endpoints (same shape as your old api.py):
    GET  /v1/healthz
    GET  /v1/readyz
    POST /v1/jobs           -> 202 {"job_id": ...}
    GET  /v1/jobs/{job_id}  -> status + html when done
    GET  /v1/jobs           -> list recent jobs
    POST /v1/generate       -> synchronous variant
"""
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import sys |
| import threading |
| import time |
| import uuid |
| from concurrent.futures import ThreadPoolExecutor |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any, Literal, Optional |
|
|
| import requests |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel, Field, field_validator |
|
|
| |
| |
| sys.path.insert(0, "/workspace") |
| import inference_edited_chat_opt as inf |
|
|
| |
| |
| |
# --- trtllm-serve connection (overridable through the environment) ----------
TRTLLM_BASE_URL = os.getenv("TRTLLM_BASE_URL", "http://localhost:8000")
TRTLLM_MODEL = os.getenv("TRTLLM_MODEL", "final_model")

# --- Request limits and timing knobs ----------------------------------------
MAX_PROMPT_CHARS = 8_000      # upper bound enforced on GenerateRequest.prompt
MAX_CONCURRENT_JOBS = 16      # slot limit and worker-pool size
JOB_TIMEOUT_S = 25 * 60       # HTTP timeout for the trtllm-serve request
SYNC_TIMEOUT_S = 20 * 60      # how long /v1/generate waits for a job
JOB_RETENTION_S = 60 * 60     # finished jobs older than this are GC'd

# --- Persistence and sampling ------------------------------------------------
OUTPUT_DIR: Optional[Path] = Path(os.getenv("API_OUTPUT_DIR", "/workspace/api_output"))
GENERATION_MAX_TOKENS = int(os.getenv("GENERATION_MAX_TOKENS", "8192"))
GENERATION_TEMPERATURE = float(os.getenv("GENERATION_TEMPERATURE", "0.0"))

# --- Logging -------------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("pod_api")
|
|
| |
| |
| |
# Lifecycle states a job can be in.
JobStatus = Literal["queued", "running", "done", "error"]


@dataclass
class Job:
    """In-memory record of a single generation request.

    Timestamps are ``time.time()`` epoch seconds. ``done_event`` is a
    per-job ``threading.Event`` that waiters can block on.
    """

    id: str
    raw_prompt: str
    normalized_prompt: Optional[str] = None
    status: JobStatus = "queued"
    html: Optional[str] = None
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    done_event: threading.Event = field(default_factory=threading.Event)

    def to_response(self) -> dict[str, Any]:
        """Return the JSON-serialisable view of this job.

        Optional keys are included only once the corresponding data exists:
        ``started_at``; ``finished_at`` plus ``duration_seconds``;
        ``normalized_prompt``; and ``html`` (status ``done``) or ``error``
        (status ``error``).
        """
        payload: dict[str, Any] = {
            "job_id": self.id,
            "status": self.status,
            "created_at": self.created_at,
        }

        started, finished = self.started_at, self.finished_at
        if started is not None:
            payload["started_at"] = started
        if finished is not None:
            payload["finished_at"] = finished
            # Measure from the start time when there is one, otherwise
            # from creation.
            baseline = started or self.created_at
            payload["duration_seconds"] = round(finished - baseline, 2)

        if self.normalized_prompt is not None:
            payload["normalized_prompt"] = self.normalized_prompt

        if self.status == "done":
            payload["html"] = self.html
        elif self.status == "error":
            payload["error"] = self.error
        return payload
|
|
|
|
# Job registry keyed by job id; every access goes through _jobs_lock.
_jobs_lock = threading.Lock()
_jobs: dict[str, Job] = {}

# Count of reserved job slots; every access goes through _inflight_lock.
_inflight_lock = threading.Lock()
_inflight: int = 0

# Worker pool; None until the lifespan handler creates it.
_executor: Optional[ThreadPoolExecutor] = None
|
|
|
|
def _store_job(job: Job) -> None:
    """Register ``job`` in the in-memory registry under its id."""
    key = job.id
    with _jobs_lock:
        _jobs[key] = job
|
|
|
|
def _get_job(job_id: str) -> Optional[Job]:
    """Return the in-memory job for ``job_id``, or ``None`` if unknown."""
    with _jobs_lock:
        found = _jobs.get(job_id)
    return found
|
|
|
|
def _gc_jobs() -> None:
    """Evict jobs that finished more than ``JOB_RETENTION_S`` seconds ago.

    Jobs without a ``finished_at`` timestamp are never evicted.
    """
    now = time.time()
    with _jobs_lock:
        # Iterate over a snapshot so entries can be removed while scanning.
        for job_id, job in list(_jobs.items()):
            done_at = job.finished_at
            if done_at is None:
                continue
            if (now - done_at) > JOB_RETENTION_S:
                del _jobs[job_id]
|
|
|
|
def _try_reserve_slot() -> bool:
    """Claim one in-flight slot if capacity allows.

    Returns:
        ``True`` when a slot was reserved, ``False`` when
        ``MAX_CONCURRENT_JOBS`` slots are already taken.
    """
    global _inflight
    with _inflight_lock:
        has_room = _inflight < MAX_CONCURRENT_JOBS
        if has_room:
            _inflight += 1
    return has_room
|
|
|
|
def _release_slot() -> None:
    """Give back one in-flight slot; the counter never drops below zero."""
    global _inflight
    with _inflight_lock:
        if _inflight > 0:
            _inflight -= 1
        else:
            _inflight = 0
|
|
|
|
def _inflight_count() -> int:
    """Return the current number of reserved slots."""
    with _inflight_lock:
        current = _inflight
    return current
|
|
|
|
| |
| |
| |
def _trtllm_generate(prompt_text: str) -> str:
    """Send a chat-completion request to trtllm-serve and return its text.

    Args:
        prompt_text: Content of the user message; the system message is
            ``inf.SYSTEM_PROMPT``.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        requests.HTTPError: trtllm-serve answered with an error status.
        requests.RequestException: The request could not be completed
            (connection failure, timeout, ...).
        RuntimeError: The response body is not valid JSON, does not have the
            ``choices[0].message.content`` shape, or the content is empty.
    """
    payload = {
        "model": TRTLLM_MODEL,
        "messages": [
            {"role": "system", "content": inf.SYSTEM_PROMPT},
            {"role": "user", "content": prompt_text},
        ],
        "max_tokens": GENERATION_MAX_TOKENS,
        "temperature": GENERATION_TEMPERATURE,
    }
    resp = requests.post(
        f"{TRTLLM_BASE_URL}/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=JOB_TIMEOUT_S,
    )
    resp.raise_for_status()

    # A successful status does not guarantee a well-formed body. Translate
    # parse/shape failures into one descriptive error instead of letting a
    # bare KeyError/IndexError (or a JSON decode error that would be reported
    # as "unreachable") escape.
    try:
        text = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(
            f"trtllm-serve returned a malformed response "
            f"({type(exc).__name__}: {exc})"
        ) from exc

    if not isinstance(text, str) or not text.strip():
        raise RuntimeError("trtllm-serve returned empty content")
    return text
|
|
|
|
| |
| |
| |
def _persist_job(job: Job, html: str, finished_at: float) -> None:
    """Best-effort write of a finished job's HTML and metadata to OUTPUT_DIR.

    Failures are logged and swallowed: persistence must never turn a
    successful job into a failed one.
    """
    if OUTPUT_DIR is None:
        return
    try:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        (OUTPUT_DIR / f"{job.id}.html").write_text(html, encoding="utf-8")
        metadata = {
            "job_id": job.id,
            "raw_prompt": job.raw_prompt,
            "normalized_prompt": job.normalized_prompt,
            "created_at": job.created_at,
            "started_at": job.started_at,
            "finished_at": finished_at,
            "duration_seconds": round(
                finished_at - (job.started_at or job.created_at), 2
            ),
        }
        (OUTPUT_DIR / f"{job.id}.json").write_text(
            json.dumps(metadata, indent=2),
            encoding="utf-8",
        )
        logger.info("job %s saved to %s", job.id, OUTPUT_DIR)
    except Exception as e:
        logger.warning("failed to persist job %s: %s", job.id, e)


def _run_job(job: Job) -> None:
    """Run one job end to end; intended to execute on a worker thread.

    Pipeline: normalize the prompt (falling back to the raw prompt when the
    normalizer fails or returns nothing usable), generate via trtllm-serve,
    post-process, store the result on ``job`` and persist it to disk.

    Pipeline failures are recorded on ``job`` (``status="error"`` plus
    ``error``) rather than raised. In every case the job gets a
    ``finished_at`` timestamp, ``done_event`` is set, the in-flight slot is
    released and old jobs are garbage-collected.
    """
    job.started_at = time.time()
    job.status = "running"
    logger.info("job %s started", job.id)

    try:
        # Step 1: normalization is best-effort; never fail the job over it.
        try:
            normalized = inf.normalize_prompt(job.raw_prompt)
        except Exception as e:
            logger.warning(
                "normalize failed for job %s: %s — falling back to raw prompt",
                job.id, e,
            )
            normalized = job.raw_prompt
        if not isinstance(normalized, str) or not normalized.strip():
            normalized = job.raw_prompt
        job.normalized_prompt = normalized

        # Step 2: generation.
        raw_html = _trtllm_generate(job.normalized_prompt)

        # Step 3: post-processing.
        html = inf.post_process(raw_html)
        if not html.strip():
            raise RuntimeError("post_process returned empty output")

        # Take the completion timestamp once so the in-memory job, the log
        # line and the persisted metadata all report the same values.
        finished_at = time.time()
        job.html = html
        job.status = "done"
        job.finished_at = finished_at
        logger.info(
            "job %s done in %.1fs (%d chars)",
            job.id, finished_at - job.started_at, len(html),
        )

        # Step 4: auto-save.
        _persist_job(job, html, finished_at)

    except requests.HTTPError as e:
        resp = e.response
        # Guard against a missing response object: an AttributeError here
        # would leave the job stuck in "running" with no error recorded.
        if resp is not None:
            job.error = f"trtllm-serve returned {resp.status_code}: {resp.text[:500]}"
        else:
            job.error = f"trtllm-serve returned an HTTP error: {e}"
        job.status = "error"
        logger.exception("job %s — trtllm-serve HTTP error", job.id)

    except requests.RequestException as e:
        job.error = f"trtllm-serve unreachable: {e}"
        job.status = "error"
        logger.exception("job %s — trtllm-serve unreachable", job.id)

    except Exception as e:
        job.error = f"{type(e).__name__}: {e}"
        job.status = "error"
        logger.exception("job %s failed", job.id)

    finally:
        # The success path has already stamped finished_at; only fill it in
        # for failures.
        if job.finished_at is None:
            job.finished_at = time.time()
        job.done_event.set()
        _release_slot()
        _gc_jobs()
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: probe trtllm-serve, then own the worker pool.

    On startup, ``GET /v1/models`` is issued against trtllm-serve. A failure
    is only logged; startup continues regardless. The job executor is then
    created with ``MAX_CONCURRENT_JOBS`` workers. On shutdown the executor is
    shut down without waiting, pending futures are cancelled, and the global
    is reset to ``None``.
    """
    global _executor

    try:
        r = requests.get(f"{TRTLLM_BASE_URL}/v1/models", timeout=10)
        r.raise_for_status()
        logger.info(
            "trtllm-serve OK at %s (%d models loaded)",
            TRTLLM_BASE_URL, len(r.json().get("data", [])),
        )
    except Exception as e:
        logger.error(
            "trtllm-serve not reachable at %s — %s. "
            "Start it before this API: trtllm-serve /workspace/final_model --host 0.0.0.0 --port 8000",
            TRTLLM_BASE_URL, e,
        )

    _executor = ThreadPoolExecutor(
        max_workers=MAX_CONCURRENT_JOBS,
        thread_name_prefix="job-runner",
    )
    logger.info(
        "executor started (max_workers=%d), output_dir=%s",
        MAX_CONCURRENT_JOBS, OUTPUT_DIR,
    )

    try:
        yield
    finally:
        # Clear the global first so nothing can observe (and submit to) a
        # pool that has already been shut down.
        pool, _executor = _executor, None
        if pool is not None:
            pool.shutdown(wait=False, cancel_futures=True)
|
|
|
|
# The ASGI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="HTML Generation API (TRT-LLM backed)",
    version="2.0.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /v1/jobs`` and ``POST /v1/generate``."""

    # Length is bounded to 1..MAX_PROMPT_CHARS characters.
    prompt: str = Field(..., min_length=1, max_length=MAX_PROMPT_CHARS)

    @field_validator("prompt")
    @classmethod
    def _strip(cls, v: str) -> str:
        """Trim surrounding whitespace; reject prompts that end up empty."""
        trimmed = v.strip()
        if trimmed:
            return trimmed
        raise ValueError("prompt is empty after stripping whitespace")
|
|
|
|
@app.exception_handler(Exception)
async def _unhandled(request, exc):
    """Catch-all handler: log the exception and answer with a JSON 500.

    The response body carries ``str(exc)`` in its ``detail`` field.
    """
    logger.exception("unhandled exception in request: %s", exc)
    payload = {"error": "internal_server_error", "detail": str(exc)}
    return JSONResponse(status_code=500, content=payload)
|
|
|
|
| |
| |
| |
@app.get("/v1/healthz")
def healthz():
    """Liveness probe: always answers ``{"status": "ok"}``."""
    return dict(status="ok")
|
|
|
|
@app.get("/v1/readyz")
def readyz():
    """Readiness probe.

    Answers 503 when the executor has not been created yet, when
    ``GET /v1/models`` on trtllm-serve fails, or when it returns a status
    other than 200. Otherwise reports readiness along with the current
    in-flight count and configured limits.
    """

    def _not_ready(payload):
        return JSONResponse(status_code=503, content=payload)

    if _executor is None:
        return _not_ready({"status": "executor_not_ready"})

    try:
        probe = requests.get(f"{TRTLLM_BASE_URL}/v1/models", timeout=5)
    except Exception as e:
        return _not_ready({"status": "trtllm_unreachable", "detail": str(e)})

    if probe.status_code != 200:
        return _not_ready(
            {"status": "trtllm_unhealthy", "trtllm_status": probe.status_code}
        )

    return {
        "status": "ready",
        "in_flight": _inflight_count(),
        "max_concurrent_jobs": MAX_CONCURRENT_JOBS,
        "trtllm_url": TRTLLM_BASE_URL,
    }
|
|
|
|
@app.post("/v1/jobs", status_code=202)
def create_job(req: GenerateRequest):
    """Queue a generation job and return its id immediately (HTTP 202).

    Raises:
        HTTPException: 503 when the executor does not exist yet, when all
            ``MAX_CONCURRENT_JOBS`` slots are taken, or when the executor
            refuses the job.
    """
    executor = _executor
    if executor is None:
        raise HTTPException(status_code=503, detail="server still warming up")
    if not _try_reserve_slot():
        raise HTTPException(
            status_code=503,
            detail=f"server at capacity ({MAX_CONCURRENT_JOBS} in-flight) — try again shortly",
        )

    job = Job(id=uuid.uuid4().hex, raw_prompt=req.prompt)
    _store_job(job)
    try:
        executor.submit(_run_job, job)
    except RuntimeError as e:
        # submit() raises RuntimeError once the pool has been shut down.
        # _run_job will never run for this job, so undo the reservation and
        # the registry entry here; otherwise the slot leaks permanently and
        # the job stays "queued" forever.
        _release_slot()
        with _jobs_lock:
            _jobs.pop(job.id, None)
        raise HTTPException(
            status_code=503,
            detail="server is not accepting new jobs",
        ) from e

    in_flight = _inflight_count()
    logger.info(
        "job %s queued (in_flight=%d, prompt_chars=%d)",
        job.id, in_flight, len(req.prompt),
    )
    return {
        "job_id": job.id,
        "status": "queued",
        "in_flight": in_flight,
    }
|
|
|
|
@app.get("/v1/jobs/{job_id}")
def get_job(job_id: str):
    """Return a job's status, plus its HTML once it is done.

    The in-memory registry is consulted first. If the job is not there, the
    handler falls back to ``<job_id>.html`` / ``<job_id>.json`` under
    ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when the job is neither in memory nor readable
            from disk.
    """
    job = _get_job(job_id)
    if job is not None:
        return job.to_response()

    # job_id is client-supplied and is about to become part of a filesystem
    # path. As a defensive measure, only accept values that are a bare file
    # name (no directory components).
    is_plain_name = bool(job_id) and Path(job_id).name == job_id

    if OUTPUT_DIR is not None and is_plain_name:
        html_path = OUTPUT_DIR / f"{job_id}.html"
        meta_path = OUTPUT_DIR / f"{job_id}.json"

        # Read directly instead of exists()-then-read, so a file that vanishes
        # in between cannot cause an error.
        html = None
        try:
            html = html_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            pass
        except (OSError, ValueError) as e:
            logger.warning("failed to read persisted job %s: %s", job_id, e)

        if html is not None:
            # Metadata is optional: a missing or corrupt .json file must not
            # hide an HTML result that is present.
            meta = {}
            try:
                loaded = json.loads(meta_path.read_text(encoding="utf-8"))
                if isinstance(loaded, dict):
                    meta = loaded
            except FileNotFoundError:
                pass
            except (OSError, ValueError) as e:
                logger.warning("failed to read persisted job %s: %s", job_id, e)
            return {
                "job_id": job_id,
                "status": "done",
                "html": html,
                "source": "disk",
                **meta,
            }

    raise HTTPException(
        status_code=404,
        detail="job not found (not in memory and not persisted to disk)",
    )
|
|
|
|
@app.get("/v1/jobs")
def list_jobs(limit: int = 50):
    """List the most recently created in-memory jobs, newest first.

    Args:
        limit: Maximum number of jobs to return; must be within 1..500.

    Raises:
        HTTPException: 400 when ``limit`` is out of range.
    """
    if not (1 <= limit <= 500):
        raise HTTPException(status_code=400, detail="limit must be between 1 and 500")

    with _jobs_lock:
        snapshot = list(_jobs.values())
    snapshot.sort(key=lambda j: j.created_at, reverse=True)
    recent = snapshot[:limit]

    summaries = []
    for j in recent:
        summaries.append(
            {"job_id": j.id, "status": j.status, "created_at": j.created_at}
        )
    return {"count": len(recent), "jobs": summaries}
|
|
|
|
@app.post("/v1/generate")
def generate_sync(req: GenerateRequest):
    """Run a generation job and block until it finishes or times out.

    Returns:
        The job id, generated HTML, normalized prompt and duration on
        success.

    Raises:
        HTTPException: 503 when the executor does not exist yet, when the
            server is at capacity, or when the executor refuses the job;
            504 when the job does not finish within ``SYNC_TIMEOUT_S``
            (the job keeps running and stays retrievable by id);
            500 when the job ends in error.
    """
    executor = _executor
    if executor is None:
        raise HTTPException(status_code=503, detail="server still warming up")
    if not _try_reserve_slot():
        raise HTTPException(
            status_code=503,
            detail=f"server at capacity ({MAX_CONCURRENT_JOBS} in-flight) — try again shortly",
        )

    job = Job(id=uuid.uuid4().hex, raw_prompt=req.prompt)
    _store_job(job)
    try:
        executor.submit(_run_job, job)
    except RuntimeError as e:
        # submit() raises RuntimeError once the pool has been shut down.
        # Without this cleanup the slot would leak and this request would
        # wait for the full sync timeout on a job that can never run.
        _release_slot()
        with _jobs_lock:
            _jobs.pop(job.id, None)
        raise HTTPException(
            status_code=503,
            detail="server is not accepting new jobs",
        ) from e

    finished = job.done_event.wait(timeout=SYNC_TIMEOUT_S)
    if not finished:
        raise HTTPException(
            status_code=504,
            detail={
                "job_id": job.id,
                "error": "generation timed out — use GET /v1/jobs/{id} to retrieve",
            },
        )

    if job.status == "done":
        return {
            "job_id": job.id,
            "html": job.html,
            "normalized_prompt": job.normalized_prompt,
            "duration_seconds": round(
                (job.finished_at or 0) - (job.started_at or 0), 2
            ),
        }

    raise HTTPException(
        status_code=500,
        detail={"job_id": job.id, "error": job.error or "unknown error"},
    )
|
|