| """ |
| pod_api.py — RunPod-side FastAPI server with structured-output normalizer. |
| |
| Architecture: |
| |
| Client ─POST /v1/jobs──▶ pod_api.py (this file, port 5000) |
| │ |
| │ enqueues job |
| ▼ |
| ThreadPoolExecutor |
| │ |
| │ 1. structured-output normalize via Gemini |
| │ 2. POST to trtllm-serve |
| ▼ |
| trtllm-serve (port 8000) ──▶ model on GPU |
| |
| Run: |
| pip install fastapi "uvicorn[standard]" pydantic requests google-genai |
| export GEMINI_API_KEY=... |
| uvicorn pod_api:app --host 0.0.0.0 --port 5000 --workers 1 |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import sys |
| import threading |
| import time |
| import uuid |
| from concurrent.futures import ThreadPoolExecutor |
| from contextlib import asynccontextmanager |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any, List, Literal, Optional |
|
|
| import requests |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
| from pydantic import BaseModel, Field, field_validator |
|
|
| sys.path.insert(0, "/workspace") |
| import inference_edited_chat_opt as inf |
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# Where trtllm-serve listens and which model name is sent with each
# chat-completions request. Both can be overridden through the environment.
TRTLLM_BASE_URL = os.environ.get("TRTLLM_BASE_URL", "http://localhost:8000")
TRTLLM_MODEL = os.environ.get("TRTLLM_MODEL", "final_model")

# Upper bound on the request prompt length (enforced by GenerateRequest).
MAX_PROMPT_CHARS = 8_000
# Size of the worker pool and of the in-flight admission limit.
MAX_CONCURRENT_JOBS = 16
# HTTP timeout for a single trtllm-serve generation request, in seconds.
JOB_TIMEOUT_S = 60 * 25
# How long the synchronous endpoint waits for a job before answering 504.
SYNC_TIMEOUT_S = 60 * 20
# How long finished jobs stay in the in-memory registry before being dropped.
JOB_RETENTION_S = 60 * 60

# Directory for persisted results. Setting API_OUTPUT_DIR to an empty string
# disables persistence (OUTPUT_DIR becomes None); previously an empty value
# silently resolved to the current working directory.
_output_dir_env = os.environ.get("API_OUTPUT_DIR", "/workspace/api_output").strip()
OUTPUT_DIR: Optional[Path] = Path(_output_dir_env) if _output_dir_env else None

# Sampling parameters forwarded to trtllm-serve.
GENERATION_MAX_TOKENS = int(os.environ.get("GENERATION_MAX_TOKENS", "8192"))
GENERATION_TEMPERATURE = float(os.environ.get("GENERATION_TEMPERATURE", "0.0"))


logger = logging.getLogger("pod_api")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
|
|
|
|
| |
| |
| |
# Colour palette portion of the structured spec (nested in _NormalizedSpec).
# Every field is required; the descriptions are part of the schema handed to
# the normalizer model, so they are kept verbatim.
# NOTE: documented with comments rather than a class docstring because
# pydantic may surface a class docstring in the generated schema.
class _Colors(BaseModel):
    base_hex: str = Field(
        ...,
        description="Page background hex like #F4EFE2",
    )
    text_hex: str = Field(
        ...,
        description="Primary text hex like #1A1814",
    )
    muted_hex: str = Field(
        ...,
        description="Muted secondary text hex",
    )
    surface_hex: str = Field(
        ...,
        description="Card/surface background hex",
    )
    border_hex: str = Field(
        ...,
        description="Hairline border hex",
    )
    accent_hex: str = Field(
        ...,
        description="Single primary accent hex",
    )
    accent_role: str = Field(
        ...,
        description="Where accent is used",
    )
    success_hex: str = Field(
        ...,
        description="Success state hex",
    )
    warning_hex: str = Field(
        ...,
        description="Warning state hex",
    )
    danger_hex: str = Field(
        ...,
        description="Danger state hex",
    )
|
|
|
|
# Typography portion of the structured spec (nested in _NormalizedSpec).
# mono_family is the only optional field: it defaults to an empty string,
# which _assemble treats as "no mono family".
# NOTE: comments instead of a class docstring so the generated schema is
# left untouched.
class _Typography(BaseModel):
    display_family: str = Field(
        ...,
        description="Real Google Font name like Fraunces, Tiempos, Geist. NEVER serif or sans-serif.",
    )
    display_weight: str = Field(
        ...,
        description="Weight range like semibold-to-extrabold",
    )
    body_family: str = Field(
        ...,
        description="Real Google Font name. NEVER serif or sans-serif.",
    )
    body_weight: str = Field(
        ...,
        description="Weight range like regular-to-medium",
    )
    mono_family: str = Field(
        default="",
        description="Optional mono family for tabular only, or empty string",
    )
|
|
|
|
# Closing style rules of the structured spec (nested in _NormalizedSpec).
# _assemble joins the three values with commas into one closing sentence.
# NOTE: comments instead of a class docstring so the generated schema is
# left untouched.
class _ClosingRules(BaseModel):
    gradients: str = Field(
        ...,
        description="Gradient rule",
    )
    shadows: str = Field(
        ...,
        description="Shadow rule",
    )
    corners: str = Field(
        ...,
        description="Corner radius rule",
    )
|
|
|
|
# One page section of the structured spec; _NormalizedSpec holds a list of
# these and _assemble emits one sentence per entry.
# NOTE: comments instead of a class docstring so the generated schema is
# left untouched.
class _Section(BaseModel):
    description: str = Field(
        ...,
        description="One paragraph describing this section's layout, content, specific copy. Use frame-language and named hex colors.",
    )
|
|
|
|
# Top-level structured spec requested from the normalizer model (it is passed
# as response_schema in _normalize_via_gemini) and flattened back into prose
# by _assemble. `sections` is constrained to 8-14 entries.
# NOTE: comments instead of a class docstring so the generated schema is
# left untouched.
class _NormalizedSpec(BaseModel):
    opening: str = Field(
        ...,
        description="Opening clause: Design me a [type] for [context] - audience X, goal Y",
    )
    register_commitment: str = Field(
        ...,
        description="One sentence committing to the visual register with hex codes, named fonts, and motifs",
    )
    distinctive_flourish: str = Field(
        ...,
        description="One sentence about a single standout interactive or visual behavior",
    )
    sections: List[_Section] = Field(
        ...,
        min_length=8,
        max_length=14,
        description="8-14 sections in DOM order",
    )
    colors: _Colors
    typography: _Typography
    closing: _ClosingRules
|
|
|
|
| def _assemble(spec: _NormalizedSpec) -> str: |
| parts = [spec.opening.strip(), spec.register_commitment.strip(), spec.distinctive_flourish.strip()] |
| connectives = ["Start with", "Then", "Flow into", "Follow with", "Then", "Then", "Follow with", "Then", "Follow with", "Then", "Follow with", "Then", "Then", "Close with"] |
| starters = {c.split()[0].lower() for c in connectives + ["close"]} |
| for i, s in enumerate(spec.sections): |
| prefix = connectives[i] if i < len(connectives) else "Then" |
| desc = s.description.strip() |
| first = desc.split(" ", 1)[0].lower() if desc else "" |
| if first in starters or not desc: |
| parts.append(desc) |
| else: |
| parts.append(prefix + " " + (desc[0].lower() + desc[1:] if desc[0].isupper() else desc)) |
|
|
| c = spec.colors |
| parts.append( |
| "Use " + c.base_hex + " as the base with " + c.text_hex + " primary text, " + |
| c.muted_hex + " muted copy, " + c.surface_hex + " for card surfaces, " + |
| c.border_hex + " for hairlines, and " + c.accent_hex + " as the primary accent for " + c.accent_role + ", " + |
| "with a state palette of " + c.success_hex + " success, " + c.warning_hex + " warning, and " + c.danger_hex + " danger." |
| ) |
|
|
| t = spec.typography |
| typo = t.display_family + " " + t.display_weight + " for display and headings, paired with " + t.body_family + " " + t.body_weight + " for body" |
| if t.mono_family.strip(): |
| typo += ", plus " + t.mono_family + " used only for tabular figures, IDs, and timestamps - two type families plus a single mono used only for tabular contexts." |
| else: |
| typo += " - exactly two type families across the entire page, no third family anywhere." |
| parts.append(typo) |
|
|
| cr = spec.closing |
| parts.append( |
| cr.gradients + ", " + cr.shadows + ", " + cr.corners + ". " + |
| "Icons via Font Awesome only - never inline SVG - never hidden body overflow." |
| ) |
|
|
| return " ".join(parts) |
|
|
|
|
# Suffix appended to the normalizer system prompt to pin the JSON output
# contract. Built as single-space-joined sentences behind a blank-line prefix.
# NOTE(review): Tiempos and Recoleta look like commercial typefaces rather
# than Google Fonts - confirm whether they belong in this list.
SCHEMA_DIRECTIVE = "\n\n" + " ".join((
    "IMPORTANT OUTPUT FORMAT: Output as JSON matching the provided schema.",
    "Every field is mandatory and non-empty.",
    "All hex codes must be valid 6-digit hex like #1A1814 - never named colors.",
    "Font families must be real Google Fonts (Fraunces, Inter, Geist, Space Grotesk, "
    "Tiempos, Recoleta, Outfit, Plus Jakarta Sans, IBM Plex Mono, JetBrains Mono, etc.)"
    " - NEVER use the placeholder serif or sans-serif alone.",
    "Sections array must have between 8 and 14 entries, each describing one DOM-order "
    "region with concrete layout, content, and specific copy.",
))
|
|
|
|
def _normalize_via_gemini(raw_prompt: str) -> str:
    """Rewrite *raw_prompt* through Gemini structured output.

    Returns ``raw_prompt`` unchanged when ``inf.NORMALIZE_PROMPTS`` is falsy.
    Otherwise a system prompt is chosen (dashboard vs. generic), Gemini is
    asked for JSON matching ``_NormalizedSpec``, and the parsed spec is
    flattened with ``_assemble``. Any failure inside the Gemini call, parsing
    or assembly is logged and the raw prompt is returned instead.
    """
    if not getattr(inf, "NORMALIZE_PROMPTS", True):
        return raw_prompt

    if inf.is_dashboard_prompt(raw_prompt):
        system_prompt = inf.DASHBOARD_NORMALIZER_SYSTEM_PROMPT
    else:
        system_prompt = inf.NORMALIZER_SYSTEM_PROMPT

    try:
        # Imported lazily so a missing google-genai package only triggers the
        # fallback below instead of breaking module import.
        from google import genai
        from google.genai import types

        client = genai.Client()
        generation_config = types.GenerateContentConfig(
            system_instruction=system_prompt + SCHEMA_DIRECTIVE,
            temperature=0.6,
            max_output_tokens=8192,
            thinking_config=types.ThinkingConfig(thinking_level="high"),
            response_mime_type="application/json",
            response_schema=_NormalizedSpec,
        )
        response = client.models.generate_content(
            model="gemini-3-flash-preview",
            contents=raw_prompt,
            config=generation_config,
        )

        # Use the SDK-parsed object when available, else validate the raw JSON.
        spec = getattr(response, "parsed", None)
        if spec is None:
            spec = _NormalizedSpec.model_validate(json.loads(response.text))

        assembled = _assemble(spec)
        if not assembled or not assembled.strip():
            raise RuntimeError("assembled normalized prompt is empty")
        return assembled

    except Exception as e:
        logger.warning("structured normalize failed: %s - falling back to raw prompt", e)
        return raw_prompt


# Install this implementation as the normalizer used through `inf`.
inf.normalize_prompt = _normalize_via_gemini
|
|
|
|
| |
| |
| |
JobStatus = Literal["queued", "running", "done", "error"]


@dataclass
class Job:
    """In-memory record of one generation job.

    ``done_event`` is set by the worker when the job has finished (either
    successfully or with an error) so synchronous callers can wait on it.
    """

    id: str
    raw_prompt: str
    normalized_prompt: Optional[str] = None
    status: JobStatus = "queued"
    html: Optional[str] = None
    error: Optional[str] = None
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    done_event: threading.Event = field(default_factory=threading.Event)

    def to_response(self) -> dict[str, Any]:
        """Build the JSON-serialisable view of this job.

        Always contains ``job_id``, ``status`` and ``created_at``. Timestamps,
        ``duration_seconds`` and ``normalized_prompt`` are added once known;
        ``html`` is included for done jobs and ``error`` for failed ones.
        """
        payload: dict[str, Any] = dict(
            job_id=self.id,
            status=self.status,
            created_at=self.created_at,
        )

        if self.started_at is not None:
            payload["started_at"] = self.started_at

        if self.finished_at is not None:
            # Measure from start when the job ran, otherwise from creation.
            reference = self.started_at or self.created_at
            payload["finished_at"] = self.finished_at
            payload["duration_seconds"] = round(self.finished_at - reference, 2)

        if self.normalized_prompt is not None:
            payload["normalized_prompt"] = self.normalized_prompt

        if self.status == "done":
            payload["html"] = self.html
        elif self.status == "error":
            payload["error"] = self.error

        return payload
|
|
|
|
| _jobs: dict[str, Job] = {} |
| _jobs_lock = threading.Lock() |
| _executor: Optional[ThreadPoolExecutor] = None |
| _inflight = 0 |
| _inflight_lock = threading.Lock() |
|
|
|
|
def _store_job(job: Job) -> None:
    """Register *job* in the in-memory registry under its id."""
    key = job.id
    with _jobs_lock:
        _jobs[key] = job
|
|
|
|
def _get_job(job_id: str) -> Optional[Job]:
    """Return the registered job for *job_id*, or ``None`` if it is unknown."""
    with _jobs_lock:
        found = _jobs.get(job_id)
    return found
|
|
|
|
def _gc_jobs() -> None:
    """Drop finished jobs older than ``JOB_RETENTION_S`` from the registry.

    Jobs without a ``finished_at`` timestamp are never removed.
    """
    now = time.time()
    with _jobs_lock:
        # Iterate over a snapshot so entries can be deleted while scanning.
        for job_id, job in list(_jobs.items()):
            if job.finished_at is None:
                continue
            if (now - job.finished_at) > JOB_RETENTION_S:
                del _jobs[job_id]
|
|
|
|
def _try_reserve_slot() -> bool:
    """Claim one in-flight slot.

    Returns ``True`` and increments the counter when fewer than
    ``MAX_CONCURRENT_JOBS`` slots are taken; returns ``False`` otherwise.
    """
    global _inflight
    with _inflight_lock:
        has_capacity = _inflight < MAX_CONCURRENT_JOBS
        if has_capacity:
            _inflight += 1
        return has_capacity
|
|
|
|
def _release_slot() -> None:
    """Give back one in-flight slot; the counter never drops below zero."""
    global _inflight
    with _inflight_lock:
        _inflight = _inflight - 1 if _inflight > 0 else 0
|
|
|
|
def _inflight_count() -> int:
    """Return the current number of reserved in-flight slots."""
    with _inflight_lock:
        current = _inflight
    return current
|
|
|
|
| |
| |
| |
def _trtllm_generate(prompt_text: str) -> str:
    """Send *prompt_text* to trtllm-serve and return the generated text.

    Posts a chat-completions request (system prompt from ``inf.SYSTEM_PROMPT``
    plus the user prompt) and returns the content of the first choice.

    Raises:
        requests.HTTPError: trtllm-serve answered with an error status.
        requests.RequestException: the request itself failed.
        RuntimeError: the body is not JSON, does not have the expected
            ``choices[0].message.content`` shape, or the content is empty.
    """
    payload = {
        "model": TRTLLM_MODEL,
        "messages": [
            {"role": "system", "content": inf.SYSTEM_PROMPT},
            {"role": "user", "content": prompt_text},
        ],
        "max_tokens": GENERATION_MAX_TOKENS,
        "temperature": GENERATION_TEMPERATURE,
    }
    resp = requests.post(
        f"{TRTLLM_BASE_URL}/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=JOB_TIMEOUT_S,
    )
    resp.raise_for_status()

    try:
        data = resp.json()
    except ValueError as exc:
        # Report a garbage body as such instead of letting the caller
        # misclassify the decode error as a connectivity problem.
        raise RuntimeError(
            f"trtllm-serve returned a non-JSON body: {resp.text[:500]}"
        ) from exc

    try:
        text = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError(
            f"trtllm-serve returned an unexpected response shape: {str(data)[:500]}"
        ) from exc

    if not isinstance(text, str) or not text.strip():
        raise RuntimeError("trtllm-serve returned empty content")
    return text
|
|
|
|
| |
| |
| |
def _persist_job(job: Job, html: str) -> None:
    """Best-effort write of the job's HTML and a JSON metadata sidecar.

    Does nothing when ``OUTPUT_DIR`` is ``None``. Failures are logged and
    swallowed so that a disk problem never turns a finished job into an error.
    """
    if OUTPUT_DIR is None:
        return
    try:
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        (OUTPUT_DIR / f"{job.id}.html").write_text(html, encoding="utf-8")
        # One timestamp so finished_at and duration_seconds agree.
        now = time.time()
        meta = {
            "job_id": job.id,
            "raw_prompt": job.raw_prompt,
            "normalized_prompt": job.normalized_prompt,
            "created_at": job.created_at,
            "started_at": job.started_at,
            "finished_at": now,
            "duration_seconds": round(now - job.started_at, 2),
        }
        (OUTPUT_DIR / f"{job.id}.json").write_text(
            json.dumps(meta, indent=2),
            encoding="utf-8",
        )
        logger.info("job %s saved to %s", job.id, OUTPUT_DIR)
    except Exception as e:
        logger.warning("failed to persist job %s: %s", job.id, e)


def _run_job(job: Job) -> None:
    """Execute one job on a worker thread: normalize, generate, post-process.

    The outcome is recorded on *job* (``status``, ``html`` or ``error``);
    nothing is returned and no exception escapes. On every path the job gets
    its ``finished_at`` timestamp, ``done_event`` is set, the in-flight slot
    is released and stale registry entries are collected.
    """
    job.started_at = time.time()
    job.status = "running"
    logger.info("job %s started", job.id)

    try:
        # Normalization is optional: any failure or empty result falls back
        # to the caller's raw prompt.
        try:
            normalized = inf.normalize_prompt(job.raw_prompt)
        except Exception as e:
            logger.warning(
                "normalize failed for job %s: %s — falling back to raw prompt",
                job.id, e,
            )
            normalized = job.raw_prompt
        if not isinstance(normalized, str) or not normalized.strip():
            normalized = job.raw_prompt
        job.normalized_prompt = normalized

        raw_html = _trtllm_generate(job.normalized_prompt)

        html = inf.post_process(raw_html)
        if not html.strip():
            raise RuntimeError("post_process returned empty output")

        job.html = html
        job.status = "done"
        logger.info(
            "job %s done in %.1fs (%d chars)",
            job.id, time.time() - job.started_at, len(html),
        )

        _persist_job(job, html)

    except requests.HTTPError as e:
        # e.response is normally set by raise_for_status(); guard anyway so a
        # response-less HTTPError cannot raise inside this handler.
        response = e.response
        if response is not None:
            job.error = f"trtllm-serve returned {response.status_code}: {response.text[:500]}"
        else:
            job.error = f"trtllm-serve returned an HTTP error: {e}"
        job.status = "error"
        logger.exception("job %s — trtllm-serve HTTP error", job.id)

    except requests.RequestException as e:
        job.error = f"trtllm-serve unreachable: {e}"
        job.status = "error"
        logger.exception("job %s — trtllm-serve unreachable", job.id)

    except Exception as e:
        job.error = f"{type(e).__name__}: {e}"
        job.status = "error"
        logger.exception("job %s failed", job.id)

    finally:
        job.finished_at = time.time()
        job.done_event.set()
        _release_slot()
        _gc_jobs()
|
|
|
|
| |
| |
| |
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan: probe trtllm-serve, then own the worker pool.

    Startup logs whether trtllm-serve answers ``/v1/models`` (a failure is
    logged but does not abort startup) and creates the module-level executor.
    Shutdown stops the executor without waiting and cancels queued futures.
    """
    global _executor

    # NOTE: this is a blocking call inside an async function; it runs once at
    # startup and is bounded by the 10 second timeout.
    try:
        r = requests.get(f"{TRTLLM_BASE_URL}/v1/models", timeout=10)
        r.raise_for_status()
        logger.info(
            "trtllm-serve OK at %s (%d models loaded)",
            TRTLLM_BASE_URL, len(r.json().get("data", [])),
        )
    except Exception as e:
        logger.error(
            "trtllm-serve not reachable at %s — %s. "
            "Start it before this API: trtllm-serve /workspace/final_model --host 0.0.0.0 --port 8000",
            TRTLLM_BASE_URL, e,
        )

    _executor = ThreadPoolExecutor(
        max_workers=MAX_CONCURRENT_JOBS,
        thread_name_prefix="job-runner",
    )
    logger.info(
        "executor started (max_workers=%d), output_dir=%s",
        MAX_CONCURRENT_JOBS, OUTPUT_DIR,
    )

    try:
        yield
    finally:
        if _executor is not None:
            _executor.shutdown(wait=False, cancel_futures=True)
