| import { Elysia, error, t } from "elysia"; |
| import { authPlugin } from "$api/authPlugin"; |
| import { collections } from "$lib/server/database"; |
| import { ObjectId } from "mongodb"; |
| import { authCondition } from "$lib/server/auth"; |
| import { models, validModelIdSchema } from "$lib/server/models"; |
| import { convertLegacyConversation } from "$lib/utils/tree/convertLegacyConversation"; |
| import type { Conversation } from "$lib/types/Conversation"; |
|
|
| import { CONV_NUM_PER_PAGE } from "$lib/constants/pagination"; |
| import pkg from "natural"; |
| const { PorterStemmer } = pkg; |
|
|
| export const conversationGroup = new Elysia().use(authPlugin).group("/conversations", (app) => { |
| return app |
.guard({
  // NOTE(review): `as: "scoped"` widens how far Elysia propagates this hook
  // beyond the current instance — confirm against Elysia's lifecycle-scope docs.
  as: "scoped",
  // Reject any request that carries neither a logged-in user nor an anonymous
  // session with a 401; otherwise fall through to the route handler.
  beforeHandle: async ({ locals }) => {
    const hasIdentity = Boolean(locals.user?._id) || Boolean(locals.sessionId);
    if (!hasIdentity) {
      return error(401, "Must have a valid session or user");
    }
  },
})
.get(
  "",
  /**
   * List the caller's conversations, newest first, one page at a time.
   *
   * @param query.p Zero-based page index (defaults to 0). Negative, fractional or
   *   non-finite values are clamped, because MongoDB rejects a negative skip.
   * @returns `{ conversations, nConversations }` — the requested page and the total
   *   number of conversations visible to the caller.
   */
  async ({ locals, query }) => {
    const rawPage = query.p ?? 0;
    const page = Number.isFinite(rawPage) ? Math.max(0, Math.floor(rawPage)) : 0;

    // The page query and the total count are independent, so run them concurrently.
    const [convs, nConversations] = await Promise.all([
      collections.conversations
        .find(authCondition(locals))
        .project<Pick<Conversation, "_id" | "title" | "updatedAt" | "model" | "assistantId">>({
          title: 1,
          updatedAt: 1,
          model: 1,
          assistantId: 1,
        })
        .sort({ updatedAt: -1 })
        .skip(page * CONV_NUM_PER_PAGE)
        .limit(CONV_NUM_PER_PAGE)
        .toArray(),
      collections.conversations.countDocuments(authCondition(locals)),
    ]);

    const res = convs.map((conv) => ({
      _id: conv._id,
      id: conv._id,
      title: conv.title,
      updatedAt: conv.updatedAt,
      model: conv.model,
      modelId: conv.model,
      assistantId: conv.assistantId,
      // `tools` of the configured model, or false when the model is unknown.
      modelTools: models.find((m) => m.id === conv.model)?.tools ?? false,
    }));

    return { conversations: res, nConversations };
  },
  {
    query: t.Object({
      p: t.Optional(t.Number()),
    }),
  }
)
.delete("", async ({ locals }) => {
  // Bulk-delete every conversation matching the caller's auth condition and
  // respond with the number of documents removed.
  const ownerFilter = { ...authCondition(locals) };
  const { deletedCount } = await collections.conversations.deleteMany(ownerFilter);
  return deletedCount;
})
.get(
  "/search",
  /**
   * Full-text search over the caller's conversations.
   *
   * Uses MongoDB `$text` to find candidate conversations, then, per conversation,
   * locates the best-matching message and builds a short context snippet around it.
   *
   * @param query.q Search text; fewer than 3 characters returns `[]`.
   * @param query.p Zero-based page index (5 results per page); clamped to >= 0.
   * @returns One entry per conversation with the snippet in `content` and the
   *   matched text in `matchedText`.
   */
  async ({ locals, query }) => {
    const SEARCH_PAGE_SIZE = 5;
    // Target snippet length in characters (before word-boundary widening / ellipses).
    const MAX_SNIPPET_LENGTH = 160;
    // Max characters a snippet edge may move to land on whitespace.
    const WORD_BOUNDARY_SLACK = 15;
    // Length of the fallback preview when no message text matches.
    const FALLBACK_PREVIEW_LENGTH = 200;
    // A match on the literal query outranks any length-based variation priority.
    const EXACT_MATCH_PRIORITY = 1000;

    const searchQuery = query.q;
    const rawPage = query.p ?? 0;
    // MongoDB rejects negative skip values; clamp instead of surfacing a 500.
    const page = Number.isFinite(rawPage) ? Math.max(0, Math.floor(rawPage)) : 0;

    if (!searchQuery || searchQuery.length < 3) {
      return [];
    }

    // The scoped guard already enforces this; keep it as a defensive check but
    // answer with the same 401 as the guard rather than a generic 500.
    if (!locals.user?._id && !locals.sessionId) {
      return error(401, "Must have a valid session or user");
    }

    const normalizedQuery = searchQuery.toLowerCase();

    const escapeRegExp = (text: string): string =>
      text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");

    /**
     * Find `needle` in `content`, preferring a whole-word match and falling back to
     * any substring match (both case-insensitive). `end` is an inclusive index.
     * Both passes run on the original text so indices always line up with `content`
     * (lowercasing can change string length for some Unicode characters).
     */
    const findMatch = (
      content: string,
      needle: string
    ): { start: number; end: number; text: string } | null => {
      const pattern = escapeRegExp(needle.toLowerCase());
      const match =
        new RegExp(`\\b${pattern}\\b`, "i").exec(content) ?? new RegExp(pattern, "i").exec(content);
      if (!match) return null;
      return { start: match.index, end: match.index + match[0].length - 1, text: match[0] };
    };

    // Porter stems are cached across all conversations in this request.
    const stemCache = new Map<string, string>();
    const stem = (word: string): string => {
      let stemmed = stemCache.get(word);
      if (stemmed === undefined) {
        stemmed = PorterStemmer.stem(word);
        stemCache.set(word, stemmed);
      }
      return stemmed;
    };

    /**
     * Lowercased strings to look for, in priority order:
     * 1. the literal query;
     * 2. its Porter stem;
     * 3. any word in `messages` sharing that stem;
     * 4. a naive singular/plural form.
     */
    const buildVariations = (messages: Conversation["messages"]): string[] => {
      const variations = new Set<string>([normalizedQuery]);
      try {
        const stemmedQuery = stem(normalizedQuery);
        variations.add(stemmedQuery);
        for (const message of messages) {
          if (!message.content) continue;
          for (const word of message.content.toLowerCase().match(/\b\w+\b/g) ?? []) {
            if (stem(word) === stemmedQuery) variations.add(word);
          }
        }
      } catch (e) {
        // Stemming is best-effort; the literal query is still searched.
        console.warn("Stemming failed for:", searchQuery, e);
      }
      if (normalizedQuery.endsWith("s")) {
        if (normalizedQuery.length > 3) variations.add(normalizedQuery.slice(0, -1));
      } else {
        variations.add(normalizedQuery + "s");
      }
      return Array.from(variations);
    };

    type BestMatch = { content: string; matchStart: number; matchEnd: number; matchedText: string };

    /**
     * Scan messages in order and keep the highest-priority match.
     * Priority is `EXACT_MATCH_PRIORITY` for the literal query, otherwise the matched
     * length. Stops as soon as the literal query is found.
     */
    const findBestMatch = (
      messages: Conversation["messages"],
      variations: string[]
    ): BestMatch | null => {
      let best: BestMatch | null = null;
      let bestPriority = 0;
      for (const message of messages) {
        if (!message.content) continue;
        for (const variation of variations) {
          const match = findMatch(message.content, variation);
          if (!match) continue;
          const isExactQuery = variation === normalizedQuery;
          const priority = isExactQuery ? EXACT_MATCH_PRIORITY : match.text.length;
          if (priority > bestPriority) {
            best = {
              content: message.content,
              matchStart: match.start,
              matchEnd: match.end,
              matchedText: match.text,
            };
            bestPriority = priority;
            if (isExactQuery) break;
          }
        }
        if (bestPriority >= EXACT_MATCH_PRIORITY) break;
      }
      return best;
    };

    /**
     * Cut a window of about MAX_SNIPPET_LENGTH characters centred on the match.
     * Each edge is widened by up to WORD_BOUNDARY_SLACK characters to the nearest
     * space/newline so words are not cut. "..." is added on any truncated side.
     */
    const buildSnippet = (content: string, matchStart: number, matchEnd: number): string => {
      const matchLength = matchEnd - matchStart + 1;
      // Never negative: a match longer than the budget simply gets no extra context.
      const contextPerSide = Math.max(
        0,
        Math.floor((Math.min(MAX_SNIPPET_LENGTH, content.length) - matchLength) / 2)
      );

      const initialStart = Math.max(0, matchStart - contextPerSide);
      const initialEnd = Math.min(content.length, matchStart + matchLength + contextPerSide);
      let start = initialStart;
      let end = initialEnd;

      while (
        start > 0 &&
        content[start] !== " " &&
        content[start] !== "\n" &&
        initialStart - start < WORD_BOUNDARY_SLACK
      ) {
        start--;
      }
      while (
        end < content.length &&
        content[end] !== " " &&
        content[end] !== "\n" &&
        end - initialEnd < WORD_BOUNDARY_SLACK
      ) {
        end++;
      }

      let snippet = content.substring(start, end).trim();
      if (start > 0) snippet = "..." + snippet;
      if (end < content.length) snippet += "...";
      return snippet;
    };

    const convs = await collections.conversations
      .find({
        // NOTE(review): overridden when authCondition supplies `sessionId`; otherwise how
        // `undefined` is serialized depends on the driver's `ignoreUndefined` — confirm intent.
        sessionId: undefined,
        ...authCondition(locals),
        $text: { $search: searchQuery },
      })
      .sort({ updatedAt: -1 })
      .project<
        Pick<
          Conversation,
          "_id" | "title" | "updatedAt" | "model" | "assistantId" | "messages" | "userId"
        >
      >({
        title: 1,
        updatedAt: 1,
        model: 1,
        assistantId: 1,
        messages: 1,
        userId: 1,
      })
      .skip(page * SEARCH_PAGE_SIZE)
      .limit(SEARCH_PAGE_SIZE)
      .toArray();

    return convs.map((conv) => {
      let matchedContent = "";
      let matchedText = "";

      const best = findBestMatch(conv.messages, buildVariations(conv.messages));
      if (best) {
        matchedContent = buildSnippet(best.content, best.matchStart, best.matchEnd);
        matchedText = best.matchedText;
      } else {
        // $text matched something we could not locate (e.g. the title or a stem
        // form); fall back to a preview of the first message.
        const firstContent = conv.messages[0]?.content;
        if (firstContent) {
          matchedContent =
            firstContent.length > FALLBACK_PREVIEW_LENGTH
              ? firstContent.substring(0, FALLBACK_PREVIEW_LENGTH) + "..."
              : firstContent;
          matchedText = searchQuery;
        }
      }

      return {
        _id: conv._id,
        id: conv._id,
        title: conv.title,
        content: matchedContent,
        matchedText,
        updatedAt: conv.updatedAt,
        model: conv.model,
        assistantId: conv.assistantId,
        modelTools: models.find((m) => m.id === conv.model)?.tools ?? false,
      };
    });
  },
  {
    query: t.Object({
      q: t.String(),
      p: t.Optional(t.Number()),
    }),
  }
)
| .group( |
| "/:id", |
| { |
| params: t.Object({ |
| id: t.String(), |
| }), |
| }, |
| (app) => { |
| return app |
.derive(async ({ locals, params }) => {
  // Load the conversation addressed by `:id` for every route in this group and
  // expose it as `conversation` (legacy format converted, `shared` flag added).
  //
  // Failures are reported with proper HTTP statuses instead of generic 500s:
  //   400 — id is not a valid ObjectId;
  //   403 — the conversation exists but is not visible to the caller;
  //   404 — no such conversation.
  let conversation;
  let shared = false;

  if (params.id.length === 7) {
    // 7-character ids are looked up in `sharedConversations`.
    conversation = await collections.sharedConversations.findOne({
      _id: params.id,
    });
    shared = true;

    if (!conversation) {
      throw error(404, "Conversation not found");
    }
  } else {
    // Parse the id once and reuse it for both lookups.
    let conversationId: ObjectId | undefined;
    try {
      conversationId = new ObjectId(params.id);
    } catch {
      // handled just below
    }
    if (!conversationId) {
      throw error(400, "Invalid conversation ID format");
    }

    conversation = await collections.conversations.findOne({
      _id: conversationId,
      ...authCondition(locals),
    });

    if (!conversation) {
      // Distinguish "exists but not yours" from "does not exist".
      const conversationExists =
        (await collections.conversations.countDocuments({
          _id: conversationId,
        })) !== 0;

      if (conversationExists) {
        throw error(
          403,
          "You don't have access to this conversation. If someone gave you this link, ask them to use the 'share' feature instead."
        );
      }

      throw error(404, "Conversation not found.");
    }
  }

  const convertedConv = {
    ...conversation,
    ...convertLegacyConversation(conversation),
    shared,
  };

  return { conversation: convertedConv };
})
.get("", async ({ conversation }) => {
  // Serialize the conversation loaded by `derive`. When it references an
  // assistant, embed that assistant document; a missing document becomes
  // `undefined` rather than `null`.
  let assistant;
  if (conversation.assistantId) {
    const assistantDoc = await collections.assistants.findOne({
      _id: new ObjectId(conversation.assistantId),
    });
    assistant = assistantDoc ?? undefined;
  }

  const modelTools = models.find((m) => m.id === conversation.model)?.tools ?? false;

  return {
    messages: conversation.messages,
    title: conversation.title,
    model: conversation.model,
    preprompt: conversation.preprompt,
    rootMessageId: conversation.rootMessageId,
    assistant,
    id: conversation._id.toString(),
    updatedAt: conversation.updatedAt,
    modelId: conversation.model,
    assistantId: conversation.assistantId,
    modelTools,
    shared: conversation.shared,
  };
})
.post("", () => {
  // Placeholder route: answer 501 Not Implemented instead of throwing, which
  // surfaced as a misleading 500 Internal Server Error.
  return error(501, "Not implemented");
})
.delete("", async ({ locals, params, conversation }) => {
  // Delete one of the caller's own conversations.
  //
  // A 7-character share id resolves (in `derive`) to a shared copy that is not
  // stored in `conversations`. Passing it to `new ObjectId` would throw and
  // return a 500, so report it as not found instead.
  if (conversation.shared) {
    return error(404, "Conversation not found");
  }

  const res = await collections.conversations.deleteOne({
    _id: new ObjectId(params.id),
    ...authCondition(locals),
  });

  if (res.deletedCount === 0) {
    return error(404, "Conversation not found");
  }

  return { success: true };
})
.get("/output/:sha256", () => {
  // Placeholder route: 501 Not Implemented rather than a generic 500.
  return error(501, "Not implemented");
})
.post("/share", () => {
  // Placeholder route: 501 Not Implemented rather than a generic 500.
  return error(501, "Not implemented");
})
.post("/stop-generating", () => {
  // Placeholder route: 501 Not Implemented rather than a generic 500.
  return error(501, "Not implemented");
})
.patch(
  "",
  /**
   * Update the title and/or model of one of the caller's conversations.
   *
   * Responses:
   * - 400 if `model` is supplied (even as "") and fails `validModelIdSchema`;
   * - 404 if the id is a share id or no owned conversation matches;
   * - `{ success: true }` otherwise, including when the values are unchanged.
   */
  async ({ locals, params, body, conversation }) => {
    // Share ids (7 chars) are not ObjectIds; `new ObjectId` would throw a 500.
    if (conversation.shared) {
      return error(404, "Conversation not found");
    }

    // Validate whenever a model is present — a falsy check would let "" through
    // unvalidated and then persist it.
    if (body.model !== undefined && !validModelIdSchema.safeParse(body.model).success) {
      return error(400, "Invalid model ID");
    }

    const updateValues: Partial<Pick<Conversation, "title" | "model">> = {};
    if (body.title !== undefined) updateValues.title = body.title;
    if (body.model !== undefined) updateValues.model = body.model;

    // Nothing to change: `derive` has already confirmed the conversation exists
    // and is accessible, so skip sending an empty `$set`.
    if (Object.keys(updateValues).length === 0) {
      return { success: true };
    }

    const res = await collections.conversations.updateOne(
      {
        _id: new ObjectId(params.id),
        ...authCondition(locals),
      },
      {
        $set: updateValues,
      }
    );

    // Use matchedCount, not modifiedCount: re-saving identical values modifies
    // nothing but is still a successful update of an existing conversation.
    if (res.matchedCount === 0) {
      return error(404, "Conversation not found");
    }

    return { success: true };
  },
  {
    body: t.Object({
      title: t.Optional(
        t.String({
          minLength: 1,
          maxLength: 100,
        })
      ),
      model: t.Optional(t.String()),
    }),
  }
)
.delete(
  "/message/:messageId",
  /**
   * Remove a message and every message that lists it as an ancestor, then
   * unlink it from any parent's `children`. Writes the pruned message list back
   * to the conversation.
   */
  async ({ locals, params, conversation }) => {
    const { messageId } = params;

    // Share ids resolve to a shared copy whose `_id` is not an ObjectId; the
    // update below would throw a 500 on `new ObjectId`.
    if (conversation.shared) {
      return error(404, "Conversation not found");
    }

    if (!conversation.messages.some((m) => m.id === messageId)) {
      return error(404, "Message not found");
    }

    const filteredMessages = conversation.messages
      // Drop the target and its descendants. Messages without an `ancestors`
      // list are kept; previously they were silently removed.
      .filter(
        (message) => message.id !== messageId && !message.ancestors?.includes(messageId)
      )
      // Unlink the deleted message from its parent without mutating the loaded objects.
      .map((message) => {
        const children = message.children;
        if (!children || !children.includes(messageId)) {
          return message;
        }
        return { ...message, children: children.filter((child) => child !== messageId) };
      });

    const res = await collections.conversations.updateOne(
      { _id: new ObjectId(conversation._id), ...authCondition(locals) },
      { $set: { messages: filteredMessages } }
    );

    if (res.modifiedCount === 0) {
      throw new Error("Deleting message failed");
    }

    return { success: true };
  },
  {
    params: t.Object({
      id: t.String(),
      messageId: t.String(),
    }),
  }
);
| } |
| ); |
| }); |
|
|