File size: 4,830 Bytes
1b83e76
 
 
 
 
 
 
 
 
 
a5a755a
 
1b83e76
 
 
 
 
 
a5a755a
1b83e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a755a
1b83e76
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import { createServerFn } from "@tanstack/react-start";
import type { ExplainerAction, StepResult } from "@/lib/types";

/**
 * A cached websocket connection to the env server. Instances are stored in
 * `sessions` keyed by episode id.
 */
interface EnvSession {
  /** Env base URL the socket was opened from (compared against the current base URL before reuse). */
  baseUrl: string;
  /** The underlying socket. */
  ws: WebSocket;
  /** Handshake promise: resolves on the socket's first "open", rejects on its first "error". */
  pending: Promise<void>;
}

/** Property name on `globalThis` under which the session registry is stored. */
const SESSION_KEY = "__explainer_dashboard_env_sessions__";
/** Env server used when `ENV_BASE_URL` is unset or empty. */
const DEFAULT_ENV_BASE_URL = "https://kgdrathan-explainer-env.hf.space";
/** Model name reported when `MODEL_NAME` is unset or empty. */
const DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b";

/** `globalThis` augmented with the optional registry slot. */
type SessionRegistryHost = typeof globalThis & { [SESSION_KEY]?: Map<string, EnvSession> };

// Keep the registry on globalThis: if this module is evaluated again in the
// same process, the existing Map (and its open sockets) is reused rather than
// replaced by an empty one.
const registryHost = globalThis as SessionRegistryHost;
const sessions = (registryHost[SESSION_KEY] ??= new Map<string, EnvSession>());

/**
 * Resolve the env server base URL.
 *
 * Reads `ENV_BASE_URL` and strips any trailing slashes. Falls back to
 * `DEFAULT_ENV_BASE_URL` when the variable is unset or becomes empty after
 * stripping.
 *
 * @returns Base URL without a trailing slash.
 */
function envBaseUrl(): string {
  const configured = process.env.ENV_BASE_URL;
  const withoutTrailingSlash =
    configured === undefined ? undefined : configured.replace(/\/+$/, "");
  return withoutTrailingSlash ? withoutTrailingSlash : DEFAULT_ENV_BASE_URL;
}

/**
 * Derive the env websocket endpoint (`<base>/ws`) from an env base URL.
 *
 * `https:` and `wss:` become `wss:`. `http:` becomes `ws:`. Trailing slashes
 * are trimmed from the existing path before `/ws` is appended. Any query
 * string and fragment are dropped.
 *
 * @param baseUrl Absolute env base URL, e.g. `https://host/prefix`.
 * @returns The websocket URL as a string.
 * @throws TypeError if `baseUrl` is not a valid absolute URL (from `new URL`).
 */
function wsUrl(baseUrl: string): string {
  const url = new URL(baseUrl);
  // Keep TLS whenever the input is already secure. Previously a `wss:` input
  // was downgraded to plaintext `ws:` because only `https:` was recognized.
  const secure = url.protocol === "https:" || url.protocol === "wss:";
  url.protocol = secure ? "wss:" : "ws:";
  url.pathname = `${url.pathname.replace(/\/+$/, "")}/ws`;
  url.search = "";
  url.hash = "";
  return url.toString();
}

/**
 * Start a websocket connection to the env server behind `url`.
 *
 * The socket is created immediately. The returned `pending` promise resolves
 * on the socket's first "open" event and rejects on its first "error" event.
 *
 * @param url Env base URL; converted to its `/ws` endpoint with `wsUrl`.
 * @returns A session record holding the base URL, the socket, and the handshake promise.
 */
function openWs(url: string): EnvSession {
  const endpoint = wsUrl(url);
  const ws = new WebSocket(endpoint);
  const pending = new Promise<void>((resolve, reject) => {
    const handleOpen = (): void => resolve();
    const handleError = (): void => {
      reject(new Error(`Failed to connect to env websocket at ${endpoint}`));
    };
    ws.addEventListener("open", handleOpen, { once: true });
    ws.addEventListener("error", handleError, { once: true });
  });
  return { baseUrl: url, ws, pending };
}

