File size: 3,214 Bytes
1b83e76
 
 
 
 
 
 
 
 
 
 
 
 
a5a755a
1b83e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a755a
1b83e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import { createServerFn } from "@tanstack/react-start";
import type { ExplainerObservation, ParsedAction } from "@/lib/types";
import {
  SYSTEM_PROMPT,
  buildExplorePrompt,
  buildGeneratePrompt,
  buildRepairPrompt,
  parseExploreResponse,
  parseGenerateResponse,
} from "./llm/prompts.server";

/** Sampling temperature sent with every chat-completion request. */
const TEMPERATURE = 0.7;

/** `max_tokens` sent with every chat-completion request. */
const MAX_TOKENS = 4096;

/** Model id used when the `MODEL_NAME` environment variable is unset or empty. */
const DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b";

/**
 * Tool names the explore phase may request. `runLlmStep` replaces any tool
 * not in this set with `"search_wikipedia"`.
 */
const ALLOWED_TOOLS: ReadonlySet<string> = new Set<string>([
  "search_wikipedia",
  "search_hf_papers",
  "search_arxiv",
  "search_scholar",
  "fetch_docs",
  "search_hf_hub",
]);

/**
 * Extracts `choices[0].message.content` from a parsed chat-completion body.
 *
 * Every step of the path is read defensively because the body is untrusted
 * JSON of unknown shape.
 *
 * @param data - The parsed response body.
 * @param rawText - The raw response text; a prefix is quoted in error messages.
 * @returns The trimmed content, or `""` when the content is missing/empty.
 * @throws Error when the content exists but is not a string (e.g. an array of
 *   content parts), so the caller sees the real problem rather than a TypeError.
 */
function readMessageContent(data: unknown, rawText: string): string {
  const child = (value: unknown, key: string | number): unknown =>
    typeof value === "object" && value !== null
      ? (value as Record<string | number, unknown>)[key]
      : undefined;

  const content = child(child(child(child(data, "choices"), 0), "message"), "content");
  if (content === undefined || content === null || content === "") return "";
  if (typeof content !== "string") {
    throw new Error(`LLM unexpected response shape: ${rawText.slice(0, 200)}`);
  }
  return content.trim();
}

/**
 * Sends one system + user chat-completion request to an OpenAI-compatible
 * `/chat/completions` endpoint and returns the assistant's text.
 *
 * Configuration is read from the environment on every call:
 * - `API_BASE_URL` (default `https://router.huggingface.co/v1`, trailing slashes stripped)
 * - `HF_TOKEN`, falling back to `API_KEY`, sent as a Bearer token
 * - `MODEL_NAME` (default `DEFAULT_MODEL_NAME`)
 *
 * @param userPrompt - Content of the user message (the system message is `SYSTEM_PROMPT`).
 * @returns The trimmed content of the first choice's message, or `""` when the
 *   response carries no content.
 * @throws Error when no API key is configured, on a non-2xx status, when the
 *   body is not valid JSON, or when the message content is not a string.
 */
