Spaces:
Sleeping
Sleeping
| import { createServerFn } from "@tanstack/react-start"; | |
| import type { ExplainerAction, StepResult } from "@/lib/types"; | |
/**
 * One cached connection to the environment server.
 */
interface EnvSession {
  /** Settles once the socket's opening handshake has succeeded or failed. */
  pending: Promise<void>;
  /** The underlying WebSocket used to exchange JSON messages with the env. */
  ws: WebSocket;
  /** Base URL the socket was opened against; used to detect configuration changes. */
  baseUrl: string;
}
/** Model name reported when `MODEL_NAME` is not set. */
const DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b";
/** Environment base URL used when `ENV_BASE_URL` is not set. */
const DEFAULT_ENV_BASE_URL = "https://kgdrathan-explainer-env.hf.space";
/** Property name on `globalThis` under which the session registry is stored. */
const SESSION_KEY = "__explainer_dashboard_env_sessions__";

/**
 * Registry of env sessions keyed by episode id.
 *
 * The map is stored on `globalThis`, so evaluating this module again reuses
 * the map that is already there instead of creating a new, empty one.
 */
const sessions: Map<string, EnvSession> = (() => {
  const holder = globalThis as typeof globalThis & {
    [SESSION_KEY]?: Map<string, EnvSession>;
  };
  const cached = holder[SESSION_KEY];
  if (cached != null) return cached;
  const fresh = new Map<string, EnvSession>();
  holder[SESSION_KEY] = fresh;
  return fresh;
})();
/**
 * Resolves the environment server's base URL.
 *
 * @returns `ENV_BASE_URL` with any trailing slashes removed, or the built-in
 *   default when the variable is unset or empty after trimming.
 */
function envBaseUrl(): string {
  const configured = process.env.ENV_BASE_URL;
  const withoutTrailingSlashes = configured?.replace(/\/+$/, "");
  if (withoutTrailingSlashes) {
    return withoutTrailingSlashes;
  }
  return DEFAULT_ENV_BASE_URL;
}
/**
 * Derives the environment's WebSocket endpoint from its base URL.
 *
 * The scheme is mapped to its WebSocket counterpart, `/ws` is appended to the
 * path (after dropping trailing slashes), and any query string or fragment is
 * discarded.
 *
 * @param baseUrl Absolute base URL of the environment server.
 * @returns The absolute `ws:`/`wss:` URL of the `/ws` endpoint.
 * @throws TypeError if `baseUrl` is not a valid absolute URL.
 */
function wsUrl(baseUrl: string): string {
  const url = new URL(baseUrl);
  // A base URL that is already `wss:` must stay secure; previously only
  // `https:` mapped to `wss:`, so `wss://…` was downgraded to `ws:`.
  const secure = url.protocol === "https:" || url.protocol === "wss:";
  url.protocol = secure ? "wss:" : "ws:";
  url.pathname = `${url.pathname.replace(/\/+$/, "")}/ws`;
  url.search = "";
  url.hash = "";
  return url.toString();
}
/**
 * Starts a WebSocket connection to the environment and wraps it in a session.
 *
 * The socket is returned immediately; callers must await `pending` before
 * sending anything.
 *
 * @param url Base URL of the environment server (not the websocket URL).
 * @returns A session whose `pending` promise resolves when the socket opens
 *   and rejects when the connection fails or closes before opening.
 */