/**
 * Return a connected websocket session for `episodeId`.
 *
 * The cached session is reused only if it was opened against the current env
 * base URL and its socket is still CONNECTING or OPEN. A CLOSING or CLOSED
 * socket can never deliver a reply, so it is discarded and a fresh connection
 * is opened instead. If the fresh connection's handshake fails, its registry
 * entry is removed so the dead socket is not handed out again.
 *
 * @param episodeId Registry key for the session.
 * @returns The session, after its handshake has completed.
 * @throws Error when the websocket connection cannot be established.
 */
async function getSession(episodeId: string): Promise<EnvSession> {
  const baseUrl = envBaseUrl();
  const existing = sessions.get(episodeId);
  if (
    existing &&
    existing.baseUrl === baseUrl &&
    (existing.ws.readyState === WebSocket.CONNECTING || existing.ws.readyState === WebSocket.OPEN)
  ) {
    await existing.pending;
    return existing;
  }
  if (existing) closeSession(episodeId);
  const session = openWs(baseUrl);
  sessions.set(episodeId, session);
  try {
    await session.pending;
  } catch (err) {
    // Drop the failed socket unless a concurrent caller already replaced it.
    if (sessions.get(episodeId) === session) sessions.delete(episodeId);
    throw err;
  }
  return session;
}

/**
 * Forget the session for `episodeId` and close its socket.
 *
 * The registry entry is always removed. `close()` is called only when a
 * session existed and its socket is not already CLOSED. Unknown ids are a no-op.
 *
 * @param episodeId Registry key of the session to drop.
 */
function closeSession(episodeId: string): void {
  const session = sessions.get(episodeId);
  sessions.delete(episodeId);
  if (!session) return;
  if (session.ws.readyState === WebSocket.CLOSED) return;
  session.ws.close();
}

/**
 * Tail of the request chain for each session.
 *
 * A reply is not matched to its request: the first "message" after a send is
 * taken as the reply. So at most one request may be in flight per socket, and
 * each new request waits for the previous one to settle. Without this,
 * concurrent callers would all resolve with the first reply.
 */
const sendQueues = new WeakMap<EnvSession, Promise<unknown>>();

/**
 * Send `message` as JSON on `ws` and settle with the next message received.
 *
 * Resolves with the reply's `data` field. Rejects when:
 * - the reply has `type: "error"` (using `data.message` when present);
 * - the reply is not valid JSON;
 * - the socket errors or closes first;
 * - the socket is not OPEN when the request starts;
 * - the message cannot be serialized or sent.
 *
 * All listeners are detached once the promise settles.
 */
function requestReply(ws: WebSocket, message: unknown): Promise<unknown> {
  return new Promise((resolve, reject) => {
    // A queued request may start after the socket went away while it waited;
    // if the close event already fired, nothing would ever settle this promise.
    if (ws.readyState !== ws.OPEN) {
      reject(new Error("Env websocket closed"));
      return;
    }
    const cleanup = (): void => {
      ws.removeEventListener("message", onMessage);
      ws.removeEventListener("error", onError);
      ws.removeEventListener("close", onClose);
    };
    const onMessage = (event: MessageEvent): void => {
      cleanup();
      try {
        const parsed = JSON.parse(String(event.data));
        if (parsed.type === "error") {
          reject(new Error(parsed.data?.message || "Env websocket error"));
          return;
        }
        resolve(parsed.data);
      } catch (e) {
        reject(e instanceof Error ? e : new Error(String(e)));
      }
    };
    const onError = (): void => {
      cleanup();
      reject(new Error("Env websocket error"));
    };
    const onClose = (): void => {
      cleanup();
      reject(new Error("Env websocket closed"));
    };
    ws.addEventListener("message", onMessage);
    ws.addEventListener("error", onError);
    ws.addEventListener("close", onClose);
    try {
      ws.send(JSON.stringify(message));
    } catch (e) {
      // e.g. an unserializable payload: detach listeners so they do not linger on the shared socket.
      cleanup();
      reject(e instanceof Error ? e : new Error(String(e)));
    }
  });
}

