// inference-playground / src/lib/chat/triggerAiCall.ts
// Last commit: b3e1908 — "run format" by enzostvs (HF Staff)
import type { ChatMessage, TokenUsage } from '$lib/helpers/types';
import { modelsState } from '$lib/state/models.svelte';
import { token } from '$lib/state/token.svelte';
import type { Edge, Node } from '@xyflow/svelte';
/**
 * Everything `triggerAiCall` needs from its caller: the current graph/chat
 * state it reads, plus the callbacks it uses to write results back into the
 * flow and to surface loading/error states.
 */
export interface TriggerAiCallContext {
	/** Id of the originating user node; its data is patched when calls fail. */
	userId: string;
	/** Freshly created model nodes — one AI call is fired per node. */
	newNodes: Node[];
	/** Conversation history to send to each model. */
	messages: ChatMessage[];
	/** Model ids selected for this round (carried onto the next user node). */
	selectedModels: string[];
	/** Prompt text of the current round (restored on total failure). */
	prompt: string;
	/** Existing data of the user node, spread back when patching it. */
	nodeData: Record<string, unknown> | undefined;
	// NOTE(review): not read by triggerAiCall, which authenticates via
	// `token.value` instead — possibly vestigial; confirm against callers.
	authToken: string;
	/** Org/user to bill the inference call to (sent as `billingTo`). */
	billingOption: string;
	/** Patches a single node's data (optionally replacing it wholesale). */
	updateNodeData: (id: string, data: Record<string, unknown>, opts?: { replace?: boolean }) => void;
	/** Functional updates over the full node list. */
	updateNodes: (fn: (nodes: Node[]) => Node[]) => void;
	/** Functional updates over the full edge list. */
	updateEdges: (fn: (edges: Edge[]) => Edge[]) => void;
	/** Toggles the caller's loading indicator. */
	onLoadingChange: (loading: boolean) => void;
	/** Reports one human-readable error message per failed call. */
	onError: (message: string) => void;
}
/**
 * Builds the role/content message list sent to the API for one model.
 *
 * User messages pass through with their `content` as-is. For assistant
 * messages, the reply stored under `modelId` is preferred; if that model has
 * no reply on the message, the first reply found under any other key is used,
 * falling back to an empty string.
 */
function formatMessagesForModel(messages: ChatMessage[], modelId: string) {
	return messages.map((entry) => {
		if (entry.role === 'user') {
			return { role: 'user' as const, content: entry.content };
		}
		// Prefer the reply this specific model produced.
		const ownReply = entry[modelId];
		if (typeof ownReply === 'object' && ownReply && 'content' in ownReply) {
			return { role: 'assistant' as const, content: ownReply.content };
		}
		// Otherwise take any model's reply present on the message.
		const anyReply = Object.values(entry).find(
			(v): v is { content: string } => typeof v === 'object' && v !== null && 'content' in v
		);
		return { role: 'assistant' as const, content: anyReply?.content ?? '' };
	});
}
/**
 * Fires one streaming API call per freshly-created model node, writes each
 * model's streamed content/reasoning into its node, then — if at least one
 * call succeeded — appends a combined assistant message and a new user node
 * (with edges from every model node) to the flow.
 *
 * Failed nodes are removed from the graph together with their incoming edges,
 * their models are excluded from the next node's selection, and each error is
 * reported through `ctx.onError`. On total failure the originating user node
 * gets its last message and prompt restored.
 */
