coyotte508 HF Staff commited on
Commit
c4f6eb3
·
unverified ·
1 Parent(s): 02c599a

Store oauth token in DB + use it when doing API calls (#1885)

Browse files

* store oauth token in DB

* store user token/oauth token in session

* use token when making API call if USE_USER_TOKEN is true in env

* lint

.env CHANGED
@@ -7,9 +7,9 @@
7
  OPENAI_BASE_URL=https://router.huggingface.co/v1
8
 
9
  # Canonical auth token for any OpenAI-compatible provider
10
- OPENAI_API_KEY=#your provider API key (works for HF router, OpenAI, LM Studio, etc.)
11
- # Legacy alias (still supported): if set and OPENAI_API_KEY is empty, it will be used
12
- # HF_TOKEN=
13
 
14
  ### MongoDB ###
15
  MONGODB_URL=#your mongodb URL here, use chat-ui-db image if you don't want to set this
 
7
  OPENAI_BASE_URL=https://router.huggingface.co/v1
8
 
9
  # Canonical auth token for any OpenAI-compatible provider
10
+ OPENAI_API_KEY=#your provider API key (works for HF router, OpenAI, LM Studio, etc.).
11
+ # When set to true, user token will be used for inference calls
12
+ USE_USER_TOKEN=false
13
 
14
  ### MongoDB ###
15
  MONGODB_URL=#your mongodb URL here, use chat-ui-db image if you don't want to set this
src/app.d.ts CHANGED
@@ -12,6 +12,7 @@ declare global {
12
  sessionId: string;
13
  user?: User & { logoutDisabled?: boolean };
14
  isAdmin: boolean;
 
15
  }
16
 
17
  interface Error {
 
12
  sessionId: string;
13
  user?: User & { logoutDisabled?: boolean };
14
  isAdmin: boolean;
15
+ token?: string;
16
  }
17
 
18
  interface Error {
src/hooks.server.ts CHANGED
@@ -128,6 +128,7 @@ export const handle: Handle = async ({ event, resolve }) => {
128
 
129
  event.locals.user = auth.user || undefined;
130
  event.locals.sessionId = auth.sessionId;
 
131
 
132
  event.locals.isAdmin =
133
  event.locals.user?.isAdmin || adminTokenManager.isAdmin(event.locals.sessionId);
 
128
 
129
  event.locals.user = auth.user || undefined;
130
  event.locals.sessionId = auth.sessionId;
131
+ event.locals.token = auth.token;
132
 
133
  event.locals.isAdmin =
134
  event.locals.user?.isAdmin || adminTokenManager.isAdmin(event.locals.sessionId);
src/lib/server/apiToken.ts ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { config } from "$lib/server/config";
2
+
3
+ export function getApiToken(locals: App.Locals | undefined) {
4
+ if (config.USE_USER_TOKEN === "true") {
5
+ if (!locals?.token) {
6
+ throw new Error("User token not found");
7
+ }
8
+ return locals.token;
9
+ }
10
+ return config.OPENAI_API_KEY || config.HF_TOKEN;
11
+ }
src/lib/server/auth.ts CHANGED
@@ -18,6 +18,7 @@ import { ObjectId } from "mongodb";
18
  import type { Cookie } from "elysia";
19
  import { adminTokenManager } from "./adminToken";
20
  import type { User } from "$lib/types/User";
 
21
 
22
  export interface OIDCSettings {
23
  redirectURI: string;
@@ -79,6 +80,7 @@ export async function findUser(
79
  ): Promise<{
80
  user: User | null;
81
  invalidateSession: boolean;
 
82
  }> {
83
  const session = await collections.sessions.findOne({ sessionId });
84
 
@@ -93,6 +95,7 @@ export async function findUser(
93
  return {
94
  user: await collections.users.findOne({ _id: session.userId }),
95
  invalidateSession: false,
 
96
  };
97
  }
98
  export const authCondition = (locals: App.Locals) => {
@@ -283,6 +286,7 @@ export async function authenticateRequest(
283
 
284
  return {
285
  user: result.user ?? undefined,
 
286
  sessionId,
287
  secretSessionId,
288
  isAdmin: result.user?.isAdmin || adminTokenManager.isAdmin(sessionId),
@@ -308,6 +312,7 @@ export async function authenticateRequest(
308
  return {
309
  user,
310
  sessionId,
 
311
  secretSessionId,
312
  isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
313
  };
@@ -338,6 +343,7 @@ export async function authenticateRequest(
338
  user,
339
  sessionId,
340
  secretSessionId,
 
341
  isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
342
  };
343
  }
 
18
  import type { Cookie } from "elysia";
19
  import { adminTokenManager } from "./adminToken";
20
  import type { User } from "$lib/types/User";
21
+ import type { Session } from "$lib/types/Session";
22
 
23
  export interface OIDCSettings {
24
  redirectURI: string;
 
80
  ): Promise<{
81
  user: User | null;
82
  invalidateSession: boolean;
83
+ oauth?: Session["oauth"];
84
  }> {
85
  const session = await collections.sessions.findOne({ sessionId });
86
 
 
95
  return {
96
  user: await collections.users.findOne({ _id: session.userId }),
97
  invalidateSession: false,
98
+ oauth: session.oauth,
99
  };
100
  }
101
  export const authCondition = (locals: App.Locals) => {
 
286
 
287
  return {
288
  user: result.user ?? undefined,
289
+ token: result.oauth?.token?.value,
290
  sessionId,
291
  secretSessionId,
292
  isAdmin: result.user?.isAdmin || adminTokenManager.isAdmin(sessionId),
 
312
  return {
313
  user,
314
  sessionId,
315
+ token,
316
  secretSessionId,
317
  isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
318
  };
 
343
  user,
344
  sessionId,
345
  secretSessionId,
346
+ token,
347
  isAdmin: user.isAdmin || adminTokenManager.isAdmin(sessionId),
348
  };
349
  }
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -16,6 +16,7 @@ export interface EndpointParameters {
16
  generateSettings?: Partial<Model["parameters"]>;
17
  isMultimodal?: boolean;
18
  conversationId?: ObjectId;
 
19
  }
20
 
21
  export type TextGenerationStreamOutputSimplified = TextGenerationStreamOutput & {
 
16
  generateSettings?: Partial<Model["parameters"]>;
17
  isMultimodal?: boolean;
18
  conversationId?: ObjectId;
19
+ locals: App.Locals | undefined;
20
  }
21
 
22
  export type TextGenerationStreamOutputSimplified = TextGenerationStreamOutput & {
src/lib/server/generateFromDefaultEndpoint.ts CHANGED
@@ -7,18 +7,20 @@ export async function* generateFromDefaultEndpoint({
7
  preprompt,
8
  generateSettings,
9
  modelId,
 
10
  }: {
11
  messages: EndpointMessage[];
12
  preprompt?: string;
13
  generateSettings?: Record<string, unknown>;
14
  /** Optional: use this model instead of the default task model */
15
  modelId?: string;
 
16
  }): AsyncGenerator<MessageUpdate, string, undefined> {
17
  try {
18
  // Choose endpoint based on provided modelId, else fall back to taskModel
19
  const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
20
  const endpoint = await model.getEndpoint();
21
- const tokenStream = await endpoint({ messages, preprompt, generateSettings });
22
 
23
  for await (const output of tokenStream) {
24
  // if not generated_text is here it means the generation is not done
 
7
  preprompt,
8
  generateSettings,
9
  modelId,
10
+ locals,
11
  }: {
12
  messages: EndpointMessage[];
13
  preprompt?: string;
14
  generateSettings?: Record<string, unknown>;
15
  /** Optional: use this model instead of the default task model */
16
  modelId?: string;
17
+ locals: App.Locals | undefined;
18
  }): AsyncGenerator<MessageUpdate, string, undefined> {
19
  try {
20
  // Choose endpoint based on provided modelId, else fall back to taskModel
21
  const model = modelId ? (models.find((m) => m.id === modelId) ?? taskModel) : taskModel;
22
  const endpoint = await model.getEndpoint();
23
+ const tokenStream = await endpoint({ messages, preprompt, generateSettings, locals });
24
 
25
  for await (const output of tokenStream) {
26
  // if not generated_text is here it means the generation is not done
src/lib/server/models.ts CHANGED
@@ -109,7 +109,7 @@ if (openaiBaseUrl) {
109
  logger.info({ baseURL }, "[models] Using OpenAI-compatible base URL");
110
 
111
  // Canonical auth token is OPENAI_API_KEY; keep HF_TOKEN as legacy alias
112
- const authToken = config.OPENAI_API_KEY || config.HF_TOKEN || "";
113
 
114
  // Try unauthenticated request first (many model lists are public, e.g. HF router)
115
  let response = await fetch(`${baseURL}/models`);
 
109
  logger.info({ baseURL }, "[models] Using OpenAI-compatible base URL");
110
 
111
  // Canonical auth token is OPENAI_API_KEY; keep HF_TOKEN as legacy alias
112
+ const authToken = config.OPENAI_API_KEY || config.HF_TOKEN;
113
 
114
  // Try unauthenticated request first (many model lists are public, e.g. HF router)
115
  let response = await fetch(`${baseURL}/models`);
src/lib/server/router/arch.ts CHANGED
@@ -3,6 +3,7 @@ import { logger } from "$lib/server/logger";
3
  import type { EndpointMessage } from "../endpoints/endpoints";
4
  import type { Route, RouteConfig } from "./types";
5
  import { getRoutes } from "./policy";
 
6
 
7
  const DEFAULT_LAST_TURNS = 16;
8
  const PROMPT_TEMPLATE = `
@@ -68,7 +69,8 @@ function parseRouteName(text: string): string | undefined {
68
 
69
  export async function archSelectRoute(
70
  messages: EndpointMessage[],
71
- traceId?: string
 
72
  ): Promise<{ routeName: string }> {
73
  const routes = await getRoutes();
74
  const prompt = toRouterPrompt(messages, routes);
@@ -82,7 +84,7 @@ export async function archSelectRoute(
82
  }
83
 
84
  const headers: HeadersInit = {
85
- Authorization: `Bearer ${config.OPENAI_API_KEY || config.HF_TOKEN}`,
86
  "Content-Type": "application/json",
87
  };
88
  const body = {
 
3
  import type { EndpointMessage } from "../endpoints/endpoints";
4
  import type { Route, RouteConfig } from "./types";
5
  import { getRoutes } from "./policy";
6
+ import { getApiToken } from "$lib/server/apiToken";
7
 
8
  const DEFAULT_LAST_TURNS = 16;
9
  const PROMPT_TEMPLATE = `
 
69
 
70
  export async function archSelectRoute(
71
  messages: EndpointMessage[],
72
+ traceId: string | undefined,
73
+ locals: App.Locals | undefined
74
  ): Promise<{ routeName: string }> {
75
  const routes = await getRoutes();
76
  const prompt = toRouterPrompt(messages, routes);
 
84
  }
85
 
86
  const headers: HeadersInit = {
87
+ Authorization: `Bearer ${getApiToken(locals)}`,
88
  "Content-Type": "application/json",
89
  };
90
  const body = {
src/lib/server/router/endpoint.ts CHANGED
@@ -10,6 +10,7 @@ import { config } from "$lib/server/config";
10
  import { logger } from "$lib/server/logger";
11
  import { archSelectRoute } from "./arch";
12
  import { getRoutes, resolveRouteModels } from "./policy";
 
13
 
14
  const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
15
 
@@ -72,7 +73,7 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
72
  return endpoints.openai({
73
  type: "openai",
74
  baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
75
- apiKey: config.OPENAI_API_KEY || config.HF_TOKEN || "sk-",
76
  model: modelForCall,
77
  // Ensure streaming path is used
78
  streamingSupported: true,
@@ -133,7 +134,7 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
133
  }
134
  }
135
 
136
- const { routeName } = await archSelectRoute(sanitizedMessages);
137
 
138
  const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
139
  const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
 
10
  import { logger } from "$lib/server/logger";
11
  import { archSelectRoute } from "./arch";
12
  import { getRoutes, resolveRouteModels } from "./policy";
13
+ import { getApiToken } from "$lib/server/apiToken";
14
 
15
  const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
16
 
 
73
  return endpoints.openai({
74
  type: "openai",
75
  baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
76
+ apiKey: getApiToken(params.locals),
77
  model: modelForCall,
78
  // Ensure streaming path is used
79
  streamingSupported: true,
 
134
  }
135
  }
136
 
137
+ const { routeName } = await archSelectRoute(sanitizedMessages, undefined, params.locals);
138
 
139
  const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
140
  const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
src/lib/server/textGeneration/generate.ts CHANGED
@@ -23,6 +23,7 @@ export async function* generate(
23
  isContinue,
24
  promptedAt,
25
  forceMultimodal,
 
26
  }: GenerateContext,
27
  preprompt?: string
28
  ): AsyncIterable<MessageUpdate> {
@@ -57,6 +58,7 @@ export async function* generate(
57
  // Allow user-level override to force multimodal
58
  isMultimodal: (forceMultimodal ?? false) || model.multimodal,
59
  conversationId: conv._id,
 
60
  })) {
61
  // Check if this output contains router metadata
62
  if (
@@ -114,6 +116,7 @@ Do not use prefixes such as Response: or Answer: when answering to the user.`,
114
  max_tokens: 1024,
115
  },
116
  modelId: model.id,
 
117
  });
118
  finalAnswer = summary;
119
  yield {
@@ -224,7 +227,7 @@ Do not use prefixes such as Response: or Answer: when answering to the user.`,
224
  ) {
225
  lastReasoningUpdate = new Date();
226
  try {
227
- generateSummaryOfReasoning(reasoningBuffer, model.id).then((summary) => {
228
  status = summary;
229
  });
230
  } catch (e) {
 
23
  isContinue,
24
  promptedAt,
25
  forceMultimodal,
26
+ locals,
27
  }: GenerateContext,
28
  preprompt?: string
29
  ): AsyncIterable<MessageUpdate> {
 
58
  // Allow user-level override to force multimodal
59
  isMultimodal: (forceMultimodal ?? false) || model.multimodal,
60
  conversationId: conv._id,
61
+ locals,
62
  })) {
63
  // Check if this output contains router metadata
64
  if (
 
116
  max_tokens: 1024,
117
  },
118
  modelId: model.id,
119
+ locals,
120
  });
121
  finalAnswer = summary;
122
  yield {
 
227
  ) {
228
  lastReasoningUpdate = new Date();
229
  try {
230
+ generateSummaryOfReasoning(reasoningBuffer, model.id, locals).then((summary) => {
231
  status = summary;
232
  });
233
  } catch (e) {
src/lib/server/textGeneration/index.ts CHANGED
@@ -23,7 +23,7 @@ async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, unde
23
  export async function* textGeneration(ctx: TextGenerationContext) {
24
  const done = new AbortController();
25
 
26
- const titleGen = generateTitleForConversation(ctx.conv);
27
  const textGen = textGenerationWithoutTitle(ctx, done);
28
  const keepAliveGen = keepAlive(done.signal);
29
 
 
23
  export async function* textGeneration(ctx: TextGenerationContext) {
24
  const done = new AbortController();
25
 
26
+ const titleGen = generateTitleForConversation(ctx.conv, ctx.locals);
27
  const textGen = textGenerationWithoutTitle(ctx, done);
28
  const keepAliveGen = keepAlive(done.signal);
29
 
src/lib/server/textGeneration/reasoning.ts CHANGED
@@ -3,7 +3,8 @@ import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
3
 
4
  export async function generateSummaryOfReasoning(
5
  buffer: string,
6
- modelId?: string
 
7
  ): Promise<string> {
8
  let summary: string | undefined;
9
 
@@ -25,6 +26,7 @@ export async function generateSummaryOfReasoning(
25
  max_tokens: 50,
26
  },
27
  modelId,
 
28
  })
29
  );
30
  }
 
3
 
4
  export async function generateSummaryOfReasoning(
5
  buffer: string,
6
+ modelId: string | undefined,
7
+ locals: App.Locals | undefined
8
  ): Promise<string> {
9
  let summary: string | undefined;
10
 
 
26
  max_tokens: 50,
27
  },
28
  modelId,
29
+ locals,
30
  })
31
  );
32
  }
src/lib/server/textGeneration/title.ts CHANGED
@@ -6,7 +6,8 @@ import type { Conversation } from "$lib/types/Conversation";
6
  import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
7
 
8
  export async function* generateTitleForConversation(
9
- conv: Conversation
 
10
  ): AsyncGenerator<MessageUpdate, undefined, undefined> {
11
  try {
12
  const userMessage = conv.messages.find((m) => m.from === "user");
@@ -15,7 +16,7 @@ export async function* generateTitleForConversation(
15
 
16
  const prompt = userMessage.content;
17
  const modelForTitle = config.TASK_MODEL?.trim() ? config.TASK_MODEL : conv.model;
18
- const title = (await generateTitle(prompt, modelForTitle)) ?? "New Chat";
19
 
20
  yield {
21
  type: MessageUpdateType.Title,
@@ -26,7 +27,11 @@ export async function* generateTitleForConversation(
26
  }
27
  }
28
 
29
- export async function generateTitle(prompt: string, modelId?: string) {
 
 
 
 
30
  if (config.LLM_SUMMARIZATION !== "true") {
31
  // When summarization is disabled, use the first five words without adding emojis
32
  return prompt.split(/\s+/g).slice(0, 5).join(" ");
@@ -48,6 +53,7 @@ Return ONLY the title text.`,
48
  max_tokens: 30,
49
  },
50
  modelId,
 
51
  })
52
  )
53
  .then((summary) => {
 
6
  import { getReturnFromGenerator } from "$lib/utils/getReturnFromGenerator";
7
 
8
  export async function* generateTitleForConversation(
9
+ conv: Conversation,
10
+ locals: App.Locals | undefined
11
  ): AsyncGenerator<MessageUpdate, undefined, undefined> {
12
  try {
13
  const userMessage = conv.messages.find((m) => m.from === "user");
 
16
 
17
  const prompt = userMessage.content;
18
  const modelForTitle = config.TASK_MODEL?.trim() ? config.TASK_MODEL : conv.model;
19
+ const title = (await generateTitle(prompt, modelForTitle, locals)) ?? "New Chat";
20
 
21
  yield {
22
  type: MessageUpdateType.Title,
 
27
  }
28
  }
29
 
30
+ async function generateTitle(
31
+ prompt: string,
32
+ modelId: string | undefined,
33
+ locals: App.Locals | undefined
34
+ ) {
35
  if (config.LLM_SUMMARIZATION !== "true") {
36
  // When summarization is disabled, use the first five words without adding emojis
37
  return prompt.split(/\s+/g).slice(0, 5).join(" ");
 
53
  max_tokens: 30,
54
  },
55
  modelId,
56
+ locals,
57
  })
58
  )
59
  .then((summary) => {
src/lib/server/textGeneration/types.ts CHANGED
@@ -16,4 +16,5 @@ export interface TextGenerationContext {
16
  username?: string;
17
  /** Force-enable multimodal handling for endpoints that support it */
18
  forceMultimodal?: boolean;
 
19
  }
 
16
  username?: string;
17
  /** Force-enable multimodal handling for endpoints that support it */
18
  forceMultimodal?: boolean;
19
+ locals: App.Locals | undefined;
20
  }
src/lib/types/Session.ts CHANGED
@@ -11,4 +11,12 @@ export interface Session extends Timestamps {
11
  expiresAt: Date;
12
  admin?: boolean;
13
  coupledCookieHash?: string;
 
 
 
 
 
 
 
 
14
  }
 
11
  expiresAt: Date;
12
  admin?: boolean;
13
  coupledCookieHash?: string;
14
+
15
+ oauth?: {
16
+ token: {
17
+ value: string;
18
+ expiresAt: Date;
19
+ };
20
+ refreshToken?: string;
21
+ };
22
  }
src/routes/conversation/[id]/+server.ts CHANGED
@@ -459,6 +459,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
459
  model.id
460
  ]
461
  ),
 
462
  };
463
  // run the text generation and send updates to the client
464
  for await (const event of textGeneration(ctx)) await update(event);
 
459
  model.id
460
  ]
461
  ),
462
+ locals,
463
  };
464
  // run the text generation and send updates to the client
465
  for await (const event of textGeneration(ctx)) await update(event);
src/routes/login/callback/+server.ts CHANGED
@@ -52,7 +52,7 @@ export async function GET({ url, locals, cookies, request, getClientAddress }) {
52
  throw error(403, "Invalid or expired CSRF token");
53
  }
54
 
55
- const { userData } = await getOIDCUserData(
56
  { redirectURI: validatedToken.redirectUrl },
57
  code,
58
  iss
@@ -79,6 +79,7 @@ export async function GET({ url, locals, cookies, request, getClientAddress }) {
79
 
80
  await updateUser({
81
  userData,
 
82
  locals,
83
  cookies,
84
  userAgent: request.headers.get("user-agent") ?? undefined,
 
52
  throw error(403, "Invalid or expired CSRF token");
53
  }
54
 
55
+ const { userData, token } = await getOIDCUserData(
56
  { redirectURI: validatedToken.redirectUrl },
57
  code,
58
  iss
 
79
 
80
  await updateUser({
81
  userData,
82
+ token,
83
  locals,
84
  cookies,
85
  userAgent: request.headers.get("user-agent") ?? undefined,
src/routes/login/callback/updateUser.spec.ts CHANGED
@@ -6,6 +6,7 @@ import { ObjectId } from "mongodb";
6
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
7
  import { defaultModel } from "$lib/server/models";
8
  import { findUser } from "$lib/server/auth";
 
9
 
10
  const userData = {
11
  preferred_username: "new-username",
@@ -21,6 +22,13 @@ const locals = {
21
  isAdmin: false,
22
  };
23
 
 
 
 
 
 
 
 
24
  // @ts-expect-error SvelteKit cookies dumb mock
25
  const cookiesMock: Cookies = {
26
  set: vi.fn(),
@@ -61,7 +69,7 @@ describe("login", () => {
61
  it("should update user if existing", async () => {
62
  await insertRandomUser();
63
 
64
- await updateUser({ userData, locals, cookies: cookiesMock });
65
 
66
  const existingUser = await collections.users.findOne({ hfUserId: userData.sub });
67
 
@@ -75,7 +83,7 @@ describe("login", () => {
75
 
76
  await insertRandomConversations(2);
77
 
78
- await updateUser({ userData, locals, cookies: cookiesMock });
79
 
80
  const conversationCount = await collections.conversations.countDocuments({
81
  userId: insertedId,
@@ -88,7 +96,7 @@ describe("login", () => {
88
  });
89
 
90
  it("should create default settings for new user", async () => {
91
- await updateUser({ userData, locals, cookies: cookiesMock });
92
 
93
  const user = (await findUser(locals.sessionId)).user;
94
 
@@ -115,7 +123,7 @@ describe("login", () => {
115
  shareConversationsWithModelAuthors: false,
116
  });
117
 
118
- await updateUser({ userData, locals, cookies: cookiesMock });
119
 
120
  const settings = await collections.settings.findOne({
121
  _id: insertedId,
 
6
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
7
  import { defaultModel } from "$lib/server/models";
8
  import { findUser } from "$lib/server/auth";
9
+ import type { TokenSet } from "openid-client";
10
 
11
  const userData = {
12
  preferred_username: "new-username",
 
22
  isAdmin: false,
23
  };
24
 
25
+ const token = {
26
+ access_token: "access_token",
27
+ refresh_token: "refresh_token",
28
+ expires_at: 1717334400,
29
+ expires_in: 3600,
30
+ } as TokenSet;
31
+
32
  // @ts-expect-error SvelteKit cookies dumb mock
33
  const cookiesMock: Cookies = {
34
  set: vi.fn(),
 
69
  it("should update user if existing", async () => {
70
  await insertRandomUser();
71
 
72
+ await updateUser({ userData, locals, cookies: cookiesMock, token });
73
 
74
  const existingUser = await collections.users.findOne({ hfUserId: userData.sub });
75
 
 
83
 
84
  await insertRandomConversations(2);
85
 
86
+ await updateUser({ userData, locals, cookies: cookiesMock, token });
87
 
88
  const conversationCount = await collections.conversations.countDocuments({
89
  userId: insertedId,
 
96
  });
97
 
98
  it("should create default settings for new user", async () => {
99
+ await updateUser({ userData, locals, cookies: cookiesMock, token });
100
 
101
  const user = (await findUser(locals.sessionId)).user;
102
 
 
123
  shareConversationsWithModelAuthors: false,
124
  });
125
 
126
+ await updateUser({ userData, locals, cookies: cookiesMock, token });
127
 
128
  const settings = await collections.settings.findOne({
129
  _id: insertedId,
src/routes/login/callback/updateUser.ts CHANGED
@@ -3,23 +3,24 @@ import { collections } from "$lib/server/database";
3
  import { ObjectId } from "mongodb";
4
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
5
  import { z } from "zod";
6
- import type { UserinfoResponse } from "openid-client";
7
  import { error, type Cookies } from "@sveltejs/kit";
8
  import crypto from "crypto";
9
  import { sha256 } from "$lib/utils/sha256";
10
- import { addWeeks } from "date-fns";
11
  import { OIDConfig } from "$lib/server/auth";
12
  import { config } from "$lib/server/config";
13
  import { logger } from "$lib/server/logger";
14
 
15
  export async function updateUser(params: {
16
  userData: UserinfoResponse;
 
17
  locals: App.Locals;
18
  cookies: Cookies;
19
  userAgent?: string;
20
  ip?: string;
21
  }) {
22
- const { userData, locals, cookies, userAgent, ip } = params;
23
 
24
  // Microsoft Entra v1 tokens do not provide preferred_username, instead the username is provided in the upn
25
  // claim. See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference
@@ -122,6 +123,21 @@ export async function updateUser(params: {
122
  // Get cookie hash if coupling is enabled
123
  const coupledCookieHash = await getCoupledCookieHash({ type: "svelte", value: cookies });
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if (existingUser) {
126
  // update existing user if any
127
  await collections.users.updateOne(
@@ -141,6 +157,7 @@ export async function updateUser(params: {
141
  ip,
142
  expiresAt: addWeeks(new Date(), 2),
143
  ...(coupledCookieHash ? { coupledCookieHash } : {}),
 
144
  });
145
  } else {
146
  // user doesn't exist yet, create a new one
@@ -169,6 +186,7 @@ export async function updateUser(params: {
169
  ip,
170
  expiresAt: addWeeks(new Date(), 2),
171
  ...(coupledCookieHash ? { coupledCookieHash } : {}),
 
172
  });
173
 
174
  // move pre-existing settings to new user
 
3
  import { ObjectId } from "mongodb";
4
  import { DEFAULT_SETTINGS } from "$lib/types/Settings";
5
  import { z } from "zod";
6
+ import type { UserinfoResponse, TokenSet } from "openid-client";
7
  import { error, type Cookies } from "@sveltejs/kit";
8
  import crypto from "crypto";
9
  import { sha256 } from "$lib/utils/sha256";
10
+ import { addWeeks, subMinutes } from "date-fns";
11
  import { OIDConfig } from "$lib/server/auth";
12
  import { config } from "$lib/server/config";
13
  import { logger } from "$lib/server/logger";
14
 
15
  export async function updateUser(params: {
16
  userData: UserinfoResponse;
17
+ token: TokenSet;
18
  locals: App.Locals;
19
  cookies: Cookies;
20
  userAgent?: string;
21
  ip?: string;
22
  }) {
23
+ const { userData, token, locals, cookies, userAgent, ip } = params;
24
 
25
  // Microsoft Entra v1 tokens do not provide preferred_username, instead the username is provided in the upn
26
  // claim. See https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference
 
123
  // Get cookie hash if coupling is enabled
124
  const coupledCookieHash = await getCoupledCookieHash({ type: "svelte", value: cookies });
125
 
126
+ // Prepare OAuth token data for session storage
127
+ const oauthData = token.access_token
128
+ ? {
129
+ token: {
130
+ value: token.access_token,
131
+ expiresAt: token.expires_at
132
+ ? subMinutes(new Date(token.expires_at * 1000), 1)
133
+ : token.expires_in
134
+ ? subMinutes(new Date(Date.now() + token.expires_in * 1000), 1)
135
+ : addWeeks(new Date(), 2),
136
+ },
137
+ ...(token.refresh_token ? { refreshToken: token.refresh_token } : {}),
138
+ }
139
+ : undefined;
140
+
141
  if (existingUser) {
142
  // update existing user if any
143
  await collections.users.updateOne(
 
157
  ip,
158
  expiresAt: addWeeks(new Date(), 2),
159
  ...(coupledCookieHash ? { coupledCookieHash } : {}),
160
+ ...(oauthData ? { oauth: oauthData } : {}),
161
  });
162
  } else {
163
  // user doesn't exist yet, create a new one
 
186
  ip,
187
  expiresAt: addWeeks(new Date(), 2),
188
  ...(coupledCookieHash ? { coupledCookieHash } : {}),
189
+ ...(oauthData ? { oauth: oauthData } : {}),
190
  });
191
 
192
  // move pre-existing settings to new user