/**
 * Send `message` on the episode's websocket session and resolve with the
 * reply's `data`.
 *
 * Requests on the same session run one at a time, in call order. A failed
 * request does not block the ones queued after it.
 *
 * @param episodeId Session key passed to `getSession`.
 * @param message JSON-serializable payload.
 * @returns The `data` field of the reply.
 * @throws Error when connecting fails, the env replies with `type: "error"`,
 *   the reply is not JSON, or the socket errors/closes before replying.
 */
async function sendWs(episodeId: string, message: unknown): Promise<unknown> {
  const session = await getSession(episodeId);
  const run = (): Promise<unknown> => requestReply(session.ws, message);
  const previous = sendQueues.get(session);
  const request = previous ? previous.then(run, run) : run();
  sendQueues.set(session, request);
  return request;
}

/**
 * Server function: start a new episode on a fresh env websocket and return
 * the reset result.
 *
 * `seed`, `episode_id`, and `topic` are forwarded to the env only when set.
 * The socket is cached under `episode_id` so later `envStep` calls with the
 * same id reuse it.
 *
 * When no `episode_id` is given, a random UUID keys the socket instead. That
 * key is never returned to the caller, so no later call could reach the
 * socket; it is closed once the reset finishes rather than left open.
 */
export const envReset = createServerFn({ method: "POST" })
  .inputValidator(
    (input: { seed?: number; episode_id?: string; topic?: string } | undefined) => input ?? {},
  )
  .handler(async ({ data }): Promise<StepResult> => {
    const body: Record<string, unknown> = {};
    if (typeof data.seed === "number") body.seed = data.seed;
    if (data.episode_id) body.episode_id = data.episode_id;
    if (data.topic) body.topic = data.topic;
    const ephemeral = !data.episode_id;
    const episodeId = data.episode_id || crypto.randomUUID();
    // Every reset starts on a new socket rather than reusing a previous one.
    closeSession(episodeId);
    try {
      const result = (await sendWs(episodeId, { type: "reset", data: body })) as Partial<StepResult>;
      return {
        observation: result.observation as StepResult["observation"],
        reward: result.reward ?? null,
        done: Boolean(result.done),
      };
    } finally {
      // An unreachable, randomly keyed socket would otherwise stay open indefinitely.
      if (ephemeral) closeSession(episodeId);
    }
  });

/**
 * Server function: send one action to the env for an existing episode.
 *
 * The action is sent as `{ type: "step", data: action }` on the episode's
 * websocket session. The reply is normalized so `reward` defaults to `null`
 * and `done` is a boolean.
 */
export const envStep = createServerFn({ method: "POST" })
  .inputValidator((input: { episode_id: string; action: ExplainerAction }) => input)
  .handler(async ({ data }): Promise<StepResult> => {
    const { episode_id: episodeId, action } = data;
    const reply = (await sendWs(episodeId, { type: "step", data: action })) as Partial<StepResult>;
    const { observation, reward, done } = reply;
    return {
      observation: observation as StepResult["observation"],
      reward: reward ?? null,
      done: Boolean(done),
    };
  });

/**
 * Server function: report the runtime configuration visible to the dashboard.
 *
 * Returns:
 * - `envUrl`: the resolved env base URL;
 * - `modelName`: `MODEL_NAME`, or the default model;
 * - `apiBaseUrl`: `API_BASE_URL`, or the Hugging Face router;
 * - `hasToken`: whether `HF_TOKEN` or `API_KEY` is set to a non-empty value.
 *
 * The token value itself is never returned.
 */
export const getRuntimeConfig = createServerFn({ method: "GET" }).handler(async () => {
  const env = process.env;
  const envUrl = envBaseUrl();
  const modelName = env.MODEL_NAME || DEFAULT_MODEL_NAME;
  const apiBaseUrl = env.API_BASE_URL || "https://router.huggingface.co/v1";
  const hasToken = !!(env.HF_TOKEN || env.API_KEY);
  return { envUrl, modelName, apiBaseUrl, hasToken };
});