Store oauth token in DB + use it when doing API calls (#1885)
Browse files* store oauth token in DB
* store user token/oauth token in session
* use token when making API call if USE_USER_TOKEN is true in env
* lint
- .env +3 -3
- src/app.d.ts +1 -0
- src/hooks.server.ts +1 -0
- src/lib/server/apiToken.ts +11 -0
- src/lib/server/auth.ts +6 -0
- src/lib/server/endpoints/endpoints.ts +1 -0
- src/lib/server/generateFromDefaultEndpoint.ts +3 -1
- src/lib/server/models.ts +1 -1
- src/lib/server/router/arch.ts +4 -2
- src/lib/server/router/endpoint.ts +3 -2
- src/lib/server/textGeneration/generate.ts +4 -1
- src/lib/server/textGeneration/index.ts +1 -1
- src/lib/server/textGeneration/reasoning.ts +3 -1
- src/lib/server/textGeneration/title.ts +9 -3
- src/lib/server/textGeneration/types.ts +1 -0
- src/lib/types/Session.ts +8 -0
- src/routes/conversation/[id]/+server.ts +1 -0
- src/routes/login/callback/+server.ts +2 -1
- src/routes/login/callback/updateUser.spec.ts +12 -4
- src/routes/login/callback/updateUser.ts +21 -3
.env
CHANGED
|
@@ -7,9 +7,9 @@
|
|
| 7 |
OPENAI_BASE_URL=https://router.huggingface.co/v1
|
| 8 |
|
| 9 |
# Canonical auth token for any OpenAI-compatible provider
|
| 10 |
-
OPENAI_API_KEY=#your provider API key (works for HF router, OpenAI, LM Studio, etc.)
|
| 11 |
-
#
|
| 12 |
-
|
| 13 |
|
| 14 |
### MongoDB ###
|
| 15 |
MONGODB_URL=#your mongodb URL here, use chat-ui-db image if you don't want to set this
|
|
|
|
| 7 |
OPENAI_BASE_URL=https://router.huggingface.co/v1
|
| 8 |
|
| 9 |
# Canonical auth token for any OpenAI-compatible provider
|
| 10 |
+
OPENAI_API_KEY=#your provider API key (works for HF router, OpenAI, LM Studio, etc.).
|
| 11 |
+
# When set to true, user token will be used for inference calls
|
| 12 |
+
USE_USER_TOKEN=false
|
| 13 |
|
| 14 |
### MongoDB ###
|
| 15 |
MONGODB_URL=#your mongodb URL here, use chat-ui-db image if you don't want to set this
|
src/app.d.ts
CHANGED
|
@@ -12,6 +12,7 @@ declare global {
|
|
| 12 |
sessionId: string;
|
| 13 |
user?: User & { logoutDisabled?: boolean };
|
| 14 |
isAdmin: boolean;
|
|
|
|
| 15 |
}
|
| 16 |
|
| 17 |
interface Error {
|
|
|
|
| 12 |
sessionId: string;
|
| 13 |
user?: User & { logoutDisabled?: boolean };
|
| 14 |
isAdmin: boolean;
|
| 15 |
+
token?: string;
|
| 16 |
}
|
| 17 |
|
| 18 |
interface Error {
|
src/hooks.server.ts
CHANGED
|
@@ -128,6 +128,7 @@ export const handle: Handle = async ({ event, resolve }) => {
|
|
| 128 |
|
| 129 |
event.locals.user = auth.user || undefined;
|
| 130 |
event.locals.sessionId = auth.sessionId;
|
|
|
|
| 131 |
|
| 132 |
event.locals.isAdmin =
|
| 133 |
event.locals.user?.isAdmin || adminTokenManager.isAdmin(event.locals.sessionId);
|
|
|
|
| 128 |
|
| 129 |
event.locals.user = auth.user || undefined;
|
| 130 |
event.locals.sessionId = auth.sessionId;
|
| 131 |
+
event.locals.token = auth.token;
|
| 132 |
|
| 133 |
event.locals.isAdmin =
|
| 134 |
event.locals.user?.isAdmin || adminTokenManager.isAdmin(event.locals.sessionId);
|
src/lib/server/apiToken.ts
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { config } from "$lib/server/config";
|
| 2 |
+
|
| 3 |
+
export function getApiToken(locals: App.Locals | undefined) {
|
| 4 |
+
if (config.USE_USER_TOKEN === "true") {
|
| 5 |
+
if (!locals?.token) {
|
| 6 |
+
throw new Error("User token not found");
|
| 7 |
+
}
|
| 8 |
+
return locals.token;
|
| 9 |
+
}
|
| 10 |
+
return config.OPENAI_API_KEY || config.HF_TOKEN;
|
| 11 |
+
}
|
src/lib/server/auth.ts
CHANGED
|
@@ -18,6 +18,7 @@ import { ObjectId } from "mongodb";
|
|
| 18 |
import type { Cookie } from "elysia";
|
| 19 |
import { adminTokenManager } from "./adminToken";
|
| 20 |
import type { User } from "$lib/types/User";
|
|
|
|
| 21 |
|
| 22 |
export interface OIDCSettings {
|
| 23 |
redirectURI: string;
|
|
@@ -79,6 +80,7 @@ export async function findUser(
|
|
| 79 |
): Promise<{
|
| 80 |
user: User | null;
|
| 81 |
invalidateSession: boolean;
|
|
|
|
| 82 |
}> {
|
| 83 |
const session = await collections.sessions.findOne({ sessionId });
|
| 84 |
|
|
@@ -93,6 +95,7 @@ export async function findUser(
|
|
| 93 |
return {
|
| 94 |
user: await collections.users.findOne({ _id: session.userId }),
|
| 95 |
invalidateSession: false,
|
|
|
|
| 96 |
};
|
| 97 |
}
|
| 98 |
export const authCondition = (locals: App.Locals) => {
|
|
@@ -283,6 +286,7 @@ export async function authenticateRequest(
|
|
| 283 |
|
| 284 |
return {
|
| 285 |
user: result.user ?? undefined,
|
|
|
|
| 286 |
sessionId,
|
| 287 |
secretSessionId,
|
| 288 |
isAdmin: result.user?.isAdmin || adminTokenManager.isAdmin(sessionId),
|
|
@@ -308,6 +312,7 @@ export async function authenticateRequest(
|
|
| 308 |
return {
|
| 309 |
user,
|
| 310 |
sessionId,
|
|
|
|
| 311 |
secretSessionId,
|
| 312 |
isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
|
| 313 |
};
|
|
@@ -338,6 +343,7 @@ export async function authenticateRequest(
|
|
| 338 |
user,
|
| 339 |
sessionId,
|
| 340 |
secretSessionId,
|
|
|
|
| 341 |
isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
|
| 342 |
};
|
| 343 |
}
|
|
|
|
| 18 |
import type { Cookie } from "elysia";
|
| 19 |
import { adminTokenManager } from "./adminToken";
|
| 20 |
import type { User } from "$lib/types/User";
|
| 21 |
+
import type { Session } from "$lib/types/Session";
|
| 22 |
|
| 23 |
export interface OIDCSettings {
|
| 24 |
redirectURI: string;
|
|
|
|
| 80 |
): Promise<{
|
| 81 |
user: User | null;
|
| 82 |
invalidateSession: boolean;
|
| 83 |
+
oauth?: Session["oauth"];
|
| 84 |
}> {
|
| 85 |
const session = await collections.sessions.findOne({ sessionId });
|
| 86 |
|
|
|
|
| 95 |
return {
|
| 96 |
user: await collections.users.findOne({ _id: session.userId }),
|
| 97 |
invalidateSession: false,
|
| 98 |
+
oauth: session.oauth,
|
| 99 |
};
|
| 100 |
}
|
| 101 |
export const authCondition = (locals: App.Locals) => {
|
|
|
|
| 286 |
|
| 287 |
return {
|
| 288 |
user: result.user ?? undefined,
|
| 289 |
+
token: result.oauth?.token?.value,
|
| 290 |
sessionId,
|
| 291 |
secretSessionId,
|
| 292 |
isAdmin: result.user?.isAdmin || adminTokenManager.isAdmin(sessionId),
|
|
|
|
| 312 |
return {
|
| 313 |
user,
|
| 314 |
sessionId,
|
| 315 |
+
token,
|
| 316 |
secretSessionId,
|
| 317 |
isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
|
| 318 |
};
|
|
|
|
| 343 |
user,
|
| 344 |
sessionId,
|
| 345 |
secretSessionId,
|
| 346 |
+
token,
|
| 347 |
isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
|
| 348 |
};
|
| 349 |
}
|
src/lib/server/endpoints/endpoints.ts
CHANGED
|
@@ -16,6 +16,7 @@ export interface EndpointParameters {
|
|
| 16 |
generateSettings?: Partial<Model["parameters"]>;
|
| 17 |
isMultimodal?: boolean;
|
| 18 |
conversationId?: ObjectId;
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
export type TextGenerationStreamOutputSimplified = TextGenerationStreamOutput & {
|
|
|
|
| 16 |
generateSettings?: Partial<Model["parameters"]>;
|
| 17 |
isMultimodal?: boolean;
|
| 18 |
conversationId?: ObjectId;
|
| 19 |
+
locals: App.Locals | undefined;
|
| 20 |
}
|
| 21 |
|
| 22 |
export type TextGenerationStreamOutputSimplified = TextGenerationStreamOutput & {
|
src/lib/server/generateFromDefaultEndpoint.ts
CHANGED
|
@@ -7,18 +7,20 @@ export async function* generateFromDefaultEndpoint({
|
|
| 7 |
preprompt,
|
| 8 |
generateSettings,
|
| 9 |
modelId,
|
|
|
|
| 10 |
}: {
|
| 11 |
messages: EndpointMessage[];
|
| 12 |
preprompt?: string;
|
| 13 |
generateSettings?: Record<string, unknown>;
|
| 14 |
/** Optional: use this model instead of the default task model */
|
| 15 |
modelId?: string;
|
|
|
|
| 16 |
}): AsyncGenerator<MessageUpdate, string, undefined> {
|
| 17 |
try {
|
| 18 |
// Choose endpoint based on provided modelId, else fall back to taskModel
|
| 19 |
const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
|
| 20 |
const endpoint = await model.getEndpoint();
|
| 21 |
-
const tokenStream = await endpoint({ messages, preprompt, generateSettings });
|
| 22 |
|
| 23 |
for await (const output of tokenStream) {
|
| 24 |
// if not generated_text is here it means the generation is not done
|
|
|
|
| 7 |
preprompt,
|
| 8 |
generateSettings,
|
| 9 |
modelId,
|
| 10 |
+
locals,
|
| 11 |
}: {
|
| 12 |
messages: EndpointMessage[];
|
| 13 |
preprompt?: string;
|
| 14 |
generateSettings?: Record<string, unknown>;
|
| 15 |
/** Optional: use this model instead of the default task model */
|
| 16 |
modelId?: string;
|
| 17 |
+
locals: App.Locals | undefined;
|
| 18 |
}): AsyncGenerator<MessageUpdate, string, undefined> {
|
| 19 |
try {
|
| 20 |
// Choose endpoint based on provided modelId, else fall back to taskModel
|
| 21 |
const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
|
| 22 |
const endpoint = await model.getEndpoint();
|
| 23 |
+
const tokenStream = await endpoint({ messages, preprompt, generateSettings, locals });
|
| 24 |
|
| 25 |
for await (const output of tokenStream) {
|
| 26 |
// if not generated_text is here it means the generation is not done
|
src/lib/server/models.ts
CHANGED
|
@@ -109,7 +109,7 @@ if (openaiBaseUrl) {
|
|
| 109 |
logger.info({ baseURL }, "[models] Using OpenAI-compatible base URL");
|
| 110 |
|
| 111 |
// Canonical auth token is OPENAI_API_KEY; keep HF_TOKEN as legacy alias
|
| 112 |
-
const authToken = config.OPENAI_API_KEY || config.HF_TOKEN
|
| 113 |
|
| 114 |
// Try unauthenticated request first (many model lists are public, e.g. HF router)
|
| 115 |
let response = await fetch(`${baseURL}/models`);
|
|
|
|
| 109 |
logger.info({ baseURL }, "[models] Using OpenAI-compatible base URL");
|
| 110 |
|
| 111 |
// Canonical auth token is OPENAI_API_KEY; keep HF_TOKEN as legacy alias
|
| 112 |
+
const authToken = config.OPENAI_API_KEY || config.HF_TOKEN;
|
| 113 |
|
| 114 |
// Try unauthenticated request first (many model lists are public, e.g. HF router)
|
| 115 |
let response = await fetch(`${baseURL}/models`);
|
src/lib/server/router/arch.ts
CHANGED
|
@@ -3,6 +3,7 @@ import { logger } from "$lib/server/logger";
|
|
| 3 |
import type { EndpointMessage } from "../endpoints/endpoints";
|
| 4 |
import type { Route, RouteConfig } from "./types";
|
| 5 |
import { getRoutes } from "./policy";
|
|
|
|
| 6 |
|
| 7 |
const DEFAULT_LAST_TURNS = 16;
|
| 8 |
const PROMPT_TEMPLATE = `
|
|
@@ -68,7 +69,8 @@ function parseRouteName(text: string): string | undefined {
|
|
| 68 |
|
| 69 |
export async function archSelectRoute(
|
| 70 |
messages: EndpointMessage[],
|
| 71 |
-
traceId
|
|
|
|
| 72 |
): Promise<{ routeName: string }> {
|
| 73 |
const routes = await getRoutes();
|
| 74 |
const prompt = toRouterPrompt(messages, routes);
|
|
@@ -82,7 +84,7 @@ export async function archSelectRoute(
|
|
| 82 |
}
|
| 83 |
|
| 84 |
const headers: HeadersInit = {
|
| 85 |
-
Authorization: `Bearer ${
|
| 86 |
"Content-Type": "application/json",
|
| 87 |
};
|
| 88 |
const body = {
|
|
|
|
| 3 |
import type { EndpointMessage } from "../endpoints/endpoints";
|
| 4 |
import type { Route, RouteConfig } from "./types";
|
| 5 |
import { getRoutes } from "./policy";
|
| 6 |
+
import { getApiToken } from "$lib/server/apiToken";
|
| 7 |
|
| 8 |
const DEFAULT_LAST_TURNS = 16;
|
| 9 |
const PROMPT_TEMPLATE = `
|
|
|
|
| 69 |
|
| 70 |
export async function archSelectRoute(
|
| 71 |
messages: EndpointMessage[],
|
| 72 |
+
traceId: string | undefined,
|
| 73 |
+
locals: App.Locals | undefined
|
| 74 |
): Promise<{ routeName: string }> {
|
| 75 |
const routes = await getRoutes();
|
| 76 |
const prompt = toRouterPrompt(messages, routes);
|
|
|
|
| 84 |
}
|
| 85 |
|
| 86 |
const headers: HeadersInit = {
|
| 87 |
+
Authorization: `Bearer ${getApiToken(locals)}`,
|
| 88 |
"Content-Type": "application/json",
|
| 89 |
};
|
| 90 |
const body = {
|
src/lib/server/router/endpoint.ts
CHANGED
|
@@ -10,6 +10,7 @@ import { config } from "$lib/server/config";
|
|
| 10 |
import { logger } from "$lib/server/logger";
|
| 11 |
import { archSelectRoute } from "./arch";
|
| 12 |
import { getRoutes, resolveRouteModels } from "./policy";
|
|
|
|
| 13 |
|
| 14 |
const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
|
| 15 |
|
|
@@ -72,7 +73,7 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
|
|
| 72 |
return endpoints.openai({
|
| 73 |
type: "openai",
|
| 74 |
baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
|
| 75 |
-
apiKey:
|
| 76 |
model: modelForCall,
|
| 77 |
// Ensure streaming path is used
|
| 78 |
streamingSupported: true,
|
|
@@ -133,7 +134,7 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
|
|
| 133 |
}
|
| 134 |
}
|
| 135 |
|
| 136 |
-
const { routeName } = await archSelectRoute(sanitizedMessages);
|
| 137 |
|
| 138 |
const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
|
| 139 |
const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
|
|
|
|
| 10 |
import { logger } from "$lib/server/logger";
|
| 11 |
import { archSelectRoute } from "./arch";
|
| 12 |
import { getRoutes, resolveRouteModels } from "./policy";
|
| 13 |
+
import { getApiToken } from "$lib/server/apiToken";
|
| 14 |
|
| 15 |
const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
|
| 16 |
|
|
|
|
| 73 |
return endpoints.openai({
|
| 74 |
type: "openai",
|
| 75 |
baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
|
| 76 |
+
apiKey: getApiToken(params.locals),
|
| 77 |
model: modelForCall,
|
| 78 |
// Ensure streaming path is used
|
| 79 |
streamingSupported: true,
|
|
|
|
| 134 |
}
|
| 135 |
}
|
| 136 |
|
| 137 |
+
const { routeName } = await archSelectRoute(sanitizedMessages, undefined, params.locals);
|
| 138 |
|
| 139 |
const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
|
| 140 |
const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
|
src/lib/server/textGeneration/generate.ts
CHANGED
|
@@ -23,6 +23,7 @@ export async function* generate(
|
|
| 23 |
isContinue,
|
| 24 |
promptedAt,
|
| 25 |
forceMultimodal,
|
|
|
|
| 26 |
}: GenerateContext,
|
| 27 |
preprompt?: string
|
| 28 |
): AsyncIterable<MessageUpdate> {
|
|
@@ -57,6 +58,7 @@ export async function* generate(
|
|
| 57 |
// Allow user-level override to force multimodal
|
| 58 |
isMultimodal: (forceMultimodal ?? false) || model.multimodal,
|
| 59 |
conversationId: conv._id,
|
|
|
|
| 60 |
})) {
|
| 61 |
// Check if this output contains router metadata
|
| 62 |
if (
|
|
@@ -114,6 +116,7 @@ Do not use prefixes such as Response: or Answer: when answering to the user.`,
|
|
| 114 |
max_tokens: 1024,
|
| 115 |
},
|
| 116 |
modelId: model.id,
|
|
|
|
| 117 |
});
|
| 118 |
finalAnswer = summary;
|
| 119 |
yield {
|
|
@@ -224,7 +227,7 @@ Do not use prefixes such as Response: or Answer: when answering to the user.`,
|
|
| 224 |
) {
|
| 225 |
lastReasoningUpdate = new Date();
|
| 226 |
try {
|
| 227 |
-
generateSummaryOfReasoning(reasoningBuffer, model.id).then((summary) => {
|
| 228 |
status = summary;
|
| 229 |
});
|
| 230 |
} catch (e) {
|
|
|
|
| 23 |
isContinue,
|
| 24 |
promptedAt,
|
| 25 |
forceMultimodal,
|
| 26 |
+
locals,
|
| 27 |
}: GenerateContext,
|
| 28 |
preprompt?: string
|
| 29 |
): AsyncIterable<MessageUpdate> {
|
|
|
|
| 58 |
// Allow user-level override to force multimodal
|
| 59 |
isMultimodal: (forceMultimodal ?? false) || model.multimodal,
|
| 60 |
conversationId: conv._id,
|
| 61 |
+
locals,
|
| 62 |
})) {
|
| 63 |
// Check if this output contains router metadata
|
| 64 |
if (
|
|
|
|
| 116 |
max_tokens: 1024,
|
| 117 |
},
|
| 118 |
modelId: model.id,
|
| 119 |
+
locals,
|
| 120 |
});
|
| 121 |
finalAnswer = summary;
|
| 122 |
yield {
|
|
|
|
| 227 |
) {
|
| 228 |
lastReasoningUpdate = new Date();
|
| 229 |
try {
|
| 230 |
+
generateSummaryOfReasoning(reasoningBuffer, model.id, locals).then((summary) => {
|
| 231 |
status = summary;
|
| 232 |
});
|
| 233 |
} catch (e) {
|
src/lib/server/textGeneration/index.ts
CHANGED
|
@@ -23,7 +23,7 @@ async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, unde
|
|
| 23 |
export async function* textGeneration(ctx: TextGenerationContext) {
|
| 24 |
const done = new AbortController();
|
| 25 |
|
| 26 |
-
const titleGen = generateTitleForConversation(ctx.conv);
|
| 27 |
const textGen = textGenerationWithoutTitle(ctx, done);
|
| 28 |
const keepAliveGen = keepAlive(done.signal);
|
| 29 |
|
|
|
|
| 23 |
export async function* textGeneration(ctx: TextGenerationContext) {
|
| 24 |
const done = new AbortController();
|
| 25 |
|
| 26 |
+
const titleGen = generateTitleForConversation(ctx.conv, ctx.locals);
|
| 27 |
const textGen = textGenerationWithoutTitle(ctx, done);
|
| 28 |
const keepAliveGen = keepAlive(done.signal);
|
| 29 |
|
src/lib/server/textGeneration/reasoning.ts
CHANGED
|
@@ -3,7 +3,8 @@ import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
|
|
| 3 |
|
| 4 |
export async function generateSummaryOfReasoning(
|
| 5 |
buffer: string,
|
| 6 |
-
modelId
|
|
|
|
| 7 |
): Promise<string> {
|
| 8 |
let summary: string | undefined;
|
| 9 |
|
|
@@ -25,6 +26,7 @@ export async function generateSummaryOfReasoning(
|
|
| 25 |
max_tokens: 50,
|
| 26 |
},
|
| 27 |
modelId,
|
|
|
|
| 28 |
})
|
| 29 |
);
|
| 30 |
}
|
|
|
|
| 3 |
|
| 4 |
export async function generateSummaryOfReasoning(
|
| 5 |
buffer: string,
|
| 6 |
+
modelId: string | undefined,
|
| 7 |
+
locals: App.Locals | undefined
|
| 8 |
): Promise<string> {
|
| 9 |
let summary: string | undefined;
|
| 10 |
|
|
|
|
| 26 |
max_tokens: 50,
|
| 27 |
},
|
| 28 |
modelId,
|
| 29 |
+
locals,
|
| 30 |
})
|
| 31 |
);
|
| 32 |
}
|
src/lib/server/textGeneration/title.ts
CHANGED
|
@@ -6,7 +6,8 @@ import type { Conversation } from "$lib/types/Conversation";
|
|
| 6 |
import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
|
| 7 |
|
| 8 |
export async function* generateTitleForConversation(
|
| 9 |
-
conv: Conversation
|
|
|
|
| 10 |
): AsyncGenerator<MessageUpdate, undefined, undefined> {
|
| 11 |
try {
|
| 12 |
const userMessage = conv.messages.find((m) => m.from === "user");
|
|
@@ -15,7 +16,7 @@ export async function* generateTitleForConversation(
|
|
| 15 |
|
| 16 |
const prompt = userMessage.content;
|
| 17 |
const modelForTitle = config.TASK_MODEL?.trim() ? config.TASK_MODEL : conv.model;
|
| 18 |
-
const title = (await generateTitle(prompt, modelForTitle)) ?? "New Chat";
|
| 19 |
|
| 20 |
yield {
|
| 21 |
type: MessageUpdateType.Title,
|
|
@@ -26,7 +27,11 @@ export async function* generateTitleForConversation(
|
|
| 26 |
}
|
| 27 |
}
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
if (config.LLM_SUMMARIZATION !== "true") {
|
| 31 |
// When summarization is disabled, use the first five words without adding emojis
|
| 32 |
return prompt.split(/\s+/g).slice(0, 5).join(" ");
|
|
@@ -48,6 +53,7 @@ Return ONLY the title text.`,
|
|
| 48 |
max_tokens: 30,
|
| 49 |
},
|
| 50 |
modelId,
|
|
|
|
| 51 |
})
|
| 52 |
)
|
| 53 |
.then((summary) => {
|
|
|
|
| 6 |
import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
|
| 7 |
|
| 8 |
export async function* generateTitleForConversation(
|
| 9 |
+
conv: Conversation,
|
| 10 |
+
locals: App.Locals | undefined
|
| 11 |
): AsyncGenerator<MessageUpdate, undefined, undefined> {
|
| 12 |
try {
|
| 13 |
const userMessage = conv.messages.find((m) => m.from === "user");
|
|
|
|
| 16 |
|
| 17 |
const prompt = userMessage.content;
|
| 18 |
const modelForTitle = config.TASK_MODEL?.trim() ? config.TASK_MODEL : conv.model;
|
| 19 |
+
const title = (await generateTitle(prompt, modelForTitle, locals)) ?? "New Chat";
|
| 20 |
|
| 21 |
yield {
|
| 22 |
type: MessageUpdateType.Title,
|
|
|
|
| 27 |
}
|
| 28 |
}
|
| 29 |
|
| 30 |
+
async function generateTitle(
|
| 31 |
+
prompt: string,
|
| 32 |
+
modelId: string | undefined,
|
| 33 |
+
locals: App.Locals | undefined
|
| 34 |
+
) {
|
| 35 |
if (config.LLM_SUMMARIZATION !== "true") {
|
| 36 |
// When summarization is disabled, use the first five words without adding emojis
|
| 37 |
return prompt.split(/\s+/g).slice(0, 5).join(" ");
|
|
|
|
| 53 |
max_tokens: 30,
|
| 54 |
},
|
| 55 |
modelId,
|
| 56 |
+
locals,
|
| 57 |
})
|
| 58 |
)
|
| 59 |
.then((summary) => {
|
src/lib/server/textGeneration/types.ts
CHANGED
|
@@ -16,4 +16,5 @@ export interface TextGenerationContext {
|
|
| 16 |
username?: string;
|
| 17 |
/** Force-enable multimodal handling for endpoints that support it */
|
| 18 |
forceMultimodal?: boolean;
|
|
|
|
| 19 |
}
|
|
|
|
| 16 |
username?: string;
|
| 17 |
/** Force-enable multimodal handling for endpoints that support it */
|
| 18 |
forceMultimodal?: boolean;
|
| 19 |
+
locals: App.Locals | undefined;
|
| 20 |
}
|
src/lib/types/Session.ts
CHANGED
|
@@ -11,4 +11,12 @@ export interface Session extends Timestamps {
|
|
| 11 |
expiresAt: Date;
|
| 12 |
admin?: boolean;
|
| 13 |
coupledCookieHash?: string;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
}
|
|
|
|
| 11 |
expiresAt: Date;
|
| 12 |
admin?: boolean;
|
| 13 |
coupledCookieHash?: string;
|
| 14 |
+
|
| 15 |
+
oauth?: {
|
| 16 |
+
token: {
|
| 17 |
+
value: string;
|
| 18 |
+
expiresAt: Date;
|
| 19 |
+
};
|
| 20 |
+
refreshToken?: string;
|
| 21 |
+
};
|
| 22 |
}
|
src/routes/conversation/[id]/+server.ts
CHANGED
|
@@ -459,6 +459,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
| 459 |
model.id
|
| 460 |
]
|
| 461 |
),
|
|
|
|
| 462 |
};
|
| 463 |
// run the text generation and send updates to the client
|
| 464 |
for await (const event of textGeneration(ctx)) await update(event);
|
|
|
|
| 459 |
model.id
|
| 460 |
]
|
| 461 |
),
|
| 462 |
+
locals,
|
| 463 |
};
|
| 464 |
// run the text generation and send updates to the client
|
| 465 |
for await (const event of textGeneration(ctx)) await update(event);
|
src/routes/login/callback/+server.ts
CHANGED
|
@@ -52,7 +52,7 @@ export async function GET({ url, locals, cookies, request, getClientAddress }) {
|
|
| 52 |
throw error(403, "Invalid or expired CSRF token");
|
| 53 |
}
|
| 54 |
|
| 55 |
-
const { userData } = await getOIDCUserData(
|
| 56 |
{ redirectURI: validatedToken.redirectUrl },
|
| 57 |
code,
|
| 58 |
iss
|
|
@@ -79,6 +79,7 @@ export async function GET({ url, locals, cookies, request, getClientAddress }) {
|
|
| 79 |
|
| 80 |
await updateUser({
|
| 81 |
userData,
|
|
|
|
| 82 |
locals,
|
| 83 |
cookies,
|
| 84 |
userAgent: request.headers.get("user-agent") ?? undefined,
|
|
|
|
| 52 |
throw error(403, "Invalid or expired CSRF token");
|
| 53 |
}
|
| 54 |
|
| 55 |
+
const { userData, token } = await getOIDCUserData(
|
| 56 |
{ redirectURI: validatedToken.redirectUrl },
|
| 57 |
code,
|
| 58 |
iss
|
|
|
|
| 79 |
|
| 80 |
await updateUser({
|
| 81 |
userData,
|
| 82 |
+
token,
|
| 83 |
locals,
|
| 84 |
cookies,
|
| 85 |
userAgent: request.headers.get("user-agent") ?? undefined,
|
src/routes/login/callback/updateUser.spec.ts
CHANGED
|
@@ -6,6 +6,7 @@ import { ObjectId } from "mongodb";
|
|
| 6 |
import { DEFAULT_SETTINGS } from "$lib/types/Settings";
|
| 7 |
import { defaultModel } from "$lib/server/models";
|
| 8 |
import { findUser } from "$lib/server/auth";
|
|
|
|
| 9 |
|
| 10 |
const userData = {
|
| 11 |
preferred_username: "new-username",
|
|
@@ -21,6 +22,13 @@ const locals = {
|
|
| 21 |
isAdmin: false,
|
| 22 |
};
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
// @ts-expect-error SvelteKit cookies dumb mock
|
| 25 |
const cookiesMock: Cookies = {
|
| 26 |
set: vi.fn(),
|
|
@@ -61,7 +69,7 @@ describe("login", () => {
|
|
| 61 |
it("should update user if existing", async () => {
|
| 62 |
await insertRandomUser();
|
| 63 |
|
| 64 |
-
await updateUser({ userData, locals, cookies: cookiesMock });
|
| 65 |
|
| 66 |
const existingUser = await collections.users.findOne({ hfUserId: userData.sub });
|
| 67 |
|
|
@@ -75,7 +83,7 @@ describe("login", () => {
|
|
| 75 |
|
| 76 |
await insertRandomConversations(2);
|
| 77 |
|
| 78 |
-
await updateUser({ userData, locals, cookies: cookiesMock });
|
| 79 |
|
| 80 |
const conversationCount = await collections.conversations.countDocuments({
|
| 81 |
userId: insertedId,
|
|
@@ -88,7 +96,7 @@ describe("login", () => {
|
|
| 88 |
});
|
| 89 |
|
| 90 |
it("should create default settings for new user", async () => {
|
| 91 |
-
await updateUser({ userData, locals, cookies: cookiesMock });
|
| 92 |
|
| 93 |
const user = (await findUser(locals.sessionId)).user;
|
| 94 |
|
|
@@ -115,7 +123,7 @@ describe("login", () => {
|
|
| 115 |
shareConversationsWithModelAuthors: false,
|
| 116 |
});
|
| 117 |
|
| 118 |
-
await updateUser({ userData, locals, cookies: cookiesMock });
|
| 119 |
|
| 120 |
const settings = await collections.settings.findOne({
|
| 121 |
_id: insertedId,
|
|
|
|
| 6 |
import { DEFAULT_SETTINGS } from "$lib/types/Settings";
|
| 7 |
import { defaultModel } from "$lib/server/models";
|
| 8 |
import { findUser } from "$lib/server/auth";
|
| 9 |
+
import type { TokenSet } from "openid-client";
|
| 10 |
|
| 11 |
const userData = {
|
| 12 |
preferred_username: "new-username",
|
|
|
|
| 22 |
isAdmin: false,
|
| 23 |
};
|
| 24 |
|
| 25 |
+
const token = {
|
| 26 |
+
access_token: "access_token",
|
| 27 |
+
refresh_token: "refresh_token",
|
| 28 |
+
expires_at: 1717334400,
|
| 29 |
+
expires_in: 3600,
|
| 30 |
+
} as TokenSet;
|
| 31 |
+
|
| 32 |
// @ts-expect-error SvelteKit cookies dumb mock
|
| 33 |
const cookiesMock: Cookies = {
|
| 34 |
set: vi.fn(),
|
|
|
|
| 69 |
it("should update user if existing", async () => {
|
| 70 |
await insertRandomUser();
|
| 71 |
|
| 72 |
+
await updateUser({ userData, locals, cookies: cookiesMock, token });
|
| 73 |
|
| 74 |
const existingUser = await collections.users.findOne({ hfUserId: userData.sub });
|
| 75 |
|
|
|
|
| 83 |
|
| 84 |
await insertRandomConversations(2);
|
| 85 |
|
| 86 |
+
await updateUser({ userData, locals, cookies: cookiesMock, token });
|
| 87 |
|
| 88 |
const conversationCount = await collections.conversations.countDocuments({
|
| 89 |
userId: insertedId,
|
|
|
|
| 96 |
});
|
| 97 |
|
| 98 |
it("should create default settings for new user", async () => {
|
| 99 |
+
await updateUser({ userData, locals, cookies: cookiesMock, token });
|
| 100 |
|
| 101 |
const user = (await findUser(locals.sessionId)).user;
|
| 102 |
|
|
|
|
| 123 |
shareConversationsWithModelAuthors: false,
|
| 124 |
});
|
| 125 |
|
| 126 |
+
await updateUser({ userData, locals, cookies: cookiesMock, token });
|
| 127 |
|
| 128 |
const settings = await collections.settings.findOne({
|
| 129 |
_id: insertedId,
|
src/routes/login/callback/updateUser.ts
CHANGED
|
@@ -3,23 +3,24 @@ import { collections } from "$lib/server/database";
|
|
| 3 |
import { ObjectId } from "mongodb";
|
| 4 |
import { DEFAULT_SETTINGS } from "$lib/types/Settings";
|
| 5 |
import { z } from "zod";
|
| 6 |
-
import type { UserinfoResponse } from "openid-client";
|
| 7 |
import { error, type Cookies } from "@sveltejs/kit";
|
| 8 |
import crypto from "crypto";
|
| 9 |
import { sha256 } from "$lib/utils/sha256";
|
| 10 |
-
import { addWeeks } from "date-fns";
|
| 11 |
import { OIDConfig } from "$lib/server/auth";
|
| 12 |
import { config } from "$lib/server/config";
|
| 13 |
import { logger } from "$lib/server/logger";
|
| 14 |
|
| 15 |
export async function updateUser(params: {
|
| 16 |
userData: UserinfoResponse;
|
|
|
|
| 17 |
locals: App.Locals;
|
| 18 |
cookies: Cookies;
|
| 19 |
userAgent?: string;
|
| 20 |
ip?: string;
|
| 21 |
}) {
|
| 22 |
-
const { userData, locals, cookies, userAgent, ip } = params;
|
| 23 |
|
| 24 |
// Microsoft Entra v1 tokens do not provide preferred_username, instead the username is provided in the upn
|
| 25 |
// claim. See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference
|
|
@@ -122,6 +123,21 @@ export async function updateUser(params: {
|
|
| 122 |
// Get cookie hash if coupling is enabled
|
| 123 |
const coupledCookieHash = await getCoupledCookieHash({ type: "svelte", value: cookies });
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
if (existingUser) {
|
| 126 |
// update existing user if any
|
| 127 |
await collections.users.updateOne(
|
|
@@ -141,6 +157,7 @@ export async function updateUser(params: {
|
|
| 141 |
ip,
|
| 142 |
expiresAt: addWeeks(new Date(), 2),
|
| 143 |
...(coupledCookieHash ? { coupledCookieHash } : {}),
|
|
|
|
| 144 |
});
|
| 145 |
} else {
|
| 146 |
// user doesn't exist yet, create a new one
|
|
@@ -169,6 +186,7 @@ export async function updateUser(params: {
|
|
| 169 |
ip,
|
| 170 |
expiresAt: addWeeks(new Date(), 2),
|
| 171 |
...(coupledCookieHash ? { coupledCookieHash } : {}),
|
|
|
|
| 172 |
});
|
| 173 |
|
| 174 |
// move pre-existing settings to new user
|
|
|
|
| 3 |
import { ObjectId } from "mongodb";
|
| 4 |
import { DEFAULT_SETTINGS } from "$lib/types/Settings";
|
| 5 |
import { z } from "zod";
|
| 6 |
+
import type { UserinfoResponse, TokenSet } from "openid-client";
|
| 7 |
import { error, type Cookies } from "@sveltejs/kit";
|
| 8 |
import crypto from "crypto";
|
| 9 |
import { sha256 } from "$lib/utils/sha256";
|
| 10 |
+
import { addWeeks, subMinutes } from "date-fns";
|
| 11 |
import { OIDConfig } from "$lib/server/auth";
|
| 12 |
import { config } from "$lib/server/config";
|
| 13 |
import { logger } from "$lib/server/logger";
|
| 14 |
|
| 15 |
export async function updateUser(params: {
|
| 16 |
userData: UserinfoResponse;
|
| 17 |
+
token: TokenSet;
|
| 18 |
locals: App.Locals;
|
| 19 |
cookies: Cookies;
|
| 20 |
userAgent?: string;
|
| 21 |
ip?: string;
|
| 22 |
}) {
|
| 23 |
+
const { userData, token, locals, cookies, userAgent, ip } = params;
|
| 24 |
|
| 25 |
// Microsoft Entra v1 tokens do not provide preferred_username, instead the username is provided in the upn
|
| 26 |
// claim. See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference
|
|
|
|
| 123 |
// Get cookie hash if coupling is enabled
|
| 124 |
const coupledCookieHash = await getCoupledCookieHash({ type: "svelte", value: cookies });
|
| 125 |
|
| 126 |
+
// Prepare OAuth token data for session storage
|
| 127 |
+
const oauthData = token.access_token
|
| 128 |
+
? {
|
| 129 |
+
token: {
|
| 130 |
+
value: token.access_token,
|
| 131 |
+
expiresAt: token.expires_at
|
| 132 |
+
? subMinutes(new Date(token.expires_at * 1000), 1)
|
| 133 |
+
: token.expires_in
|
| 134 |
+
? subMinutes(new Date(Date.now() + token.expires_in * 1000), 1)
|
| 135 |
+
: addWeeks(new Date(), 2),
|
| 136 |
+
},
|
| 137 |
+
...(token.refresh_token ? { refreshToken: token.refresh_token } : {}),
|
| 138 |
+
}
|
| 139 |
+
: undefined;
|
| 140 |
+
|
| 141 |
if (existingUser) {
|
| 142 |
// update existing user if any
|
| 143 |
await collections.users.updateOne(
|
|
|
|
| 157 |
ip,
|
| 158 |
expiresAt: addWeeks(new Date(), 2),
|
| 159 |
...(coupledCookieHash ? { coupledCookieHash } : {}),
|
| 160 |
+
...(oauthData ? { oauth: oauthData } : {}),
|
| 161 |
});
|
| 162 |
} else {
|
| 163 |
// user doesn't exist yet, create a new one
|
|
|
|
| 186 |
ip,
|
| 187 |
expiresAt: addWeeks(new Date(), 2),
|
| 188 |
...(coupledCookieHash ? { coupledCookieHash } : {}),
|
| 189 |
+
...(oauthData ? { oauth: oauthData } : {}),
|
| 190 |
});
|
| 191 |
|
| 192 |
// move pre-existing settings to new user
|