lewtun HF Staff OpenAI Codex commited on
Commit
71e1892
·
unverified ·
1 Parent(s): 0bd7547

Use HF username for personal trace uploads (#199)

Browse files

* Use HF username for personal trace uploads

Co-authored-by: OpenAI Codex <codex@openai.com>

* Remove redundant HF token branch

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

agent/config.py CHANGED
@@ -31,7 +31,7 @@ class Config(BaseModel):
31
  # format so the HF Agent Trace Viewer auto-renders it
32
  # (https://huggingface.co/changelog/agent-trace-viewer). Created private
33
  # on first use; user flips it public via /share-traces. ``{hf_user}`` is
34
- # substituted at upload time from ``Session.user_id``.
35
  share_traces: bool = True
36
  personal_trace_repo_template: str = "{hf_user}/ml-intern-sessions"
37
  auto_save_interval: int = 1 # Save every N user turns (0 = disabled)
 
31
  # format so the HF Agent Trace Viewer auto-renders it
32
  # (https://huggingface.co/changelog/agent-trace-viewer). Created private
33
  # on first use; user flips it public via /share-traces. ``{hf_user}`` is
34
+ # substituted at upload time from the authenticated HF username.
35
  share_traces: bool = True
36
  personal_trace_repo_template: str = "{hf_user}/ml-intern-sessions"
37
  auto_save_interval: int = 1 # Save every N user turns (0 = disabled)
agent/core/session.py CHANGED
@@ -89,10 +89,12 @@ class Session:
89
  defer_turn_complete_notification: bool = False,
90
  session_id: str | None = None,
91
  user_id: str | None = None,
 
92
  persistence_store: Any | None = None,
93
  ):
94
  self.hf_token: Optional[str] = hf_token
95
  self.user_id: Optional[str] = user_id
 
96
  self.persistence_store = persistence_store
97
  self.tool_router = tool_router
98
  self.stream = stream
@@ -363,6 +365,7 @@ class Session:
363
  return {
364
  "session_id": self.session_id,
365
  "user_id": self.user_id,
 
366
  "session_start_time": self.session_start_time,
367
  "session_end_time": datetime.now().isoformat(),
368
  "model_name": self.config.model_name,
@@ -458,7 +461,7 @@ class Session:
458
  return False
459
 
460
  def _personal_trace_repo_id(self) -> Optional[str]:
461
- """Resolve the per-user trace repo id from config + user_id.
462
 
463
  Returns ``None`` when sharing is disabled, the user is anonymous,
464
  or the template is missing — caller skips the personal upload in
@@ -466,13 +469,14 @@ class Session:
466
  """
467
  if not getattr(self.config, "share_traces", False):
468
  return None
469
- if not self.user_id:
 
470
  return None
471
  template = getattr(self.config, "personal_trace_repo_template", None)
472
  if not template:
473
  return None
474
  try:
475
- return template.format(hf_user=self.user_id)
476
  except (KeyError, IndexError):
477
  logger.debug("personal_trace_repo_template format failed: %r", template)
478
  return None
 
89
  defer_turn_complete_notification: bool = False,
90
  session_id: str | None = None,
91
  user_id: str | None = None,
92
+ hf_username: str | None = None,
93
  persistence_store: Any | None = None,
94
  ):
95
  self.hf_token: Optional[str] = hf_token
96
  self.user_id: Optional[str] = user_id
97
+ self.hf_username: Optional[str] = hf_username
98
  self.persistence_store = persistence_store
99
  self.tool_router = tool_router
100
  self.stream = stream
 
