File size: 1,479 Bytes
7bf1507
bf75aa7
5b1a9aa
b5ae065
bf75aa7
b5ae065
 
66fa870
7bf1507
d0901fb
b5ae065
5b1a9aa
b5ae065
66fa870
7bf1507
 
d0901fb
bf75aa7
2ee0fbc
7bf1507
 
d0901fb
63f4e60
 
2ee0fbc
 
 
 
7bf1507
2ee0fbc
 
 
b5ae065
2ee0fbc
b5ae065
2ee0fbc
 
 
 
04f5527
2ee0fbc
 
04f5527
2ee0fbc
 
0c6af63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import { taskModel, models } from "$lib/server/models";
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import type { EndpointMessage } from "./endpoints/endpoints";

/**
 * Streams a generation from a model endpoint and returns the final text.
 *
 * The model is the entry of `models` whose `id` equals `modelId`; when
 * `modelId` is omitted or matches no entry, `taskModel` is used instead.
 * Each streamed output without a (non-empty) `generated_text` is yielded as a
 * `MessageUpdateType.Stream` update carrying `output.token.text`. The first
 * output with a non-empty `generated_text` ends the generation.
 *
 * @param messages - Messages forwarded as-is to the endpoint.
 * @param preprompt - Optional preprompt forwarded as-is to the endpoint.
 * @param generateSettings - Optional settings forwarded as-is to the endpoint.
 * @param modelId - Optional: use this model instead of the default task model.
 * @param apiKey - Optional key; when set it is passed to `getEndpoint` as `{ apiKey }`.
 * @returns The final text with a trailing stop sequence removed for each of the
 * model's `parameters.stop` entries and `"<|endoftext|>"` (checked in that
 * order), or `""` when the stream ends without a final text or an error occurs.
 */
export async function* generateFromDefaultEndpoint({
	messages,
	preprompt,
	generateSettings,
	modelId,
	apiKey,
}: {
	messages: EndpointMessage[];
	preprompt?: string;
	generateSettings?: Record<string, unknown>;
	/** Optional: use this model instead of the default task model */
	modelId?: string;
	apiKey?: string;
}): AsyncGenerator<MessageUpdate, string, undefined> {
	try {
		// Choose endpoint based on provided modelId, else fall back to taskModel
		const model = (modelId ? models.find((m) => m.id === modelId) : undefined) ?? taskModel;
		const endpoint = await model.getEndpoint(apiKey ? { apiKey } : undefined);
		const tokenStream = await endpoint({ messages, preprompt, generateSettings });

		for await (const output of tokenStream) {
			// Until generated_text is present the generation is not done
			if (output.generated_text) {
				let generatedText = output.generated_text;
				for (const stop of [...(model.parameters?.stop ?? []), "<|endoftext|>"]) {
					// Skip empty stops: "" is a suffix of every string and
					// slice(0, -0) is slice(0, 0), which would erase the whole text.
					if (stop.length > 0 && generatedText.endsWith(stop)) {
						generatedText = generatedText.slice(0, -stop.length).trimEnd();
					}
				}
				return generatedText;
			}
			yield {
				type: MessageUpdateType.Stream,
				token: output.token.text,
			};
		}
	} catch (error) {
		// Best-effort: callers still get an empty result, but the failure is
		// reported instead of being dropped silently.
		console.error("generateFromDefaultEndpoint failed:", error);
		return "";
	}

	// The stream ended without ever producing a final generated_text.
	return "";
}