Spaces:
Running
Running
File size: 2,423 Bytes
fc69895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import { MessageUpdateType, type MessageUpdate } from "$lib/types/MessageUpdate";
import { AbortedGenerations } from "../abortedGenerations";
import type { TextGenerationContext } from "./types";
import type { EndpointMessage } from "../endpoints/endpoints";
import { logger } from "../logger";
type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: EndpointMessage[] };
export async function* generate(
{
model,
endpoint,
conv,
messages,
assistant,
promptedAt,
forceMultimodal,
locals,
abortController,
}: GenerateContext,
preprompt?: string
): AsyncIterable<MessageUpdate> {
const stream = await endpoint({
messages,
preprompt,
generateSettings: assistant?.generateSettings,
// Allow user-level override to force multimodal
isMultimodal: (forceMultimodal ?? false) || model.multimodal,
conversationId: conv._id,
locals,
abortSignal: abortController.signal,
});
for await (const output of stream) {
// Check if this output contains router metadata
if (
"routerMetadata" in output &&
output.routerMetadata &&
((output.routerMetadata.route && output.routerMetadata.model) ||
output.routerMetadata.provider)
) {
yield {
type: MessageUpdateType.RouterMetadata,
route: output.routerMetadata.route || "",
model: output.routerMetadata.model || "",
provider: output.routerMetadata.provider,
};
continue;
}
// text generation completed
if (output.generated_text) {
let interrupted =
!output.token.special && !model.parameters.stop?.includes(output.token.text);
let text = output.generated_text.trimEnd();
for (const stopToken of model.parameters.stop ?? []) {
if (!text.endsWith(stopToken)) continue;
interrupted = false;
text = text.slice(0, text.length - stopToken.length);
}
yield {
type: MessageUpdateType.FinalAnswer,
text,
interrupted,
};
continue;
}
// ignore special tokens
if (output.token.special) continue;
// yield normal token
yield { type: MessageUpdateType.Stream, token: output.token.text };
// abort check
const date = AbortedGenerations.getInstance().getAbortTime(conv._id.toString());
if (date && date > promptedAt) {
logger.info(`Aborting generation for conversation ${conv._id}`);
if (!abortController.signal.aborted) {
abortController.abort();
}
break;
}
// no output check
if (!output) break;
}
}
|