Spaces:
Sleeping
Sleeping
File size: 4,830 Bytes
1b83e76 a5a755a 1b83e76 a5a755a 1b83e76 a5a755a 1b83e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | import { createServerFn } from "@tanstack/react-start";
import type { ExplainerAction, StepResult } from "@/lib/types";
/**
 * One cached websocket connection to the env server, keyed by episode id in `sessions`.
 */
interface EnvSession {
  /** The socket itself. */
  ws: WebSocket;
  /** Resolves once the socket has opened; rejects if it failed to connect. */
  pending: Promise<void>;
  /** Env base URL this socket was opened against; a different current URL forces a reconnect. */
  baseUrl: string;
}
/** Property name under which the session registry is stored on `globalThis`. */
const SESSION_KEY = "__explainer_dashboard_env_sessions__";
/** Env server used when `ENV_BASE_URL` is unset or empty. */
const DEFAULT_ENV_BASE_URL = "https://kgdrathan-explainer-env.hf.space";
/** Model name reported when `MODEL_NAME` is unset or empty. */
const DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b";

type SessionRegistryHost = typeof globalThis & { [SESSION_KEY]?: Map<string, EnvSession> };

/**
 * Episode id -> open env session. The map is created once and then reused from
 * `globalThis`, presumably so it survives module re-evaluation (e.g. dev-server
 * reloads) — NOTE(review): confirm that is the intent.
 */
const sessions = ((globalThis as SessionRegistryHost)[SESSION_KEY] ??= new Map<string, EnvSession>());
/**
 * Resolves the env server base URL.
 *
 * @returns `ENV_BASE_URL` with trailing slashes stripped, or `DEFAULT_ENV_BASE_URL`
 *   when the variable is unset or becomes empty after stripping.
 */
function envBaseUrl(): string {
  const configured = process.env.ENV_BASE_URL;
  const trimmed = configured === undefined ? undefined : configured.replace(/\/+$/, "");
  return trimmed ? trimmed : DEFAULT_ENV_BASE_URL;
}
/**
 * Derives the env websocket endpoint (`<base path>/ws`) from an HTTP(S) or WS(S) base URL.
 *
 * Secure bases (`https:` or `wss:`) map to `wss:`; every other scheme maps to `ws:`.
 * Trailing slashes on the path are collapsed before `/ws` is appended, and any query
 * string or fragment is dropped.
 *
 * @param baseUrl - absolute base URL of the env server.
 * @returns the websocket URL as a string.
 * @throws TypeError if `baseUrl` is not a valid absolute URL (from `new URL`).
 */
function wsUrl(baseUrl: string): string {
  const url = new URL(baseUrl);
  // Keep TLS when the base is already a secure websocket URL; previously `wss:` was downgraded to `ws:`.
  const secure = url.protocol === "https:" || url.protocol === "wss:";
  url.protocol = secure ? "wss:" : "ws:";
  url.pathname = `${url.pathname.replace(/\/+$/, "")}/ws`;
  url.search = "";
  url.hash = "";
  return url.toString();
}
/**
 * Opens a websocket to the env server at `url` (converted via `wsUrl`).
 *
 * The returned session's `pending` promise resolves on `open` and rejects on `error`.
 * It also rejects on a `close` that arrives before `open`; previously that case left
 * `pending` unsettled forever. All three listeners are removed as soon as one fires.
 *
 * @param url - env base URL; stored on the session as `baseUrl`.
 * @returns a new, not-yet-cached session.
 */
