File size: 4,235 Bytes
db242f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import * as Joi from 'joi';
import {
Body,
Controller,
Delete,
Get,
Param,
Post,
Sse,
} from '@nestjs/common';
import { Role } from '@prisma/client';
import { BizException } from '@/common/exceptions/biz.exception';
import { Payload } from '@/common/guards/auth.guard';
import { JoiValidationPipe } from '@/common/pipes/joi';
import { JWTPayload } from '@/libs/jwt/jwt.service';
import { NewMessageDto } from 'shared';
import { ErrorCodeEnum } from 'shared/dist/error-code';
import { ChatService } from './chat.service';
/**
 * Joi schema for the request body of the "new message" endpoint.
 *
 * - `modelId`: required number.
 * - `content`: required string.
 * - `memoryPrompt`: optional string.
 */
const newMessageSchema = Joi.object().keys({
  modelId: Joi.number().required(),
  content: Joi.string().required(),
  memoryPrompt: Joi.string().optional(),
});
/**
 * Controller registered under the `chat` route prefix.
 *
 * Every handler delegates the real work to {@link ChatService}; this class
 * only unpacks request data, applies a few guard checks and shapes the
 * response objects.
 */
@Controller('chat')
export class ChatController {
  constructor(private readonly chatService: ChatService) {}

  /**
   * Get the list of recent sessions for the current user.
   *
   * @param userId `id` field supplied by the `@Payload` decorator.
   * @returns Whatever `ChatService.getRecentChatSession` returns.
   */
  @Get('sessions')
  getMyChatSession(@Payload('id') userId: number) {
    return this.chatService.getRecentChatSession(userId);
  }

  /**
   * Delete a chat session.
   *
   * @param userId `id` field supplied by the `@Payload` decorator.
   * @param sessionId Value of the `:sessionId` route parameter.
   * @returns `{ success: true }` once the service call has resolved.
   */
  @Delete('sessions/:sessionId')
  async deleteChatSession(
    @Payload('id') userId: number,
    @Param('sessionId') sessionId: string,
  ) {
    await this.chatService.deleteChatSession(userId, sessionId);
    return { success: true };
  }

  /**
   * Delete a single chat message.
   *
   * @param userId `id` field supplied by the `@Payload` decorator.
   * @param messageId Value of the `:messageId` route parameter.
   * @returns `{ success: true }` once the service call has resolved.
   */
  @Delete('messages/:messageId')
  async deleteChatMessage(
    @Payload('id') userId: number,
    @Param('messageId') messageId: string,
  ) {
    await this.chatService.deleteChatMessage(userId, messageId);
    return { success: true };
  }

  /**
   * Get the message history of a specific session.
   *
   * When the service finds nothing for the given id, a placeholder session
   * containing a single system greeting is returned instead of an error.
   * Otherwise the stored session is returned with every message `role`
   * lower-cased.
   *
   * @param userId `id` field supplied by the `@Payload` decorator.
   * @param sessionId Value of the `:sessionId` route parameter.
   */
  @Get('messages/:sessionId')
  async getChatMessages(
    @Payload('id') userId: number,
    @Param('sessionId') sessionId: string,
  ) {
    const session = await this.chatService.getChatMessages(userId, sessionId);

    if (session) {
      // Roles are sent to the client in lower case.
      const messages = session.messages.map((message) => ({
        ...message,
        role: message.role.toLowerCase(),
      }));
      return { ...session, messages };
    }

    // Nothing stored for this id: answer with a placeholder session.
    const greeting = {
      role: 'system',
      content: '你好,请问有什么可以帮助您?',
    };
    return {
      sid: sessionId,
      topic: undefined,
      messages: [greeting],
      updateAt: new Date(),
      _count: { messages: 1 },
    };
  }

  /**
   * Get the summary (topic) of a conversation.
   *
   * A session that already has a topic gets it back unchanged, so it is not
   * summarized again. A session without a topic is summarized only when it
   * holds more than one message; otherwise `topic` is `undefined`.
   *
   * @param userId `id` field supplied by the `@Payload` decorator.
   * @param sessionId Value of the `:sessionId` route parameter.
   * @returns `{ success: true, data: { topic } }`.
   * @throws BizException with `ErrorCodeEnum.SessionNotFound` when the
   *   service returns no session.
   */
  @Post('summary/:sessionId')
  async getSummary(
    @Payload('id') userId: number,
    @Param('sessionId') sessionId: string,
  ) {
    const session = await this.chatService.getChatMessages(userId, sessionId);
    if (!session) {
      throw new BizException(ErrorCodeEnum.SessionNotFound);
    }

    const { topic: existingTopic, messages } = session;

    // A topic is already present: hand it back without summarizing.
    if (existingTopic) {
      return { success: true, data: { topic: existingTopic } };
    }

    // Zero or one message: no summary is requested.
    if (messages.length <= 1) {
      return { success: true, data: { topic: undefined } };
    }

    // One "role: content" line per message, joined with newlines.
    const transcript = messages
      .map(({ role, content }) => `${role}: ${content}`)
      .join('\n');
    const topic = await this.chatService.summarizeTopic(transcript, sessionId);
    return { success: true, data: { topic } };
  }

  /**
   * Create a new user message and stream the conversation back.
   *
   * Steps:
   * 1. Unless the caller's role is `Role.Admin`, run the usage-limit check;
   *    a result of zero or less is rejected as out of quota.
   * 2. Fetch the session for `sessionId`, or create it if it does not exist.
   * 3. Delegate to `ChatService.newMessageStream`.
   *
   * @param payload JWT payload; its `id` and `role` fields are read.
   * @param data Request body validated against `newMessageSchema`.
   * @param sessionId Value of the optional `:sessionId?` route parameter.
   * @returns The result of `ChatService.newMessageStream`.
   * @throws BizException with `ErrorCodeEnum.OutOfQuota` when the limit
   *   check yields a value `<= 0`.
   */
  @Post('messages/:sessionId?')
  @Sse()
  async newMessageStream(
    @Payload() payload: JWTPayload,
    @Body(new JoiValidationPipe(newMessageSchema)) data: NewMessageDto,
    @Param('sessionId') sessionId: string,
  ) {
    const userId = payload.id;
    const isAdmin = payload.role === Role.Admin;

    // Usage limit: admins skip the check entirely.
    if (!isAdmin) {
      const limitResult = await this.chatService.limitCheck(
        userId,
        data.modelId,
      );
      if (limitResult <= 0) {
        throw new BizException(ErrorCodeEnum.OutOfQuota);
      }
    }

    // Check whether the session exists in the database; create it if not.
    const { id: resolvedSessionId, messages } =
      await this.chatService.getOrNewChatSession(
        sessionId,
        userId,
        data.memoryPrompt,
      );

    // Pick a suitable key from the Key Pool (currently disabled).
    // const key = await this.keyPool.select();
    const stream = await this.chatService.newMessageStream({
      userId,
      sessionId: resolvedSessionId,
      modelId: data.modelId,
      input: data.content,
      messages,
      // key,
    });
    return stream;
  }
}
|