File size: 1,471 Bytes
4331e77
764ecdf
b4690c8
0ab1dfb
764ecdf
0ab1dfb
 
42dd7d2
4331e77
c4f6eb3
0ab1dfb
b4690c8
0ab1dfb
42dd7d2
4331e77
 
c4f6eb3
764ecdf
e7e98be
4331e77
 
 
c4f6eb3
c5af266
e7e98be
 
 
 
4331e77
e7e98be
 
 
0ab1dfb
e7e98be
0ab1dfb
e7e98be
 
 
 
86f4420
e7e98be
 
86f4420
e7e98be
 
6c72ede
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import { taskModel, models } from "$lib/server/models";
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import type { EndpointMessage } from "./endpoints/endpoints";

/**
 * Removes trailing stop sequences (and the whitespace before them) from generated text.
 *
 * Stripping repeats until no stop sequence matches, so a combination such as a
 * user-configured stop followed by "<|endoftext|>" is fully removed regardless of
 * the order of `stops`. Empty stop strings are skipped: `"x".endsWith("")` is always
 * true and `slice(0, -0)` returns "", which would otherwise erase the whole text.
 * The loop terminates because every strip shortens the text by at least one character.
 */
function stripTrailingStopSequences(text: string, stops: readonly string[]): string {
	let result = text;
	let changed = true;
	while (changed) {
		changed = false;
		for (const stop of stops) {
			if (stop.length > 0 && result.endsWith(stop)) {
				result = result.slice(0, -stop.length).trimEnd();
				changed = true;
			}
		}
	}
	return result;
}

/**
 * Streams a completion from a model endpoint, yielding intermediate tokens and
 * returning the final generated text.
 *
 * @param messages - Conversation forwarded to the endpoint.
 * @param preprompt - Optional preprompt forwarded to the endpoint.
 * @param generateSettings - Optional generation settings forwarded to the endpoint.
 * @param modelId - Optional id of the model to use. If omitted, or if no entry in
 *   `models` has this id, `taskModel` is used instead.
 * @param locals - Request locals forwarded to the endpoint.
 * @yields A `MessageUpdateType.Stream` update carrying each intermediate token's text.
 * @returns The final generated text with trailing stop sequences (the model's
 *   `parameters.stop` plus "<|endoftext|>") removed. Returns "" if the endpoint
 *   throws or the stream ends without producing a final text. Errors are logged,
 *   not rethrown.
 */
export async function* generateFromDefaultEndpoint({
	messages,
	preprompt,
	generateSettings,
	modelId,
	locals,
}: {
	messages: EndpointMessage[];
	preprompt?: string;
	generateSettings?: Record<string, unknown>;
	/** Optional: use this model instead of the default task model */
	modelId?: string;
	locals: App.Locals | undefined;
}): AsyncGenerator<MessageUpdate, string, undefined> {
	try {
		// Choose endpoint based on provided modelId, else fall back to taskModel
		const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
		const endpoint = await model.getEndpoint();
		const tokenStream = await endpoint({ messages, preprompt, generateSettings, locals });
		const stops = [...(model.parameters?.stop ?? []), "<|endoftext|>"];

		for await (const output of tokenStream) {
			// A missing (or empty) generated_text means generation is not done yet
			if (output.generated_text) {
				return stripTrailingStopSequences(output.generated_text, stops);
			}
			yield {
				type: MessageUpdateType.Stream,
				token: output.token.text,
			};
		}
	} catch (error: unknown) {
		// Best-effort: callers receive "" on failure, but the cause must not vanish silently.
		console.error("generateFromDefaultEndpoint: generation failed", error);
		return "";
	}

	// The stream ended without ever delivering a final generated_text.
	return "";
}