import { createServerFn } from "@tanstack/react-start";
import type { ExplainerAction, StepResult } from "@/lib/types";

/** One websocket connection to the env server, stored in `sessions` under an episode id. */
interface EnvSession {
  /** Env base URL the socket was opened against; a different current URL forces a reconnect. */
  baseUrl: string;
  ws: WebSocket;
  /** Resolves when the socket opens; rejects if it errors or closes before opening. */
  pending: Promise<void>;
}

const SESSION_KEY = "__explainer_dashboard_env_sessions__";
const DEFAULT_ENV_BASE_URL = "https://kgdrathan-explainer-env.hf.space";
const DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b";

// Kept on globalThis and reused via `??=` so re-evaluating this module does not drop live
// sessions (presumably for dev-server hot reloads — confirm).
const sessions = ((globalThis as typeof globalThis & { [SESSION_KEY]?: Map<string, EnvSession> })[
  SESSION_KEY
] ??= new Map<string, EnvSession>());

// Tail of each session's request chain. Replies are matched to requests purely by arrival
// order (no request id is sent or checked), so requests on one socket must run one at a time.
// A WeakMap (instead of a field on EnvSession) also tolerates sessions in `sessions` that were
// created by an earlier evaluation of this module.
const requestTails = new WeakMap<EnvSession, Promise<unknown>>();

/** Env server base URL from `ENV_BASE_URL` (trailing slashes stripped), or the default. */
function envBaseUrl(): string {
  return process.env.ENV_BASE_URL?.replace(/\/+$/, "") || DEFAULT_ENV_BASE_URL;
}

/**
 * Derive the env websocket endpoint (`<base path>/ws`) from a base URL.
 * Secure schemes (`https:`/`wss:`) map to `wss:`, everything else to `ws:`;
 * query string and fragment are dropped.
 */
function wsUrl(baseUrl: string): string {
  const url = new URL(baseUrl);
  url.protocol = url.protocol === "https:" || url.protocol === "wss:" ? "wss:" : "ws:";
  url.pathname = `${url.pathname.replace(/\/+$/, "")}/ws`;
  url.search = "";
  url.hash = "";
  return url.toString();
}

/** Start connecting to the env websocket for `baseUrl`; `pending` tracks the handshake. */
function openWs(baseUrl: string): EnvSession {
  const target = wsUrl(baseUrl);
  const ws = new WebSocket(target);
  const pending = new Promise<void>((resolve, reject) => {
    ws.addEventListener("open", () => resolve(), { once: true });
    ws.addEventListener(
      "error",
      () => reject(new Error(`Failed to connect to env websocket at ${target}`)),
      { once: true },
    );
    // Without this, a close before "open" (with no "error") would leave `pending` unsettled
    // forever. After a successful open this reject is a no-op.
    ws.addEventListener(
      "close",
      () => reject(new Error(`Env websocket at ${target} closed before opening`)),
      { once: true },
    );
  });
  return { baseUrl, ws, pending };
}

/** True while the socket is connecting or open, i.e. it can still carry messages. */
function isLive(ws: WebSocket): boolean {
  return ws.readyState === WebSocket.CONNECTING || ws.readyState === WebSocket.OPEN;
}

/**
 * Return an open session for `episodeId`, reusing the cached one when it targets the current
 * base URL and is still live; otherwise replace it with a new connection.
 * @throws Error if the connection cannot be opened.
 */
async function getSession(episodeId: string): Promise<EnvSession> {
  const baseUrl = envBaseUrl();
  const existing = sessions.get(episodeId);
  if (existing?.baseUrl === baseUrl && isLive(existing.ws)) {
    await existing.pending;
    return existing;
  }
  if (existing) closeSession(episodeId);

  const session = openWs(baseUrl);
  sessions.set(episodeId, session);
  try {
    await session.pending;
  } catch (error) {
    // Don't keep a dead socket in the map — unless a concurrent caller already replaced it.
    if (sessions.get(episodeId) === session) sessions.delete(episodeId);
    throw error;
  }
  return session;
}

/** Forget the session for `episodeId` and close its socket if it is not already closed. */
function closeSession(episodeId: string): void {
  const existing = sessions.get(episodeId);
  sessions.delete(episodeId);
  if (existing && existing.ws.readyState !== WebSocket.CLOSED) {
    existing.ws.close();
  }
}

/**
 * Parse one env reply envelope `{ type, data }`.
 * @returns the envelope's `data`.
 * @throws Error with the env's message when `type === "error"`, or if the reply is not an object.
 */