function openWs(url: string): EnvSession {
  const target = wsUrl(url);
  const ws = new WebSocket(target);
  const pending = new Promise<void>((resolve, reject) => {
    const settle = (failure?: Error) => {
      ws.removeEventListener("open", onOpen);
      ws.removeEventListener("error", onError);
      ws.removeEventListener("close", onClose);
      if (failure) reject(failure);
      else resolve();
    };
    const onOpen = () => settle();
    const onError = () => settle(new Error(`Failed to connect to env websocket at ${target}`));
    const onClose = () => settle(new Error(`Env websocket at ${target} closed before opening`));
    ws.addEventListener("open", onOpen);
    ws.addEventListener("error", onError);
    ws.addEventListener("close", onClose);
  });
  return { baseUrl: url, ws, pending };
}
/**
 * Returns a connected session for `episodeId`.
 *
 * The cached session is reused only if it targets the current `envBaseUrl()` and its
 * socket is OPEN or CONNECTING. A CLOSING socket can no longer send, so it is replaced.
 * Otherwise, any stale session is closed and a new socket is opened and cached.
 *
 * @param episodeId - key into `sessions`.
 * @returns the session once its socket is open.
 * @throws Error (from `openWs`'s `pending`) if the socket fails to open. The failed
 *   session is evicted first, so the next call reconnects.
 */
async function getSession(episodeId: string): Promise<EnvSession> {
  const baseUrl = envBaseUrl();
  const existing = sessions.get(episodeId);

  let session: EnvSession;
  if (
    existing &&
    existing.baseUrl === baseUrl &&
    (existing.ws.readyState === WebSocket.OPEN || existing.ws.readyState === WebSocket.CONNECTING)
  ) {
    session = existing;
  } else {
    if (existing) closeSession(episodeId);
    session = openWs(baseUrl);
    sessions.set(episodeId, session);
  }

  try {
    await session.pending;
  } catch (err) {
    // Don't leave a dead session cached; only evict it if nothing has replaced it meanwhile.
    if (sessions.get(episodeId) === session) closeSession(episodeId);
    throw err;
  }
  return session;
}
/**
 * Removes the cached session for `episodeId` and closes its socket unless it is
 * already fully CLOSED. Does nothing beyond the map delete when no session exists.
 */
function closeSession(episodeId: string): void {
  const session = sessions.get(episodeId);
  sessions.delete(episodeId);
  if (!session) return;
  if (session.ws.readyState === WebSocket.CLOSED) return;
  session.ws.close();
}
/**
 * Tail of the request chain for each socket. Replies carry no request id, so a reply
 * is matched to a request purely by arrival order and requests must not overlap.
 */
const sendWsQueues = new WeakMap<WebSocket, Promise<unknown>>();

/**
 * Sends one message on `ws` and settles on the next `message`, `error`, or `close` event.
 *
 * @returns the `data` field of the parsed reply.
 * @throws Error if the socket is not open, the reply has `type === "error"`, the reply
 *   is not valid JSON, or the socket errors/closes before replying.
 */
function exchangeWsMessage(ws: WebSocket, message: unknown): Promise<unknown> {
  return new Promise((resolve, reject) => {
    if (ws.readyState !== WebSocket.OPEN) {
      reject(new Error("Env websocket closed"));
      return;
    }
    const cleanup = () => {
      ws.removeEventListener("message", onMessage);
      ws.removeEventListener("error", onError);
      ws.removeEventListener("close", onClose);
    };
    const onMessage = (event: MessageEvent) => {
      cleanup();
      try {
        const parsed = JSON.parse(String(event.data));
        if (parsed.type === "error") {
          reject(new Error(parsed.data?.message || "Env websocket error"));
          return;
        }
        resolve(parsed.data);
      } catch (e) {
        reject(e instanceof Error ? e : new Error(String(e)));
      }
    };
    const onError = () => {
      cleanup();
      reject(new Error("Env websocket error"));
    };
    const onClose = () => {
      cleanup();
      reject(new Error("Env websocket closed"));
    };
    ws.addEventListener("message", onMessage);
    ws.addEventListener("error", onError);
    ws.addEventListener("close", onClose);
    try {
      ws.send(JSON.stringify(message));
    } catch (e) {
      // JSON.stringify or send() threw synchronously; detach listeners so they
      // don't swallow the reply meant for a later request.
      cleanup();
      reject(e instanceof Error ? e : new Error(String(e)));
    }
  });
}