export async function triggerAiCall(ctx: TriggerAiCallContext): Promise<void> {
	const {
		userId,
		newNodes,
		messages,
		selectedModels,
		prompt,
		nodeData,
		billingOption,
		updateNodeData,
		updateNodes,
		updateEdges,
		onLoadingChange,
		onError
	} = ctx;
	onLoadingChange(true);
	const failedNodeIds = new Set<string>();
	const failedModelIds = new Set<string>();
	const errorMessages: string[] = [];
	let results: (Record<string, { content: string; timestamp: string }> | null)[];
	try {
		results = await Promise.all(
			newNodes.map(async (node) => {
				const model = node?.data?.selectedModel as string;
				if (!model) return null;
				try {
					const modelSettings = modelsState.models.find((m) => m.id === model);
					const start = Date.now();
					const formattedMessages = formatMessagesForModel(messages, model);
					const response = await fetch('/api', {
						method: 'POST',
						body: JSON.stringify({
							model: model,
							provider: modelSettings?.provider ?? 'auto',
							messages: formattedMessages,
							billingTo: billingOption,
							// Sampling options are only sent when the model has
							// locally-stored settings; the server uses its defaults otherwise.
							...(modelSettings
								? {
										options: {
											temperature: modelSettings.temperature,
											max_tokens: modelSettings.max_tokens,
											top_p: modelSettings.top_p
										}
									}
								: {})
						}),
						headers: { Authorization: `Bearer ${token.value}` }
					});
					if (!response.ok) {
						const errorBody = await response.text().catch(() => response.statusText);
						throw new Error(errorBody || response.statusText);
					}
					if (!response.body) throw new Error('No response body');
					let content = '';
					let reasoning = '';
					let usage: TokenUsage | null = null;
					let inThink = false;
					let buffer = '';
					const reader = response.body.getReader();
					const decoder = new TextDecoder();
					while (true) {
						const { done, value } = await reader.read();
						if (done) {
							// Flush bytes the decoder buffered from a multi-byte
							// sequence split at the very end of the stream.
							const tail = decoder.decode();
							if (tail) {
								if (inThink) reasoning += tail;
								else content += tail;
							}
							// In-band markers: the server appends __ERROR__/__USAGE__
							// sentinels to the plain-text stream.
							if (content.includes('__ERROR__')) {
								const errMsg = content.split('__ERROR__').pop() ?? 'Unknown error';
								throw new Error(errMsg);
							}
							if (content.includes('__USAGE__')) {
								const usageParts = content.split('__USAGE__');
								const usageJson = usageParts.pop() ?? '';
								content = usageParts.join('').trimEnd();
								try {
									usage = JSON.parse(usageJson) as TokenUsage;
								} catch {
									// ignore malformed usage JSON
								}
							}
							const end = Date.now();
							updateNodeData(
								node.id,
								{
									...node.data,
									content,
									reasoning,
									timestamp: end - start,
									loading: false,
									messages,
									usage
								} as Record<string, unknown>,
								{ replace: true }
							);
							return { [model]: { content, timestamp: String(end - start) } };
						}
						buffer += decoder.decode(value, { stream: true });
						// Split the buffer on <think>…</think> tags, routing text to
						// `reasoning` inside a think block and `content` outside it.
						// TODO(review): a tag split across two chunks (e.g. "<thi" +
						// "nk>") is not recognized and leaks into content — confirm
						// whether the server guarantees tags arrive whole.
						while (true) {
							if (inThink) {
								const closeIdx = buffer.indexOf('</think>');
								if (closeIdx === -1) {
									reasoning += buffer;
									buffer = '';
									break;
								}
								reasoning += buffer.slice(0, closeIdx);
								buffer = buffer.slice(closeIdx + '</think>'.length);
								inThink = false;
							} else {
								const openIdx = buffer.indexOf('<think>');
								if (openIdx === -1) {
									content += buffer;
									buffer = '';
									break;
								}
								content += buffer.slice(0, openIdx);
								buffer = buffer.slice(openIdx + '<think>'.length);
								inThink = true;
							}
						}
						// Intermediate paint: show partial content while streaming.
						updateNodeData(
							node.id,
							{ ...node.data, content, reasoning, loading: false } as Record<string, unknown>,
							{ replace: true }
						);
					}
				} catch (error) {
					const msg = error instanceof Error ? error.message : 'An unknown error occurred';
					failedNodeIds.add(node.id);
					// Track the model too, so the next user node drops it from
					// its selection (consumed below when building `newNode`).
					failedModelIds.add(model);
					errorMessages.push(msg);
					return null;
				}
			})
		);
	} finally {
		// Clear loading only once EVERY call has settled; doing this per-node
		// would switch the indicator off as soon as the first model finished.
		onLoadingChange(false);
	}
	if (failedNodeIds.size > 0) {
		updateNodes((currentNodes) => currentNodes.filter((n) => !failedNodeIds.has(n.id)));
		updateEdges((currentEdges) => currentEdges.filter((e) => !failedNodeIds.has(e.target)));
		updateNodeData(
			userId,
			{
				...nodeData,
				// If every call failed, roll back: drop the assistant-less last
				// message and restore the prompt so the user can retry.
				messages: newNodes.length === failedNodeIds.size ? messages.slice(0, -1) : messages,
				prompt: newNodes.length === failedNodeIds.size ? prompt : ''
				// selectedModels: selectedModels.filter((m) => !failedModelIds.has(m))
			} as Record<string, unknown>,
			{ replace: true }
		);
		errorMessages.forEach((msg) => onError(msg));
	}
	const validResults = results.filter(
		(r): r is Record<string, { content: string; timestamp: string }> => r != null
	);
	if (validResults.length === 0) return;
	// Merge every model's reply into one assistant message keyed by model id.
	const assistantMessage = validResults.reduce<ChatMessage>(
		(acc, result) => (result ? { ...acc, ...result } : acc),
		{ role: 'assistant' }
	);
	const newNodeId = `user-${crypto.randomUUID()}`;
	const newNode: Node = {
		id: newNodeId,
		type: 'user',
		position: { x: 0, y: 0 },
		data: {
			role: 'user',
			selectedModels: selectedModels.filter((m) => !failedModelIds.has(m)),
			messages: [...messages, assistantMessage]
		}
	};
	// One edge from each model node to the follow-up user node.
	const newEdges: Edge[] = newNodes.map((node) => ({
		id: `edge-${crypto.randomUUID()}`,
		source: node.id,
		target: newNodeId
	}));
	updateNodes((currentNodes) => [...currentNodes, newNode]);
	updateEdges((currentEdges) => [...currentEdges, ...newEdges]);
}