365
  return {
366
  "session_id": self.session_id,
367
  "user_id": self.user_id,
368
+ "hf_username": self.hf_username,
369
  "session_start_time": self.session_start_time,
370
  "session_end_time": datetime.now().isoformat(),
371
  "model_name": self.config.model_name,
 
461
  return False
462
 
463
  def _personal_trace_repo_id(self) -> Optional[str]:
464
+ """Resolve the per-user trace repo id from config + HF username.
465
 
466
  Returns ``None`` when sharing is disabled, the user is anonymous,
467
  or the template is missing — caller skips the personal upload in
 
469
  """
470
  if not getattr(self.config, "share_traces", False):
471
  return None
472
+ hf_user = self.hf_username or self.user_id
473
+ if not hf_user:
474
  return None
475
  template = getattr(self.config, "personal_trace_repo_template", None)
476
  if not template:
477
  return None
478
  try:
479
+ return template.format(hf_user=hf_user)
480
  except (KeyError, IndexError):
481
  logger.debug("personal_trace_repo_template format failed: %r", template)
482
  return None
backend/routes/agent.py CHANGED
@@ -185,6 +185,7 @@ async def _check_session_access(
185
  session_id,
186
  user["user_id"],
187
  hf_token=hf_token,
 
188
  )
189
  if not agent_session:
190
  raise HTTPException(status_code=404, detail="Session not found")
@@ -369,6 +370,7 @@ async def create_session(
369
  try:
370
  session_id = await session_manager.create_session(
371
  user_id=user["user_id"],
 
372
  hf_token=hf_token,
373
  model=model,
374
  is_pro=user.get("plan") == "pro",
@@ -408,6 +410,7 @@ async def restore_session_summary(
408
  try:
409
  session_id = await session_manager.create_session(
410
  user_id=user["user_id"],
 
411
  hf_token=hf_token,
412
  model=model,
413
  is_pro=user.get("plan") == "pro",
 
185
  session_id,
186
  user["user_id"],
187
  hf_token=hf_token,
188
+ hf_username=user.get("username"),
189
  )
190
  if not agent_session:
191
  raise HTTPException(status_code=404, detail="Session not found")
 
370
  try:
371
  session_id = await session_manager.create_session(
372
  user_id=user["user_id"],
373
+ hf_username=user.get("username"),
374
  hf_token=hf_token,
375
  model=model,
376
  is_pro=user.get("plan") == "pro",
 
410
  try:
411
  session_id = await session_manager.create_session(
412
  user_id=user["user_id"],
413
+ hf_username=user.get("username"),
414
  hf_token=hf_token,
415
  model=model,
416
  is_pro=user.get("plan") == "pro",
backend/session_manager.py CHANGED
@@ -87,6 +87,7 @@ class AgentSession:
87
  tool_router: ToolRouter
88
  submission_queue: asyncio.Queue
89
  user_id: str = "dev" # Owner of this session
 
90
  hf_token: str | None = None # User's HF OAuth token for tool execution
91
  task: asyncio.Task | None = None
92
  created_at: datetime = field(default_factory=datetime.utcnow)
@@ -157,6 +158,7 @@ class SessionManager:
157
  *,
158
  session_id: str,
159
  user_id: str,
 
160
  hf_token: str | None,
161
  model: str | None,
162
  event_queue: asyncio.Queue,
@@ -178,6 +180,7 @@ class SessionManager:
178
  tool_router=tool_router,
179
  hf_token=hf_token,
180
  user_id=user_id,
 
181
  notification_gateway=self.messaging_gateway,
182
  notification_destinations=notification_destinations or [],
183
  session_id=session_id,
@@ -327,11 +330,18 @@ class SessionManager:
327
  )
328
 
329
  @staticmethod
330
- def _update_hf_token(agent_session: AgentSession, hf_token: str | None) -> None:
331
- if not hf_token:
332
- return
333
- agent_session.hf_token = hf_token
334
- agent_session.session.hf_token = hf_token
 
 
 
 
 
 
 
335
 
336
  async def persist_session_snapshot(
337
  self,
@@ -373,13 +383,18 @@ class SessionManager:
373
  session_id: str,
374
  user_id: str,
375
  hf_token: str | None = None,
 
376
  ) -> AgentSession | None:
377
  """Return a live runtime session, lazily restoring it from Mongo."""
378
  async with self._lock:
379
  existing = self.sessions.get(session_id)
380
  if existing:
381
  if self._can_access_session(existing, user_id):
382
- self._update_hf_token(existing, hf_token)
 
 
 
 
383
  return existing
384
  return None
385
 
@@ -392,7 +407,11 @@ class SessionManager:
392
  existing = self.sessions.get(session_id)
393
  if existing:
394
  if self._can_access_session(existing, user_id):
395
- self._update_hf_token(existing, hf_token)
 
 
 
 
396
  return existing
397
  return None
398
 
@@ -410,6 +429,7 @@ class SessionManager:
410
  self._create_session_sync,
411
  session_id=session_id,
412
  user_id=owner or user_id,
 
413
  hf_token=hf_token,
414
  model=model,
415
  event_queue=event_queue,
@@ -442,6 +462,7 @@ class SessionManager:
442
  tool_router=tool_router,
443
  submission_queue=submission_queue,
444
  user_id=owner or user_id,
 
445
  hf_token=hf_token,
446
  created_at=created_at,
447
  is_active=True,
@@ -455,7 +476,11 @@ class SessionManager:
455
  tool_router=tool_router,
456
  )
457
  if started is not agent_session:
458
- self._update_hf_token(started, hf_token)
 
 
 
 
459
  return started
460
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
461
  return agent_session
@@ -463,6 +488,7 @@ class SessionManager:
463
  async def create_session(
464
  self,
465
  user_id: str = "dev",
 
466
  hf_token: str | None = None,
467
  model: str | None = None,
468
  is_pro: bool | None = None,
@@ -475,6 +501,7 @@ class SessionManager:
475
 
476
  Args:
477
  user_id: The ID of the user who owns this session.
 
478
  hf_token: The user's HF OAuth token, stored for tool execution.
479
  model: Optional model override. When set, replaces ``model_name``
480
  on the per-session config clone. None falls back to the
@@ -513,6 +540,7 @@ class SessionManager:
513
  self._create_session_sync,
514
  session_id=session_id,
515
  user_id=user_id,
 
516
  hf_token=hf_token,
517
  model=model,
518
  event_queue=event_queue,
@@ -525,6 +553,7 @@ class SessionManager:
525
  tool_router=tool_router,
526
  submission_queue=submission_queue,
527
  user_id=user_id,
 
528
  hf_token=hf_token,
529
  )
530
 
 
87
  tool_router: ToolRouter
88
  submission_queue: asyncio.Queue
89
  user_id: str = "dev" # Owner of this session
90
+ hf_username: str | None = None # HF namespace used for personal trace uploads
91
  hf_token: str | None = None # User's HF OAuth token for tool execution
92
  task: asyncio.Task | None = None
93
  created_at: datetime = field(default_factory=datetime.utcnow)
 
158
  *,
159
  session_id: str,
160
  user_id: str,
161
+ hf_username: str | None,
162
  hf_token: str | None,
163
  model: str | None,
164
  event_queue: asyncio.Queue,
 
180
  tool_router=tool_router,
181
  hf_token=hf_token,
182
  user_id=user_id,
183
+ hf_username=hf_username,
184
  notification_gateway=self.messaging_gateway,
185
  notification_destinations=notification_destinations or [],
186
  session_id=session_id,
 
330
  )
331
 
332
  @staticmethod
333
+ def _update_hf_identity(
334
+ agent_session: AgentSession,
335
+ *,
336
+ hf_token: str | None,
337
+ hf_username: str | None,
338
+ ) -> None:
339
+ if hf_token:
340
+ agent_session.hf_token = hf_token
341
+ agent_session.session.hf_token = hf_token
342
+ if hf_username:
343
+ agent_session.hf_username = hf_username
344
+ agent_session.session.hf_username = hf_username
345
 
346
  async def persist_session_snapshot(
347
  self,
 
383
  session_id: str,
384
  user_id: str,
385
  hf_token: str | None = None,
386
+ hf_username: str | None = None,
387
  ) -> AgentSession | None:
388
  """Return a live runtime session, lazily restoring it from Mongo."""
389
  async with self._lock:
390
  existing = self.sessions.get(session_id)
391
  if existing:
392
  if self._can_access_session(existing, user_id):
393
+ self._update_hf_identity(
394
+ existing,
395
+ hf_token=hf_token,
396
+ hf_username=hf_username,
397
+ )
398
  return existing
399
  return None
400
 
 
407
  existing = self.sessions.get(session_id)
408
  if existing:
409
  if self._can_access_session(existing, user_id):
410
+ self._update_hf_identity(
411
+ existing,
412
+ hf_token=hf_token,
413
+ hf_username=hf_username,
414
+ )
415
  return existing
416
  return None
417
 
 
429
  self._create_session_sync,
430
  session_id=session_id,
431
  user_id=owner or user_id,
432
+ hf_username=hf_username,
433
  hf_token=hf_token,
434
  model=model,
435
  event_queue=event_queue,
 
462
  tool_router=tool_router,
463
  submission_queue=submission_queue,
464
  user_id=owner or user_id,
465
+ hf_username=hf_username,
466
  hf_token=hf_token,
467
  created_at=created_at,
468
  is_active=True,
 
476
  tool_router=tool_router,
477
  )
478
  if started is not agent_session:
479
+ self._update_hf_identity(
480
+ started,
481
+ hf_token=hf_token,
482
+ hf_username=hf_username,
483
+ )
484
  return started
485
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
486
  return agent_session
 
488
  async def create_session(
489
  self,
490
  user_id: str = "dev",
491
+ hf_username: str | None = None,
492
  hf_token: str | None = None,
493
  model: str | None = None,
494
  is_pro: bool | None = None,
 
501
 
502
  Args:
503
  user_id: The ID of the user who owns this session.
504
+ hf_username: The HF username/namespace used for personal trace uploads.
505
  hf_token: The user's HF OAuth token, stored for tool execution.
506
  model: Optional model override. When set, replaces ``model_name``
507
  on the per-session config clone. None falls back to the
 
540
  self._create_session_sync,
541
  session_id=session_id,
542
  user_id=user_id,
543
+ hf_username=hf_username,
544
  hf_token=hf_token,
545
  model=model,
546
  event_queue=event_queue,
 
553
  tool_router=tool_router,
554
  submission_queue=submission_queue,
555
  user_id=user_id,
556
+ hf_username=hf_username,
557
  hf_token=hf_token,
558
  )
559
 
tests/unit/test_personal_trace_repo.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from types import SimpleNamespace
3
+
4
+ from agent.core.session import Session
5
+
6
+
7
+ class DummyToolRouter:
8
+ def get_tool_specs_for_llm(self) -> list[dict]:
9
+ return []
10
+
11
+
12
+ def _session(*, user_id: str | None, hf_username: str | None) -> Session:
13
+ config = SimpleNamespace(
14
+ model_name="moonshotai/Kimi-K2.6",
15
+ save_sessions=True,
16
+ share_traces=True,
17
+ personal_trace_repo_template="{hf_user}/ml-intern-sessions",
18
+ session_dataset_repo="smolagents/ml-intern-sessions",
19
+ auto_save_interval=1,
20
+ heartbeat_interval_s=0,
21
+ reasoning_effort=None,
22
+ )
23
+ context_manager = SimpleNamespace(items=[], on_message_added=None)
24
+ return Session(
25
+ event_queue=asyncio.Queue(),
26
+ config=config,
27
+ tool_router=DummyToolRouter(),
28
+ context_manager=context_manager,
29
+ user_id=user_id,
30
+ hf_username=hf_username,
31
+ )
32
+
33
+
34
+ def test_personal_trace_repo_uses_hf_username_before_oauth_subject():
35
+ session = _session(user_id="oauth-subject", hf_username="lewtun")
36
+
37
+ assert session._personal_trace_repo_id() == "lewtun/ml-intern-sessions"
38
+
39
+
40
+ def test_personal_trace_repo_falls_back_to_user_id_for_cli():
41
+ session = _session(user_id="lewtun", hf_username=None)
42
+
43
+ assert session._personal_trace_repo_id() == "lewtun/ml-intern-sessions"