function openWs(url: string): EnvSession {
  // Compute the endpoint once so the socket and the error messages agree.
  const endpoint = wsUrl(url);
  const ws = new WebSocket(endpoint);
  const pending = new Promise<void>((resolve, reject) => {
    // Drop all handshake listeners as soon as the outcome is known.
    const detach = () => {
      ws.removeEventListener("open", onOpen);
      ws.removeEventListener("error", onError);
      ws.removeEventListener("close", onClose);
    };
    const onOpen = () => {
      detach();
      resolve();
    };
    const onError = () => {
      detach();
      reject(new Error(`Failed to connect to env websocket at ${endpoint}`));
    };
    // Guards against a socket that closes before opening without an error
    // event, which would otherwise leave `pending` unsettled forever.
    const onClose = () => {
      detach();
      reject(new Error(`Env websocket at ${endpoint} closed before the connection opened`));
    };
    ws.addEventListener("open", onOpen);
    ws.addEventListener("error", onError);
    ws.addEventListener("close", onClose);
  });
  return { baseUrl: url, ws, pending };
}
/**
 * Returns a connected session for the given episode, reusing the cached one
 * when possible.
 *
 * A cached session is reused only if it targets the currently configured base
 * URL and its socket is still connecting or open. Otherwise it is closed and
 * replaced with a fresh connection.
 *
 * @param episodeId Key under which the session is cached.
 * @returns The session, after its opening handshake has completed.
 * @throws Error if the websocket connection cannot be established.
 */
async function getSession(episodeId: string): Promise<EnvSession> {
  const baseUrl = envBaseUrl();
  const existing = sessions.get(episodeId);
  if (existing) {
    const state = existing.ws.readyState;
    // A CLOSING socket can no longer carry a request, so only CONNECTING and
    // OPEN sockets are reusable.
    const usable = state === WebSocket.CONNECTING || state === WebSocket.OPEN;
    if (usable && existing.baseUrl === baseUrl) {
      await existing.pending;
      return existing;
    }
    closeSession(episodeId);
  }
  const session = openWs(baseUrl);
  sessions.set(episodeId, session);
  try {
    await session.pending;
  } catch (error) {
    // Evict the failed session so it does not linger in the registry. The
    // identity check skips eviction if the entry was replaced in the meantime.
    if (sessions.get(episodeId) === session) closeSession(episodeId);
    throw error;
  }
  return session;
}
/**
 * Forgets the session cached for an episode and closes its socket.
 *
 * Does nothing beyond the registry removal when no session is cached or the
 * socket is already fully closed.
 *
 * @param episodeId Key of the session to drop.
 */
function closeSession(episodeId: string): void {
  const session = sessions.get(episodeId);
  sessions.delete(episodeId);
  if (!session) return;
  if (session.ws.readyState === WebSocket.CLOSED) return;
  session.ws.close();
}
/**
 * Sends one JSON message over the episode's websocket and waits for the next
 * incoming message as its reply.
 *
 * NOTE(review): replies are matched purely by arrival order. Two requests in
 * flight on the same episode would both be settled by the first message that
 * arrives. Callers appear to issue requests one at a time; confirm.
 *
 * @param episodeId Episode whose session carries the request.
 * @param message JSON-serialisable payload to send.
 * @returns The `data` field of the reply message.
 * @throws Error if the connection cannot be established, the message cannot
 *   be sent, the reply is not valid JSON, the reply has `type === "error"`, or
 *   the socket errors or closes before a reply arrives.
 */
async function sendWs(episodeId: string, message: unknown): Promise<unknown> {
  const session = await getSession(episodeId);
  return new Promise((resolve, reject) => {
    // Every outcome detaches all three listeners so nothing leaks onto the
    // long-lived socket.
    const cleanup = () => {
      session.ws.removeEventListener("message", onMessage);
      session.ws.removeEventListener("error", onError);
      session.ws.removeEventListener("close", onClose);
    };
    const onMessage = (event: MessageEvent) => {
      cleanup();
      try {
        const parsed = JSON.parse(String(event.data));
        if (parsed.type === "error") {
          reject(new Error(parsed.data?.message || "Env websocket error"));
          return;
        }
        resolve(parsed.data);
      } catch (e) {
        reject(e instanceof Error ? e : new Error(String(e)));
      }
    };
    const onError = () => {
      cleanup();
      reject(new Error("Env websocket error"));
    };
    const onClose = () => {
      cleanup();
      reject(new Error("Env websocket closed"));
    };
    session.ws.addEventListener("message", onMessage);
    session.ws.addEventListener("error", onError);
    session.ws.addEventListener("close", onClose);
    try {
      session.ws.send(JSON.stringify(message));
    } catch (e) {
      // Serialisation or send failed: detach the listeners registered above
      // before rejecting, otherwise they would stay attached to the socket.
      cleanup();
      reject(e instanceof Error ? e : new Error(String(e)));
    }
  });
}
/**
 * Server function that resets the environment for an episode.
 *
 * Any cached session for the episode is closed first, so the reset always
 * runs on a fresh websocket. Only the fields actually supplied (`seed`,
 * `episode_id`, `topic`) are forwarded in the reset payload.
 *
 * NOTE(review): when no `episode_id` is supplied, a random id is used as the
 * session key but is neither sent to the env nor returned to the caller, so a
 * later step cannot address that session. Confirm whether callers always pass
 * an id.
 *
 * @returns The env's observation, with `reward` defaulted to `null` and
 *   `done` coerced to a boolean.
 * @throws Error if the env cannot be reached, reports an error, or replies
 *   without a data payload.
 */