|
|
|
|
# The ASGI application; `lifespan` manages the worker pool around its lifetime.
app = FastAPI(
    title="HTML Generation API (TRT-LLM backed)",
    version="2.1.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
# Request body shared by POST /v1/jobs and POST /v1/generate.
# `prompt` must be 1..MAX_PROMPT_CHARS characters long and is stored with
# surrounding whitespace removed.
# NOTE: comments instead of a class docstring so the generated schema is
# left untouched.
class GenerateRequest(BaseModel):
    prompt: str = Field(
        ...,
        min_length=1,
        max_length=MAX_PROMPT_CHARS,
    )

    @field_validator("prompt")
    @classmethod
    def _strip(cls, v: str) -> str:
        # Reject prompts that consist of whitespace only.
        stripped = v.strip()
        if stripped:
            return stripped
        raise ValueError("prompt is empty after stripping whitespace")
|
|
|
|
# Catch-all handler: logs the traceback and answers 500 with a JSON body.
# NOTE(review): the exception text is echoed to the client in `detail` -
# confirm that this is acceptable for the deployment.
@app.exception_handler(Exception)
async def _unhandled(request, exc):
    logger.exception("unhandled exception in request: %s", exc)
    payload = {"error": "internal_server_error", "detail": str(exc)}
    return JSONResponse(status_code=500, content=payload)
|
|
|
|
| |
| |
| |
# Liveness probe: answers as soon as the process is serving HTTP and does not
# check the executor or trtllm-serve.
@app.get("/v1/healthz")
def healthz():
    return dict(status="ok")
|
|
|
|
# Readiness probe: 503 until the executor exists and trtllm-serve answers
# GET /v1/models with 200; otherwise reports current load and configuration.
@app.get("/v1/readyz")
def readyz():
    if _executor is None:
        return JSONResponse(status_code=503, content={"status": "executor_not_ready"})

    models_url = f"{TRTLLM_BASE_URL}/v1/models"
    try:
        probe = requests.get(models_url, timeout=5)
    except Exception as e:
        return JSONResponse(
            status_code=503,
            content={"status": "trtllm_unreachable", "detail": str(e)},
        )

    if probe.status_code != 200:
        return JSONResponse(
            status_code=503,
            content={"status": "trtllm_unhealthy", "trtllm_status": probe.status_code},
        )

    return {
        "status": "ready",
        "in_flight": _inflight_count(),
        "max_concurrent_jobs": MAX_CONCURRENT_JOBS,
        "trtllm_url": TRTLLM_BASE_URL,
    }
|
|
|
|
@app.post("/v1/jobs", status_code=202)
def create_job(req: GenerateRequest):
    """Queue a generation job and return its id immediately (HTTP 202).

    Answers 503 when the executor is not available or when all
    ``MAX_CONCURRENT_JOBS`` slots are taken. Poll ``GET /v1/jobs/{job_id}``
    for the result.
    """
    if _executor is None:
        raise HTTPException(status_code=503, detail="server still warming up")
    if not _try_reserve_slot():
        raise HTTPException(
            status_code=503,
            detail=f"server at capacity ({MAX_CONCURRENT_JOBS} in-flight) — try again shortly",
        )

    job = Job(id=uuid.uuid4().hex, raw_prompt=req.prompt)
    _store_job(job)
    try:
        _executor.submit(_run_job, job)
    except RuntimeError as e:
        # submit() raises RuntimeError once the executor has been shut down.
        # _run_job will never run, so finish the job and free its slot here.
        job.error = f"could not schedule job: {e}"
        job.status = "error"
        job.finished_at = time.time()
        job.done_event.set()
        _release_slot()
        raise HTTPException(status_code=503, detail="server is shutting down") from e

    logger.info(
        "job %s queued (in_flight=%d, prompt_chars=%d)",
        job.id, _inflight_count(), len(req.prompt),
    )
    return {
        "job_id": job.id,
        "status": "queued",
        "in_flight": _inflight_count(),
    }
|
|
|
|
@app.get("/v1/jobs/{job_id}")
def get_job(job_id: str):
    """Return a job's state from memory, or its persisted result from disk.

    In-memory jobs are answered with ``Job.to_response()``. Otherwise, when
    persistence is enabled and ``<job_id>.html`` exists in ``OUTPUT_DIR``,
    the stored HTML (plus sidecar metadata, if readable) is returned with
    ``"source": "disk"``. Anything else yields 404.
    """
    job = _get_job(job_id)
    if job is not None:
        return job.to_response()

    # job_id comes from the URL. Only ids shaped like the ones this server
    # issues (uuid4().hex: 32 lowercase hex chars) may be used in a file path.
    looks_like_issued_id = len(job_id) == 32 and all(
        ch in "0123456789abcdef" for ch in job_id
    )

    if OUTPUT_DIR is not None and looks_like_issued_id:
        html_path = OUTPUT_DIR / f"{job_id}.html"
        meta_path = OUTPUT_DIR / f"{job_id}.json"
        if html_path.exists():
            try:
                meta = json.loads(meta_path.read_text(encoding="utf-8")) if meta_path.exists() else {}
                if not isinstance(meta, dict):
                    # A sidecar that is not a JSON object carries no usable fields.
                    meta = {}
                # Metadata first so the authoritative keys below always win.
                return {
                    **meta,
                    "job_id": job_id,
                    "status": "done",
                    "html": html_path.read_text(encoding="utf-8"),
                    "source": "disk",
                }
            except Exception as e:
                logger.warning("failed to read persisted job %s: %s", job_id, e)

    raise HTTPException(
        status_code=404,
        detail="job not found (not in memory and not persisted to disk)",
    )
|
|
|
|
# Lists up to `limit` in-memory jobs, newest first (by created_at).
# `limit` outside 1..500 is rejected with HTTP 400.
@app.get("/v1/jobs")
def list_jobs(limit: int = 50):
    if not (1 <= limit <= 500):
        raise HTTPException(status_code=400, detail="limit must be between 1 and 500")

    # Only the snapshot is taken under the lock; the response is built after.
    with _jobs_lock:
        newest_first = sorted(
            _jobs.values(),
            key=lambda entry: entry.created_at,
            reverse=True,
        )
        selected = newest_first[:limit]

    summaries = []
    for entry in selected:
        summaries.append(
            {"job_id": entry.id, "status": entry.status, "created_at": entry.created_at}
        )
    return {"count": len(selected), "jobs": summaries}
|
|
|
|
@app.post("/v1/generate")
def generate_sync(req: GenerateRequest):
    """Run a generation job and block until it finishes.

    Answers 503 when the executor is unavailable or at capacity, 504 when the
    job does not finish within ``SYNC_TIMEOUT_S`` (the job keeps running and
    can be fetched by id), 500 when the job ends in error, and the HTML with
    timing information on success.
    """
    if _executor is None:
        raise HTTPException(status_code=503, detail="server still warming up")
    if not _try_reserve_slot():
        raise HTTPException(
            status_code=503,
            detail=f"server at capacity ({MAX_CONCURRENT_JOBS} in-flight) — try again shortly",
        )

    job = Job(id=uuid.uuid4().hex, raw_prompt=req.prompt)
    _store_job(job)
    try:
        _executor.submit(_run_job, job)
    except RuntimeError as e:
        # submit() raises RuntimeError once the executor has been shut down.
        # _run_job will never run, so finish the job and free its slot here
        # instead of waiting out the full sync timeout.
        job.error = f"could not schedule job: {e}"
        job.status = "error"
        job.finished_at = time.time()
        job.done_event.set()
        _release_slot()
        raise HTTPException(status_code=503, detail="server is shutting down") from e

    finished = job.done_event.wait(timeout=SYNC_TIMEOUT_S)
    if not finished:
        raise HTTPException(
            status_code=504,
            detail={
                "job_id": job.id,
                "error": "generation timed out — use GET /v1/jobs/{id} to retrieve",
            },
        )

    if job.status == "done":
        return {
            "job_id": job.id,
            "html": job.html,
            "normalized_prompt": job.normalized_prompt,
            "duration_seconds": round(
                (job.finished_at or 0) - (job.started_at or 0), 2
            ),
        }

    raise HTTPException(
        status_code=500,
        detail={"job_id": job.id, "error": job.error or "unknown error"},
    )
|
|