async function callChat(userPrompt: string): Promise<string> {
  // `||` (not `??`) on purpose: an empty env var is treated as unset.
  const apiBase = (process.env.API_BASE_URL || "https://router.huggingface.co/v1").replace(
    /\/+$/,
    "",
  );
  const apiKey = process.env.HF_TOKEN || process.env.API_KEY;
  const model = process.env.MODEL_NAME || DEFAULT_MODEL_NAME;
  if (!apiKey) throw new Error("HF_TOKEN (or API_KEY) is not configured on the server.");

  const res = await fetch(`${apiBase}/chat/completions`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${apiKey}`,
    },
    body: JSON.stringify({
      model,
      temperature: TEMPERATURE,
      max_tokens: MAX_TOKENS,
      stream: false,
      messages: [
        { role: "system", content: SYSTEM_PROMPT },
        { role: "user", content: userPrompt },
      ],
    }),
  });

  // Read as text first so error messages can quote the body even when it is not JSON.
  const text = await res.text();
  if (!res.ok) {
    throw new Error(`LLM ${res.status}: ${text.slice(0, 500)}`);
  }

  // Guard only the parse itself: shape problems are reported separately
  // instead of being misattributed to invalid JSON.
  let data: unknown;
  try {
    data = JSON.parse(text);
  } catch {
    throw new Error(`LLM invalid JSON: ${text.slice(0, 200)}`);
  }
  return readMessageContent(data, text);
}

/** Phases accepted by `runLlmStep`. */
const LLM_PHASES = ["explore", "generate", "repair"] as const;

/** Formats accepted for the repair phase's `previousFormat`. */
const CODE_FORMATS = ["marimo", "manim"] as const;

/** Client payload for `runLlmStep`. */
type LlmStepInput = {
  phase: (typeof LLM_PHASES)[number];
  obs: ExplainerObservation;
  exploreStepIndex?: number;
  previousCode?: string;
  previousFormat?: (typeof CODE_FORMATS)[number];
};

/**
 * Runtime check of the `runLlmStep` payload.
 *
 * The parameter's type annotation is erased at compile time, and a server
 * function can be invoked with arbitrary JSON. Without this check, an unknown
 * `phase` would silently fall through to the repair branch. Optional fields may
 * be `null`/`undefined`; the handler substitutes defaults for them.
 *
 * @throws Error describing the first invalid field.
 */
function validateLlmStepInput(input: LlmStepInput): LlmStepInput {
  if (typeof input !== "object" || input === null) {
    throw new Error("runLlmStep: expected an object payload");
  }
  if (!(LLM_PHASES as readonly unknown[]).includes(input.phase)) {
    throw new Error(`runLlmStep: unknown phase ${JSON.stringify(input.phase)}`);
  }
  if (typeof input.obs !== "object" || input.obs === null) {
    throw new Error("runLlmStep: obs must be an object");
  }
  if (input.exploreStepIndex != null && !Number.isFinite(input.exploreStepIndex)) {
    throw new Error("runLlmStep: exploreStepIndex must be a finite number");
  }
  if (input.previousCode != null && typeof input.previousCode !== "string") {
    throw new Error("runLlmStep: previousCode must be a string");
  }
  if (
    input.previousFormat != null &&
    !(CODE_FORMATS as readonly unknown[]).includes(input.previousFormat)
  ) {
    throw new Error(
      `runLlmStep: unknown previousFormat ${JSON.stringify(input.previousFormat)}`,
    );
  }
  return input;
}

/**
 * Server function running one LLM step of the explainer loop.
 *
 * - `explore`: builds the explore prompt for `obs` (step index defaults to 1)
 *   and parses a tool request. A tool outside `ALLOWED_TOOLS` is replaced with
 *   `"search_wikipedia"`.
 * - `generate`: builds the generate prompt and parses the generated code.
 * - `repair`: builds the repair prompt from `previousCode` (default `""`) and
 *   `previousFormat` (default `"marimo"`), and parses the result like `generate`.
 *
 * @returns The raw model text and the parsed action tagged with the phase.
 * @throws Error on an invalid payload, or any error raised by `callChat`.
 */
export const runLlmStep = createServerFn({ method: "POST" })
  .inputValidator((input: LlmStepInput) => validateLlmStepInput(input))
  .handler(async ({ data }): Promise<{ raw: string; parsed: ParsedAction }> => {
    switch (data.phase) {
      case "explore": {
        const prompt = buildExplorePrompt(data.obs, data.exploreStepIndex ?? 1);
        const raw = await callChat(prompt);
        const parsed = parseExploreResponse(raw, data.obs.topic);
        // Never forward a tool name the client-side executor does not know about.
        const tool = ALLOWED_TOOLS.has(parsed.tool) ? parsed.tool : "search_wikipedia";
        return {
          raw,
          // NOTE(review): these casts are kept as-is because the ParsedAction
          // union is declared elsewhere. Prefer a precise
          // Extract<ParsedAction, { kind: "explore" }> type if it fits.
          parsed: {
            kind: "explore",
            tool: tool as ParsedAction extends { tool: infer T } ? T : never,
            query: parsed.query,
            intent: parsed.intent,
            skip: parsed.skip,
          } as ParsedAction,
        };
      }
      case "generate": {
        const prompt = buildGeneratePrompt(data.obs);
        const raw = await callChat(prompt);
        const parsed = parseGenerateResponse(raw);
        return { raw, parsed: { kind: "generate", ...parsed } };
      }
      case "repair": {
        const fmt = data.previousFormat ?? "marimo";
        const prompt = buildRepairPrompt(data.obs, fmt, data.previousCode ?? "");
        const raw = await callChat(prompt);
        const parsed = parseGenerateResponse(raw);
        return { raw, parsed: { kind: "repair", ...parsed } };
      }
      default: {
        // Unreachable after validation; the `never` binding keeps the switch exhaustive.
        const unreachable: never = data.phase;
        throw new Error(`runLlmStep: unknown phase ${String(unreachable)}`);
      }
    }
  });