function unwrapReply(raw: string): unknown {
  const parsed: unknown = JSON.parse(raw);
  if (typeof parsed !== "object" || parsed === null) {
    throw new Error("Malformed env websocket message");
  }
  const { type, data } = parsed as { type?: unknown; data?: unknown };
  if (type === "error") {
    const detail =
      typeof data === "object" && data !== null ? (data as { message?: unknown }).message : undefined;
    throw new Error(detail ? String(detail) : "Env websocket error");
  }
  return data;
}

/** Send `message` on `ws` and settle with the `data` of the next message received. */
function exchange(ws: WebSocket, message: unknown): Promise<unknown> {
  return new Promise<unknown>((resolve, reject) => {
    if (ws.readyState !== WebSocket.OPEN) {
      // Listeners on a closing/closed socket might never fire; fail fast instead of hanging.
      reject(new Error("Env websocket closed"));
      return;
    }
    const cleanup = () => {
      ws.removeEventListener("message", onMessage);
      ws.removeEventListener("error", onError);
      ws.removeEventListener("close", onClose);
    };
    const onMessage = (event: MessageEvent) => {
      cleanup();
      try {
        resolve(unwrapReply(String(event.data)));
      } catch (e) {
        reject(e instanceof Error ? e : new Error(String(e)));
      }
    };
    const onError = () => {
      cleanup();
      reject(new Error("Env websocket error"));
    };
    const onClose = () => {
      cleanup();
      reject(new Error("Env websocket closed"));
    };
    ws.addEventListener("message", onMessage);
    ws.addEventListener("error", onError);
    ws.addEventListener("close", onClose);
    try {
      ws.send(JSON.stringify(message));
    } catch (e) {
      // Serialization or send failed: detach the listeners so they don't leak on the socket.
      cleanup();
      reject(e instanceof Error ? e : new Error(String(e)));
    }
  });
}

/**
 * Send `message` on the episode's socket and resolve with the reply's `data`.
 * Requests for the same session are queued so each reply is paired with its own request.
 */
async function sendWs(episodeId: string, message: unknown): Promise<unknown> {
  const session = await getSession(episodeId);
  const previous = requestTails.get(session) ?? Promise.resolve();
  const request = previous.then(() => exchange(session.ws, message));
  // Store a never-rejecting tail so one failed request does not poison later ones.
  requestTails.set(
    session,
    request.catch(() => undefined),
  );
  return request;
}

/**
 * Normalize an env reply payload into a StepResult (missing reward → null, done → boolean).
 * @throws Error if the payload is not an object.
 */
function toStepResult(payload: unknown): StepResult {
  if (typeof payload !== "object" || payload === null) {
    throw new Error("Env websocket reply has no payload");
  }
  const result = payload as Partial<StepResult>;
  return {
    observation: result.observation as StepResult["observation"],
    reward: result.reward ?? null,
    done: Boolean(result.done),
  };
}

/**
 * Start a new episode on the env server over a fresh websocket.
 * If no `episode_id` is given, a random UUID is generated. It is always forwarded to the env
 * so that the env's episode and the socket's session key are the same id.
 */
export const envReset = createServerFn({ method: "POST" })
  .inputValidator(
    (input: { seed?: number; episode_id?: string; topic?: string } | undefined) => input ?? {},
  )
  .handler(async ({ data }): Promise<StepResult> => {
    const episodeId = data.episode_id || crypto.randomUUID();
    const body: Record<string, unknown> = {};
    if (typeof data.seed === "number") body.seed = data.seed;
    body.episode_id = episodeId;
    if (data.topic) body.topic = data.topic;
    // Drop any existing socket for this id so the reset runs on a new connection.
    closeSession(episodeId);
    return toStepResult(await sendWs(episodeId, { type: "reset", data: body }));
  });

/** Send one action to the env for an episode previously started with `envReset`. */
export const envStep = createServerFn({ method: "POST" })
  .inputValidator((input: { episode_id: string; action: ExplainerAction }) => input)
  .handler(async ({ data }): Promise<StepResult> => {
    const payload = await sendWs(data.episode_id, { type: "step", data: data.action });
    return toStepResult(payload);
  });

/**
 * Report the configuration these server functions run with.
 * Only whether a token is set is exposed, never the token itself.
 */
export const getRuntimeConfig = createServerFn({ method: "GET" }).handler(async () => {
  return {
    envUrl: envBaseUrl(),
    modelName: process.env.MODEL_NAME || DEFAULT_MODEL_NAME,
    apiBaseUrl: process.env.API_BASE_URL || "https://router.huggingface.co/v1",
    hasToken: Boolean(process.env.HF_TOKEN || process.env.API_KEY),
  };
});