NeoPy's picture
Upload folder using huggingface_hub
867b17d verified
raw
history blame
6.91 kB
import {
convertToModelMessages,
createUIMessageStream,
JsonToSseTransformStream,
smoothStream,
stepCountIs,
streamText,
} from 'ai';
import { auth, type UserType } from '@/app/(auth)/auth';
import { type RequestHints, systemPrompt } from '@/lib/ai/prompts';
import {
createStreamId,
deleteChatById,
getChatById,
getMessageCountByUserId,
getMessagesByChatId,
saveChat,
saveMessages,
} from '@/lib/db/queries';
import { convertToUIMessages, generateUUID } from '@/lib/utils';
import { generateTitleFromUserMessage } from '../../actions';
import { createDocument } from '@/lib/ai/tools/create-document';
import { updateDocument } from '@/lib/ai/tools/update-document';
import { requestSuggestions } from '@/lib/ai/tools/request-suggestions';
import { getWeather } from '@/lib/ai/tools/get-weather';
import { isProductionEnvironment } from '@/lib/constants';
import { myProvider } from '@/lib/ai/providers';
import { entitlementsByUserType } from '@/lib/ai/entitlements';
import { postRequestBodySchema, type PostRequestBody } from './schema';
import { geolocation } from '@vercel/functions';
import {
createResumableStreamContext,
type ResumableStreamContext,
} from 'resumable-stream';
import { after } from 'next/server';
import { ChatSDKError } from '@/lib/errors';
import type { ChatMessage } from '@/lib/types';
import type { ChatModel } from '@/lib/ai/models';
import type { VisibilityType } from '@/components/visibility-selector';
export const maxDuration = 60;
let globalStreamContext: ResumableStreamContext | null = null;
export function getStreamContext() {
if (!globalStreamContext) {
try {
globalStreamContext = createResumableStreamContext({
waitUntil: after,
});
} catch (error: any) {
if (error.message.includes('REDIS_URL')) {
console.log(
' > Resumable streams are disabled due to missing REDIS_URL',
);
} else {
console.error(error);
}
}
}
return globalStreamContext;
}
export async function POST(request: Request) {
let requestBody: PostRequestBody;
try {
const json = await request.json();
requestBody = postRequestBodySchema.parse(json);
} catch (_) {
return new ChatSDKError('bad_request:api').toResponse();
}
try {
const {
id,
message,
selectedChatModel,
selectedVisibilityType,
}: {
id: string;
message: ChatMessage;
selectedChatModel: ChatModel['id'];
selectedVisibilityType: VisibilityType;
} = requestBody;
const session = await auth();
if (!session?.user) {
return new ChatSDKError('unauthorized:chat').toResponse();
}
const userType: UserType = session.user.type;
const messageCount = await getMessageCountByUserId({
id: session.user.id,
differenceInHours: 24,
});
if (messageCount > entitlementsByUserType[userType].maxMessagesPerDay) {
return new ChatSDKError('rate_limit:chat').toResponse();
}
const chat = await getChatById({ id });
if (!chat) {
const title = await generateTitleFromUserMessage({
message,
});
await saveChat({
id,
userId: session.user.id,
title,
visibility: selectedVisibilityType,
});
} else {
if (chat.userId !== session.user.id) {
return new ChatSDKError('forbidden:chat').toResponse();
}
}
const messagesFromDb = await getMessagesByChatId({ id });
const uiMessages = [...convertToUIMessages(messagesFromDb), message];
const { longitude, latitude, city, country } = geolocation(request);
const requestHints: RequestHints = {
longitude,
latitude,
city,
country,
};
await saveMessages({
messages: [
{
chatId: id,
id: message.id,
role: 'user',
parts: message.parts,
attachments: [],
createdAt: new Date(),
},
],
});
const streamId = generateUUID();
await createStreamId({ streamId, chatId: id });
const stream = createUIMessageStream({
execute: ({ writer: dataStream }) => {
const result = streamText({
model: myProvider.languageModel(selectedChatModel),
system: systemPrompt({ selectedChatModel, requestHints }),
messages: convertToModelMessages(uiMessages),
stopWhen: stepCountIs(5),
experimental_activeTools:
selectedChatModel === 'chat-model-reasoning'
? []
: [
'getWeather',
'createDocument',
'updateDocument',
'requestSuggestions',
],
experimental_transform: smoothStream({ chunking: 'word' }),
tools: {
getWeather,
createDocument: createDocument({ session, dataStream }),
updateDocument: updateDocument({ session, dataStream }),
requestSuggestions: requestSuggestions({
session,
dataStream,
}),
},
experimental_telemetry: {
isEnabled: isProductionEnvironment,
functionId: 'stream-text',
},
});
result.consumeStream();
dataStream.merge(
result.toUIMessageStream({
sendReasoning: true,
}),
);
},
generateId: generateUUID,
onFinish: async ({ messages }) => {
await saveMessages({
messages: messages.map((message) => ({
id: message.id,
role: message.role,
parts: message.parts,
createdAt: new Date(),
attachments: [],
chatId: id,
})),
});
},
onError: () => {
return 'Oops, an error occurred!';
},
});
const streamContext = getStreamContext();
if (streamContext) {
return new Response(
await streamContext.resumableStream(streamId, () =>
stream.pipeThrough(new JsonToSseTransformStream()),
),
);
} else {
return new Response(stream.pipeThrough(new JsonToSseTransformStream()));
}
} catch (error) {
if (error instanceof ChatSDKError) {
return error.toResponse();
}
}
}
export async function DELETE(request: Request) {
const { searchParams } = new URL(request.url);
const id = searchParams.get('id');
if (!id) {
return new ChatSDKError('bad_request:api').toResponse();
}
const session = await auth();
if (!session?.user) {
return new ChatSDKError('unauthorized:chat').toResponse();
}
const chat = await getChatById({ id });
if (chat.userId !== session.user.id) {
return new ChatSDKError('forbidden:chat').toResponse();
}
const deletedChat = await deleteChatById({ id });
return Response.json(deletedChat, { status: 200 });
}