import { taskModel, models } from "$lib/server/models";
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import type { EndpointMessage } from "./endpoints/endpoints";
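
/**
 * Streams a completion from the configured task model (or an explicitly
 * requested model), yielding intermediate tokens as message updates and
 * returning the final generated text with trailing stop sequences trimmed.
 */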
export async function* generateFromDefaultEndpoint({
	messages,
	preprompt,
	generateSettings,
	modelId,
	apiKey,
}: {
	messages: EndpointMessage[];
	preprompt?: string;
	generateSettings?: Record<string, unknown>;
	/** Optional: use this model instead of the default task model */
	modelId?: string;
	apiKey?: string;
}): AsyncGenerator<MessageUpdate, string, undefined> {
	try {
		// Choose the endpoint from the explicitly requested model, falling back to the default task model
		const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
		const endpoint = await model.getEndpoint(apiKey ? { apiKey } : undefined);

		const tokenStream = await endpoint({ messages, preprompt, generateSettings });

		for await (const output of tokenStream) {
			// generated_text is only present on the final output; until then the generation is still in progress
			if (output.generated_text) {
				let generated_text = output.generated_text;
				// Trim any trailing stop sequence (plus the default end-of-text marker) from the final text
				for (const stop of [...(model.parameters?.stop ?? []), "<|endoftext|>"]) {
					if (generated_text.endsWith(stop)) {
						generated_text = generated_text.slice(0, -stop.length).trimEnd();
					}
				}
				return generated_text;
			}
			// Intermediate outputs carry a single token; stream it to the caller as a message update
			yield {
				type: MessageUpdateType.Stream,
				token: output.token.text,
			};
		}
	} catch (error) {
		// Swallow endpoint/stream errors and fall back to an empty result rather than failing the caller
		return "";
	}
	return "";
}
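
// Usage sketch (illustrative, not part of this module): how a caller might
// consume the generator, assuming `messages` and `preprompt` are already built.
//
//   const gen = generateFromDefaultEndpoint({ messages, preprompt });
//   let result = await gen.next();
//   while (!result.done) {
//     // result.value is a MessageUpdate of type Stream carrying one token
//     result = await gen.next();
//   }
//   const generatedText = result.value; // final text with stop sequences trimmed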