export const envReset = createServerFn({ method: "POST" })
  .inputValidator(
    (input: { seed?: number; episode_id?: string; topic?: string } | undefined) => input ?? {},
  )
  .handler(async ({ data }): Promise<StepResult> => {
    const body: Record<string, unknown> = {};
    if (typeof data.seed === "number") body.seed = data.seed;
    if (data.episode_id) body.episode_id = data.episode_id;
    if (data.topic) body.topic = data.topic;
    const episodeId = data.episode_id || crypto.randomUUID();
    closeSession(episodeId);
    const result = (await sendWs(episodeId, { type: "reset", data: body })) as
      | Partial<StepResult>
      | null
      | undefined;
    // A reply without `data` used to surface as an opaque TypeError on
    // `result.observation`; fail with an explicit message instead.
    if (result == null) {
      throw new Error("Env reset response did not include a data payload");
    }
    return {
      observation: result.observation as StepResult["observation"],
      reward: result.reward ?? null,
      done: Boolean(result.done),
    };
  });
/**
 * Server function that applies one action to an episode's environment.
 *
 * The action is sent as a `step` message over the websocket session keyed by
 * `episode_id`.
 *
 * @returns The env's observation, with `reward` defaulted to `null` and
 *   `done` coerced to a boolean.
 * @throws Error if `episode_id` is missing or empty, the env cannot be
 *   reached, the env reports an error, or the reply has no data payload.
 */
export const envStep = createServerFn({ method: "POST" })
  .inputValidator((input: { episode_id: string; action: ExplainerAction }) => {
    // Without a usable id the step would silently open a brand-new socket
    // under a meaningless key and run against an env that was never reset.
    if (typeof input?.episode_id !== "string" || input.episode_id.length === 0) {
      throw new Error("envStep requires a non-empty episode_id");
    }
    return input;
  })
  .handler(async ({ data }): Promise<StepResult> => {
    const result = (await sendWs(data.episode_id, {
      type: "step",
      data: data.action,
    })) as Partial<StepResult> | null | undefined;
    // A reply without `data` used to surface as an opaque TypeError on
    // `result.observation`; fail with an explicit message instead.
    if (result == null) {
      throw new Error("Env step response did not include a data payload");
    }
    return {
      observation: result.observation as StepResult["observation"],
      reward: result.reward ?? null,
      done: Boolean(result.done),
    };
  });
/**
 * Server function exposing the runtime configuration derived from environment
 * variables.
 *
 * The token itself is never returned; only whether `HF_TOKEN` or `API_KEY` is
 * set.
 */
export const getRuntimeConfig = createServerFn({ method: "GET" }).handler(async () => {
  const envUrl = envBaseUrl();
  const modelName = process.env.MODEL_NAME || DEFAULT_MODEL_NAME;
  const apiBaseUrl = process.env.API_BASE_URL || "https://router.huggingface.co/v1";
  const hasToken = Boolean(process.env.HF_TOKEN || process.env.API_KEY);
  return { envUrl, modelName, apiBaseUrl, hasToken };
});