/**
 * Sends a JSON message on the websocket session for `episodeId` and resolves with
 * the `data` field of its reply.
 *
 * Calls on the same socket are queued: each request is sent only after the previous
 * one has settled. Previously, overlapping calls all resolved with the first reply.
 *
 * @param episodeId - session key; a socket is opened on demand via `getSession`.
 * @param message - payload, serialized with `JSON.stringify`.
 * @returns the reply's `data` field.
 * @throws Error on connection failure, an error reply, malformed JSON, or socket error/close.
 */
async function sendWs(episodeId: string, message: unknown): Promise<unknown> {
  const { ws } = await getSession(episodeId);
  const previous = sendWsQueues.get(ws) ?? Promise.resolve();
  const send = () => exchangeWsMessage(ws, message);
  // Run after the previous request settles, whether it succeeded or failed.
  const request = previous.then(send, send);
  sendWsQueues.set(ws, request);
  return request;
}
/**
 * Server function: starts an episode on the env server over a fresh websocket.
 *
 * Any cached socket for the episode is closed first. When the caller omits
 * `episode_id`, a UUID is generated. That UUID is now forwarded to the env as
 * `episode_id`; previously the socket was cached under a key that was neither sent
 * nor returned, so it could never be stepped or closed.
 * NOTE(review): confirm the env adopts the provided `episode_id`.
 *
 * @throws Error if the reply carries no result object, or on any `sendWs` failure.
 */
export const envReset = createServerFn({ method: "POST" })
  .inputValidator(
    (input: { seed?: number; episode_id?: string; topic?: string } | undefined) => input ?? {},
  )
  .handler(async ({ data }): Promise<StepResult> => {
    const episodeId = data.episode_id || crypto.randomUUID();
    const body: Record<string, unknown> = {};
    if (typeof data.seed === "number") body.seed = data.seed;
    body.episode_id = episodeId;
    if (data.topic) body.topic = data.topic;
    closeSession(episodeId);
    const reply = await sendWs(episodeId, { type: "reset", data: body });
    // `sendWs` resolves with the reply's `data` field, which may be absent; fail clearly
    // instead of with a TypeError on property access.
    if (reply === null || typeof reply !== "object") {
      throw new Error("Env websocket returned no reset result");
    }
    const result = reply as Partial<StepResult>;
    return {
      observation: result.observation as StepResult["observation"],
      reward: result.reward ?? null,
      done: Boolean(result.done),
    };
  });
/**
 * Server function: sends `action` as a `step` message on the websocket session for
 * `episode_id` and normalizes the reply into a `StepResult`.
 *
 * If no live session is cached for `episode_id`, `sendWs` opens a new socket, and no
 * `reset` has been sent on that new socket.
 *
 * @throws Error if the reply carries no result object, or on any `sendWs` failure.
 */
export const envStep = createServerFn({ method: "POST" })
  .inputValidator((input: { episode_id: string; action: ExplainerAction }) => input)
  .handler(async ({ data }): Promise<StepResult> => {
    const reply = await sendWs(data.episode_id, {
      type: "step",
      data: data.action,
    });
    // `sendWs` resolves with the reply's `data` field, which may be absent; fail clearly
    // instead of with a TypeError on property access.
    if (reply === null || typeof reply !== "object") {
      throw new Error("Env websocket returned no step result");
    }
    const result = reply as Partial<StepResult>;
    return {
      observation: result.observation as StepResult["observation"],
      reward: result.reward ?? null,
      done: Boolean(result.done),
    };
  });
/**
 * Server function: reports the effective runtime configuration to the client.
 *
 * Only a boolean `hasToken` is exposed: true when `HF_TOKEN` or `API_KEY` is non-empty.
 * The token values themselves are never returned.
 */
export const getRuntimeConfig = createServerFn({ method: "GET" }).handler(async () => {
  const envUrl = envBaseUrl();
  const { MODEL_NAME, API_BASE_URL, HF_TOKEN, API_KEY } = process.env;
  return {
    envUrl,
    modelName: MODEL_NAME || DEFAULT_MODEL_NAME,
    apiBaseUrl: API_BASE_URL || "https://router.huggingface.co/v1",
    hasToken: !!(HF_TOKEN || API_KEY),
  };
});
|