akseljoonas HF Staff commited on
Commit
bdbcdab
Β·
1 Parent(s): 7b48ae0

feat: merge HF Space improvements

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. README.md +4 -0
  2. agent/context_manager/manager.py +76 -9
  3. agent/core/agent_loop.py +382 -123
  4. agent/core/session.py +48 -9
  5. agent/core/session_uploader.py +2 -4
  6. agent/core/tools.py +12 -9
  7. agent/prompts/system_prompt.yaml +2 -2
  8. agent/prompts/system_prompt_v2.yaml +46 -59
  9. agent/prompts/system_prompt_v3.yaml +0 -118
  10. agent/tools/dataset_tools.py +16 -9
  11. agent/tools/docs_tools.py +21 -10
  12. agent/tools/github_find_examples.py +49 -10
  13. agent/tools/github_read_file.py +52 -6
  14. agent/tools/jobs_tool.py +138 -122
  15. agent/tools/plan_tool.py +12 -5
  16. agent/tools/sandbox_client.py +0 -714
  17. agent/tools/sandbox_tool.py +0 -201
  18. backend/dependencies.py +144 -0
  19. backend/main.py +8 -0
  20. backend/models.py +12 -0
  21. backend/routes/agent.py +282 -27
  22. backend/routes/auth.py +74 -51
  23. backend/session_manager.py +114 -14
  24. backend/websocket.py +0 -10
  25. configs/main_agent_config.json +2 -2
  26. frontend/package-lock.json +168 -0
  27. frontend/package.json +2 -0
  28. frontend/src/App.tsx +5 -0
  29. frontend/src/components/ApprovalModal/ApprovalModal.tsx +0 -208
  30. frontend/src/components/Chat/ActivityStatusBar.tsx +57 -0
  31. frontend/src/components/Chat/ApprovalFlow.tsx +0 -515
  32. frontend/src/components/Chat/AssistantMessage.tsx +119 -0
  33. frontend/src/components/Chat/ChatInput.tsx +218 -15
  34. frontend/src/components/Chat/MarkdownContent.tsx +160 -0
  35. frontend/src/components/Chat/MessageBubble.tsx +32 -203
  36. frontend/src/components/Chat/MessageList.tsx +125 -74
  37. frontend/src/components/Chat/ThinkingIndicator.tsx +48 -0
  38. frontend/src/components/Chat/ToolCallGroup.tsx +655 -0
  39. frontend/src/components/Chat/UserMessage.tsx +105 -0
  40. frontend/src/components/CodePanel/CodePanel.tsx +479 -256
  41. frontend/src/components/Layout/AppLayout.tsx +351 -167
  42. frontend/src/components/SessionSidebar/SessionSidebar.tsx +279 -181
  43. frontend/src/components/WelcomeScreen/WelcomeScreen.tsx +247 -0
  44. frontend/src/hooks/useAgentChat.ts +278 -0
  45. frontend/src/hooks/useAgentWebSocket.ts +0 -503
  46. frontend/src/hooks/useAuth.ts +77 -0
  47. frontend/src/lib/chat-message-store.ts +62 -0
  48. frontend/src/lib/ws-chat-transport.ts +593 -0
  49. frontend/src/main.tsx +13 -3
  50. frontend/src/store/agentStore.ts +121 -206
README.md CHANGED
@@ -9,7 +9,11 @@ hf_oauth: true
9
  hf_oauth_scopes:
10
  - read-repos
11
  - write-repos
 
 
12
  - inference-api
 
 
13
  ---
14
 
15
  # HF Agent
 
9
  hf_oauth_scopes:
10
  - read-repos
11
  - write-repos
12
+ - contribute-repos
13
+ - manage-repos
14
  - inference-api
15
+ - jobs
16
+ - write-discussions
17
  ---
18
 
19
  # HF Agent
agent/context_manager/manager.py CHANGED
@@ -2,6 +2,7 @@
2
  Context management for conversation history
3
  """
4
 
 
5
  import os
6
  import zoneinfo
7
  from datetime import datetime
@@ -13,6 +14,72 @@ from huggingface_hub import HfApi
13
  from jinja2 import Template
14
  from litellm import Message, acompletion
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class ContextManager:
18
  """Manages conversation context and message history for the agent"""
@@ -23,11 +90,11 @@ class ContextManager:
23
  compact_size: float = 0.1,
24
  untouched_messages: int = 5,
25
  tool_specs: list[dict[str, Any]] | None = None,
26
- prompt_file_suffix: str = "system_prompt_v3.yaml",
27
  ):
28
  self.system_prompt = self._load_system_prompt(
29
  tool_specs or [],
30
- prompt_file_suffix="system_prompt_v3.yaml",
31
  )
32
  self.max_context = max_context
33
  self.compact_size = int(max_context * compact_size)
@@ -54,9 +121,8 @@ class ContextManager:
54
  current_time = now.strftime("%H:%M:%S.%f")[:-3]
55
  current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})"
56
 
57
- # Get HF user info with explicit token from env
58
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
59
- hf_user_info = HfApi(token=hf_token).whoami().get("name", "unknown")
60
 
61
  template = Template(template_str)
62
  return template.render(
@@ -78,9 +144,7 @@ class ContextManager:
78
  """Get all messages for sending to LLM"""
79
  return self.items
80
 
81
- async def compact(
82
- self, model_name: str, tool_specs: list[dict] | None = None
83
- ) -> None:
84
  """Remove old messages to keep history under target size"""
85
  if (self.context_length <= self.max_context) or not self.items:
86
  return
@@ -110,11 +174,14 @@ class ContextManager:
110
  )
111
  )
112
 
 
113
  response = await acompletion(
114
  model=model_name,
115
  messages=messages_to_summarize,
116
  max_completion_tokens=self.compact_size,
117
- tools=tool_specs,
 
 
118
  )
119
  summarized_message = Message(
120
  role="assistant", content=response.choices[0].message.content
 
2
  Context management for conversation history
3
  """
4
 
5
+ import logging
6
  import os
7
  import zoneinfo
8
  from datetime import datetime
 
14
  from jinja2 import Template
15
  from litellm import Message, acompletion
16
 
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Module-level cache for HF username β€” avoids repeating the slow whoami() call
20
+ _hf_username_cache: str | None = None
21
+
22
+ _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2"
23
+ _HF_WHOAMI_TIMEOUT = 5 # seconds
24
+
25
+
26
+ def _get_hf_username() -> str:
27
+ """Return the HF username, cached after the first call.
28
+
29
+ Uses subprocess + curl to avoid Python HTTP client IPv6 issues that
30
+ cause 40+ second hangs (httpx/urllib try IPv6 first which times out
31
+ at OS level before falling back to IPv4 β€” the "Happy Eyeballs" problem).
32
+ """
33
+ import json
34
+ import subprocess
35
+ import time as _t
36
+
37
+ global _hf_username_cache
38
+ if _hf_username_cache is not None:
39
+ return _hf_username_cache
40
+
41
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
42
+ if not hf_token:
43
+ logger.warning("No HF_TOKEN set, using 'unknown' as username")
44
+ _hf_username_cache = "unknown"
45
+ return _hf_username_cache
46
+
47
+ t0 = _t.monotonic()
48
+ try:
49
+ result = subprocess.run(
50
+ [
51
+ "curl",
52
+ "-s",
53
+ "-4", # force IPv4
54
+ "-m",
55
+ str(_HF_WHOAMI_TIMEOUT), # max time
56
+ "-H",
57
+ f"Authorization: Bearer {hf_token}",
58
+ _HF_WHOAMI_URL,
59
+ ],
60
+ capture_output=True,
61
+ text=True,
62
+ timeout=_HF_WHOAMI_TIMEOUT + 2,
63
+ )
64
+ t1 = _t.monotonic()
65
+ if result.returncode == 0 and result.stdout:
66
+ data = json.loads(result.stdout)
67
+ _hf_username_cache = data.get("name", "unknown")
68
+ logger.info(
69
+ f"HF username resolved to '{_hf_username_cache}' in {t1 - t0:.2f}s"
70
+ )
71
+ else:
72
+ logger.warning(
73
+ f"curl whoami failed (rc={result.returncode}) in {t1 - t0:.2f}s"
74
+ )
75
+ _hf_username_cache = "unknown"
76
+ except Exception as e:
77
+ t1 = _t.monotonic()
78
+ logger.warning(f"HF whoami failed in {t1 - t0:.2f}s: {e}")
79
+ _hf_username_cache = "unknown"
80
+
81
+ return _hf_username_cache
82
+
83
 
84
  class ContextManager:
85
  """Manages conversation context and message history for the agent"""
 
90
  compact_size: float = 0.1,
91
  untouched_messages: int = 5,
92
  tool_specs: list[dict[str, Any]] | None = None,
93
+ prompt_file_suffix: str = "system_prompt_v2.yaml",
94
  ):
95
  self.system_prompt = self._load_system_prompt(
96
  tool_specs or [],
97
+ prompt_file_suffix="system_prompt_v2.yaml",
98
  )
99
  self.max_context = max_context
100
  self.compact_size = int(max_context * compact_size)
 
121
  current_time = now.strftime("%H:%M:%S.%f")[:-3]
122
  current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})"
123
 
124
+ # Get HF user info (cached after the first call)
125
+ hf_user_info = _get_hf_username()
 
126
 
127
  template = Template(template_str)
128
  return template.render(
 
144
  """Get all messages for sending to LLM"""
145
  return self.items
146
 
147
+ async def compact(self, model_name: str) -> None:
 
 
148
  """Remove old messages to keep history under target size"""
149
  if (self.context_length <= self.max_context) or not self.items:
150
  return
 
174
  )
175
  )
176
 
177
+ hf_key = os.environ.get("INFERENCE_TOKEN")
178
  response = await acompletion(
179
  model=model_name,
180
  messages=messages_to_summarize,
181
  max_completion_tokens=self.compact_size,
182
+ api_key=hf_key
183
+ if hf_key and model_name.startswith("huggingface/")
184
+ else None,
185
  )
186
  summarized_message = Message(
187
  role="assistant", content=response.choices[0].message.content
agent/core/agent_loop.py CHANGED
@@ -4,9 +4,10 @@ Main agent implementation with integrated tool system and MCP support
4
 
5
  import asyncio
6
  import json
 
 
7
 
8
- from litellm import ChatCompletionMessageToolCall, Message, ModelResponse, acompletion
9
- from litellm.exceptions import ContextWindowExceededError
10
  from lmnr import observe
11
 
12
  from agent.config import Config
@@ -14,7 +15,42 @@ from agent.core.session import Event, OpType, Session
14
  from agent.core.tools import ToolRouter
15
  from agent.tools.jobs_tool import CPU_FLAVORS
16
 
 
 
17
  ToolCall = ChatCompletionMessageToolCall
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
@@ -52,9 +88,6 @@ def _needs_approval(
52
  if not args_valid:
53
  return False
54
 
55
- if tool_name == "sandbox_create":
56
- return True
57
-
58
  if tool_name == "hf_jobs":
59
  operation = tool_args.get("operation", "")
60
  if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
@@ -109,31 +142,49 @@ def _needs_approval(
109
  return False
110
 
111
 
112
- async def _compact_and_notify(session: Session) -> None:
113
- """Run compaction and send event if context was reduced."""
114
- old_length = session.context_manager.context_length
115
- tool_specs = session.tool_router.get_tool_specs_for_llm()
116
- await session.context_manager.compact(
117
- model_name=session.config.model_name,
118
- tool_specs=tool_specs,
119
- )
120
- new_length = session.context_manager.context_length
121
- if new_length != old_length:
122
- await session.send_event(
123
- Event(
124
- event_type="compacted",
125
- data={"old_tokens": old_length, "new_tokens": new_length},
 
 
 
 
 
 
 
 
126
  )
127
- )
128
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- class Handlers:
131
- """Handler functions for each operation type"""
132
 
133
  @staticmethod
134
  @observe(name="run_agent")
135
  async def run_agent(
136
- session: Session, text: str, max_iterations: int = 300
137
  ) -> str | None:
138
  """
139
  Handle user input (like user_input_or_turn in codex.rs:1291)
@@ -145,6 +196,11 @@ class Handlers:
145
 
146
  Laminar.set_trace_session_id(session_id=session.session_id)
147
 
 
 
 
 
 
148
  # Add user message to history only if there's actual content
149
  if text:
150
  user_msg = Message(role="user", content=text)
@@ -160,42 +216,102 @@ class Handlers:
160
  final_response = None
161
 
162
  while iteration < max_iterations:
163
- # Compact before calling the LLM if context is near the limit
164
- await _compact_and_notify(session)
165
-
166
  messages = session.context_manager.get_messages()
167
  tools = session.tool_router.get_tool_specs_for_llm()
168
-
169
  try:
170
- response: ModelResponse = await acompletion(
171
- model=session.config.model_name,
 
172
  messages=messages,
173
  tools=tools,
174
  tool_choice="auto",
 
 
 
175
  )
176
 
177
- # Extract text response, token usage, and tool calls
178
- message = response.choices[0].message
179
- content = message.content
180
- token_count = response.usage.total_tokens
181
- tool_calls: list[ToolCall] = message.get("tool_calls", [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # If no tool calls, add assistant message and we're done
184
  if not tool_calls:
185
  if content:
186
  assistant_msg = Message(role="assistant", content=content)
187
  session.context_manager.add_message(assistant_msg, token_count)
188
- await session.send_event(
189
- Event(
190
- event_type="assistant_message",
191
- data={"content": content},
192
- )
193
- )
194
  final_response = content
195
  break
196
 
197
  # Add assistant message with tool calls to history
198
- # LiteLLM will format this correctly for the provider
199
  assistant_msg = Message(
200
  role="assistant",
201
  content=content,
@@ -203,66 +319,97 @@ class Handlers:
203
  )
204
  session.context_manager.add_message(assistant_msg, token_count)
205
 
206
- if content:
207
- await session.send_event(
208
- Event(event_type="assistant_message", data={"content": content})
209
- )
210
-
211
  # Separate tools into those requiring approval and those that don't
212
  approval_required_tools = []
213
  non_approval_tools = []
214
 
215
  for tc in tool_calls:
216
  tool_name = tc.function.name
217
- tool_args = json.loads(tc.function.arguments)
 
 
 
 
218
 
219
  if _needs_approval(tool_name, tool_args, session.config):
220
  approval_required_tools.append(tc)
221
  else:
222
  non_approval_tools.append(tc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- # Execute non-approval tools first
225
- for tc in non_approval_tools:
226
- tool_name = tc.function.name
227
- tool_args = json.loads(tc.function.arguments)
228
-
229
- # Validate tool arguments before calling
230
- args_valid, error_msg = _validate_tool_args(tool_args)
231
- if not args_valid:
232
- # Return error to agent instead of calling tool
233
- output = error_msg
234
- success = False
235
- else:
236
- await session.send_event(
237
- Event(
238
- event_type="tool_call",
239
- data={"tool": tool_name, "arguments": tool_args},
240
  )
241
- )
242
 
243
- output, success = await session.tool_router.call_tool(
244
- tool_name, tool_args, session=session
 
 
 
 
 
 
 
 
 
 
245
  )
 
246
 
247
- # Add tool result to history
248
- tool_msg = Message(
249
- role="tool",
250
- content=output,
251
- tool_call_id=tc.id,
252
- name=tool_name,
253
  )
254
- session.context_manager.add_message(tool_msg)
255
 
256
- await session.send_event(
257
- Event(
258
- event_type="tool_output",
259
- data={
260
- "tool": tool_name,
261
- "output": output,
262
- "success": success,
263
- },
 
 
 
 
 
 
 
 
 
 
 
 
264
  )
265
- )
266
 
267
  # If there are tools requiring approval, ask for batch approval
268
  if approval_required_tools:
@@ -270,7 +417,10 @@ class Handlers:
270
  tools_data = []
271
  for tc in approval_required_tools:
272
  tool_name = tc.function.name
273
- tool_args = json.loads(tc.function.arguments)
 
 
 
274
  tools_data.append(
275
  {
276
  "tool": tool_name,
@@ -299,14 +449,6 @@ class Handlers:
299
 
300
  iteration += 1
301
 
302
- except ContextWindowExceededError:
303
- # Force compact and retry this iteration
304
- session.context_manager.context_length = (
305
- session.context_manager.max_context + 1
306
- )
307
- await _compact_and_notify(session)
308
- continue
309
-
310
  except Exception as e:
311
  import traceback
312
 
@@ -318,6 +460,18 @@ class Handlers:
318
  )
319
  break
320
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  await session.send_event(
322
  Event(
323
  event_type="turn_complete",
@@ -337,13 +491,43 @@ class Handlers:
337
  session.interrupt()
338
  await session.send_event(Event(event_type="interrupted"))
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  @staticmethod
341
  async def undo(session: Session) -> None:
342
- """Handle undo (like undo in codex.rs:1314)"""
343
- # Remove last user turn and all following items
344
- # Simplified: just remove last 2 items
345
- for _ in range(min(2, len(session.context_manager.items))):
346
- session.context_manager.items.pop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  await session.send_event(Event(event_type="undo_complete"))
349
 
@@ -371,6 +555,9 @@ class Handlers:
371
 
372
  # Create a map of tool_call_id -> approval decision
373
  approval_map = {a["tool_call_id"]: a for a in approvals}
 
 
 
374
 
375
  # Separate approved and rejected tool calls
376
  approved_tasks = []
@@ -378,36 +565,99 @@ class Handlers:
378
 
379
  for tc in tool_calls:
380
  tool_name = tc.function.name
381
- tool_args = json.loads(tc.function.arguments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  approval_decision = approval_map.get(tc.id, {"approved": False})
383
 
384
  if approval_decision.get("approved", False):
385
- approved_tasks.append((tc, tool_name, tool_args))
 
 
 
 
 
 
386
  else:
387
  rejected_tasks.append((tc, tool_name, approval_decision))
388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  # Execute all approved tools concurrently
390
- async def execute_tool(tc, tool_name, tool_args):
391
- """Execute a single tool and return its result"""
 
 
 
 
 
392
  await session.send_event(
393
  Event(
394
- event_type="tool_call",
395
- data={"tool": tool_name, "arguments": tool_args},
 
 
 
 
396
  )
397
  )
398
 
399
  output, success = await session.tool_router.call_tool(
400
- tool_name, tool_args, session=session
401
  )
402
 
403
- return (tc, tool_name, output, success)
404
 
405
  # Execute all approved tools concurrently and wait for ALL to complete
406
  if approved_tasks:
407
  results = await asyncio.gather(
408
  *[
409
- execute_tool(tc, tool_name, tool_args)
410
- for tc, tool_name, tool_args in approved_tasks
411
  ],
412
  return_exceptions=True,
413
  )
@@ -416,10 +666,13 @@ class Handlers:
416
  for result in results:
417
  if isinstance(result, Exception):
418
  # Handle execution error
419
- print(f"Tool execution error: {result}")
420
  continue
421
 
422
- tc, tool_name, output, success = result
 
 
 
423
 
424
  # Add tool result to context
425
  tool_msg = Message(
@@ -435,6 +688,7 @@ class Handlers:
435
  event_type="tool_output",
436
  data={
437
  "tool": tool_name,
 
438
  "output": output,
439
  "success": success,
440
  },
@@ -446,7 +700,14 @@ class Handlers:
446
  rejection_msg = "Job execution cancelled by user"
447
  user_feedback = approval_decision.get("feedback")
448
  if user_feedback:
449
- rejection_msg += f". User feedback: {user_feedback}"
 
 
 
 
 
 
 
450
 
451
  tool_msg = Message(
452
  role="tool",
@@ -461,6 +722,7 @@ class Handlers:
461
  event_type="tool_output",
462
  data={
463
  "tool": tool_name,
 
464
  "output": rejection_msg,
465
  "success": False,
466
  },
@@ -478,11 +740,9 @@ class Handlers:
478
  """Handle shutdown (like shutdown in codex.rs:1329)"""
479
  # Save session trajectory if enabled (fire-and-forget, returns immediately)
480
  if session.config.save_sessions:
481
- print("πŸ’Ύ Saving session...")
482
  repo_id = session.config.session_dataset_repo
483
  _ = session.save_and_upload_detached(repo_id)
484
- # if local_path:
485
- # print("βœ… Session saved locally, upload in progress")
486
 
487
  session.is_running = False
488
  await session.send_event(Event(event_type="shutdown"))
@@ -497,7 +757,7 @@ async def process_submission(session: Session, submission) -> bool:
497
  bool: True to continue, False to shutdown
498
  """
499
  op = submission.operation
500
- # print(f"πŸ“¨ Received: {op.op_type.value}")
501
 
502
  if op.op_type == OpType.USER_INPUT:
503
  text = op.data.get("text", "") if op.data else ""
@@ -509,8 +769,7 @@ async def process_submission(session: Session, submission) -> bool:
509
  return True
510
 
511
  if op.op_type == OpType.COMPACT:
512
- # compact from the frontend
513
- await _compact_and_notify(session)
514
  return True
515
 
516
  if op.op_type == OpType.UNDO:
@@ -525,7 +784,7 @@ async def process_submission(session: Session, submission) -> bool:
525
  if op.op_type == OpType.SHUTDOWN:
526
  return not await Handlers.shutdown(session)
527
 
528
- print(f"⚠️ Unknown operation: {op.op_type}")
529
  return True
530
 
531
 
@@ -543,7 +802,7 @@ async def submission_loop(
543
 
544
  # Create session with tool router
545
  session = Session(event_queue, config=config, tool_router=tool_router)
546
- print("Agent loop started")
547
 
548
  # Retry any failed uploads from previous sessions (fire-and-forget)
549
  if config and config.save_sessions:
@@ -567,25 +826,25 @@ async def submission_loop(
567
  if not should_continue:
568
  break
569
  except asyncio.CancelledError:
570
- print("\n⚠️ Agent loop cancelled")
571
  break
572
  except Exception as e:
573
- print(f"❌ Error in agent loop: {e}")
574
  await session.send_event(
575
  Event(event_type="error", data={"error": str(e)})
576
  )
577
 
578
- print("πŸ›‘ Agent loop exited")
579
 
580
  finally:
581
  # Emergency save if session saving is enabled and shutdown wasn't called properly
582
  if session.config.save_sessions and session.is_running:
583
- print("\nπŸ’Ύ Emergency save: preserving session before exit...")
584
  try:
585
  local_path = session.save_and_upload_detached(
586
  session.config.session_dataset_repo
587
  )
588
  if local_path:
589
- print("βœ… Emergency save successful, upload in progress")
590
  except Exception as e:
591
- print(f"❌ Emergency save failed: {e}")
 
4
 
5
  import asyncio
6
  import json
7
+ import logging
8
+ import os
9
 
10
+ from litellm import ChatCompletionMessageToolCall, Message, acompletion
 
11
  from lmnr import observe
12
 
13
  from agent.config import Config
 
15
  from agent.core.tools import ToolRouter
16
  from agent.tools.jobs_tool import CPU_FLAVORS
17
 
18
+ logger = logging.getLogger(__name__)
19
+
20
  ToolCall = ChatCompletionMessageToolCall
21
+ # Explicit inference token β€” needed because litellm checks HF_TOKEN before
22
+ # HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions.
23
+ _INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
24
+
25
+
26
+ def _resolve_hf_router_params(model_name: str) -> dict:
27
+ """
28
+ Build LiteLLM kwargs for HuggingFace Router models.
29
+
30
+ api-inference.huggingface.co is deprecated; the new router lives at
31
+ router.huggingface.co/<provider>/v3/openai. LiteLLM's built-in
32
+ ``huggingface/`` provider still targets the old endpoint, so we
33
+ rewrite model names to ``openai/`` and supply the correct api_base.
34
+
35
+ Input format: huggingface/<router_provider>/<org>/<model>
36
+ Example: huggingface/novita/moonshotai/kimi-k2.5
37
+ """
38
+ if not model_name.startswith("huggingface/"):
39
+ return {"model": model_name}
40
+
41
+ parts = model_name.split("/", 2) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
42
+ if len(parts) < 3:
43
+ return {"model": model_name}
44
+
45
+ router_provider = parts[1]
46
+ actual_model = parts[2]
47
+ api_key = _INFERENCE_API_KEY or os.environ.get("HF_TOKEN")
48
+
49
+ return {
50
+ "model": f"openai/{actual_model}",
51
+ "api_base": f"https://router.huggingface.co/{router_provider}/v3/openai",
52
+ "api_key": api_key,
53
+ }
54
 
55
 
56
  def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
 
88
  if not args_valid:
89
  return False
90
 
 
 
 
91
  if tool_name == "hf_jobs":
92
  operation = tool_args.get("operation", "")
93
  if operation not in ["run", "uv", "scheduled run", "scheduled uv"]:
 
142
  return False
143
 
144
 
145
+ class Handlers:
146
+ """Handler functions for each operation type"""
147
+
148
+ @staticmethod
149
+ async def _abandon_pending_approval(session: Session) -> None:
150
+ """Cancel pending approval tools when the user continues the conversation.
151
+
152
+ Injects rejection tool-result messages into the LLM context (so the
153
+ history stays valid) and notifies the frontend that those tools were
154
+ abandoned.
155
+ """
156
+ tool_calls = session.pending_approval.get("tool_calls", [])
157
+ for tc in tool_calls:
158
+ tool_name = tc.function.name
159
+ abandon_msg = "Task abandoned β€” user continued the conversation without approving."
160
+
161
+ # Keep LLM context valid: every tool_call needs a tool result
162
+ tool_msg = Message(
163
+ role="tool",
164
+ content=abandon_msg,
165
+ tool_call_id=tc.id,
166
+ name=tool_name,
167
  )
168
+ session.context_manager.add_message(tool_msg)
169
 
170
+ await session.send_event(
171
+ Event(
172
+ event_type="tool_state_change",
173
+ data={
174
+ "tool_call_id": tc.id,
175
+ "tool": tool_name,
176
+ "state": "abandoned",
177
+ },
178
+ )
179
+ )
180
 
181
+ session.pending_approval = None
182
+ logger.info("Abandoned %d pending approval tool(s)", len(tool_calls))
183
 
184
  @staticmethod
185
  @observe(name="run_agent")
186
  async def run_agent(
187
+ session: Session, text: str, max_iterations: int = 10
188
  ) -> str | None:
189
  """
190
  Handle user input (like user_input_or_turn in codex.rs:1291)
 
196
 
197
  Laminar.set_trace_session_id(session_id=session.session_id)
198
 
199
+ # If there's a pending approval and the user sent a new message,
200
+ # abandon the pending tools so the LLM context stays valid.
201
+ if text and session.pending_approval:
202
+ await Handlers._abandon_pending_approval(session)
203
+
204
  # Add user message to history only if there's actual content
205
  if text:
206
  user_msg = Message(role="user", content=text)
 
216
  final_response = None
217
 
218
  while iteration < max_iterations:
 
 
 
219
  messages = session.context_manager.get_messages()
220
  tools = session.tool_router.get_tool_specs_for_llm()
 
221
  try:
222
+ # ── Stream the LLM response ──────────────────────────
223
+ llm_params = _resolve_hf_router_params(session.config.model_name)
224
+ response = await acompletion(
225
  messages=messages,
226
  tools=tools,
227
  tool_choice="auto",
228
+ stream=True,
229
+ stream_options={"include_usage": True},
230
+ **llm_params,
231
  )
232
 
233
+ full_content = ""
234
+ tool_calls_acc: dict[int, dict] = {}
235
+ token_count = 0
236
+
237
+ async for chunk in response:
238
+ choice = chunk.choices[0] if chunk.choices else None
239
+ if not choice:
240
+ # Last chunk may carry only usage info
241
+ if hasattr(chunk, "usage") and chunk.usage:
242
+ token_count = chunk.usage.total_tokens
243
+ continue
244
+
245
+ delta = choice.delta
246
+
247
+ # Stream text deltas to the frontend
248
+ if delta.content:
249
+ full_content += delta.content
250
+ await session.send_event(
251
+ Event(
252
+ event_type="assistant_chunk",
253
+ data={"content": delta.content},
254
+ )
255
+ )
256
+
257
+ # Accumulate tool-call deltas (name + args arrive in pieces)
258
+ if delta.tool_calls:
259
+ for tc_delta in delta.tool_calls:
260
+ idx = tc_delta.index
261
+ if idx not in tool_calls_acc:
262
+ tool_calls_acc[idx] = {
263
+ "id": "",
264
+ "type": "function",
265
+ "function": {"name": "", "arguments": ""},
266
+ }
267
+ if tc_delta.id:
268
+ tool_calls_acc[idx]["id"] = tc_delta.id
269
+ if tc_delta.function:
270
+ if tc_delta.function.name:
271
+ tool_calls_acc[idx]["function"]["name"] += (
272
+ tc_delta.function.name
273
+ )
274
+ if tc_delta.function.arguments:
275
+ tool_calls_acc[idx]["function"]["arguments"] += (
276
+ tc_delta.function.arguments
277
+ )
278
+
279
+ # Capture usage from the final chunk
280
+ if hasattr(chunk, "usage") and chunk.usage:
281
+ token_count = chunk.usage.total_tokens
282
+
283
+ # ── Stream finished β€” reconstruct full message ───────
284
+ content = full_content or None
285
+
286
+ # Build tool_calls list from accumulated deltas
287
+ tool_calls: list[ToolCall] = []
288
+ for idx in sorted(tool_calls_acc.keys()):
289
+ tc_data = tool_calls_acc[idx]
290
+ tool_calls.append(
291
+ ToolCall(
292
+ id=tc_data["id"],
293
+ type="function",
294
+ function={
295
+ "name": tc_data["function"]["name"],
296
+ "arguments": tc_data["function"]["arguments"],
297
+ },
298
+ )
299
+ )
300
+
301
+ # Signal end of streaming to the frontend
302
+ await session.send_event(
303
+ Event(event_type="assistant_stream_end", data={})
304
+ )
305
 
306
  # If no tool calls, add assistant message and we're done
307
  if not tool_calls:
308
  if content:
309
  assistant_msg = Message(role="assistant", content=content)
310
  session.context_manager.add_message(assistant_msg, token_count)
 
 
 
 
 
 
311
  final_response = content
312
  break
313
 
314
  # Add assistant message with tool calls to history
 
315
  assistant_msg = Message(
316
  role="assistant",
317
  content=content,
 
319
  )
320
  session.context_manager.add_message(assistant_msg, token_count)
321
 
 
 
 
 
 
322
  # Separate tools into those requiring approval and those that don't
323
  approval_required_tools = []
324
  non_approval_tools = []
325
 
326
  for tc in tool_calls:
327
  tool_name = tc.function.name
328
+ try:
329
+ tool_args = json.loads(tc.function.arguments)
330
+ except (json.JSONDecodeError, TypeError) as e:
331
+ logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
332
+ tool_args = {}
333
 
334
  if _needs_approval(tool_name, tool_args, session.config):
335
  approval_required_tools.append(tc)
336
  else:
337
  non_approval_tools.append(tc)
338
+ # Execute non-approval tools (in parallel when possible)
339
+ if non_approval_tools:
340
+ # 1. Parse args and validate upfront
341
+ parsed_tools: list[
342
+ tuple[ChatCompletionMessageToolCall, str, dict, bool, str]
343
+ ] = []
344
+ for tc in non_approval_tools:
345
+ tool_name = tc.function.name
346
+ try:
347
+ tool_args = json.loads(tc.function.arguments)
348
+ except (json.JSONDecodeError, TypeError):
349
+ tool_args = {}
350
+
351
+ args_valid, error_msg = _validate_tool_args(tool_args)
352
+ parsed_tools.append(
353
+ (tc, tool_name, tool_args, args_valid, error_msg)
354
+ )
355
 
356
+ # 2. Send all tool_call events upfront (so frontend shows them all)
357
+ for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
358
+ if args_valid:
359
+ await session.send_event(
360
+ Event(
361
+ event_type="tool_call",
362
+ data={
363
+ "tool": tool_name,
364
+ "arguments": tool_args,
365
+ "tool_call_id": tc.id,
366
+ },
367
+ )
 
 
 
 
368
  )
 
369
 
370
+ # 3. Execute all valid tools in parallel
371
+ async def _exec_tool(
372
+ tc: ChatCompletionMessageToolCall,
373
+ name: str,
374
+ args: dict,
375
+ valid: bool,
376
+ err: str,
377
+ ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]:
378
+ if not valid:
379
+ return (tc, name, args, err, False)
380
+ out, ok = await session.tool_router.call_tool(
381
+ name, args, session=session
382
  )
383
+ return (tc, name, args, out, ok)
384
 
385
+ results = await asyncio.gather(
386
+ *[
387
+ _exec_tool(tc, name, args, valid, err)
388
+ for tc, name, args, valid, err in parsed_tools
389
+ ]
 
390
  )
 
391
 
392
+ # 4. Record results and send outputs (order preserved)
393
+ for tc, tool_name, tool_args, output, success in results:
394
+ tool_msg = Message(
395
+ role="tool",
396
+ content=output,
397
+ tool_call_id=tc.id,
398
+ name=tool_name,
399
+ )
400
+ session.context_manager.add_message(tool_msg)
401
+
402
+ await session.send_event(
403
+ Event(
404
+ event_type="tool_output",
405
+ data={
406
+ "tool": tool_name,
407
+ "tool_call_id": tc.id,
408
+ "output": output,
409
+ "success": success,
410
+ },
411
+ )
412
  )
 
413
 
414
  # If there are tools requiring approval, ask for batch approval
415
  if approval_required_tools:
 
417
  tools_data = []
418
  for tc in approval_required_tools:
419
  tool_name = tc.function.name
420
+ try:
421
+ tool_args = json.loads(tc.function.arguments)
422
+ except (json.JSONDecodeError, TypeError):
423
+ tool_args = {}
424
  tools_data.append(
425
  {
426
  "tool": tool_name,
 
449
 
450
  iteration += 1
451
 
 
 
 
 
 
 
 
 
452
  except Exception as e:
453
  import traceback
454
 
 
460
  )
461
  break
462
 
463
+ old_length = session.context_manager.context_length
464
+ await session.context_manager.compact(model_name=session.config.model_name)
465
+ new_length = session.context_manager.context_length
466
+
467
+ if new_length != old_length:
468
+ await session.send_event(
469
+ Event(
470
+ event_type="compacted",
471
+ data={"old_tokens": old_length, "new_tokens": new_length},
472
+ )
473
+ )
474
+
475
  await session.send_event(
476
  Event(
477
  event_type="turn_complete",
 
491
  session.interrupt()
492
  await session.send_event(Event(event_type="interrupted"))
493
 
494
+ @staticmethod
495
+ async def compact(session: Session) -> None:
496
+ """Handle compact (like compact in codex.rs:1317)"""
497
+ old_length = session.context_manager.context_length
498
+ await session.context_manager.compact(model_name=session.config.model_name)
499
+ new_length = session.context_manager.context_length
500
+
501
+ await session.send_event(
502
+ Event(
503
+ event_type="compacted",
504
+ data={"removed": old_length, "remaining": new_length},
505
+ )
506
+ )
507
+
508
  @staticmethod
509
  async def undo(session: Session) -> None:
510
+ """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
511
+
512
+ Anthropic requires every tool_use to have a matching tool_result,
513
+ so we can't just pop 2 items β€” we must pop everything back to
514
+ (and including) the last user message to keep the history valid.
515
+ """
516
+ items = session.context_manager.items
517
+ if not items:
518
+ await session.send_event(Event(event_type="undo_complete"))
519
+ return
520
+
521
+ # Pop from the end until we've removed the last user message
522
+ removed_user = False
523
+ while items:
524
+ msg = items.pop()
525
+ if getattr(msg, "role", None) == "user":
526
+ removed_user = True
527
+ break
528
+
529
+ if not removed_user:
530
+ logger.warning("Undo: no user message found to remove")
531
 
532
  await session.send_event(Event(event_type="undo_complete"))
533
 
 
555
 
556
  # Create a map of tool_call_id -> approval decision
557
  approval_map = {a["tool_call_id"]: a for a in approvals}
558
+ for a in approvals:
559
+ if a.get("edited_script"):
560
+ logger.info(f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)")
561
 
562
  # Separate approved and rejected tool calls
563
  approved_tasks = []
 
565
 
566
  for tc in tool_calls:
567
  tool_name = tc.function.name
568
+ try:
569
+ tool_args = json.loads(tc.function.arguments)
570
+ except (json.JSONDecodeError, TypeError) as e:
571
+ # Malformed arguments β€” treat as failed, notify agent
572
+ logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
573
+ tool_msg = Message(
574
+ role="tool",
575
+ content=f"Malformed arguments: {e}",
576
+ tool_call_id=tc.id,
577
+ name=tool_name,
578
+ )
579
+ session.context_manager.add_message(tool_msg)
580
+ await session.send_event(
581
+ Event(
582
+ event_type="tool_output",
583
+ data={
584
+ "tool": tool_name,
585
+ "tool_call_id": tc.id,
586
+ "output": f"Malformed arguments: {e}",
587
+ "success": False,
588
+ },
589
+ )
590
+ )
591
+ continue
592
+
593
  approval_decision = approval_map.get(tc.id, {"approved": False})
594
 
595
  if approval_decision.get("approved", False):
596
+ edited_script = approval_decision.get("edited_script")
597
+ was_edited = False
598
+ if edited_script and "script" in tool_args:
599
+ tool_args["script"] = edited_script
600
+ was_edited = True
601
+ logger.info(f"Using user-edited script for {tool_name} ({tc.id})")
602
+ approved_tasks.append((tc, tool_name, tool_args, was_edited))
603
  else:
604
  rejected_tasks.append((tc, tool_name, approval_decision))
605
 
606
+ # Notify frontend of approval decisions immediately (before execution)
607
+ for tc, tool_name, tool_args, _was_edited in approved_tasks:
608
+ await session.send_event(
609
+ Event(
610
+ event_type="tool_state_change",
611
+ data={
612
+ "tool_call_id": tc.id,
613
+ "tool": tool_name,
614
+ "state": "approved",
615
+ },
616
+ )
617
+ )
618
+ for tc, tool_name, approval_decision in rejected_tasks:
619
+ await session.send_event(
620
+ Event(
621
+ event_type="tool_state_change",
622
+ data={
623
+ "tool_call_id": tc.id,
624
+ "tool": tool_name,
625
+ "state": "rejected",
626
+ },
627
+ )
628
+ )
629
+
630
  # Execute all approved tools concurrently
631
+ async def execute_tool(tc, tool_name, tool_args, was_edited):
632
+ """Execute a single tool and return its result.
633
+
634
+ The TraceLog already exists on the frontend (created by
635
+ approval_required), so we send tool_state_change instead of
636
+ tool_call to avoid creating a duplicate.
637
+ """
638
  await session.send_event(
639
  Event(
640
+ event_type="tool_state_change",
641
+ data={
642
+ "tool_call_id": tc.id,
643
+ "tool": tool_name,
644
+ "state": "running",
645
+ },
646
  )
647
  )
648
 
649
  output, success = await session.tool_router.call_tool(
650
+ tool_name, tool_args, session=session, tool_call_id=tc.id
651
  )
652
 
653
+ return (tc, tool_name, output, success, was_edited)
654
 
655
  # Execute all approved tools concurrently and wait for ALL to complete
656
  if approved_tasks:
657
  results = await asyncio.gather(
658
  *[
659
+ execute_tool(tc, tool_name, tool_args, was_edited)
660
+ for tc, tool_name, tool_args, was_edited in approved_tasks
661
  ],
662
  return_exceptions=True,
663
  )
 
666
  for result in results:
667
  if isinstance(result, Exception):
668
  # Handle execution error
669
+ logger.error(f"Tool execution error: {result}")
670
  continue
671
 
672
+ tc, tool_name, output, success, was_edited = result
673
+
674
+ if was_edited:
675
+ output = f"[Note: The user edited the script before execution. The output below reflects the user-modified version, not your original script.]\n\n{output}"
676
 
677
  # Add tool result to context
678
  tool_msg = Message(
 
688
  event_type="tool_output",
689
  data={
690
  "tool": tool_name,
691
+ "tool_call_id": tc.id,
692
  "output": output,
693
  "success": success,
694
  },
 
700
  rejection_msg = "Job execution cancelled by user"
701
  user_feedback = approval_decision.get("feedback")
702
  if user_feedback:
703
+ # Ensure feedback is a string and sanitize any problematic characters
704
+ feedback_str = str(user_feedback).strip()
705
+ # Remove any control characters that might break JSON parsing
706
+ feedback_str = "".join(char for char in feedback_str if ord(char) >= 32 or char in "\n\t")
707
+ rejection_msg += f". User feedback: {feedback_str}"
708
+
709
+ # Ensure rejection_msg is a clean string
710
+ rejection_msg = str(rejection_msg).strip()
711
 
712
  tool_msg = Message(
713
  role="tool",
 
722
  event_type="tool_output",
723
  data={
724
  "tool": tool_name,
725
+ "tool_call_id": tc.id,
726
  "output": rejection_msg,
727
  "success": False,
728
  },
 
740
  """Handle shutdown (like shutdown in codex.rs:1329)"""
741
  # Save session trajectory if enabled (fire-and-forget, returns immediately)
742
  if session.config.save_sessions:
743
+ logger.info("Saving session...")
744
  repo_id = session.config.session_dataset_repo
745
  _ = session.save_and_upload_detached(repo_id)
 
 
746
 
747
  session.is_running = False
748
  await session.send_event(Event(event_type="shutdown"))
 
757
  bool: True to continue, False to shutdown
758
  """
759
  op = submission.operation
760
+ logger.debug("Received operation: %s", op.op_type.value)
761
 
762
  if op.op_type == OpType.USER_INPUT:
763
  text = op.data.get("text", "") if op.data else ""
 
769
  return True
770
 
771
  if op.op_type == OpType.COMPACT:
772
+ await Handlers.compact(session)
 
773
  return True
774
 
775
  if op.op_type == OpType.UNDO:
 
784
  if op.op_type == OpType.SHUTDOWN:
785
  return not await Handlers.shutdown(session)
786
 
787
+ logger.warning(f"Unknown operation: {op.op_type}")
788
  return True
789
 
790
 
 
802
 
803
  # Create session with tool router
804
  session = Session(event_queue, config=config, tool_router=tool_router)
805
+ logger.info("Agent loop started")
806
 
807
  # Retry any failed uploads from previous sessions (fire-and-forget)
808
  if config and config.save_sessions:
 
826
  if not should_continue:
827
  break
828
  except asyncio.CancelledError:
829
+ logger.warning("Agent loop cancelled")
830
  break
831
  except Exception as e:
832
+ logger.error(f"Error in agent loop: {e}")
833
  await session.send_event(
834
  Event(event_type="error", data={"error": str(e)})
835
  )
836
 
837
+ logger.info("Agent loop exited")
838
 
839
  finally:
840
  # Emergency save if session saving is enabled and shutdown wasn't called properly
841
  if session.config.save_sessions and session.is_running:
842
+ logger.info("Emergency save: preserving session before exit...")
843
  try:
844
  local_path = session.save_and_upload_detached(
845
  session.config.session_dataset_repo
846
  )
847
  if local_path:
848
+ logger.info("Emergency save successful, upload in progress")
849
  except Exception as e:
850
+ logger.error(f"Emergency save failed: {e}")
agent/core/session.py CHANGED
@@ -1,5 +1,6 @@
1
  import asyncio
2
  import json
 
3
  import subprocess
4
  import sys
5
  import uuid
@@ -9,11 +10,48 @@ from enum import Enum
9
  from pathlib import Path
10
  from typing import Any, Optional
11
 
12
- from litellm import get_max_tokens
13
-
14
  from agent.config import Config
15
  from agent.context_manager.manager import ContextManager
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class OpType(Enum):
19
  USER_INPUT = "user_input"
@@ -46,7 +84,7 @@ class Session:
46
  self.tool_router = tool_router
47
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
48
  self.context_manager = context_manager or ContextManager(
49
- max_context=get_max_tokens(config.model_name),
50
  compact_size=0.1,
51
  untouched_messages=5,
52
  tool_specs=tool_specs,
@@ -59,7 +97,8 @@ class Session:
59
  self.is_running = True
60
  self.current_task: asyncio.Task | None = None
61
  self.pending_approval: Optional[dict[str, Any]] = None
62
- self.sandbox = None
 
63
 
64
  # Session trajectory logging
65
  self.logged_events: list[dict] = []
@@ -100,7 +139,7 @@ class Session:
100
 
101
  turns_since_last_save = self.turn_count - self.last_auto_save_turn
102
  if turns_since_last_save >= interval:
103
- print(f"\nπŸ’Ύ Auto-saving session (turn {self.turn_count})...")
104
  # Fire-and-forget save - returns immediately
105
  self.save_and_upload_detached(self.config.session_dataset_repo)
106
  self.last_auto_save_turn = self.turn_count
@@ -152,7 +191,7 @@ class Session:
152
 
153
  return str(filepath)
154
  except Exception as e:
155
- print(f"Failed to save session locally: {e}")
156
  return None
157
 
158
  def update_local_save_status(
@@ -172,7 +211,7 @@ class Session:
172
 
173
  return True
174
  except Exception as e:
175
- print(f"Failed to update local save status: {e}")
176
  return False
177
 
178
  def save_and_upload_detached(self, repo_id: str) -> Optional[str]:
@@ -203,7 +242,7 @@ class Session:
203
  start_new_session=True, # Detach from parent
204
  )
205
  except Exception as e:
206
- print(f"⚠️ Failed to spawn upload subprocess: {e}")
207
 
208
  return local_path
209
 
@@ -233,4 +272,4 @@ class Session:
233
  start_new_session=True, # Detach from parent
234
  )
235
  except Exception as e:
236
- print(f"⚠️ Failed to spawn retry subprocess: {e}")
 
1
  import asyncio
2
  import json
3
+ import logging
4
  import subprocess
5
  import sys
6
  import uuid
 
10
  from pathlib import Path
11
  from typing import Any, Optional
12
 
 
 
13
  from agent.config import Config
14
  from agent.context_manager.manager import ContextManager
15
 
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Local max-token lookup β€” avoids litellm.get_max_tokens() which can hang
19
+ # on network calls for certain providers (known litellm issue).
20
+ _MAX_TOKENS_MAP: dict[str, int] = {
21
+ # Anthropic
22
+ "anthropic/claude-opus-4-5-20251101": 200_000,
23
+ "anthropic/claude-sonnet-4-5-20250929": 200_000,
24
+ "anthropic/claude-sonnet-4-20250514": 200_000,
25
+ "anthropic/claude-haiku-3-5-20241022": 200_000,
26
+ "anthropic/claude-3-5-sonnet-20241022": 200_000,
27
+ "anthropic/claude-3-opus-20240229": 200_000,
28
+ "huggingface/novita/minimax/minimax-m2.1": 196_608,
29
+ "huggingface/novita/moonshotai/kimi-k2.5": 262_144,
30
+ "huggingface/novita/zai-org/glm-5": 200_000,
31
+ }
32
+ _DEFAULT_MAX_TOKENS = 200_000
33
+
34
+
35
+ def _get_max_tokens_safe(model_name: str) -> int:
36
+ """Return the max context window for a model without network calls."""
37
+ tokens = _MAX_TOKENS_MAP.get(model_name)
38
+ if tokens:
39
+ return tokens
40
+ # Fallback: try litellm but with a short timeout via threading
41
+ try:
42
+ from litellm import get_max_tokens
43
+
44
+ result = get_max_tokens(model_name)
45
+ if result and isinstance(result, int):
46
+ return result
47
+ logger.warning(
48
+ f"get_max_tokens returned {result} for {model_name}, using default"
49
+ )
50
+ return _DEFAULT_MAX_TOKENS
51
+ except Exception as e:
52
+ logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}")
53
+ return _DEFAULT_MAX_TOKENS
54
+
55
 
56
  class OpType(Enum):
57
  USER_INPUT = "user_input"
 
84
  self.tool_router = tool_router
85
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
86
  self.context_manager = context_manager or ContextManager(
87
+ max_context=_get_max_tokens_safe(config.model_name),
88
  compact_size=0.1,
89
  untouched_messages=5,
90
  tool_specs=tool_specs,
 
97
  self.is_running = True
98
  self.current_task: asyncio.Task | None = None
99
  self.pending_approval: Optional[dict[str, Any]] = None
100
+ # User's HF OAuth token β€” set by session_manager after construction
101
+ self.hf_token: Optional[str] = None
102
 
103
  # Session trajectory logging
104
  self.logged_events: list[dict] = []
 
139
 
140
  turns_since_last_save = self.turn_count - self.last_auto_save_turn
141
  if turns_since_last_save >= interval:
142
+ logger.info(f"Auto-saving session (turn {self.turn_count})...")
143
  # Fire-and-forget save - returns immediately
144
  self.save_and_upload_detached(self.config.session_dataset_repo)
145
  self.last_auto_save_turn = self.turn_count
 
191
 
192
  return str(filepath)
193
  except Exception as e:
194
+ logger.error(f"Failed to save session locally: {e}")
195
  return None
196
 
197
  def update_local_save_status(
 
211
 
212
  return True
213
  except Exception as e:
214
+ logger.error(f"Failed to update local save status: {e}")
215
  return False
216
 
217
  def save_and_upload_detached(self, repo_id: str) -> Optional[str]:
 
242
  start_new_session=True, # Detach from parent
243
  )
244
  except Exception as e:
245
+ logger.warning(f"Failed to spawn upload subprocess: {e}")
246
 
247
  return local_path
248
 
 
272
  start_new_session=True, # Detach from parent
273
  )
274
  except Exception as e:
275
+ logger.warning(f"Failed to spawn retry subprocess: {e}")
agent/core/session_uploader.py CHANGED
@@ -15,10 +15,8 @@ from dotenv import load_dotenv
15
 
16
  load_dotenv()
17
 
18
- # Fallback token for session uploads (write-only access to akseljoonas/hf-agent-sessions)
19
- _SESSION_TOKEN = "".join([
20
- "hf_", "Nzya", "Eeb", "ESz", "DtA", "BoW", "Czj", "SEC", "ZZv", "kVL", "Ac", "Vf", "Sz"
21
- ])
22
 
23
 
24
  def upload_session_as_file(
 
15
 
16
  load_dotenv()
17
 
18
+ # Token for session uploads β€” loaded from env var (never hardcode tokens in source)
19
+ _SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "")
 
 
20
 
21
 
22
  def upload_session_as_file(
agent/core/tools.py CHANGED
@@ -3,10 +3,13 @@ Tool system for the agent
3
  Provides ToolSpec and ToolRouter for managing both built-in and MCP tools
4
  """
5
 
 
6
  import warnings
7
  from dataclasses import dataclass
8
  from typing import Any, Awaitable, Callable, Optional
9
 
 
 
10
  from fastmcp import Client
11
  from fastmcp.exceptions import ToolError
12
  from lmnr import observe
@@ -45,7 +48,6 @@ from agent.tools.hf_repo_git_tool import (
45
  )
46
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
47
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
48
- from agent.tools.sandbox_tool import get_sandbox_tools
49
 
50
  # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
51
  # from agent.tools.private_hf_repo_tools import (
@@ -132,6 +134,7 @@ class ToolRouter:
132
  for tool in create_builtin_tools():
133
  self.register_tool(tool)
134
 
 
135
  if mcp_servers:
136
  mcp_servers_payload = {}
137
  for name, server in mcp_servers.items():
@@ -159,7 +162,7 @@ class ToolRouter:
159
  handler=None,
160
  )
161
  )
162
- print(
163
  f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)"
164
  )
165
 
@@ -180,7 +183,7 @@ class ToolRouter:
180
  handler=search_openapi_handler,
181
  )
182
  )
183
- print(f"Loaded OpenAPI search tool: {openapi_spec['name']}")
184
 
185
  def get_tool_specs_for_llm(self) -> list[dict[str, Any]]:
186
  """Get tool specifications in OpenAI format"""
@@ -209,7 +212,7 @@ class ToolRouter:
209
  await self.register_openapi_tool()
210
 
211
  total_tools = len(self.tools)
212
- print(f"\nAgent ready with {total_tools} tools total\n")
213
 
214
  return self
215
 
@@ -220,7 +223,7 @@ class ToolRouter:
220
 
221
  @observe(name="call_tool")
222
  async def call_tool(
223
- self, tool_name: str, arguments: dict[str, Any], session: Any = None
224
  ) -> tuple[str, bool]:
225
  """
226
  Call a tool and return (output_string, success_bool).
@@ -236,6 +239,9 @@ class ToolRouter:
236
  # Check if handler accepts session argument
237
  sig = inspect.signature(tool.handler)
238
  if "session" in sig.parameters:
 
 
 
239
  return await tool.handler(arguments, session=session)
240
  return await tool.handler(arguments)
241
 
@@ -328,10 +334,7 @@ def create_builtin_tools() -> list[ToolSpec]:
328
  ),
329
  ]
330
 
331
- # Sandbox tools
332
- tools = get_sandbox_tools() + tools
333
-
334
  tool_names = ", ".join([t.name for t in tools])
335
- print(f"Loaded {len(tools)} built-in tools: {tool_names}")
336
 
337
  return tools
 
3
  Provides ToolSpec and ToolRouter for managing both built-in and MCP tools
4
  """
5
 
6
+ import logging
7
  import warnings
8
  from dataclasses import dataclass
9
  from typing import Any, Awaitable, Callable, Optional
10
 
11
+ logger = logging.getLogger(__name__)
12
+
13
  from fastmcp import Client
14
  from fastmcp.exceptions import ToolError
15
  from lmnr import observe
 
48
  )
49
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
50
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
 
51
 
52
  # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
53
  # from agent.tools.private_hf_repo_tools import (
 
134
  for tool in create_builtin_tools():
135
  self.register_tool(tool)
136
 
137
+ self.mcp_client: Client | None = None
138
  if mcp_servers:
139
  mcp_servers_payload = {}
140
  for name, server in mcp_servers.items():
 
162
  handler=None,
163
  )
164
  )
165
+ logger.info(
166
  f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)"
167
  )
168
 
 
183
  handler=search_openapi_handler,
184
  )
185
  )
186
+ logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}")
187
 
188
  def get_tool_specs_for_llm(self) -> list[dict[str, Any]]:
189
  """Get tool specifications in OpenAI format"""
 
212
  await self.register_openapi_tool()
213
 
214
  total_tools = len(self.tools)
215
+ logger.info(f"Agent ready with {total_tools} tools total")
216
 
217
  return self
218
 
 
223
 
224
  @observe(name="call_tool")
225
  async def call_tool(
226
+ self, tool_name: str, arguments: dict[str, Any], session: Any = None, tool_call_id: str | None = None
227
  ) -> tuple[str, bool]:
228
  """
229
  Call a tool and return (output_string, success_bool).
 
239
  # Check if handler accepts session argument
240
  sig = inspect.signature(tool.handler)
241
  if "session" in sig.parameters:
242
+ # Check if handler also accepts tool_call_id parameter
243
+ if "tool_call_id" in sig.parameters:
244
+ return await tool.handler(arguments, session=session, tool_call_id=tool_call_id)
245
  return await tool.handler(arguments, session=session)
246
  return await tool.handler(arguments)
247
 
 
334
  ),
335
  ]
336
 
 
 
 
337
  tool_names = ", ".join([t.name for t in tools])
338
+ logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}")
339
 
340
  return tools
agent/prompts/system_prompt.yaml CHANGED
@@ -1,5 +1,5 @@
1
  system_prompt: |
2
- You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and ressources (models, datasets, compute) to execute them. You will aid users to do theses tasks, interacting with the Hugging Face stack via {{ num_tools }}.
3
 
4
  # General behavior
5
 
@@ -9,7 +9,7 @@ system_prompt: |
9
 
10
  **CRITICAL : Research first, Then Implement**
11
 
12
- For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in thoses three mandatory steps:
13
 
14
  1. **FIRST**: Search HF documentation to find the correct approach.
15
  - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers").
 
1
  system_prompt: |
2
+ You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and resources (models, datasets, compute) to execute them. You will aid users to do these tasks, interacting with the Hugging Face stack via {{ num_tools }}.
3
 
4
  # General behavior
5
 
 
9
 
10
  **CRITICAL : Research first, Then Implement**
11
 
12
+ For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in these three mandatory steps:
13
 
14
  1. **FIRST**: Search HF documentation to find the correct approach.
15
  - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers").
agent/prompts/system_prompt_v2.yaml CHANGED
@@ -186,59 +186,61 @@ system_prompt: |
186
  3. βœ… Determine optimal processing approach based on requirements
187
  4. βœ… Plan output format and destination
188
 
189
- ## PHASE 3: IMPLEMENT (Develop in Sandbox, Launch via Jobs)
190
-
191
- ⚠️ **CRITICAL WORKFLOW: Sandbox First, Jobs Second**
192
-
193
- For ANY implementation task (training, data processing, inference), follow this pattern:
194
-
195
- **Step 1: Create a sandbox** β€” `sandbox_create` with appropriate hardware (cpu-basic for scripting, t4-small for GPU testing)
196
- **Step 2: Develop & iterate** β€” Write scripts, install dependencies, test with small runs, fix errors interactively
197
- **Step 3: Launch via hf_jobs** β€” Once the script works, pass the sandbox file path directly: `hf_jobs(operation="run", script="/app/train.py", ...)`
198
-
199
- This is the CORRECT pattern:
200
- ```
201
- sandbox_create(hardware="t4-small") # interactive dev environment
202
- bash("pip install trl transformers") # install deps
203
- write("/app/train.py", "...") # write training script
204
- bash("cd /app && python train.py --max_steps 10") # test run
205
- edit("/app/train.py", ...) # fix issues
206
- bash("cd /app && python train.py --max_steps 10") # verify fix
207
- hf_jobs(operation="run", script="/app/train.py", hardware_flavor="a10g-large", timeout="4h") # launch at scale
208
- ```
209
-
210
- Do NOT write long inline scripts directly in hf_jobs if necessary β€” develop in sandbox first.
211
-
212
- ### Training Script Requirements
213
-
214
- **Script MUST Include:**
215
- - Imports from researched documentation (current APIs)
216
- - Trackio initialization with project/run_name/config
217
- - Model and tokenizer loading
218
- - Dataset loading with verified columns and conversational format
219
- - Training config with ALL critical settings:
220
  - `push_to_hub=True` ⚠️ MANDATORY
221
  - `hub_model_id="username/model-name"` ⚠️ MANDATORY
222
  - `report_to=["trackio"]` (for monitoring)
223
  - `output_dir="./output"`
224
  - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate`
225
  - `logging_steps`, `save_steps`
226
- - `trainer.train()` call
227
- - `trainer.push_to_hub()` at end ⚠️ MANDATORY
228
-
229
- **hf_jobs Launch Configuration:**
230
- - `script`: Path to sandbox file (e.g. "/app/train.py") or inline code
231
- - `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio']
232
- - `hardware_flavor`: Based on model size:
233
- - 1-3B models: `t4-small` or `a10g-small`
234
- - 7-13B models: `a10g-large`
235
- - 30B+ models: `a100-large`
236
- - 70B+ models: `h100` or `h100x8`
237
- - `timeout`: ⚠️ CRITICAL β€” Small (2-4h), Medium (4-8h), Large (8-24h). NEVER default 30m for training.
 
 
 
 
 
 
 
 
238
 
239
  ### For Data Processing Tasks
240
 
241
- **Same pattern:** develop script in sandbox, test on subset, launch via hf_jobs.
 
 
 
 
 
242
  - Use `cpu-upgrade` or `cpu-performance` for most data tasks
243
  - Set timeout based on dataset size (1-4 hours typical)
244
 
@@ -339,21 +341,6 @@ system_prompt: |
339
  - ⚠️ Include HF_TOKEN for Hub operations
340
  - ⚠️ Storage is EPHEMERAL - must push_to_hub
341
 
342
- ## Sandbox (Interactive Development Environment)
343
-
344
- **sandbox_create:**
345
- - ⚠️ **Create a sandbox FIRST for any implementation task** β€” develop and test before launching jobs
346
- - Persistent remote Linux environment on HF Spaces
347
- - First call sandbox_create with hardware choice, then use bash/read/write/edit freely
348
- - Hardware: cpu-basic (free tier), cpu-upgrade (8vCPU/32GB), t4-small (16GB GPU), a10g-small (24GB GPU), a10g-large (24GB GPU + 46GB RAM), a100-large (80GB GPU)
349
- - `pip install` works out of the box β€” no special flags needed
350
- - Workflow: sandbox_create β†’ write script β†’ test β†’ fix β†’ hf_jobs(script="/app/script.py") to launch at scale
351
-
352
- **bash / read / write / edit:**
353
- - Available after sandbox_create β€” no additional approvals needed
354
- - Same semantics as local file/shell operations, but run on the remote sandbox
355
- - bash: run shell commands; read/write/edit: file operations
356
-
357
  **hf_private_repos:**
358
  - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion)
359
  - Upload logs, scripts, results that can't push_to_hub
 
186
  3. βœ… Determine optimal processing approach based on requirements
187
  4. βœ… Plan output format and destination
188
 
189
+ ## PHASE 3: IMPLEMENT (Execute with Researched Approaches)
190
+
191
+ ### For Training Tasks
192
+
193
+ ⚠️ **TRAINING REQUIREMENTS CHECKLIST:**
194
+
195
+ **Before Submission:**
196
+ - [ ] Researched current TRL documentation
197
+ - [ ] Found and verified base model
198
+ - [ ] Found dataset and VALIDATED columns and conversational format matches method
199
+ - [ ] Selected optimal model + dataset + hardware configuration
200
+ - [ ] Created plan with plan_tool
201
+ - [ ] Researched Trackio monitoring setup
202
+
203
+ **Training Script MUST Include:**
204
+ - [ ] Imports from researched documentation (current APIs)
205
+ - [ ] Trackio initialization with project/run_name/config
206
+ - [ ] Model and tokenizer loading
207
+ - [ ] Dataset loading with verified columns and conversational format
208
+ - [ ] Training config with ALL critical settings:
 
 
 
 
 
 
 
 
 
 
 
209
  - `push_to_hub=True` ⚠️ MANDATORY
210
  - `hub_model_id="username/model-name"` ⚠️ MANDATORY
211
  - `report_to=["trackio"]` (for monitoring)
212
  - `output_dir="./output"`
213
  - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate`
214
  - `logging_steps`, `save_steps`
215
+ - `max_length` if needed (default 1024 usually fine)
216
+ - [ ] Trainer initialization with model, args, dataset, tokenizer
217
+ - [ ] `trainer.train()` call
218
+ - [ ] `trainer.push_to_hub()` at end ⚠️ MANDATORY
219
+ - [ ] `tracker.finish()` for Trackio
220
+
221
+ **Job Configuration MUST Include:**
222
+ - [ ] `operation`: "run" (for one-time) or "scheduled run" (for recurring)
223
+ - [ ] `script`: Training script with all above elements
224
+ - [ ] `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio']
225
+ - [ ] `hardware_flavor`: Based on model size (see hf_jobs tool for detailed vCPU/RAM/GPU specs):
226
+ - 1-3B models: `t4-small` (4vCPU/15GB/GPU 16GB) for demos or `a10g-small` (4vCPU/14GB/GPU 24GB) for production
227
+ - 7-13B models: `a10g-large` (12vCPU/46GB/GPU 24GB)
228
+ - 30B+ models: `a100-large` (12vCPU/142GB/GPU 80GB)
229
+ - 70B+ models: `h100` (23vCPU/240GB/GPU 80GB) or `h100x8` for distributed
230
+ - [ ] `timeout`: ⚠️ CRITICAL - Set based on model/data size:
231
+ - Small models (1-3B): "2h" to "4h"
232
+ - Medium models (7-13B): "4h" to "8h"
233
+ - Large models (30B+): "8h" to "24h"
234
+ - **NEVER use default 30m for training!**
235
 
236
  ### For Data Processing Tasks
237
 
238
+ **Script Requirements:**
239
+ - Load dataset with `load_dataset`
240
+ - Process according to user requirements
241
+ - Push results with `push_to_hub()` or upload to `hf_private_repos`
242
+
243
+ **Job Configuration:**
244
  - Use `cpu-upgrade` or `cpu-performance` for most data tasks
245
  - Set timeout based on dataset size (1-4 hours typical)
246
 
 
341
  - ⚠️ Include HF_TOKEN for Hub operations
342
  - ⚠️ Storage is EPHEMERAL - must push_to_hub
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  **hf_private_repos:**
345
  - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion)
346
  - Upload logs, scripts, results that can't push_to_hub
agent/prompts/system_prompt_v3.yaml DELETED
@@ -1,118 +0,0 @@
1
- system_prompt: |
2
- You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem.
3
-
4
- _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
5
- {% if hf_user_info %}_Authenticated as: **{{ hf_user_info }}**_{% endif %}
6
-
7
- Your goal is to complete what the user requested with zero errors. You are fully autonomous β€” research, validate, implement, and deliver results without asking for unnecessary confirmation.
8
-
9
- # Your knowledge of HF libraries is outdated
10
-
11
- You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations.
12
-
13
- Before writing any ML implementation code (training, fine-tuning, inference, data processing), ground yourself in current working code:
14
-
15
- github_find_examples β†’ github_read_file β†’ explore_hf_docs + fetch_hf_docs
16
-
17
- Skip research only for trivial non-code operations.
18
-
19
- # Mistakes you WILL make without research
20
-
21
- HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first.
22
-
23
- WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.
24
-
25
- WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call hf_inspect_dataset or hub_repo_details and verify columns match the training method.
26
-
27
- DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training).
28
-
29
- LOST MODELS: You will forget push_to_hub=True and hub_model_id in training config. Job storage is ephemeral β€” the filesystem is deleted when the job ends. Without push_to_hub, the trained model is permanently lost.
30
-
31
- BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest.
32
-
33
- SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.
34
-
35
- HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like 'flash-attn' for flash_attention_2 or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job.
36
-
37
- SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself β€” switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task.
38
-
39
- # When writing ML code
40
-
41
- Required sequence before any training/fine-tuning/inference script:
42
- 1. Find working examples: github_find_examples (discover) β†’ github_read_file (study)
43
- 2. Check documentation: explore_hf_docs + fetch_hf_docs for trainer configs and parameters
44
- 3. Validate dataset details: hf_inspect_dataset to confirm column names and format.
45
- 4. Validate model details: hub_repo_details to confirm model exists, it's the correct architecture/size/tokenizer etc.
46
-
47
- Dataset format requirements by training method:
48
- SFT: "messages", "text", or "prompt"/"completion"
49
- DPO: "prompt", "chosen", "rejected"
50
- GRPO: "prompt"
51
-
52
- # When submitting a training job
53
-
54
- Before calling hf_jobs, output a pre-flight check:
55
- - Reference implementation: [which example you based this on]
56
- - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
57
- - push_to_hub=True and hub_model_id set
58
- - timeout: [value] (based on: [model size] on [hardware])
59
- - Trackio monitoring included and working
60
-
61
- If you cannot fill in all items, stop and complete the missing steps first.
62
-
63
- For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once.
64
-
65
- Hardware sizing:
66
- 1-3B params: a10g-largex2
67
- 7-13B params: a100-large
68
- 30B+ params: l40sx4 or a100x4
69
- 70B+ params: a100x8
70
- Note: a10g-small and a10g-large have the SAME 24GB GPU memory. The difference is CPU/RAM only.
71
-
72
- # Sandbox-first development
73
-
74
- For non-trivial scripts, develop and test in a sandbox before launching via hf_jobs:
75
- sandbox_create β†’ install deps β†’ write script β†’ test with small run β†’ fix errors β†’ launch via hf_jobs at scale
76
-
77
- Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
78
-
79
-
80
- # When a task has 3+ steps
81
-
82
- Use plan_tool to track progress. One task in_progress at a time. Mark completed immediately after finishing. Update frequently to show the user what you're doing.
83
-
84
- # Error recovery
85
-
86
- When something fails:
87
- - Diagnose the actual error. Read the full error message and logs.
88
- - Do not retry the exact same thing. Identify what needs to change.
89
- - If an API/import error: check documentation for the correct API.
90
- - If an OOM error: (1) reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally to keep effective batch size identical, (2) enable gradient_checkpointing=True, (3) upgrade to larger GPU (a10gx4→a100→a100x4→a100x8). Do NOT switch training methods (e.g. SFT→LoRA) or reduce max_length — those change what the user gets. If OOM happens in sandbox, create a new sandbox with larger GPU hardware.
91
- - Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval.
92
- - If a tool call fails repeatedly for the same reason: stop and try a different approach.
93
- - Never silently substitute resources (datasets, models) β€” tell the user if something isn't available.
94
-
95
- # Task completion
96
-
97
- Before ending your turn, verify:
98
- - Did you actually DO what the user asked, not just explain what you would do?
99
- - If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input?
100
- - For training jobs: did you include a working Trackio dashboard URL?
101
-
102
- Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done.
103
- Do not mark plan tasks as completed if they failed or are only partially done.
104
-
105
- # Communication
106
-
107
- - Be concise and direct. No filler, no restating what the user said.
108
- - One-word answers when appropriate for simple questions.
109
- - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
110
- - For errors: state what went wrong, why, and what you're doing to fix it.
111
- - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
112
-
113
- # Tool usage
114
-
115
- - Execute multiple independent tool calls in parallel when possible.
116
- - HF_TOKEN is automatically available in job secrets β€” no need to include it extra.
117
- - For training monitoring: include Trackio in the script and provide the dashboard URL.
118
- - For private/gated datasets: HF_TOKEN is needed β€” it's auto-loaded into job secrets.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
agent/tools/dataset_tools.py CHANGED
@@ -388,15 +388,22 @@ def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
388
  HF_INSPECT_DATASET_TOOL_SPEC = {
389
  "name": "hf_inspect_dataset",
390
  "description": (
391
- "Inspect a HF dataset in one call: status, configs/splits, schema, sample rows, parquet info.\n\n"
392
- "REQUIRED before any training job to verify dataset format matches training method:\n"
393
- " SFT: needs 'messages', 'text', or 'prompt'/'completion'\n"
394
- " DPO: needs 'prompt', 'chosen', 'rejected'\n"
395
- " GRPO: needs 'prompt'\n"
396
- "All datasets used for training have to be in conversational ChatML format to be compatible with HF libraries.'\n"
397
- "Training will fail with KeyError if columns don't match.\n\n"
398
- "Also use to get example datapoints, understand column names, data types, and available splits before writing any data loading code. "
399
- "Supports private/gated datasets when HF_TOKEN is set."
 
 
 
 
 
 
 
400
  ),
401
  "parameters": {
402
  "type": "object",
 
388
  HF_INSPECT_DATASET_TOOL_SPEC = {
389
  "name": "hf_inspect_dataset",
390
  "description": (
391
+ "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
392
+ "## What you get\n"
393
+ "- Status check (validates dataset works without errors)\n"
394
+ "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n"
395
+ "- Column names and types (schema)\n"
396
+ "- Sample rows to understand data format\n"
397
+ "- Parquet file structure and sizes\n\n"
398
+ "## CRITICAL\n"
399
+ "**Always inspect datasets before writing training code** to understand:\n"
400
+ "- Column names for your dataloader\n"
401
+ "- Data types and format\n"
402
+ "- Available splits (train/test/validation)\n\n"
403
+ "Supports private/gated datasets when HF_TOKEN is set.\n\n"
404
+ "## Examples\n"
405
+ '{"dataset": "stanfordnlp/imdb"}\n'
406
+ '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
407
  ),
408
  "parameters": {
409
  "type": "object",
agent/tools/docs_tools.py CHANGED
@@ -845,12 +845,17 @@ DOC_ENDPOINTS = [
845
  EXPLORE_HF_DOCS_TOOL_SPEC = {
846
  "name": "explore_hf_docs",
847
  "description": (
848
- "Browse HF documentation structure β€” discover all available documentation with 200-char previews.\n\n"
849
- "Use this to find relevant documentation and/or examples with detailed parameter docs and API reference. "
850
- "To be used together with github_find_examples and github_read_file to find working examples and documentation.\n\n"
851
- "Pattern: explore_hf_docs (find relevant pages) β†’ fetch_hf_docs (get full content).\n\n"
852
- "For training tasks: fetch the trainer config docs (SFTConfig, DPOConfig, GRPOConfig) to verify parameter names. "
853
- "Returns top 20 results by default; set max_results (max 50) to adjust."
 
 
 
 
 
854
  ),
855
  "parameters": {
856
  "type": "object",
@@ -923,10 +928,16 @@ EXPLORE_HF_DOCS_TOOL_SPEC = {
923
  HF_DOCS_FETCH_TOOL_SPEC = {
924
  "name": "fetch_hf_docs",
925
  "description": (
926
- "Fetch full markdown content of an HF documentation page. Use after explore_hf_docs.\n\n"
927
- "Critical for finding documentation e.g. current trainer configuration parameters (SFTConfig, DPOConfig, etc.) "
928
- "Use for researching solutions and before writing training scripts. Your internal knowledge is outdated.\n\n"
929
- "Provide the full URL from explore_hf_docs results. The .md extension is added automatically."
 
 
 
 
 
 
930
  ),
931
  "parameters": {
932
  "type": "object",
 
845
  EXPLORE_HF_DOCS_TOOL_SPEC = {
846
  "name": "explore_hf_docs",
847
  "description": (
848
+ "Explore Hugging Face documentation structure and discover available pages with 200-character previews. "
849
+ "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
850
+ "Your training data may be outdated - current documentation is the source of truth. "
851
+ "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
852
+ "(3) Before writing training/processing code, (4) Researching library capabilities, "
853
+ "(5) Verifying API syntax and parameters. "
854
+ "**Pattern:** explore (discover structure) β†’ fetch_hf_docs (get details) β†’ implement with researched approach. "
855
+ "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
856
+ "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
857
+ "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
858
+ " By default returns the top 20 results; set max_results (max 50) to adjust."
859
  ),
860
  "parameters": {
861
  "type": "object",
 
928
  HF_DOCS_FETCH_TOOL_SPEC = {
929
  "name": "fetch_hf_docs",
930
  "description": (
931
+ "Fetch full markdown content of a specific HF documentation page. "
932
+ "⚠️ CRITICAL: Use this after explore_hf_docs to get detailed implementation guidance. "
933
+ "**Use when:** (1) Found relevant page in explore_hf_docs results, (2) Need complete API documentation, "
934
+ "(3) Need training method details (SFT/DPO/GRPO), (4) Need configuration examples, "
935
+ "(5) Need parameter descriptions and usage patterns. "
936
+ "**Pattern:** explore_hf_docs (find relevant page) β†’ fetch_hf_docs (get full content) β†’ implement using documented approach. "
937
+ "Provide full URL from explore_hf_docs results (e.g., 'https://huggingface.co/docs/trl/sft_trainer'). "
938
+ "Returns: Complete markdown documentation with examples, parameters, and usage patterns. "
939
+ "**For training tasks:** ALWAYS fetch trainer docs (SFTConfig, DPOConfig, etc.) before creating training scripts. "
940
+ "**Critical for reliability:** This ensures you use current APIs and best practices."
941
  ),
942
  "parameters": {
943
  "type": "object",
agent/tools/github_find_examples.py CHANGED
@@ -405,16 +405,55 @@ def find_examples(
405
  GITHUB_FIND_EXAMPLES_TOOL_SPEC = {
406
  "name": "github_find_examples",
407
  "description": (
408
- "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). "
409
- "Uses fuzzy keyword matching.\n\n"
410
- "MANDATORY before writing any ML training, fine-tuning, or inference code. "
411
- "Your internal knowledge of library APIs is outdated β€” working examples show current API patterns.\n\n"
412
- "Sequence: github_find_examples β†’ github_read_file (study the example) β†’ implement based on what you found.\n\n"
413
- "Skip this only for: simple data queries, status checks, non-code tasks.\n\n"
414
- "Examples:\n"
415
- " {keyword: 'sft', repo: 'trl'} β†’ finds examples/scripts/sft.py\n"
416
- " {keyword: 'grpo', repo: 'trl'} β†’ finds GRPO training examples\n"
417
- " {repo: 'trl', max_results: 20} β†’ lists all available training method examples"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  ),
419
  "parameters": {
420
  "type": "object",
 
405
  GITHUB_FIND_EXAMPLES_TOOL_SPEC = {
406
  "name": "github_find_examples",
407
  "description": (
408
+ "Discover working code examples, tutorials, scripts, and demos in GitHub repositories. "
409
+ "⚠️ CRITICAL: ALWAYS use this BEFORE implementing ML tasks - find working reference code first. "
410
+ "Your training data may be outdated; real repository examples show current best practices. "
411
+ "**Use when:** (1) Starting any ML implementation (training, inference, evaluation), "
412
+ "(2) User asks 'how to' questions about libraries, (3) Need reference implementations, "
413
+ "(4) Exploring library capabilities, (5) Before writing training/processing scripts. "
414
+ "**Pattern:** github_find_examples (discover) β†’ github_read_file (study code) β†’ implement with researched approach. "
415
+ "Returns: List of example files (scripts/notebooks/tutorials) with paths and URLs, sorted by relevance. "
416
+ "**Then:** Use github_read_file to read the actual implementation code. "
417
+ "**Critical for reliability:** Real examples prevent outdated API usage and show proven patterns. "
418
+ "## How it works\n\n"
419
+ "1. Fetches all example files (examples/, scripts/, tutorials/, demos/, notebooks/, etc.) from repository\n"
420
+ "2. If keyword provided, scores files against keyword using fuzzy matching\n"
421
+ "3. Returns best matches sorted by relevance and pattern priority\n"
422
+ "4. Provides copyable parameters for github_read_file tool\n\n"
423
+ "## Examples\n\n"
424
+ "<example>\n"
425
+ "// ML Workflow Step: Find GRPO training examples before implementation\n"
426
+ "// Task: Starting GRPO fine-tuning project, need reference implementation\n"
427
+ "{\n"
428
+ " keyword: 'grpo',\n"
429
+ " repo: 'trl',\n"
430
+ " org: 'huggingface'\n"
431
+ "}\n"
432
+ "// Returns: examples/scripts/grpo_agent.py, examples/scripts/grpo_vlm.py\n"
433
+ "// Next step: github_read_file to study working implementation\n"
434
+ "</example>\n\n"
435
+ "<example>\n"
436
+ "// ML Workflow Step: Discover all available training methods\n"
437
+ "// Task: Exploring TRL training options before choosing approach\n"
438
+ "{\n"
439
+ " repo: 'trl',\n"
440
+ " org: 'huggingface',\n"
441
+ " max_results: 20\n"
442
+ "}\n"
443
+ "// Lists: SFT, DPO, GRPO, PPO, reward modeling examples\n"
444
+ "// Helps user choose appropriate method\n"
445
+ "</example>\n\n"
446
+ "<example>\n"
447
+ "// ML Workflow Step: Find LoRA fine-tuning examples\n"
448
+ "// Task: Learning parameter-efficient fine-tuning patterns\n"
449
+ "{\n"
450
+ " keyword: 'lora',\n"
451
+ " repo: 'peft',\n"
452
+ " org: 'huggingface'\n"
453
+ "}\n"
454
+ "// Discovers LoRA configuration and training examples\n"
455
+ "// Shows current PEFT API usage patterns\n"
456
+ "</example>"
457
  ),
458
  "parameters": {
459
  "type": "object",
agent/tools/github_read_file.py CHANGED
@@ -250,13 +250,59 @@ def read_file(
250
  GITHUB_READ_FILE_TOOL_SPEC = {
251
  "name": "github_read_file",
252
  "description": (
253
- "Read file contents from GitHub repositories. Returns first 300 lines by default. "
254
- "Auto-converts Jupyter notebooks to markdown.\n\n"
255
- "Use AFTER github_find_examples to study the working implementation. "
256
- "The purpose is to learn current API patterns β€” imports, trainer configs, dataset handling β€” "
257
- "so your implementation uses correct, up-to-date code.\n\n"
 
 
 
 
258
  "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n"
259
- "When NOT to use: when you don't know the file path (use github_find_examples first)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  ),
261
  "parameters": {
262
  "type": "object",
 
250
  GITHUB_READ_FILE_TOOL_SPEC = {
251
  "name": "github_read_file",
252
  "description": (
253
+ "Read file contents from GitHub repositories with line range support (default 300 lines). "
254
+ "⚠️ CRITICAL: Use AFTER github_find_examples to study working implementation code. "
255
+ "**Use when:** (1) Found example file via github_find_examples and need full code, "
256
+ "(2) Need to read trainer class implementation, (3) Study configuration patterns, "
257
+ "(4) Read specific code sections with line ranges, (5) Review code from specific branches/commits. "
258
+ "**Pattern:** github_find_examples (discover files) β†’ github_read_file (read code) β†’ implement using researched patterns. "
259
+ "Returns: File contents with line numbers, formatted for LLM reading. Auto-converts Jupyter notebooks to markdown. "
260
+ "**Then:** Implement using patterns and APIs from the example code. "
261
+ "**Critical for reliability:** Reading working examples prevents API errors and shows current best practices. "
262
  "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n"
263
+ "## When to use this tool\n\n"
264
+ "- When reading example code, trainer implementations, or configuration files\n"
265
+ "- After github_find_examples returns file paths you want to study\n"
266
+ "- When investigating specific code sections with line ranges\n"
267
+ "- When reading from specific branches, tags, or commits (use ref parameter)\n\n"
268
+ "## When NOT to use this tool\n\n"
269
+ "- When you don't know exact file path (use github_find_examples or github_search_code first)\n"
270
+ "- When searching for code patterns across repos (use github_search_code instead)\n\n"
271
+ "## Examples\n\n"
272
+ "<example>\n"
273
+ "// ML Workflow Step: Read GRPO trainer class after finding via github_find_examples\n"
274
+ "// Use case: Understand GRPOTrainer API, parameters, and methods\n"
275
+ "{\n"
276
+ " repo: 'huggingface/trl',\n"
277
+ " path: 'trl/trainer/grpo_trainer.py',\n"
278
+ " line_start: 1,\n"
279
+ " line_end: 200\n"
280
+ "}\n"
281
+ "// Read class definition and constructor to understand current API\n"
282
+ "// Shows: __init__ parameters, configuration, required arguments\n"
283
+ "</example>\n\n"
284
+ "<example>\n"
285
+ "// ML Workflow Step: Study complete training script from examples\n"
286
+ "// Use case: Learn end-to-end VLM fine-tuning workflow\n"
287
+ "{\n"
288
+ " repo: 'huggingface/trl',\n"
289
+ " path: 'examples/scripts/grpo_vlm.py'\n"
290
+ "}\n"
291
+ "// Returns first 300 lines - shows full training setup\n"
292
+ "// Use line_start/line_end if need to read more\n"
293
+ "</example>\n\n"
294
+ "<example>\n"
295
+ "// ML Workflow Step: Check TrainingArguments configuration patterns\n"
296
+ "// Use case: Learn how to structure training configs correctly\n"
297
+ "{\n"
298
+ " repo: 'huggingface/transformers',\n"
299
+ " path: 'examples/pytorch/language-modeling/run_clm.py',\n"
300
+ " line_start: 50,\n"
301
+ " line_end: 150\n"
302
+ "}\n"
303
+ "// Read argument parsing and config setup section\n"
304
+ "// Shows: current parameter names, default values, best practices\n"
305
+ "</example>"
306
  ),
307
  "parameters": {
308
  "type": "object",
agent/tools/jobs_tool.py CHANGED
@@ -9,7 +9,9 @@ import base64
9
  import http.client
10
  import os
11
  import re
12
- from typing import Any, Awaitable, Callable, Dict, Literal, Optional
 
 
13
 
14
  import httpx
15
  from huggingface_hub import HfApi
@@ -17,6 +19,8 @@ from huggingface_hub.utils import HfHubHTTPError
17
 
18
  from agent.core.session import Event
19
  from agent.tools.types import ToolResult
 
 
20
  from agent.tools.utilities import (
21
  format_job_details,
22
  format_jobs_table,
@@ -25,33 +29,38 @@ from agent.tools.utilities import (
25
  )
26
 
27
  # Hardware flavors
28
- CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"]
29
  GPU_FLAVORS = [
 
 
30
  "t4-small",
31
  "t4-medium",
32
- "a10g-small",
33
- "a10g-large",
34
- "a10g-largex2",
35
- "a10g-largex4",
36
- "a100-large",
37
- "a100x4",
38
- "a100x8",
39
  "l4x1",
40
  "l4x4",
41
  "l40sx1",
42
  "l40sx4",
43
  "l40sx8",
 
 
 
 
 
 
 
44
  ]
45
 
46
  # Detailed specs for display (vCPU/RAM/GPU VRAM)
47
- CPU_FLAVORS_DESC = "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB)"
 
 
48
  GPU_FLAVORS_DESC = (
49
  "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), "
50
- "a10g-small(4vCPU/15GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), "
51
- "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), "
52
- "a100-large(12vCPU/142GB/GPU 80GB), a100x4(48vCPU/568GB/GPU 320GB), a100x8(96vCPU/1136GB/GPU 640GB), "
53
  "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), "
54
- "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB)"
 
 
 
 
55
  )
56
  SPECIALIZED_FLAVORS = ["inf2x6"]
57
  ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS
@@ -113,23 +122,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]:
113
  return logs
114
 
115
 
116
- _DEFAULT_ENV = {
117
- "HF_HUB_DISABLE_PROGRESS_BARS": "1",
118
- "TQDM_DISABLE": "1",
119
- "TRANSFORMERS_VERBOSITY": "warning",
120
- "HF_HUB_ENABLE_HF_TRANSFER": "1",
121
- }
122
-
123
-
124
- def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]:
125
- """Inject default env vars for clean, agent-friendly output."""
126
- result = dict(_DEFAULT_ENV)
127
- result.update(params or {}) # user-provided values override defaults
128
- return result
129
-
130
-
131
- def _add_environment_variables(params: Dict[str, Any] | None) -> Dict[str, Any]:
132
- token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or ""
133
 
134
  # Start with user-provided env vars, then force-set token last
135
  result = dict(params or {})
@@ -285,10 +282,15 @@ class HfJobsTool:
285
  hf_token: Optional[str] = None,
286
  namespace: Optional[str] = None,
287
  log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
 
 
288
  ):
 
289
  self.api = HfApi(token=hf_token)
290
  self.namespace = namespace
291
  self.log_callback = log_callback
 
 
292
 
293
  async def execute(self, params: Dict[str, Any]) -> ToolResult:
294
  """Execute the specified operation"""
@@ -384,9 +386,7 @@ class HfJobsTool:
384
  def log_producer():
385
  try:
386
  # fetch_job_logs is a blocking sync generator
387
- logs_gen = self.api.fetch_job_logs(
388
- job_id=job_id, namespace=namespace
389
- )
390
  for line in logs_gen:
391
  # Push line to queue thread-safely
392
  loop.call_soon_threadsafe(queue.put_nowait, line)
@@ -413,7 +413,7 @@ class HfJobsTool:
413
 
414
  # Process log line
415
  log_line = item
416
- print("\t" + log_line)
417
  if self.log_callback:
418
  await self.log_callback(log_line)
419
  all_logs.append(log_line)
@@ -441,19 +441,19 @@ class HfJobsTool:
441
 
442
  if current_status in terminal_states:
443
  # Job finished, no need to retry
444
- print(f"\tJob reached terminal state: {current_status}")
445
  break
446
 
447
  # Job still running, retry connection
448
- print(
449
- f"\tConnection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..."
450
  )
451
  await asyncio.sleep(retry_delay)
452
  continue
453
 
454
  except (ConnectionError, TimeoutError, OSError):
455
  # Can't even check job status, wait and retry
456
- print(f"\tConnection error, retrying in {retry_delay}s...")
457
  await asyncio.sleep(retry_delay)
458
  continue
459
 
@@ -509,16 +509,30 @@ class HfJobsTool:
509
  self.api.run_job,
510
  image=image,
511
  command=command,
512
- env=_add_default_env(args.get("env")),
513
- secrets=_add_environment_variables(args.get("secrets")),
514
  flavor=args.get("hardware_flavor", "cpu-basic"),
515
  timeout=args.get("timeout", "30m"),
516
  namespace=self.namespace,
517
  )
518
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
  # Wait for completion and stream logs
520
- print(f"{job_type} job started: {job.url}")
521
- print("Streaming logs...\n---\n")
522
 
523
  final_status, all_logs = await self._wait_for_job_completion(
524
  job_id=job.id,
@@ -727,8 +741,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}
727
  image=image,
728
  command=command,
729
  schedule=schedule,
730
- env=_add_default_env(args.get("env")),
731
- secrets=_add_environment_variables(args.get("secrets")),
732
  flavor=args.get("hardware_flavor", "cpu-basic"),
733
  timeout=args.get("timeout", "30m"),
734
  namespace=self.namespace,
@@ -887,31 +901,56 @@ To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_
887
  HF_JOBS_TOOL_SPEC = {
888
  "name": "hf_jobs",
889
  "description": (
890
- "Execute Python scripts or Docker containers on HF cloud infrastructure.\n\n"
891
- "Two modes (mutually exclusive): Python mode (script + dependencies) or Docker mode (command + image). "
892
- "Provide exactly ONE of 'script' or 'command'.\n\n"
893
- "BEFORE submitting training/fine-tuning jobs:\n"
894
- "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
895
- "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
896
- "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
897
- "- Training config MUST include push_to_hub=True and hub_model_id. "
898
- "Job storage is EPHEMERAL β€” all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
899
- "- Include trackio monitoring and provide the dashboard URL to the user.\n\n"
900
- "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
901
- "Only then submit the remaining jobs. Never submit all at once β€” if there's a bug, all jobs fail.\n\n"
902
- "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n"
903
- f"Hardware: CPU: {CPU_FLAVORS_DESC}. GPU: {GPU_FLAVORS_DESC}.\n"
904
- "Common picks: t4-small ($0.60/hr, 1-3B), a10g-large ($2/hr, 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+). "
905
- "Note: a10g-small and a10g-large have the SAME 24GB GPU β€” the difference is CPU/RAM only.\n\n"
906
- "OOM RECOVERY: When a training job fails with CUDA OOM:\n"
907
- "1. Reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally (keep effective batch size identical)\n"
908
- "2. Enable gradient_checkpointing=True\n"
909
- "3. Upgrade to larger GPU (a10g→a100→h100)\n"
910
- "Do NOT switch training methods (e.g. full SFT to LoRA) or reduce max_length β€” those change what the user gets and require explicit approval.\n\n"
911
- "Examples:\n"
912
- "Training: {'operation': 'run', 'script': '/app/train.py', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a100-large', 'timeout': '8h'}\n"
913
- "Monitor: {'operation': 'ps'}, {'operation': 'logs', 'job_id': 'xxx'}, {'operation': 'cancel', 'job_id': 'xxx'}"
914
- "Docker: {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb', 'hardware_flavor': 'cpu-basic', 'timeout': '1h'}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
915
  ),
916
  "parameters": {
917
  "type": "object",
@@ -931,65 +970,58 @@ HF_JOBS_TOOL_SPEC = {
931
  "scheduled suspend",
932
  "scheduled resume",
933
  ],
934
- "description": "Operation to execute.",
 
 
 
 
935
  },
 
936
  "script": {
937
  "type": "string",
938
- "description": (
939
- "Python code or sandbox file path (e.g. '/app/train.py') or URL. "
940
- "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
941
- "Mutually exclusive with 'command'."
942
- ),
943
  },
944
  "dependencies": {
945
  "type": "array",
946
  "items": {"type": "string"},
947
- "description": (
948
- "Pip packages to install. Include ALL required packages. "
949
- "Common training set: ['transformers', 'trl', 'torch', 'datasets', 'trackio', 'accelerate']. "
950
- "Only used with 'script'."
951
- ),
952
  },
 
953
  "image": {
954
  "type": "string",
955
- "description": "Docker image. Optional β€” auto-selected if not provided. Use with 'command'.",
956
  },
957
  "command": {
958
  "type": "array",
959
  "items": {"type": "string"},
960
- "description": "Command to execute as list. Triggers Docker mode. Mutually exclusive with 'script'.",
961
  },
 
962
  "hardware_flavor": {
963
  "type": "string",
964
- "description": (
965
- "Hardware type. Sizing guide: 1-3B params β†’ t4-small/a10g-small, "
966
- "7-13B β†’ a10g-large, 30B+ β†’ a100-large, 70B+ β†’ h100/h100x8. "
967
- f"All options: CPU: {CPU_FLAVORS}. GPU: {GPU_FLAVORS}."
968
- ),
969
  },
970
  "timeout": {
971
  "type": "string",
972
- "description": (
973
- "Maximum job runtime. MUST be >2h for any training job β€” default 30m kills training mid-run. "
974
- "Guidelines: 1-3B models: 3-4h, 7-13B: 6-8h, 30B+: 12-24h. "
975
- "Use 30m-1h only for quick data processing or inference tasks. Default: '30m'."
976
- ),
977
  },
978
  "env": {
979
  "type": "object",
980
- "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
981
  },
 
982
  "job_id": {
983
  "type": "string",
984
- "description": "Job ID. Required for: logs, inspect, cancel.",
985
  },
 
986
  "scheduled_job_id": {
987
  "type": "string",
988
- "description": "Scheduled job ID. Required for: scheduled inspect/delete/suspend/resume.",
989
  },
990
  "schedule": {
991
  "type": "string",
992
- "description": "Cron schedule or preset (@hourly, @daily, @weekly, @monthly). Required for: scheduled run.",
993
  },
994
  },
995
  "required": ["operation"],
@@ -998,7 +1030,7 @@ HF_JOBS_TOOL_SPEC = {
998
 
999
 
1000
  async def hf_jobs_handler(
1001
- arguments: Dict[str, Any], session: Any = None
1002
  ) -> tuple[str, bool]:
1003
  """Handler for agent tool router"""
1004
  try:
@@ -1009,36 +1041,20 @@ async def hf_jobs_handler(
1009
  Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log})
1010
  )
1011
 
1012
- # If script is a sandbox file path, read it from the sandbox
1013
- script = arguments.get("script", "")
1014
- sandbox = getattr(session, "sandbox", None) if session else None
1015
- is_path = (
1016
- sandbox
1017
- and isinstance(script, str)
1018
- and script.strip() == script
1019
- and not any(c in script for c in "\r\n\0")
1020
- and (
1021
- script.startswith("/")
1022
- or script.startswith("./")
1023
- or script.startswith("../")
1024
- )
1025
  )
1026
- if is_path:
1027
- import shlex
1028
-
1029
- result = await asyncio.to_thread(sandbox.bash, f"cat {shlex.quote(script)}")
1030
- if not result.success:
1031
- return f"Failed to read {script} from sandbox: {result.error}", False
1032
- arguments = {**arguments, "script": result.output}
1033
-
1034
- # Get token and namespace from HF token
1035
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
1036
- namespace = HfApi(token=hf_token).whoami().get("name") if hf_token else None
1037
 
1038
  tool = HfJobsTool(
1039
  namespace=namespace,
1040
  hf_token=hf_token,
1041
  log_callback=log_callback if session else None,
 
 
1042
  )
1043
  result = await tool.execute(arguments)
1044
  return result["formatted"], not result.get("isError", False)
 
9
  import http.client
10
  import os
11
  import re
12
+ from typing import Any, Dict, Literal, Optional, Callable, Awaitable
13
+
14
+ import logging
15
 
16
  import httpx
17
  from huggingface_hub import HfApi
 
19
 
20
  from agent.core.session import Event
21
  from agent.tools.types import ToolResult
22
+
23
+ logger = logging.getLogger(__name__)
24
  from agent.tools.utilities import (
25
  format_job_details,
26
  format_jobs_table,
 
29
  )
30
 
31
  # Hardware flavors
32
+ CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"]
33
  GPU_FLAVORS = [
34
+ "sprx8",
35
+ "zero-a10g",
36
  "t4-small",
37
  "t4-medium",
 
 
 
 
 
 
 
38
  "l4x1",
39
  "l4x4",
40
  "l40sx1",
41
  "l40sx4",
42
  "l40sx8",
43
+ "a10g-small",
44
+ "a10g-large",
45
+ "a10g-largex2",
46
+ "a10g-largex4",
47
+ "a100-large",
48
+ "h100",
49
+ "h100x8",
50
  ]
51
 
52
  # Detailed specs for display (vCPU/RAM/GPU VRAM)
53
+ CPU_FLAVORS_DESC = (
54
+ "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB), cpu-performance, cpu-xl"
55
+ )
56
  GPU_FLAVORS_DESC = (
57
  "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), "
 
 
 
58
  "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), "
59
+ "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB), "
60
+ "a10g-small(4vCPU/14GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), "
61
+ "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), "
62
+ "a100-large(12vCPU/142GB/GPU 80GB), h100(23vCPU/240GB/GPU 80GB), h100x8(184vCPU/1920GB/GPU 640GB), "
63
+ "zero-a10g(dynamic alloc)"
64
  )
65
  SPECIALIZED_FLAVORS = ["inf2x6"]
66
  ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS
 
122
  return logs
123
 
124
 
125
+ def _add_environment_variables(
126
+ params: Dict[str, Any] | None, user_token: str | None = None
127
+ ) -> Dict[str, Any]:
128
+ # Prefer the authenticated user's OAuth token, fall back to global env var
129
+ token = user_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or ""
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # Start with user-provided env vars, then force-set token last
132
  result = dict(params or {})
 
282
  hf_token: Optional[str] = None,
283
  namespace: Optional[str] = None,
284
  log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
285
+ session: Any = None,
286
+ tool_call_id: Optional[str] = None,
287
  ):
288
+ self.hf_token = hf_token
289
  self.api = HfApi(token=hf_token)
290
  self.namespace = namespace
291
  self.log_callback = log_callback
292
+ self.session = session
293
+ self.tool_call_id = tool_call_id
294
 
295
  async def execute(self, params: Dict[str, Any]) -> ToolResult:
296
  """Execute the specified operation"""
 
386
  def log_producer():
387
  try:
388
  # fetch_job_logs is a blocking sync generator
389
+ logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace)
 
 
390
  for line in logs_gen:
391
  # Push line to queue thread-safely
392
  loop.call_soon_threadsafe(queue.put_nowait, line)
 
413
 
414
  # Process log line
415
  log_line = item
416
+ logger.debug(log_line)
417
  if self.log_callback:
418
  await self.log_callback(log_line)
419
  all_logs.append(log_line)
 
441
 
442
  if current_status in terminal_states:
443
  # Job finished, no need to retry
444
+ logger.info(f"Job reached terminal state: {current_status}")
445
  break
446
 
447
  # Job still running, retry connection
448
+ logger.warning(
449
+ f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..."
450
  )
451
  await asyncio.sleep(retry_delay)
452
  continue
453
 
454
  except (ConnectionError, TimeoutError, OSError):
455
  # Can't even check job status, wait and retry
456
+ logger.warning(f"Connection error, retrying in {retry_delay}s...")
457
  await asyncio.sleep(retry_delay)
458
  continue
459
 
 
509
  self.api.run_job,
510
  image=image,
511
  command=command,
512
+ env=args.get("env"),
513
+ secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
514
  flavor=args.get("hardware_flavor", "cpu-basic"),
515
  timeout=args.get("timeout", "30m"),
516
  namespace=self.namespace,
517
  )
518
 
519
+ # Send job URL immediately after job creation (before waiting for completion)
520
+ if self.session and self.tool_call_id:
521
+ await self.session.send_event(
522
+ Event(
523
+ event_type="tool_state_change",
524
+ data={
525
+ "tool_call_id": self.tool_call_id,
526
+ "tool": "hf_jobs",
527
+ "state": "running",
528
+ "jobUrl": job.url,
529
+ },
530
+ )
531
+ )
532
+
533
  # Wait for completion and stream logs
534
+ logger.info(f"{job_type} job started: {job.url}")
535
+ logger.info("Streaming logs...")
536
 
537
  final_status, all_logs = await self._wait_for_job_completion(
538
  job_id=job.id,
 
741
  image=image,
742
  command=command,
743
  schedule=schedule,
744
+ env=args.get("env"),
745
+ secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
746
  flavor=args.get("hardware_flavor", "cpu-basic"),
747
  timeout=args.get("timeout", "30m"),
748
  namespace=self.namespace,
 
901
  HF_JOBS_TOOL_SPEC = {
902
  "name": "hf_jobs",
903
  "description": (
904
+ "Execute Python scripts or Docker containers on HF cloud infrastructure (CPUs/GPUs) in one of two modes. "
905
+ "\n\n"
906
+ "**Two Modes (mutually exclusive):**\n"
907
+ "1. Python mode: using 'script' arg (REQUIRED) + 'dependencies'\n"
908
+ "2. Docker mode: using 'command' arg (REQUIRED) + 'image'\n\n"
909
+ "🚨 **REQUIRED:** You MUST provide exactly ONE of: 'script' (Python code as string) OR 'command' (Docker command as array). "
910
+ "They are mutually exclusive - provide one or the other, never both, never neither. "
911
+ "Do NOT call with just {'operation': 'run'} - always include your code. Example: {'operation': 'run', 'script': 'import torch; print(torch.cuda.is_available())', 'dependencies': ['torch']} or {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2']', 'image': 'duckdb/duckdb'}\n\n"
912
+ "⚠️ CRITICAL for reliability: (1) Jobs run ASYNC - provide monitoring URL immediately, don't poll; "
913
+ "(2) Set timeout >30min (default too short - training needs 2-8h); "
914
+ "(3) HF_TOKEN auto-loaded to secrets for Hub ops (push_to_hub, private repos); "
915
+ "(4) Job storage EPHEMERAL - MUST push_to_hub() or ALL work is LOST. "
916
+ "**Use when:** User wants cloud compute, training models, data processing, batch inference, GPU workloads, scheduled tasks. "
917
+ "ALWAYS use this tool (βœ“), never bash 'hf jobs' commands (βœ—). Pass script content inline (βœ“), don't save to files unless requested (βœ—). "
918
+ "\n\n"
919
+ "**Operations:** run, ps, logs, inspect, cancel, scheduled run, scheduled ps, scheduled inspect, scheduled delete, scheduled suspend, scheduled resume. "
920
+ "**Available Hardware (vCPU/RAM/GPU):**\n"
921
+ f"β€’ CPU: {CPU_FLAVORS_DESC}\n"
922
+ f"β€’ GPU: {GPU_FLAVORS_DESC}\n"
923
+ " β—¦ Common: t4-small ($0.60/hr, demos/1-3B models), a10g-small ($1/hr), a10g-large ($2/hr, production 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+)\n\n"
924
+ "**After Submission Ground Rules:**\n"
925
+ "βœ“ Return immediately with job ID and monitoring URL\n"
926
+ "βœ“ Provide expected completion time and cost estimate\n"
927
+ "βœ“ For training: Include Trackio dashboard URL\n"
928
+ "βœ“ Note user can check status later\n"
929
+ "βœ— DON'T poll logs automatically\n"
930
+ "βœ— DON'T wait for completion\n"
931
+ "βœ— DON'T check status unless user asks\n\n"
932
+ "**For Training Tasks:**\n"
933
+ "β€’ ALWAYS research TRL docs first: explore_hf_docs('trl') β†’ fetch_hf_docs(<trainer_url>)\n"
934
+ "β€’ ALWAYS validate dataset format with hub_repo_details (SFT needs messages/text, DPO needs chosen/rejected)\n"
935
+ "β€’ ALWAYS include Trackio monitoring in script (explore_hf_docs('trackio'))\n"
936
+ "β€’ ALWAYS enable push_to_hub=True in training config\n"
937
+ "β€’ Set timeout 2-8h for training (NOT default 30m)\n"
938
+ "β€’ Confirm model/dataset choices with user before submitting\n\n"
939
+ "**Examples:**\n\n"
940
+ "**Training - Fine-tune LLM:**\n"
941
+ "{'operation': 'run', 'script': '# Training script with TRL\\nfrom trl import SFTConfig, SFTTrainer\\nfrom transformers import AutoModelForCausalLM\\nmodel = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen3-4B\")\\n# ... researched implementation from docs ...\\ntrainer.train()\\ntrainer.push_to_hub(\"user-name/my-model\")', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a10g-large', 'timeout': '4h'}\n\n"
942
+ "**Data Processing:**\n"
943
+ "{'operation': 'run', 'script': 'from datasets import load_dataset\\nds = load_dataset(\"data\")\\n# process...\\nds.push_to_hub(\"user/processed\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-upgrade', 'timeout': '2h'}\n\n"
944
+ "**Scheduled Daily Job:**\n"
945
+ "{'operation': 'scheduled run', 'schedule': '@daily', 'script': 'from datasets import Dataset\\nimport pandas as pd\\n# scrape/generate data\\ndf = pd.DataFrame(data)\\nds = Dataset.from_pandas(df)\\nds.push_to_hub(\"user-name/daily-dataset\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-basic'}\n\n"
946
+ "**Docker Mode:**\n"
947
+ "{'operation': 'run', 'image': 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime', 'command': ['python', 'train.py', '--epochs', '10'], 'hardware_flavor': 'a100-large'}\n\n"
948
+ "**Monitor Operations:**\n"
949
+ "{'operation': 'ps'} - List all jobs\n"
950
+ "{'operation': 'logs', 'job_id': 'xxx'} - Stream logs (only when user requests)\n"
951
+ "{'operation': 'inspect', 'job_id': 'xxx'} - Get job details\n"
952
+ "{'operation': 'cancel', 'job_id': 'xxx'} - Stop job\n\n"
953
+ "⚠️ CRITICAL: Files created during execution are DELETED when job finishes. MUST push_to_hub() all outputs (models, datasets, artifacts) in script. For logs/scripts, use hf_private_repos after completion."
954
  ),
955
  "parameters": {
956
  "type": "object",
 
970
  "scheduled suspend",
971
  "scheduled resume",
972
  ],
973
+ "description": (
974
+ "Operation to execute. Valid values: [run, ps, logs, inspect, cancel, "
975
+ "scheduled run, scheduled ps, scheduled inspect, scheduled delete, "
976
+ "scheduled suspend, scheduled resume]"
977
+ ),
978
  },
979
+ # Python/UV specific parameters
980
  "script": {
981
  "type": "string",
982
+ "description": "Python code to execute. Triggers Python mode (auto pip install). Use with 'run'/'scheduled run'. Mutually exclusive with 'command'.",
 
 
 
 
983
  },
984
  "dependencies": {
985
  "type": "array",
986
  "items": {"type": "string"},
987
+ "description": "Pip packages to install. Example: ['trl', 'torch', 'datasets', 'transformers']. Only used with 'script'.",
 
 
 
 
988
  },
989
+ # Docker specific parameters
990
  "image": {
991
  "type": "string",
992
+ "description": "Docker image. Example: 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime'. Use with 'run'/'scheduled run'. Optional (auto-selected if not provided).",
993
  },
994
  "command": {
995
  "type": "array",
996
  "items": {"type": "string"},
997
+ "description": "Command to execute as list. Example: ['python', 'train.py', '--epochs', '10']. Triggers Docker mode. Use with 'run'/'scheduled run'. Mutually exclusive with 'script'.",
998
  },
999
+ # Hardware and environment
1000
  "hardware_flavor": {
1001
  "type": "string",
1002
+ "description": f"Hardware type. Available CPU flavors: {CPU_FLAVORS}. Available GPU flavors: {GPU_FLAVORS}. Use with 'run'/'scheduled run'.",
 
 
 
 
1003
  },
1004
  "timeout": {
1005
  "type": "string",
1006
+ "description": "Max runtime. Examples: '30m', '2h', '4h'. Default: '30m'. Important for long training jobs. Use with 'run'/'scheduled run'.",
 
 
 
 
1007
  },
1008
  "env": {
1009
  "type": "object",
1010
+ "description": "Environment variables. Format: {'KEY': 'VALUE'}. HF_TOKEN is automatically included from your auth. Use with 'run'/'scheduled run'.",
1011
  },
1012
+ # Job management parameters
1013
  "job_id": {
1014
  "type": "string",
1015
+ "description": "Job ID to operate on. Required for: 'logs', 'inspect', 'cancel'.",
1016
  },
1017
+ # Scheduled job parameters
1018
  "scheduled_job_id": {
1019
  "type": "string",
1020
+ "description": "Scheduled job ID. Required for: 'scheduled inspect', 'scheduled delete', 'scheduled suspend', 'scheduled resume'.",
1021
  },
1022
  "schedule": {
1023
  "type": "string",
1024
+ "description": "Schedule for recurring job. Presets: '@hourly', '@daily', '@weekly', '@monthly'. Cron: '0 9 * * 1' (Mon 9am). Required for: 'scheduled run'.",
1025
  },
1026
  },
1027
  "required": ["operation"],
 
1030
 
1031
 
1032
  async def hf_jobs_handler(
1033
+ arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None
1034
  ) -> tuple[str, bool]:
1035
  """Handler for agent tool router"""
1036
  try:
 
1041
  Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log})
1042
  )
1043
 
1044
+ # Prefer the authenticated user's OAuth token, fall back to global env
1045
+ hf_token = (
1046
+ (getattr(session, "hf_token", None) if session else None)
1047
+ or os.environ.get("HF_TOKEN")
1048
+ or os.environ.get("HUGGINGFACE_HUB_TOKEN")
 
 
 
 
 
 
 
 
1049
  )
1050
+ namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None)
 
 
 
 
 
 
 
 
 
 
1051
 
1052
  tool = HfJobsTool(
1053
  namespace=namespace,
1054
  hf_token=hf_token,
1055
  log_callback=log_callback if session else None,
1056
+ session=session,
1057
+ tool_call_id=tool_call_id,
1058
  )
1059
  result = await tool.execute(arguments)
1060
  return result["formatted"], not result.get("isError", False)
agent/tools/plan_tool.py CHANGED
@@ -85,11 +85,18 @@ def get_current_plan() -> List[Dict[str, str]]:
85
  PLAN_TOOL_SPEC = {
86
  "name": "plan_tool",
87
  "description": (
88
- "Track progress on multi-step tasks with a todo list (pending/in_progress/completed).\n\n"
89
- "Use for tasks with 3+ steps. Each call replaces the entire plan (send full list).\n\n"
90
- "Rules: exactly ONE task in_progress at a time. Mark completed immediately after finishing. "
91
- "Only mark completed when the task fully succeeded β€” keep in_progress if there are errors. "
92
- "Update frequently so the user sees progress."
 
 
 
 
 
 
 
93
  ),
94
  "parameters": {
95
  "type": "object",
 
85
  PLAN_TOOL_SPEC = {
86
  "name": "plan_tool",
87
  "description": (
88
+ "Manage task planning and progress tracking with todo list (pending/in_progress/completed statuses). "
89
+ "⚠️ CRITICAL: ALWAYS use for multi-step tasks (3+ steps) and MUST update frequently to show progress. "
90
+ "**Use when:** (1) User provides multiple tasks, (2) Complex workflows (training, evaluation, data processing), "
91
+ "(3) Tasks requiring multiple tool calls, (4) Need to communicate progress clearly to user, "
92
+ "(5) Breaking down ambiguous requests into concrete steps. "
93
+ "**Pattern:** Create plan at start β†’ Mark in_progress when starting task β†’ Mark completed immediately after finishing β†’ User sees clear progress. "
94
+ "Each call replaces entire plan (full list required). "
95
+ "**Critical for reliability:** Exactly ONE task in_progress at a time (not zero, not multiple). "
96
+ "Mark tasks completed IMMEDIATELY after finishing - don't batch completions. "
97
+ "**For long-running tasks:** Update plan after each major step to keep user informed. "
98
+ "**Only mark completed when:** Task fully accomplished, no errors, all requirements met. "
99
+ "Keep tasks pending if blocked/errors occur - create new task to resolve blockers."
100
  ),
101
  "parameters": {
102
  "type": "object",
agent/tools/sandbox_client.py DELETED
@@ -1,714 +0,0 @@
1
- #!/usr/bin/env python3
2
- # /// script
3
- # requires-python = ">=3.10"
4
- # dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"]
5
- # ///
6
- """
7
- Sandbox Tools β€” Agent-native primitives for HF Space dev-mode sandboxes.
8
-
9
- Architecture:
10
- - Creates a sandbox by duplicating a template Space (runs sandbox_server.py)
11
- - Waits for it to come online
12
- - Communicates via HTTPS to the Space's API
13
- - Optionally deletes the Space when done
14
-
15
- Lifecycle:
16
- sb = Sandbox.create(owner="burtenshaw") # duplicate, wait, connect
17
- sb = Sandbox.create(owner="burtenshaw", # with options
18
- hardware="t4-small",
19
- private=True,
20
- sleep_time=3600)
21
- sb = Sandbox.connect("burtenshaw/my-sandbox-abc") # attach to existing
22
-
23
- sb.bash("uv run train.py")
24
- sb.read("/app/train.py")
25
- sb.edit("/app/train.py", old_str="lr=1e-3", new_str="lr=1e-4")
26
-
27
- sb.delete() # tear down when done
28
-
29
- # Or use as a context manager for automatic cleanup
30
- with Sandbox.create(owner="burtenshaw") as sb:
31
- sb.bash("python train.py")
32
- # Space deleted on exit
33
-
34
- Tools: bash, read, write, edit, upload
35
- """
36
-
37
- from __future__ import annotations
38
-
39
- import io
40
- import os
41
- import sys
42
- import time
43
- import uuid
44
- from dataclasses import dataclass, field
45
- from typing import Any
46
-
47
- import httpx
48
- from huggingface_hub import CommitOperationAdd, HfApi
49
-
50
- TEMPLATE_SPACE = "burtenshaw/sandbox"
51
- HARDWARE_OPTIONS = [
52
- "cpu-basic",
53
- "cpu-upgrade",
54
- "t4-small",
55
- "t4-medium",
56
- "a10g-small",
57
- "a10g-large",
58
- "a100-large",
59
- ]
60
- OUTPUT_LIMIT = 30000
61
- LINE_LIMIT = 2000
62
- DEFAULT_READ_LIMIT = 2000
63
- DEFAULT_TIMEOUT = 120
64
- MAX_TIMEOUT = 600
65
- WAIT_TIMEOUT = 300
66
- WAIT_INTERVAL = 5
67
- API_WAIT_TIMEOUT = 180
68
-
69
- _DOCKERFILE = """\
70
- FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
71
-
72
- RUN apt-get update && \\
73
- apt-get install -y \\
74
- bash git git-lfs wget curl procps \\
75
- htop vim nano jq tmux \\
76
- build-essential && \\
77
- rm -rf /var/lib/apt/lists/*
78
-
79
- RUN uv pip install --system fastapi uvicorn python-multipart
80
-
81
- RUN useradd -m -u 1000 user
82
- USER user
83
-
84
- ENV HOME=/home/user \\
85
- PATH=/home/user/.local/bin:$PATH \\
86
- PIP_USER=1 \\
87
- HF_HUB_DISABLE_PROGRESS_BARS=1 \\
88
- TQDM_DISABLE=1 \\
89
- TRANSFORMERS_VERBOSITY=warning \\
90
- HF_HUB_ENABLE_HF_TRANSFER=1
91
-
92
- WORKDIR /app
93
- COPY --chown=user . /app
94
-
95
- EXPOSE 7860
96
-
97
- CMD ["python", "sandbox_server.py"]
98
- """
99
-
100
- _SANDBOX_SERVER = '''\
101
- """Minimal FastAPI server for sandbox operations."""
102
- import os, subprocess, pathlib
103
- from fastapi import FastAPI
104
- from pydantic import BaseModel
105
- from typing import Optional
106
- import uvicorn
107
-
108
- app = FastAPI()
109
-
110
- class BashReq(BaseModel):
111
- command: str
112
- work_dir: str = "/app"
113
- timeout: int = 120
114
-
115
- class ReadReq(BaseModel):
116
- path: str
117
- offset: Optional[int] = None
118
- limit: Optional[int] = 2000
119
-
120
- class WriteReq(BaseModel):
121
- path: str
122
- content: str
123
-
124
- class EditReq(BaseModel):
125
- path: str
126
- old_str: str
127
- new_str: str
128
- replace_all: bool = False
129
-
130
- class ExistsReq(BaseModel):
131
- path: str
132
-
133
- @app.get("/api/health")
134
- def health():
135
- return {"status": "ok"}
136
-
137
- @app.post("/api/bash")
138
- def bash(req: BashReq):
139
- try:
140
- r = subprocess.run(
141
- req.command, shell=True, capture_output=True, text=True,
142
- cwd=req.work_dir, timeout=req.timeout,
143
- )
144
- output = r.stdout + r.stderr
145
- if len(output) > 30000:
146
- output = output[:30000] + "\\n... (truncated)"
147
- return {"success": r.returncode == 0, "output": output, "error": "" if r.returncode == 0 else f"Exit code {r.returncode}"}
148
- except subprocess.TimeoutExpired:
149
- return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"}
150
- except Exception as e:
151
- return {"success": False, "output": "", "error": str(e)}
152
-
153
- @app.post("/api/read")
154
- def read(req: ReadReq):
155
- try:
156
- p = pathlib.Path(req.path)
157
- if not p.exists():
158
- return {"success": False, "output": "", "error": f"File not found: {req.path}"}
159
- if p.is_dir():
160
- return {"success": False, "output": "", "error": f"Is a directory: {req.path}"}
161
- lines = p.read_text().splitlines()
162
- start = (req.offset or 1) - 1
163
- end = start + (req.limit or len(lines))
164
- selected = lines[start:end]
165
- numbered = "\\n".join(f"{start + i + 1}\\t{line}" for i, line in enumerate(selected))
166
- return {"success": True, "output": numbered, "error": ""}
167
- except Exception as e:
168
- return {"success": False, "output": "", "error": str(e)}
169
-
170
- @app.post("/api/write")
171
- def write(req: WriteReq):
172
- try:
173
- p = pathlib.Path(req.path)
174
- p.parent.mkdir(parents=True, exist_ok=True)
175
- p.write_text(req.content)
176
- return {"success": True, "output": f"Wrote {len(req.content)} bytes to {req.path}", "error": ""}
177
- except Exception as e:
178
- return {"success": False, "output": "", "error": str(e)}
179
-
180
- @app.post("/api/edit")
181
- def edit(req: EditReq):
182
- try:
183
- p = pathlib.Path(req.path)
184
- if not p.exists():
185
- return {"success": False, "output": "", "error": f"File not found: {req.path}"}
186
- content = p.read_text()
187
- if req.old_str not in content:
188
- return {"success": False, "output": "", "error": f"old_str not found in {req.path}"}
189
- if not req.replace_all and content.count(req.old_str) > 1:
190
- return {"success": False, "output": "", "error": f"old_str appears {content.count(req.old_str)} times. Use replace_all=true or provide more context."}
191
- if req.replace_all:
192
- new_content = content.replace(req.old_str, req.new_str)
193
- else:
194
- new_content = content.replace(req.old_str, req.new_str, 1)
195
- p.write_text(new_content)
196
- return {"success": True, "output": f"Edited {req.path}", "error": ""}
197
- except Exception as e:
198
- return {"success": False, "output": "", "error": str(e)}
199
-
200
- @app.post("/api/exists")
201
- def exists(req: ExistsReq):
202
- return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""}
203
-
204
- if __name__ == "__main__":
205
- uvicorn.run(app, host="0.0.0.0", port=7860)
206
- '''
207
-
208
-
209
- @dataclass
210
- class ToolResult:
211
- success: bool
212
- output: str = ""
213
- error: str = ""
214
-
215
- def __str__(self):
216
- if self.success:
217
- return self.output or "(no output)"
218
- return f"ERROR: {self.error}"
219
-
220
- def to_dict(self) -> dict:
221
- return {"success": self.success, "output": self.output, "error": self.error}
222
-
223
-
224
- @dataclass
225
- class Sandbox:
226
- """
227
- A handle to an HF Space sandbox.
228
-
229
- Use Sandbox.create() to spin up a new one, or Sandbox.connect() to
230
- attach to an existing running Space.
231
- """
232
-
233
- space_id: str
234
- token: str | None = None
235
- work_dir: str = "/app"
236
- timeout: int = DEFAULT_TIMEOUT
237
- _owns_space: bool = field(default=False, repr=False)
238
- _base_url: str = field(init=False, repr=False)
239
- _client: httpx.Client = field(init=False, repr=False)
240
- _hf_api: HfApi = field(init=False, repr=False)
241
- _files_read: set = field(init=False, repr=False, default_factory=set)
242
-
243
- def __post_init__(self):
244
- self.token = self.token or os.environ.get("HF_TOKEN")
245
- slug = self.space_id.replace("/", "-")
246
- # Trailing slash is critical: httpx resolves relative paths against base_url.
247
- # Without it, client.get("health") resolves to /health instead of /api/health.
248
- self._base_url = f"https://{slug}.hf.space/api/"
249
- self._client = httpx.Client(
250
- base_url=self._base_url,
251
- headers={"Authorization": f"Bearer {self.token}"} if self.token else {},
252
- timeout=httpx.Timeout(MAX_TIMEOUT, connect=30),
253
- follow_redirects=True,
254
- )
255
- self._hf_api = HfApi(token=self.token)
256
-
257
- # ── Lifecycle ─────────────────────────────────────────────────
258
-
259
- @classmethod
260
- def create(
261
- cls,
262
- owner: str,
263
- *,
264
- name: str | None = None,
265
- template: str = TEMPLATE_SPACE,
266
- hardware: str = "cpu-basic",
267
- private: bool = False,
268
- sleep_time: int | None = None,
269
- token: str | None = None,
270
- wait_timeout: int = WAIT_TIMEOUT,
271
- ) -> Sandbox:
272
- """
273
- Create a new sandbox by duplicating the template Space.
274
-
275
- Generates a unique space name, duplicates the template, waits for it
276
- to come online, then returns a connected Sandbox.
277
-
278
- Args:
279
- owner: HF username or org (e.g. "burtenshaw").
280
- name: Base name for the space. Defaults to "sandbox".
281
- A unique suffix is always appended.
282
- template: Source Space to duplicate (default: burtenshaw/sandbox).
283
- hardware: Hardware tier (cpu-basic, t4-small, etc.).
284
- private: Whether the Space should be private.
285
- sleep_time: Auto-sleep after N seconds of inactivity.
286
- token: HF API token. Falls back to HF_TOKEN env var.
287
- wait_timeout: Max seconds to wait for Space to start (default: 300).
288
-
289
- Returns:
290
- A Sandbox instance connected to the running Space.
291
- """
292
- token = token or os.environ.get("HF_TOKEN")
293
- api = HfApi(token=token)
294
-
295
- base = name or "sandbox"
296
- suffix = uuid.uuid4().hex[:8]
297
- space_id = f"{owner}/{base}-{suffix}"
298
-
299
- print(f"Creating sandbox: {space_id} (from {template})...")
300
-
301
- kwargs = {
302
- "from_id": template,
303
- "to_id": space_id,
304
- "private": private,
305
- "hardware": hardware,
306
- }
307
- if sleep_time is not None:
308
- kwargs["sleep_time"] = sleep_time
309
-
310
- api.duplicate_space(**kwargs)
311
- print(f"Space created: https://huggingface.co/spaces/{space_id}")
312
-
313
- # Upload sandbox server and Dockerfile (triggers rebuild)
314
- cls._setup_server(space_id, api)
315
-
316
- # Wait for it to come online (rebuild + start)
317
- print(f"Waiting for Space to start (timeout: {wait_timeout}s)...")
318
- deadline = time.time() + wait_timeout
319
- while time.time() < deadline:
320
- runtime = api.get_space_runtime(space_id)
321
- if runtime.stage == "RUNNING":
322
- print(f"Space is running (hardware: {runtime.hardware})")
323
- break
324
- if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"):
325
- raise RuntimeError(
326
- f"Space failed to start: {runtime.stage}. "
327
- f"Check https://huggingface.co/spaces/{space_id}"
328
- )
329
- print(f" {runtime.stage}...")
330
- time.sleep(WAIT_INTERVAL)
331
- else:
332
- raise TimeoutError(
333
- f"Space did not start within {wait_timeout}s. "
334
- f"Check https://huggingface.co/spaces/{space_id}"
335
- )
336
-
337
- # Wait for the API server to be responsive (non-fatal)
338
- sb = cls(space_id=space_id, token=token, _owns_space=True)
339
- try:
340
- sb._wait_for_api(timeout=API_WAIT_TIMEOUT)
341
- except TimeoutError as e:
342
- print(
343
- f"Warning: API health check timed out ({e}), but Space is RUNNING. Continuing."
344
- )
345
- return sb
346
-
347
- @staticmethod
348
- def _setup_server(space_id: str, api: HfApi) -> None:
349
- """Upload embedded sandbox server + Dockerfile to the Space (single commit)."""
350
- print(f"Uploading sandbox server to {space_id}...")
351
- api.create_commit(
352
- repo_id=space_id,
353
- repo_type="space",
354
- operations=[
355
- CommitOperationAdd(
356
- path_in_repo="sandbox_server.py",
357
- path_or_fileobj=io.BytesIO(_SANDBOX_SERVER.encode()),
358
- ),
359
- CommitOperationAdd(
360
- path_in_repo="Dockerfile",
361
- path_or_fileobj=io.BytesIO(_DOCKERFILE.encode()),
362
- ),
363
- ],
364
- commit_message="Setup sandbox server",
365
- )
366
- print("Server files uploaded, rebuild triggered.")
367
-
368
- @classmethod
369
- def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox:
370
- """
371
- Connect to an existing running Space.
372
-
373
- Does a health check to verify the Space is reachable.
374
- """
375
- sb = cls(space_id=space_id, token=token, _owns_space=False)
376
- sb._wait_for_api(timeout=60)
377
- return sb
378
-
379
- def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT):
380
- """Poll the health endpoint until the server responds."""
381
- deadline = time.time() + timeout
382
- last_err = None
383
- last_status = None
384
- while time.time() < deadline:
385
- try:
386
- resp = self._client.get("health", timeout=10)
387
- last_status = resp.status_code
388
- if resp.status_code == 200:
389
- print(f"API is responsive at {self._base_url}")
390
- return
391
- except Exception as e:
392
- last_err = e
393
- time.sleep(3)
394
- raise TimeoutError(
395
- f"Sandbox API at {self._base_url} not responding after {timeout}s. "
396
- f"Last status: {last_status}, last error: {last_err}"
397
- )
398
-
399
- def delete(self):
400
- """Delete the Space. Only works if this Sandbox created it."""
401
- if not self._owns_space:
402
- raise RuntimeError(
403
- f"This Sandbox did not create {self.space_id}. "
404
- f"Use self._hf_api.delete_repo() directly if you're sure."
405
- )
406
- print(f"Deleting sandbox: {self.space_id}...")
407
- self._hf_api.delete_repo(self.space_id, repo_type="space")
408
- self._client.close()
409
- print("Deleted.")
410
-
411
- def pause(self):
412
- """Pause the Space (stops billing, preserves state)."""
413
- self._hf_api.pause_space(self.space_id)
414
-
415
- def restart(self):
416
- """Restart the Space."""
417
- self._hf_api.restart_space(self.space_id)
418
- self._wait_for_api()
419
-
420
- @property
421
- def url(self) -> str:
422
- """Public URL of the Space."""
423
- return f"https://huggingface.co/spaces/{self.space_id}"
424
-
425
- @property
426
- def status(self) -> str:
427
- """Current Space stage (RUNNING, BUILDING, PAUSED, etc.)."""
428
- return self._hf_api.get_space_runtime(self.space_id).stage
429
-
430
- def __enter__(self) -> Sandbox:
431
- return self
432
-
433
- def __exit__(self, *exc):
434
- if self._owns_space:
435
- try:
436
- self.delete()
437
- except Exception as e:
438
- print(f"Warning: failed to delete sandbox: {e}", file=sys.stderr)
439
- self._client.close()
440
-
441
- # ── HTTP plumbing ─────────────────────────────────────────────
442
-
443
- def _call(
444
- self, endpoint: str, payload: dict, timeout: float | None = None
445
- ) -> ToolResult:
446
- # Strip leading slash for correct httpx base_url resolution
447
- endpoint = endpoint.lstrip("/")
448
- try:
449
- resp = self._client.post(
450
- endpoint,
451
- json=payload,
452
- timeout=timeout or self.timeout,
453
- )
454
- data = resp.json()
455
- if resp.status_code == 200:
456
- return ToolResult(
457
- success=data.get("success", True),
458
- output=data.get("output", ""),
459
- error=data.get("error", ""),
460
- )
461
- return ToolResult(
462
- success=False,
463
- error=data.get("error", f"HTTP {resp.status_code}"),
464
- )
465
- except httpx.TimeoutException:
466
- return ToolResult(
467
- success=False, error=f"Timeout after {timeout or self.timeout}s"
468
- )
469
- except httpx.ConnectError:
470
- return ToolResult(
471
- success=False,
472
- error=f"Cannot connect to sandbox. Is {self.space_id} running? Status: {self.status}",
473
- )
474
- except Exception as e:
475
- return ToolResult(success=False, error=str(e))
476
-
477
- # ── Tools ─────────────────────────────────────────────────────
478
-
479
- def bash(
480
- self,
481
- command: str,
482
- *,
483
- work_dir: str | None = None,
484
- timeout: int | None = None,
485
- description: str | None = None,
486
- ) -> ToolResult:
487
- return self._call(
488
- "bash",
489
- {
490
- "command": command,
491
- "work_dir": work_dir or self.work_dir,
492
- "timeout": min(timeout or self.timeout, MAX_TIMEOUT),
493
- },
494
- timeout=timeout,
495
- )
496
-
497
- def read(
498
- self, path: str, *, offset: int | None = None, limit: int | None = None
499
- ) -> ToolResult:
500
- self._files_read.add(path)
501
- return self._call(
502
- "read",
503
- {
504
- "path": path,
505
- "offset": offset,
506
- "limit": limit or (DEFAULT_READ_LIMIT if offset is None else None),
507
- },
508
- )
509
-
510
- def write(self, path: str, content: str) -> ToolResult:
511
- if path not in self._files_read:
512
- check = self._call("exists", {"path": path})
513
- if check.success and check.output == "true":
514
- return ToolResult(
515
- success=False,
516
- error=(
517
- f"File {path} exists but has not been read this session. "
518
- f"Read it first, or use sandbox_edit for targeted changes."
519
- ),
520
- )
521
- result = self._call("write", {"path": path, "content": content})
522
- if result.success:
523
- self._files_read.add(path)
524
- return result
525
-
526
- def edit(
527
- self, path: str, old_str: str, new_str: str, *, replace_all: bool = False
528
- ) -> ToolResult:
529
- if old_str == new_str:
530
- return ToolResult(success=False, error="old_str and new_str are identical.")
531
- if path not in self._files_read:
532
- return ToolResult(
533
- success=False,
534
- error=f"File {path} has not been read this session. Read it first.",
535
- )
536
- return self._call(
537
- "edit",
538
- {
539
- "path": path,
540
- "old_str": old_str,
541
- "new_str": new_str,
542
- "replace_all": replace_all,
543
- },
544
- )
545
-
546
- # ── Tool schemas & dispatch ───────────────────────────────────
547
-
548
- TOOLS = {
549
- "bash": {
550
- "description": (
551
- "Run a shell command in the remote sandbox and return stdout/stderr.\n"
552
- "\n"
553
- "Commands run in a shell at the working directory (default /app). "
554
- "Each invocation is independent β€” use files in /app to persist state.\n"
555
- "\n"
556
- "AVOID using bash for operations covered by specialized tools:\n"
557
- "- File reading: use read (not cat/head/tail)\n"
558
- "- File editing: use edit (not sed/awk)\n"
559
- "- File writing: use write (not echo/cat <<EOF)\n"
560
- "\n"
561
- "For long-running tasks, background them:\n"
562
- " nohup uv run train.py > /app/train.log 2>&1 &\n"
563
- "Then check with read on the log file.\n"
564
- "\n"
565
- "Chain dependent commands with &&. Independent commands should be "
566
- "separate bash calls (they can run in parallel).\n"
567
- "\n"
568
- "Timeout default 120s, max 600s."
569
- ),
570
- "parameters": {
571
- "type": "object",
572
- "required": ["command"],
573
- "additionalProperties": False,
574
- "properties": {
575
- "command": {
576
- "type": "string",
577
- "description": "The shell command to execute.",
578
- },
579
- "description": {
580
- "type": "string",
581
- "description": "Short description (5-10 words, active voice). E.g. 'Install dependencies', 'Run training script'.",
582
- },
583
- "work_dir": {
584
- "type": "string",
585
- "description": "Working directory (default: /app).",
586
- },
587
- "timeout": {
588
- "type": "integer",
589
- "description": "Timeout in seconds (default: 120, max: 600).",
590
- },
591
- },
592
- },
593
- },
594
- "read": {
595
- "description": (
596
- "Read file contents with line numbers (cat -n format).\n"
597
- "\n"
598
- "Returns the first 2000 lines by default. For large files, use offset/limit "
599
- "to read a specific range. Line numbers always match the original file.\n"
600
- "\n"
601
- "Lines longer than 2000 chars are truncated.\n"
602
- "Cannot read directories β€” use bash with 'ls' instead."
603
- ),
604
- "parameters": {
605
- "type": "object",
606
- "required": ["path"],
607
- "additionalProperties": False,
608
- "properties": {
609
- "path": {
610
- "type": "string",
611
- "description": "Absolute path to the file to read.",
612
- },
613
- "offset": {
614
- "type": "integer",
615
- "description": "Start from this line (1-based). Only if file is too large.",
616
- },
617
- "limit": {
618
- "type": "integer",
619
- "description": "Number of lines to read. Only if file is too large.",
620
- },
621
- },
622
- },
623
- },
624
- "write": {
625
- "description": (
626
- "Create or overwrite a file. Creates parent directories as needed.\n"
627
- "\n"
628
- "For existing files, you MUST read the file first (system enforced). "
629
- "Prefer edit for modifications."
630
- ),
631
- "parameters": {
632
- "type": "object",
633
- "required": ["path", "content"],
634
- "additionalProperties": False,
635
- "properties": {
636
- "path": {
637
- "type": "string",
638
- "description": "Absolute path to the file to write.",
639
- },
640
- "content": {
641
- "type": "string",
642
- "description": "Complete file content.",
643
- },
644
- },
645
- },
646
- },
647
- "edit": {
648
- "description": (
649
- "Targeted edit via exact string replacement.\n"
650
- "\n"
651
- "Rules:\n"
652
- "- old_str must appear EXACTLY once (unless replace_all is true).\n"
653
- "- Include enough context in old_str for uniqueness.\n"
654
- "- old_str and new_str must differ.\n"
655
- "- Preserve indentation exactly.\n"
656
- "- To delete code, set new_str to empty string.\n"
657
- "- File MUST have been read this session (system enforced).\n"
658
- "- Do NOT include line number prefixes in old_str/new_str.\n"
659
- "\n"
660
- "Use replace_all=true for batch operations like variable renaming."
661
- ),
662
- "parameters": {
663
- "type": "object",
664
- "required": ["path", "old_str", "new_str"],
665
- "additionalProperties": False,
666
- "properties": {
667
- "path": {
668
- "type": "string",
669
- "description": "Absolute path to the file.",
670
- },
671
- "old_str": {
672
- "type": "string",
673
- "description": "Exact text to find (must differ from new_str).",
674
- },
675
- "new_str": {"type": "string", "description": "Replacement text."},
676
- "replace_all": {
677
- "type": "boolean",
678
- "description": "Replace all occurrences (default: false).",
679
- "default": False,
680
- },
681
- },
682
- },
683
- },
684
- }
685
-
686
- @classmethod
687
- def tool_definitions(cls) -> list[dict]:
688
- return [{"name": name, **spec} for name, spec in cls.TOOLS.items()]
689
-
690
- def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult:
691
- dispatch = {
692
- "bash": lambda a: self.bash(
693
- a["command"],
694
- work_dir=a.get("work_dir"),
695
- timeout=a.get("timeout"),
696
- description=a.get("description"),
697
- ),
698
- "read": lambda a: self.read(
699
- a["path"],
700
- offset=a.get("offset"),
701
- limit=a.get("limit"),
702
- ),
703
- "write": lambda a: self.write(a["path"], a["content"]),
704
- "edit": lambda a: self.edit(
705
- a["path"],
706
- a["old_str"],
707
- a["new_str"],
708
- replace_all=a.get("replace_all", False),
709
- ),
710
- }
711
- fn = dispatch.get(name)
712
- if not fn:
713
- return ToolResult(success=False, error=f"Unknown tool: {name}")
714
- return fn(arguments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
agent/tools/sandbox_tool.py DELETED
@@ -1,201 +0,0 @@
1
- """
2
- Sandbox tools β€” expose the Sandbox client as agent tools.
3
-
4
- 5 tools total:
5
- sandbox_create β€” explicit sandbox creation (requires approval)
6
- bash, read, write, edit β€” operations on the sandbox
7
-
8
- If any operation tool is called without an active sandbox,
9
- a cpu-basic sandbox is auto-created (no approval needed).
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- import asyncio
15
- import os
16
- from typing import Any
17
-
18
- from huggingface_hub import HfApi, SpaceHardware
19
-
20
- from agent.core.session import Event
21
- from agent.tools.sandbox_client import Sandbox
22
-
23
- # ── Tool name mapping (short agent names β†’ Sandbox client names) ──────
24
-
25
-
26
- async def _ensure_sandbox(
27
- session: Any, hardware: str = "cpu-basic", **create_kwargs
28
- ) -> tuple[Sandbox | None, str | None]:
29
- """
30
- Ensure a sandbox exists on the session. Auto-creates with given hardware if needed.
31
-
32
- Returns:
33
- (sandbox, error_message) β€” one will be None.
34
- """
35
- if session and getattr(session, "sandbox", None):
36
- return session.sandbox, None
37
-
38
- if not session:
39
- return None, "No session available."
40
-
41
- token = os.environ.get("HF_TOKEN")
42
- if not token:
43
- return None, "HF_TOKEN environment variable not set. Cannot create sandbox."
44
-
45
- api = HfApi(token=token)
46
- user_info = api.whoami()
47
- owner = user_info.get("name", user_info.get("user", ""))
48
- if not owner:
49
- return None, "Could not determine HF username from token."
50
-
51
- await session.send_event(
52
- Event(
53
- event_type="tool_log",
54
- data={
55
- "tool": "sandbox",
56
- "log": f"Auto-creating sandbox for {owner} ({hardware})...",
57
- },
58
- )
59
- )
60
-
61
- kwargs = {"owner": owner, "hardware": hardware, "token": token, **create_kwargs}
62
- sb = await asyncio.to_thread(Sandbox.create, **kwargs)
63
- session.sandbox = sb
64
-
65
- await session.send_event(
66
- Event(
67
- event_type="tool_log",
68
- data={"tool": "sandbox", "log": f"Sandbox ready: {sb.space_id} ({sb.url})"},
69
- )
70
- )
71
-
72
- return sb, None
73
-
74
-
75
- # ── sandbox_create tool ──────────────────────────────────────────────
76
-
77
- SANDBOX_CREATE_TOOL_SPEC = {
78
- "name": "sandbox_create",
79
- "description": (
80
- "Create a persistent remote Linux environment for developing and testing scripts.\n\n"
81
- "Workflow: sandbox_create β†’ write script β†’ pip install β†’ test with small run β†’ fix errors β†’ hf_jobs at scale.\n"
82
- "The sandbox persists across tool calls within the session. pip install works out of the box.\n\n"
83
- "Use this when: you need to develop, test, and iterate on scripts before launching via hf_jobs. "
84
- "Especially for training scripts where you need to verify imports, test on a small subset, and fix errors interactively.\n\n"
85
- "Skip this when: the task is a simple one-shot operation (status check, resource search, quick data query), "
86
- "or the script is copied from a verified working example with minimal changes.\n\n"
87
- "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). "
88
- "CPU sandboxes cannot run GPU code paths β€” your test will not catch GPU-related errors.\n\n"
89
- "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
90
- ),
91
- "parameters": {
92
- "type": "object",
93
- "required": [],
94
- "additionalProperties": False,
95
- "properties": {
96
- "hardware": {
97
- "type": "string",
98
- "enum": [e.value for e in SpaceHardware],
99
- "description": "Hardware tier for the sandbox (default: cpu-basic)",
100
- },
101
- "private": {
102
- "type": "boolean",
103
- "description": "If true, create a private Space",
104
- },
105
- },
106
- },
107
- }
108
-
109
-
110
- async def sandbox_create_handler(
111
- args: dict[str, Any], session: Any = None
112
- ) -> tuple[str, bool]:
113
- """Handle sandbox_create tool calls."""
114
- # If sandbox already exists, return its info
115
- if session and getattr(session, "sandbox", None):
116
- sb = session.sandbox
117
- return (
118
- f"Sandbox already active: {sb.space_id}\n"
119
- f"URL: {sb.url}\n"
120
- f"Use bash/read/write/edit to interact with it."
121
- ), True
122
-
123
- hardware = args.get("hardware", "cpu-basic")
124
- create_kwargs = {}
125
- if "private" in args:
126
- create_kwargs["private"] = args["private"]
127
-
128
- try:
129
- sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs)
130
- except Exception as e:
131
- return f"Failed to create sandbox: {e}", False
132
-
133
- if error:
134
- return error, False
135
-
136
- return (
137
- f"Sandbox created: {sb.space_id}\n"
138
- f"URL: {sb.url}\n"
139
- f"Hardware: {hardware}\n"
140
- f"Use bash/read/write/edit to interact with it."
141
- ), True
142
-
143
-
144
- def _make_tool_handler(sandbox_tool_name: str):
145
- """Factory: create a handler for a sandbox operation tool."""
146
-
147
- async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]:
148
- # Auto-create sandbox if not present
149
- try:
150
- sb, error = await _ensure_sandbox(session)
151
- except Exception as e:
152
- return f"Failed to auto-create sandbox: {e}", False
153
-
154
- if error:
155
- return error, False
156
-
157
- try:
158
- result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
159
- if result.success:
160
- return result.output or "(no output)", True
161
- else:
162
- error_msg = result.error or "Unknown error"
163
- output = result.output
164
- if output:
165
- return f"{output}\n\nERROR: {error_msg}", False
166
- return f"ERROR: {error_msg}", False
167
- except Exception as e:
168
- return f"Sandbox operation failed: {e}", False
169
-
170
- return handler
171
-
172
-
173
- def get_sandbox_tools():
174
- """Return all 5 sandbox ToolSpecs (sandbox_create + 4 operation tools)."""
175
- from agent.core.tools import ToolSpec
176
-
177
- tools = []
178
-
179
- # sandbox_create (explicit creation, requires approval)
180
- tools.append(
181
- ToolSpec(
182
- name=SANDBOX_CREATE_TOOL_SPEC["name"],
183
- description=SANDBOX_CREATE_TOOL_SPEC["description"],
184
- parameters=SANDBOX_CREATE_TOOL_SPEC["parameters"],
185
- handler=sandbox_create_handler,
186
- )
187
- )
188
-
189
- # Operation tools (auto-execute, no approval needed)
190
- for name in Sandbox.TOOLS.keys():
191
- spec = Sandbox.TOOLS[name]
192
- tools.append(
193
- ToolSpec(
194
- name=name,
195
- description=spec["description"],
196
- parameters=spec["parameters"],
197
- handler=_make_tool_handler(name),
198
- )
199
- )
200
-
201
- return tools
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/dependencies.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication dependencies for FastAPI routes.
2
+
3
+ Provides auth validation for both REST and WebSocket endpoints.
4
+ - In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user.
5
+ - In production: validates Bearer tokens or cookies against HF OAuth.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ from typing import Any
12
+
13
+ import httpx
14
+ from fastapi import HTTPException, Request, WebSocket, status
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
19
+ AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", ""))
20
+
21
+ # Simple in-memory token cache: token -> (user_info, expiry_time)
22
+ _token_cache: dict[str, tuple[dict[str, Any], float]] = {}
23
+ TOKEN_CACHE_TTL = 300 # 5 minutes
24
+
25
+ DEV_USER: dict[str, Any] = {
26
+ "user_id": "dev",
27
+ "username": "dev",
28
+ "authenticated": True,
29
+ }
30
+
31
+
32
+ async def _validate_token(token: str) -> dict[str, Any] | None:
33
+ """Validate a token against HF OAuth userinfo endpoint.
34
+
35
+ Results are cached for TOKEN_CACHE_TTL seconds to avoid excessive API calls.
36
+ """
37
+ now = time.time()
38
+
39
+ # Check cache
40
+ if token in _token_cache:
41
+ user_info, expiry = _token_cache[token]
42
+ if now < expiry:
43
+ return user_info
44
+ del _token_cache[token]
45
+
46
+ # Validate against HF
47
+ async with httpx.AsyncClient(timeout=10.0) as client:
48
+ try:
49
+ response = await client.get(
50
+ f"{OPENID_PROVIDER_URL}/oauth/userinfo",
51
+ headers={"Authorization": f"Bearer {token}"},
52
+ )
53
+ if response.status_code != 200:
54
+ logger.debug("Token validation failed: status %d", response.status_code)
55
+ return None
56
+ user_info = response.json()
57
+ _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL)
58
+ return user_info
59
+ except httpx.HTTPError as e:
60
+ logger.warning("Token validation error: %s", e)
61
+ return None
62
+
63
+
64
+ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
65
+ """Build a normalized user dict from HF userinfo response."""
66
+ return {
67
+ "user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")),
68
+ "username": user_info.get("preferred_username", "unknown"),
69
+ "name": user_info.get("name"),
70
+ "picture": user_info.get("picture"),
71
+ "authenticated": True,
72
+ }
73
+
74
+
75
+ async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
76
+ """Validate a token and return a user dict, or None."""
77
+ user_info = await _validate_token(token)
78
+ if user_info:
79
+ return _user_from_info(user_info)
80
+ return None
81
+
82
+
83
+ async def get_current_user(request: Request) -> dict[str, Any]:
84
+ """FastAPI dependency: extract and validate the current user.
85
+
86
+ Checks (in order):
87
+ 1. Authorization: Bearer <token> header
88
+ 2. hf_access_token cookie
89
+
90
+ In dev mode (AUTH_ENABLED=False), returns a default dev user.
91
+ """
92
+ if not AUTH_ENABLED:
93
+ return DEV_USER
94
+
95
+ # Try Authorization header
96
+ auth_header = request.headers.get("Authorization", "")
97
+ if auth_header.startswith("Bearer "):
98
+ token = auth_header[7:]
99
+ user = await _extract_user_from_token(token)
100
+ if user:
101
+ return user
102
+
103
+ # Try cookie
104
+ token = request.cookies.get("hf_access_token")
105
+ if token:
106
+ user = await _extract_user_from_token(token)
107
+ if user:
108
+ return user
109
+
110
+ raise HTTPException(
111
+ status_code=status.HTTP_401_UNAUTHORIZED,
112
+ detail="Not authenticated. Please log in via /auth/login.",
113
+ headers={"WWW-Authenticate": "Bearer"},
114
+ )
115
+
116
+
117
+ async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None:
118
+ """Extract and validate user from WebSocket connection.
119
+
120
+ WebSocket doesn't support custom headers from browser, so we check:
121
+ 1. ?token= query parameter
122
+ 2. hf_access_token cookie (sent automatically for same-origin)
123
+
124
+ Returns user dict or None if not authenticated.
125
+ In dev mode, returns the default dev user.
126
+ """
127
+ if not AUTH_ENABLED:
128
+ return DEV_USER
129
+
130
+ # Try query param
131
+ token = websocket.query_params.get("token")
132
+ if token:
133
+ user = await _extract_user_from_token(token)
134
+ if user:
135
+ return user
136
+
137
+ # Try cookie (works for same-origin WebSocket)
138
+ token = websocket.cookies.get("hf_access_token")
139
+ if token:
140
+ user = await _extract_user_from_token(token)
141
+ if user:
142
+ return user
143
+
144
+ return None
backend/main.py CHANGED
@@ -5,6 +5,14 @@ import os
5
  from contextlib import asynccontextmanager
6
  from pathlib import Path
7
 
 
 
 
 
 
 
 
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.staticfiles import StaticFiles
 
5
  from contextlib import asynccontextmanager
6
  from pathlib import Path
7
 
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+ # Ensure HF_TOKEN is set β€” fall back to HF_ADMIN_TOKEN if available (HF Spaces)
13
+ if not os.environ.get("HF_TOKEN") and os.environ.get("HF_ADMIN_TOKEN"):
14
+ os.environ["HF_TOKEN"] = os.environ["HF_ADMIN_TOKEN"]
15
+
16
  from fastapi import FastAPI
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from fastapi.staticfiles import StaticFiles
backend/models.py CHANGED
@@ -37,6 +37,7 @@ class ToolApproval(BaseModel):
37
  tool_call_id: str
38
  approved: bool
39
  feedback: str | None = None
 
40
 
41
 
42
  class ApprovalRequest(BaseModel):
@@ -67,6 +68,7 @@ class SessionInfo(BaseModel):
67
  created_at: str
68
  is_active: bool
69
  message_count: int
 
70
 
71
 
72
  class HealthResponse(BaseModel):
@@ -74,3 +76,13 @@ class HealthResponse(BaseModel):
74
 
75
  status: str = "ok"
76
  active_sessions: int = 0
 
 
 
 
 
 
 
 
 
 
 
37
  tool_call_id: str
38
  approved: bool
39
  feedback: str | None = None
40
+ edited_script: str | None = None
41
 
42
 
43
  class ApprovalRequest(BaseModel):
 
68
  created_at: str
69
  is_active: bool
70
  message_count: int
71
+ user_id: str = "dev"
72
 
73
 
74
  class HealthResponse(BaseModel):
 
76
 
77
  status: str = "ok"
78
  active_sessions: int = 0
79
+ max_sessions: int = 0
80
+
81
+
82
+ class LLMHealthResponse(BaseModel):
83
+ """LLM provider health check response."""
84
+
85
+ status: str # "ok" | "error"
86
+ model: str
87
+ error: str | None = None
88
+ error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown"
backend/routes/agent.py CHANGED
@@ -1,58 +1,252 @@
1
- """Agent API routes - WebSocket and REST endpoints."""
2
 
3
- import logging
 
 
4
 
5
- from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
7
  from models import (
8
  ApprovalRequest,
9
  HealthResponse,
 
10
  SessionInfo,
11
  SessionResponse,
12
  SubmitRequest,
13
  )
14
- from session_manager import session_manager
15
  from websocket import manager as ws_manager
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
  router = APIRouter(prefix="/api", tags=["agent"])
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @router.get("/health", response_model=HealthResponse)
23
  async def health_check() -> HealthResponse:
24
  """Health check endpoint."""
25
  return HealthResponse(
26
- status="ok", active_sessions=session_manager.active_session_count
 
 
27
  )
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @router.post("/session", response_model=SessionResponse)
31
- async def create_session() -> SessionResponse:
32
- """Create a new agent session."""
33
- session_id = await session_manager.create_session()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return SessionResponse(session_id=session_id, ready=True)
35
 
36
 
37
  @router.get("/session/{session_id}", response_model=SessionInfo)
38
- async def get_session(session_id: str) -> SessionInfo:
39
- """Get session information."""
 
 
 
40
  info = session_manager.get_session_info(session_id)
41
- if not info:
42
- raise HTTPException(status_code=404, detail="Session not found")
43
  return SessionInfo(**info)
44
 
45
 
46
  @router.get("/sessions", response_model=list[SessionInfo])
47
- async def list_sessions() -> list[SessionInfo]:
48
- """List all sessions."""
49
- sessions = session_manager.list_sessions()
50
  return [SessionInfo(**s) for s in sessions]
51
 
52
 
53
  @router.delete("/session/{session_id}")
54
- async def delete_session(session_id: str) -> dict:
55
- """Delete a session."""
 
 
 
56
  success = await session_manager.delete_session(session_id)
57
  if not success:
58
  raise HTTPException(status_code=404, detail="Session not found")
@@ -60,8 +254,11 @@ async def delete_session(session_id: str) -> dict:
60
 
61
 
62
  @router.post("/submit")
63
- async def submit_input(request: SubmitRequest) -> dict:
64
- """Submit user input to a session."""
 
 
 
65
  success = await session_manager.submit_user_input(request.session_id, request.text)
66
  if not success:
67
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -69,13 +266,17 @@ async def submit_input(request: SubmitRequest) -> dict:
69
 
70
 
71
  @router.post("/approve")
72
- async def submit_approval(request: ApprovalRequest) -> dict:
73
- """Submit tool approvals to a session."""
 
 
 
74
  approvals = [
75
  {
76
  "tool_call_id": a.tool_call_id,
77
  "approved": a.approved,
78
  "feedback": a.feedback,
 
79
  }
80
  for a in request.approvals
81
  ]
@@ -86,8 +287,11 @@ async def submit_approval(request: ApprovalRequest) -> dict:
86
 
87
 
88
  @router.post("/interrupt/{session_id}")
89
- async def interrupt_session(session_id: str) -> dict:
 
 
90
  """Interrupt the current operation in a session."""
 
91
  success = await session_manager.interrupt(session_id)
92
  if not success:
93
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -95,8 +299,9 @@ async def interrupt_session(session_id: str) -> dict:
95
 
96
 
97
  @router.post("/undo/{session_id}")
98
- async def undo_session(session_id: str) -> dict:
99
  """Undo the last turn in a session."""
 
100
  success = await session_manager.undo(session_id)
101
  if not success:
102
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -104,8 +309,11 @@ async def undo_session(session_id: str) -> dict:
104
 
105
 
106
  @router.post("/compact/{session_id}")
107
- async def compact_session(session_id: str) -> dict:
 
 
108
  """Compact the context in a session."""
 
109
  success = await session_manager.compact(session_id)
110
  if not success:
111
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -113,8 +321,11 @@ async def compact_session(session_id: str) -> dict:
113
 
114
 
115
  @router.post("/shutdown/{session_id}")
116
- async def shutdown_session(session_id: str) -> dict:
 
 
117
  """Shutdown a session."""
 
118
  success = await session_manager.shutdown_session(session_id)
119
  if not success:
120
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -123,17 +334,61 @@ async def shutdown_session(session_id: str) -> dict:
123
 
124
  @router.websocket("/ws/{session_id}")
125
  async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None:
126
- """WebSocket endpoint for real-time events."""
 
 
 
 
 
 
 
 
 
 
127
  logger.info(f"WebSocket connection request for session {session_id}")
 
 
 
 
 
 
 
 
 
 
 
128
  # Verify session exists
129
  info = session_manager.get_session_info(session_id)
130
  if not info:
131
- logger.warning(f"WebSocket connection rejected: Session {session_id} not found")
 
132
  await websocket.close(code=4004, reason="Session not found")
133
  return
134
 
 
 
 
 
 
 
 
 
 
135
  await ws_manager.connect(websocket, session_id)
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  try:
138
  while True:
139
  # Keep connection alive, handle ping/pong
 
1
+ """Agent API routes - WebSocket and REST endpoints.
2
 
3
+ All routes (except /health) require authentication via the get_current_user
4
+ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically.
5
+ """
6
 
7
+ import logging
8
+ import os
9
+ from typing import Any
10
+
11
+ from dependencies import get_current_user, get_ws_user
12
+ from fastapi import (
13
+ APIRouter,
14
+ Depends,
15
+ HTTPException,
16
+ Request,
17
+ WebSocket,
18
+ WebSocketDisconnect,
19
+ )
20
+ from litellm import acompletion
21
 
22
+ from agent.core.agent_loop import _resolve_hf_router_params
23
  from models import (
24
  ApprovalRequest,
25
  HealthResponse,
26
+ LLMHealthResponse,
27
  SessionInfo,
28
  SessionResponse,
29
  SubmitRequest,
30
  )
31
+ from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager
32
  from websocket import manager as ws_manager
33
 
34
  logger = logging.getLogger(__name__)
35
 
36
  router = APIRouter(prefix="/api", tags=["agent"])
37
 
38
+ AVAILABLE_MODELS = [
39
+ {
40
+ "id": "huggingface/novita/minimax/minimax-m2.1",
41
+ "label": "MiniMax M2.1",
42
+ "provider": "huggingface",
43
+ "recommended": True,
44
+ },
45
+ {
46
+ "id": "anthropic/claude-opus-4-5-20251101",
47
+ "label": "Claude Opus 4.5",
48
+ "provider": "anthropic",
49
+ "recommended": True,
50
+ },
51
+ {
52
+ "id": "huggingface/novita/moonshotai/kimi-k2.5",
53
+ "label": "Kimi K2.5",
54
+ "provider": "huggingface",
55
+ },
56
+ {
57
+ "id": "huggingface/novita/zai-org/glm-5",
58
+ "label": "GLM 5",
59
+ "provider": "huggingface",
60
+ },
61
+ ]
62
+
63
+
64
+ def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
65
+ """Verify the user has access to the given session. Raises 403 or 404."""
66
+ info = session_manager.get_session_info(session_id)
67
+ if not info:
68
+ raise HTTPException(status_code=404, detail="Session not found")
69
+ if not session_manager.verify_session_access(session_id, user["user_id"]):
70
+ raise HTTPException(status_code=403, detail="Access denied to this session")
71
+
72
 
73
  @router.get("/health", response_model=HealthResponse)
74
  async def health_check() -> HealthResponse:
75
  """Health check endpoint."""
76
  return HealthResponse(
77
+ status="ok",
78
+ active_sessions=session_manager.active_session_count,
79
+ max_sessions=MAX_SESSIONS,
80
  )
81
 
82
 
83
+ @router.get("/health/llm", response_model=LLMHealthResponse)
84
+ async def llm_health_check() -> LLMHealthResponse:
85
+ """Check if the LLM provider is reachable and the API key is valid.
86
+
87
+ Makes a minimal 1-token completion call. Catches common errors:
88
+ - 401 β†’ invalid API key
89
+ - 402/insufficient_quota β†’ out of credits
90
+ - 429 β†’ rate limited
91
+ - timeout / network β†’ provider unreachable
92
+ """
93
+ model = session_manager.config.model_name
94
+ try:
95
+ llm_params = _resolve_hf_router_params(model)
96
+ await acompletion(
97
+ messages=[{"role": "user", "content": "hi"}],
98
+ max_tokens=1,
99
+ timeout=10,
100
+ **llm_params,
101
+ )
102
+ return LLMHealthResponse(status="ok", model=model)
103
+ except Exception as e:
104
+ err_str = str(e).lower()
105
+ error_type = "unknown"
106
+
107
+ if (
108
+ "401" in err_str
109
+ or "auth" in err_str
110
+ or "invalid" in err_str
111
+ or "api key" in err_str
112
+ ):
113
+ error_type = "auth"
114
+ elif (
115
+ "402" in err_str
116
+ or "credit" in err_str
117
+ or "quota" in err_str
118
+ or "insufficient" in err_str
119
+ or "billing" in err_str
120
+ ):
121
+ error_type = "credits"
122
+ elif "429" in err_str or "rate" in err_str:
123
+ error_type = "rate_limit"
124
+ elif "timeout" in err_str or "connect" in err_str or "network" in err_str:
125
+ error_type = "network"
126
+
127
+ logger.warning(f"LLM health check failed ({error_type}): {e}")
128
+ return LLMHealthResponse(
129
+ status="error",
130
+ model=model,
131
+ error=str(e)[:500],
132
+ error_type=error_type,
133
+ )
134
+
135
+
136
+ @router.get("/config/model")
137
+ async def get_model() -> dict:
138
+ """Get current model and available models. No auth required."""
139
+ return {
140
+ "current": session_manager.config.model_name,
141
+ "available": AVAILABLE_MODELS,
142
+ }
143
+
144
+
145
+ @router.post("/config/model")
146
+ async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict:
147
+ """Set the LLM model. Applies to new conversations."""
148
+ model_id = body.get("model")
149
+ if not model_id:
150
+ raise HTTPException(status_code=400, detail="Missing 'model' field")
151
+ valid_ids = {m["id"] for m in AVAILABLE_MODELS}
152
+ if model_id not in valid_ids:
153
+ raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
154
+ session_manager.config.model_name = model_id
155
+ logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}")
156
+ return {"model": model_id}
157
+
158
+
159
+ @router.post("/title")
160
+ async def generate_title(
161
+ request: SubmitRequest, user: dict = Depends(get_current_user)
162
+ ) -> dict:
163
+ """Generate a short title for a chat session based on the first user message."""
164
+ model = session_manager.config.model_name
165
+ llm_params = _resolve_hf_router_params(model)
166
+ try:
167
+ response = await acompletion(
168
+ messages=[
169
+ {
170
+ "role": "system",
171
+ "content": (
172
+ "Generate a very short title (max 6 words) for a chat conversation "
173
+ "that starts with the following user message. "
174
+ "Reply with ONLY the title, no quotes, no punctuation at the end."
175
+ ),
176
+ },
177
+ {"role": "user", "content": request.text[:500]},
178
+ ],
179
+ max_tokens=20,
180
+ temperature=0.3,
181
+ timeout=8,
182
+ **llm_params,
183
+ )
184
+ title = response.choices[0].message.content.strip().strip('"').strip("'")
185
+ # Safety: cap at 50 chars
186
+ if len(title) > 50:
187
+ title = title[:50].rstrip() + "…"
188
+ return {"title": title}
189
+ except Exception as e:
190
+ logger.warning(f"Title generation failed: {e}")
191
+ # Fallback: truncate the message
192
+ fallback = request.text.strip()
193
+ title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
194
+ return {"title": title}
195
+
196
+
197
  @router.post("/session", response_model=SessionResponse)
198
+ async def create_session(
199
+ request: Request, user: dict = Depends(get_current_user)
200
+ ) -> SessionResponse:
201
+ """Create a new agent session bound to the authenticated user.
202
+
203
+ The user's HF access token is extracted from the Authorization header
204
+ and stored in the session so that tools (e.g. hf_jobs) can act on
205
+ behalf of the user.
206
+
207
+ Returns 503 if the server or user has reached the session limit.
208
+ """
209
+ # Extract the user's HF token (Bearer header or HttpOnly cookie)
210
+ hf_token = None
211
+ auth_header = request.headers.get("Authorization", "")
212
+ if auth_header.startswith("Bearer "):
213
+ hf_token = auth_header[7:]
214
+ if not hf_token:
215
+ hf_token = request.cookies.get("hf_access_token")
216
+
217
+ try:
218
+ session_id = await session_manager.create_session(
219
+ user_id=user["user_id"], hf_token=hf_token
220
+ )
221
+ except SessionCapacityError as e:
222
+ raise HTTPException(status_code=503, detail=str(e))
223
+
224
  return SessionResponse(session_id=session_id, ready=True)
225
 
226
 
227
  @router.get("/session/{session_id}", response_model=SessionInfo)
228
+ async def get_session(
229
+ session_id: str, user: dict = Depends(get_current_user)
230
+ ) -> SessionInfo:
231
+ """Get session information. Only accessible by the session owner."""
232
+ _check_session_access(session_id, user)
233
  info = session_manager.get_session_info(session_id)
 
 
234
  return SessionInfo(**info)
235
 
236
 
237
  @router.get("/sessions", response_model=list[SessionInfo])
238
+ async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
239
+ """List sessions belonging to the authenticated user."""
240
+ sessions = session_manager.list_sessions(user_id=user["user_id"])
241
  return [SessionInfo(**s) for s in sessions]
242
 
243
 
244
  @router.delete("/session/{session_id}")
245
+ async def delete_session(
246
+ session_id: str, user: dict = Depends(get_current_user)
247
+ ) -> dict:
248
+ """Delete a session. Only accessible by the session owner."""
249
+ _check_session_access(session_id, user)
250
  success = await session_manager.delete_session(session_id)
251
  if not success:
252
  raise HTTPException(status_code=404, detail="Session not found")
 
254
 
255
 
256
  @router.post("/submit")
257
+ async def submit_input(
258
+ request: SubmitRequest, user: dict = Depends(get_current_user)
259
+ ) -> dict:
260
+ """Submit user input to a session. Only accessible by the session owner."""
261
+ _check_session_access(request.session_id, user)
262
  success = await session_manager.submit_user_input(request.session_id, request.text)
263
  if not success:
264
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
266
 
267
 
268
  @router.post("/approve")
269
+ async def submit_approval(
270
+ request: ApprovalRequest, user: dict = Depends(get_current_user)
271
+ ) -> dict:
272
+ """Submit tool approvals to a session. Only accessible by the session owner."""
273
+ _check_session_access(request.session_id, user)
274
  approvals = [
275
  {
276
  "tool_call_id": a.tool_call_id,
277
  "approved": a.approved,
278
  "feedback": a.feedback,
279
+ "edited_script": a.edited_script,
280
  }
281
  for a in request.approvals
282
  ]
 
287
 
288
 
289
  @router.post("/interrupt/{session_id}")
290
+ async def interrupt_session(
291
+ session_id: str, user: dict = Depends(get_current_user)
292
+ ) -> dict:
293
  """Interrupt the current operation in a session."""
294
+ _check_session_access(session_id, user)
295
  success = await session_manager.interrupt(session_id)
296
  if not success:
297
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
299
 
300
 
301
  @router.post("/undo/{session_id}")
302
+ async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
303
  """Undo the last turn in a session."""
304
+ _check_session_access(session_id, user)
305
  success = await session_manager.undo(session_id)
306
  if not success:
307
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
309
 
310
 
311
  @router.post("/compact/{session_id}")
312
+ async def compact_session(
313
+ session_id: str, user: dict = Depends(get_current_user)
314
+ ) -> dict:
315
  """Compact the context in a session."""
316
+ _check_session_access(session_id, user)
317
  success = await session_manager.compact(session_id)
318
  if not success:
319
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
321
 
322
 
323
  @router.post("/shutdown/{session_id}")
324
+ async def shutdown_session(
325
+ session_id: str, user: dict = Depends(get_current_user)
326
+ ) -> dict:
327
  """Shutdown a session."""
328
+ _check_session_access(session_id, user)
329
  success = await session_manager.shutdown_session(session_id)
330
  if not success:
331
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
334
 
335
  @router.websocket("/ws/{session_id}")
336
  async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None:
337
+ """WebSocket endpoint for real-time events.
338
+
339
+ Authentication is done via:
340
+ - ?token= query parameter (for browsers that can't send WS headers)
341
+ - Cookie (automatic for same-origin connections)
342
+ - Dev mode bypass (when OAUTH_CLIENT_ID is not set)
343
+
344
+ NOTE: We must accept() before close() so the browser receives our custom
345
+ close codes (4001, 4003, 4004). If we close() before accept(), Starlette
346
+ sends HTTP 403 and the browser only sees code 1006 (abnormal closure).
347
+ """
348
  logger.info(f"WebSocket connection request for session {session_id}")
349
+
350
+ # Authenticate the WebSocket connection
351
+ user = await get_ws_user(websocket)
352
+ if not user:
353
+ logger.warning(
354
+ f"WebSocket rejected: authentication failed for session {session_id}"
355
+ )
356
+ await websocket.accept()
357
+ await websocket.close(code=4001, reason="Authentication required")
358
+ return
359
+
360
  # Verify session exists
361
  info = session_manager.get_session_info(session_id)
362
  if not info:
363
+ logger.warning(f"WebSocket rejected: session {session_id} not found")
364
+ await websocket.accept()
365
  await websocket.close(code=4004, reason="Session not found")
366
  return
367
 
368
+ # Verify user owns the session
369
+ if not session_manager.verify_session_access(session_id, user["user_id"]):
370
+ logger.warning(
371
+ f"WebSocket rejected: user {user['user_id']} denied access to session {session_id}"
372
+ )
373
+ await websocket.accept()
374
+ await websocket.close(code=4003, reason="Access denied")
375
+ return
376
+
377
  await ws_manager.connect(websocket, session_id)
378
 
379
+ # Send "ready" immediately on WebSocket connection so the frontend
380
+ # knows the session is alive. The original ready event from _run_session
381
+ # fires before the WS is connected and is always lost.
382
+ try:
383
+ await websocket.send_json(
384
+ {
385
+ "event_type": "ready",
386
+ "data": {"message": "Agent initialized"},
387
+ }
388
+ )
389
+ except Exception as e:
390
+ logger.error(f"Failed to send ready event for session {session_id}: {e}")
391
+
392
  try:
393
  while True:
394
  # Keep connection alive, handle ping/pong
backend/routes/auth.py CHANGED
@@ -1,11 +1,17 @@
1
- """Authentication routes for HF OAuth."""
 
 
 
 
2
 
3
  import os
4
  import secrets
 
5
  from urllib.parse import urlencode
6
 
7
  import httpx
8
- from fastapi import APIRouter, HTTPException, Request
 
9
  from fastapi.responses import RedirectResponse
10
 
11
  router = APIRouter(prefix="/auth", tags=["auth"])
@@ -15,10 +21,19 @@ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
15
  OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
16
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
17
 
18
- # In-memory session store (replace with proper session management in production)
 
19
  oauth_states: dict[str, dict] = {}
20
 
21
 
 
 
 
 
 
 
 
 
22
  def get_redirect_uri(request: Request) -> str:
23
  """Get the OAuth callback redirect URI."""
24
  # In HF Spaces, use the SPACE_HOST if available
@@ -38,17 +53,26 @@ async def oauth_login(request: Request) -> RedirectResponse:
38
  detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
39
  )
40
 
 
 
 
41
  # Generate state for CSRF protection
42
  state = secrets.token_urlsafe(32)
43
- oauth_states[state] = {"redirect_uri": get_redirect_uri(request)}
 
 
 
44
 
45
  # Build authorization URL
46
  params = {
47
  "client_id": OAUTH_CLIENT_ID,
48
  "redirect_uri": get_redirect_uri(request),
49
- "scope": "openid profile",
50
  "response_type": "code",
51
  "state": state,
 
 
 
52
  }
53
  auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"
54
 
@@ -91,58 +115,57 @@ async def oauth_callback(
91
 
92
  # Get user info
93
  access_token = token_data.get("access_token")
94
- if access_token:
95
- async with httpx.AsyncClient() as client:
96
- try:
97
- userinfo_response = await client.get(
98
- f"{OPENID_PROVIDER_URL}/oauth/userinfo",
99
- headers={"Authorization": f"Bearer {access_token}"},
100
- )
101
- userinfo_response.raise_for_status()
102
- user_info = userinfo_response.json()
103
- except httpx.HTTPError:
104
- user_info = {}
105
- else:
106
- user_info = {}
107
-
108
- # For now, redirect to home with token in query params
109
- # In production, use secure cookies or session storage
110
- redirect_params = {
111
- "access_token": access_token,
112
- "username": user_info.get("preferred_username", ""),
113
- }
114
 
115
- return RedirectResponse(url=f"/?{urlencode(redirect_params)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  @router.get("/logout")
119
  async def logout() -> RedirectResponse:
120
- """Log out the user."""
121
- return RedirectResponse(url="/")
 
 
122
 
123
 
124
- @router.get("/me")
125
- async def get_current_user(request: Request) -> dict:
126
- """Get current user info from Authorization header."""
127
- auth_header = request.headers.get("Authorization", "")
128
- if not auth_header.startswith("Bearer "):
129
- return {"authenticated": False}
130
 
131
- token = auth_header.split(" ")[1]
132
 
133
- async with httpx.AsyncClient() as client:
134
- try:
135
- response = await client.get(
136
- f"{OPENID_PROVIDER_URL}/oauth/userinfo",
137
- headers={"Authorization": f"Bearer {token}"},
138
- )
139
- response.raise_for_status()
140
- user_info = response.json()
141
- return {
142
- "authenticated": True,
143
- "username": user_info.get("preferred_username"),
144
- "name": user_info.get("name"),
145
- "picture": user_info.get("picture"),
146
- }
147
- except httpx.HTTPError:
148
- return {"authenticated": False}
 
1
+ """Authentication routes for HF OAuth.
2
+
3
+ Handles the OAuth 2.0 authorization code flow with HF as provider.
4
+ After successful auth, sets an HttpOnly cookie with the access token.
5
+ """
6
 
7
  import os
8
  import secrets
9
+ import time
10
  from urllib.parse import urlencode
11
 
12
  import httpx
13
+ from dependencies import AUTH_ENABLED, get_current_user
14
+ from fastapi import APIRouter, Depends, HTTPException, Request
15
  from fastapi.responses import RedirectResponse
16
 
17
  router = APIRouter(prefix="/auth", tags=["auth"])
 
21
  OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
22
  OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
23
 
24
+ # In-memory OAuth state store with expiry (5 min TTL)
25
+ _OAUTH_STATE_TTL = 300
26
  oauth_states: dict[str, dict] = {}
27
 
28
 
29
+ def _cleanup_expired_states() -> None:
30
+ """Remove expired OAuth states to prevent memory growth."""
31
+ now = time.time()
32
+ expired = [k for k, v in oauth_states.items() if now > v.get("expires_at", 0)]
33
+ for k in expired:
34
+ del oauth_states[k]
35
+
36
+
37
  def get_redirect_uri(request: Request) -> str:
38
  """Get the OAuth callback redirect URI."""
39
  # In HF Spaces, use the SPACE_HOST if available
 
53
  detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
54
  )
55
 
56
+ # Clean up expired states to prevent memory growth
57
+ _cleanup_expired_states()
58
+
59
  # Generate state for CSRF protection
60
  state = secrets.token_urlsafe(32)
61
+ oauth_states[state] = {
62
+ "redirect_uri": get_redirect_uri(request),
63
+ "expires_at": time.time() + _OAUTH_STATE_TTL,
64
+ }
65
 
66
  # Build authorization URL
67
  params = {
68
  "client_id": OAUTH_CLIENT_ID,
69
  "redirect_uri": get_redirect_uri(request),
70
+ "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions",
71
  "response_type": "code",
72
  "state": state,
73
+ "orgIds": os.environ.get(
74
+ "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1"
75
+ ), # ml-agent-explorers
76
  }
77
  auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"
78
 
 
115
 
116
  # Get user info
117
  access_token = token_data.get("access_token")
118
+ if not access_token:
119
+ raise HTTPException(
120
+ status_code=500,
121
+ detail="Token exchange succeeded but no access_token was returned.",
122
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ # Fetch user info (optional β€” failure is not fatal)
125
+ async with httpx.AsyncClient() as client:
126
+ try:
127
+ userinfo_response = await client.get(
128
+ f"{OPENID_PROVIDER_URL}/oauth/userinfo",
129
+ headers={"Authorization": f"Bearer {access_token}"},
130
+ )
131
+ userinfo_response.raise_for_status()
132
+ except httpx.HTTPError:
133
+ pass # user_info not required for auth flow
134
+
135
+ # Set access token as HttpOnly cookie (not in URL β€” avoids leaks via
136
+ # Referrer headers, browser history, and server logs)
137
+ is_production = bool(os.environ.get("SPACE_HOST"))
138
+ response = RedirectResponse(url="/", status_code=302)
139
+ response.set_cookie(
140
+ key="hf_access_token",
141
+ value=access_token,
142
+ httponly=True,
143
+ secure=is_production, # Secure flag only in production (HTTPS)
144
+ samesite="lax",
145
+ max_age=3600 * 24, # 24 hours
146
+ path="/",
147
+ )
148
+ return response
149
 
150
 
151
  @router.get("/logout")
152
  async def logout() -> RedirectResponse:
153
+ """Log out the user by clearing the auth cookie."""
154
+ response = RedirectResponse(url="/")
155
+ response.delete_cookie(key="hf_access_token", path="/")
156
+ return response
157
 
158
 
159
+ @router.get("/status")
160
+ async def auth_status() -> dict:
161
+ """Check if OAuth is enabled on this instance."""
162
+ return {"auth_enabled": AUTH_ENABLED}
 
 
163
 
 
164
 
165
+ @router.get("/me")
166
+ async def get_me(user: dict = Depends(get_current_user)) -> dict:
167
+ """Get current user info. Returns the authenticated user or dev user.
168
+
169
+ Uses the shared auth dependency which handles cookie + Bearer token.
170
+ """
171
+ return user
 
 
 
 
 
 
 
 
 
backend/session_manager.py CHANGED
@@ -48,11 +48,28 @@ class AgentSession:
48
  session: Session
49
  tool_router: ToolRouter
50
  submission_queue: asyncio.Queue
 
 
51
  task: asyncio.Task | None = None
52
  created_at: datetime = field(default_factory=datetime.utcnow)
53
  is_active: bool = True
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class SessionManager:
57
  """Manages multiple concurrent agent sessions."""
58
 
@@ -61,19 +78,69 @@ class SessionManager:
61
  self.sessions: dict[str, AgentSession] = {}
62
  self._lock = asyncio.Lock()
63
 
64
- async def create_session(self) -> str:
65
- """Create a new agent session and return its ID."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  session_id = str(uuid.uuid4())
67
 
68
  # Create queues for this session
69
  submission_queue: asyncio.Queue = asyncio.Queue()
70
  event_queue: asyncio.Queue = asyncio.Queue()
71
 
72
- # Create tool router
73
- tool_router = ToolRouter(self.config.mcpServers)
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Create the agent session
76
- session = Session(event_queue, config=self.config, tool_router=tool_router)
 
 
77
 
78
  # Create wrapper
79
  agent_session = AgentSession(
@@ -81,6 +148,8 @@ class SessionManager:
81
  session=session,
82
  tool_router=tool_router,
83
  submission_queue=submission_queue,
 
 
84
  )
85
 
86
  async with self._lock:
@@ -92,7 +161,7 @@ class SessionManager:
92
  )
93
  agent_session.task = task
94
 
95
- logger.info(f"Created session {session_id}")
96
  return session_id
97
 
98
  async def _run_session(
@@ -245,6 +314,27 @@ class SessionManager:
245
 
246
  return True
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def get_session_info(self, session_id: str) -> dict[str, Any] | None:
249
  """Get information about a session."""
250
  agent_session = self.sessions.get(session_id)
@@ -256,15 +346,25 @@ class SessionManager:
256
  "created_at": agent_session.created_at.isoformat(),
257
  "is_active": agent_session.is_active,
258
  "message_count": len(agent_session.session.context_manager.items),
 
259
  }
260
 
261
- def list_sessions(self) -> list[dict[str, Any]]:
262
- """List all sessions."""
263
- return [
264
- self.get_session_info(sid)
265
- for sid in self.sessions
266
- if self.get_session_info(sid)
267
- ]
 
 
 
 
 
 
 
 
 
268
 
269
  @property
270
  def active_session_count(self) -> int:
 
48
  session: Session
49
  tool_router: ToolRouter
50
  submission_queue: asyncio.Queue
51
+ user_id: str = "dev" # Owner of this session
52
+ hf_token: str | None = None # User's HF OAuth token for tool execution
53
  task: asyncio.Task | None = None
54
  created_at: datetime = field(default_factory=datetime.utcnow)
55
  is_active: bool = True
56
 
57
 
58
+ class SessionCapacityError(Exception):
59
+ """Raised when no more sessions can be created."""
60
+
61
+ def __init__(self, message: str, error_type: str = "global") -> None:
62
+ super().__init__(message)
63
+ self.error_type = error_type # "global" or "per_user"
64
+
65
+
66
+ # ── Capacity limits ─────────────────────────────────────────────────
67
+ # Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM).
68
+ # Each session uses ~10-20 MB (context, tools, queues, task).
69
+ MAX_SESSIONS: int = 50
70
+ MAX_SESSIONS_PER_USER: int = 10
71
+
72
+
73
  class SessionManager:
74
  """Manages multiple concurrent agent sessions."""
75
 
 
78
  self.sessions: dict[str, AgentSession] = {}
79
  self._lock = asyncio.Lock()
80
 
81
+ def _count_user_sessions(self, user_id: str) -> int:
82
+ """Count active sessions owned by a specific user."""
83
+ return sum(
84
+ 1
85
+ for s in self.sessions.values()
86
+ if s.user_id == user_id and s.is_active
87
+ )
88
+
89
+ async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str:
90
+ """Create a new agent session and return its ID.
91
+
92
+ Session() and ToolRouter() constructors contain blocking I/O
93
+ (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
94
+ executed in a thread pool to avoid freezing the async event loop.
95
+
96
+ Args:
97
+ user_id: The ID of the user who owns this session.
98
+
99
+ Raises:
100
+ SessionCapacityError: If the server or user has reached the
101
+ maximum number of concurrent sessions.
102
+ """
103
+ # ── Capacity checks ──────────────────────────────────────────
104
+ async with self._lock:
105
+ active_count = self.active_session_count
106
+ if active_count >= MAX_SESSIONS:
107
+ raise SessionCapacityError(
108
+ f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). "
109
+ "Please try again later.",
110
+ error_type="global",
111
+ )
112
+ if user_id != "dev":
113
+ user_count = self._count_user_sessions(user_id)
114
+ if user_count >= MAX_SESSIONS_PER_USER:
115
+ raise SessionCapacityError(
116
+ f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
117
+ "concurrent sessions. Please close an existing session first.",
118
+ error_type="per_user",
119
+ )
120
+
121
  session_id = str(uuid.uuid4())
122
 
123
  # Create queues for this session
124
  submission_queue: asyncio.Queue = asyncio.Queue()
125
  event_queue: asyncio.Queue = asyncio.Queue()
126
 
127
+ # Run blocking constructors in a thread to keep the event loop responsive.
128
+ # Without this, Session.__init__ β†’ ContextManager β†’ litellm.get_max_tokens()
129
+ # blocks all HTTP/WebSocket handling.
130
+ import time as _time
131
+
132
+ def _create_session_sync():
133
+ t0 = _time.monotonic()
134
+ tool_router = ToolRouter(self.config.mcpServers)
135
+ session = Session(event_queue, config=self.config, tool_router=tool_router)
136
+ t1 = _time.monotonic()
137
+ logger.info(f"Session initialized in {t1 - t0:.2f}s")
138
+ return tool_router, session
139
 
140
+ tool_router, session = await asyncio.to_thread(_create_session_sync)
141
+
142
+ # Store user's HF token on the session so tools can use it
143
+ session.hf_token = hf_token
144
 
145
  # Create wrapper
146
  agent_session = AgentSession(
 
148
  session=session,
149
  tool_router=tool_router,
150
  submission_queue=submission_queue,
151
+ user_id=user_id,
152
+ hf_token=hf_token,
153
  )
154
 
155
  async with self._lock:
 
161
  )
162
  agent_session.task = task
163
 
164
+ logger.info(f"Created session {session_id} for user {user_id}")
165
  return session_id
166
 
167
  async def _run_session(
 
314
 
315
  return True
316
 
317
+ def get_session_owner(self, session_id: str) -> str | None:
318
+ """Get the user_id that owns a session, or None if session doesn't exist."""
319
+ agent_session = self.sessions.get(session_id)
320
+ if not agent_session:
321
+ return None
322
+ return agent_session.user_id
323
+
324
+ def verify_session_access(self, session_id: str, user_id: str) -> bool:
325
+ """Check if a user has access to a session.
326
+
327
+ Returns True if:
328
+ - The session exists AND the user owns it
329
+ - The user_id is "dev" (dev mode bypass)
330
+ """
331
+ owner = self.get_session_owner(session_id)
332
+ if owner is None:
333
+ return False
334
+ if user_id == "dev" or owner == "dev":
335
+ return True
336
+ return owner == user_id
337
+
338
  def get_session_info(self, session_id: str) -> dict[str, Any] | None:
339
  """Get information about a session."""
340
  agent_session = self.sessions.get(session_id)
 
346
  "created_at": agent_session.created_at.isoformat(),
347
  "is_active": agent_session.is_active,
348
  "message_count": len(agent_session.session.context_manager.items),
349
+ "user_id": agent_session.user_id,
350
  }
351
 
352
+ def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
353
+ """List sessions, optionally filtered by user.
354
+
355
+ Args:
356
+ user_id: If provided, only return sessions owned by this user.
357
+ If "dev", return all sessions (dev mode).
358
+ """
359
+ results = []
360
+ for sid in self.sessions:
361
+ info = self.get_session_info(sid)
362
+ if not info:
363
+ continue
364
+ if user_id and user_id != "dev" and info.get("user_id") != user_id:
365
+ continue
366
+ results.append(info)
367
+ return results
368
 
369
  @property
370
  def active_session_count(self) -> int:
backend/websocket.py CHANGED
@@ -1,6 +1,5 @@
1
  """WebSocket connection manager for real-time communication."""
2
 
3
- import asyncio
4
  import logging
5
  from typing import Any
6
 
@@ -15,23 +14,18 @@ class ConnectionManager:
15
  def __init__(self) -> None:
16
  # session_id -> WebSocket
17
  self.active_connections: dict[str, WebSocket] = {}
18
- # session_id -> asyncio.Queue for outgoing messages
19
- self.message_queues: dict[str, asyncio.Queue] = {}
20
 
21
  async def connect(self, websocket: WebSocket, session_id: str) -> None:
22
  """Accept a WebSocket connection and register it."""
23
  logger.info(f"Attempting to accept WebSocket for session {session_id}")
24
  await websocket.accept()
25
  self.active_connections[session_id] = websocket
26
- self.message_queues[session_id] = asyncio.Queue()
27
  logger.info(f"WebSocket connected and registered for session {session_id}")
28
 
29
  def disconnect(self, session_id: str) -> None:
30
  """Remove a WebSocket connection."""
31
  if session_id in self.active_connections:
32
  del self.active_connections[session_id]
33
- if session_id in self.message_queues:
34
- del self.message_queues[session_id]
35
  logger.info(f"WebSocket disconnected for session {session_id}")
36
 
37
  async def send_event(
@@ -63,10 +57,6 @@ class ConnectionManager:
63
  """Check if a session has an active WebSocket connection."""
64
  return session_id in self.active_connections
65
 
66
- def get_queue(self, session_id: str) -> asyncio.Queue | None:
67
- """Get the message queue for a session."""
68
- return self.message_queues.get(session_id)
69
-
70
 
71
  # Global connection manager instance
72
  manager = ConnectionManager()
 
1
  """WebSocket connection manager for real-time communication."""
2
 
 
3
  import logging
4
  from typing import Any
5
 
 
14
  def __init__(self) -> None:
15
  # session_id -> WebSocket
16
  self.active_connections: dict[str, WebSocket] = {}
 
 
17
 
18
  async def connect(self, websocket: WebSocket, session_id: str) -> None:
19
  """Accept a WebSocket connection and register it."""
20
  logger.info(f"Attempting to accept WebSocket for session {session_id}")
21
  await websocket.accept()
22
  self.active_connections[session_id] = websocket
 
23
  logger.info(f"WebSocket connected and registered for session {session_id}")
24
 
25
  def disconnect(self, session_id: str) -> None:
26
  """Remove a WebSocket connection."""
27
  if session_id in self.active_connections:
28
  del self.active_connections[session_id]
 
 
29
  logger.info(f"WebSocket disconnected for session {session_id}")
30
 
31
  async def send_event(
 
57
  """Check if a session has an active WebSocket connection."""
58
  return session_id in self.active_connections
59
 
 
 
 
 
60
 
61
  # Global connection manager instance
62
  manager = ConnectionManager()
configs/main_agent_config.json CHANGED
@@ -1,9 +1,9 @@
1
  {
2
- "model_name": "anthropic/claude-opus-4-5-20251101",
3
  "save_sessions": true,
4
  "session_dataset_repo": "akseljoonas/hf-agent-sessions",
5
  "yolo_mode": false,
6
- "confirm_cpu_jobs": false,
7
  "auto_file_upload": true,
8
  "mcpServers": {
9
  "hf-mcp-server": {
 
1
  {
2
+ "model_name": "huggingface/novita/moonshotai/kimi-k2.5",
3
  "save_sessions": true,
4
  "session_dataset_repo": "akseljoonas/hf-agent-sessions",
5
  "yolo_mode": false,
6
+ "confirm_cpu_jobs": true,
7
  "auto_file_upload": true,
8
  "mcpServers": {
9
  "hf-mcp-server": {
frontend/package-lock.json CHANGED
@@ -8,10 +8,12 @@
8
  "name": "hf-agent-frontend",
9
  "version": "1.0.0",
10
  "dependencies": {
 
11
  "@emotion/react": "^11.13.0",
12
  "@emotion/styled": "^11.13.0",
13
  "@mui/icons-material": "^6.1.0",
14
  "@mui/material": "^6.1.0",
 
15
  "react": "^18.3.1",
16
  "react-dom": "^18.3.1",
17
  "react-markdown": "^9.0.1",
@@ -34,6 +36,70 @@
34
  "vite": "^5.4.10"
35
  }
36
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  "node_modules/@babel/code-frame": {
38
  "version": "7.28.6",
39
  "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz",
@@ -1348,6 +1414,15 @@
1348
  }
1349
  }
1350
  },
 
 
 
 
 
 
 
 
 
1351
  "node_modules/@popperjs/core": {
1352
  "version": "2.11.8",
1353
  "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz",
@@ -1715,6 +1790,12 @@
1715
  "win32"
1716
  ]
1717
  },
 
 
 
 
 
 
1718
  "node_modules/@types/babel__core": {
1719
  "version": "7.20.5",
1720
  "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",
@@ -2155,6 +2236,15 @@
2155
  "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==",
2156
  "license": "ISC"
2157
  },
 
 
 
 
 
 
 
 
 
2158
  "node_modules/@vitejs/plugin-react": {
2159
  "version": "4.7.0",
2160
  "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz",
@@ -2200,6 +2290,24 @@
2200
  "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0"
2201
  }
2202
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2203
  "node_modules/ajv": {
2204
  "version": "6.12.6",
2205
  "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
@@ -2848,6 +2956,15 @@
2848
  "node": ">=0.10.0"
2849
  }
2850
  },
 
 
 
 
 
 
 
 
 
2851
  "node_modules/extend": {
2852
  "version": "3.0.2",
2853
  "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz",
@@ -3356,6 +3473,12 @@
3356
  "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==",
3357
  "license": "MIT"
3358
  },
 
 
 
 
 
 
3359
  "node_modules/json-schema-traverse": {
3360
  "version": "0.4.1",
3361
  "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz",
@@ -5052,6 +5175,31 @@
5052
  "url": "https://github.com/sponsors/ljharb"
5053
  }
5054
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5055
  "node_modules/tinyglobby": {
5056
  "version": "0.2.15",
5057
  "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz",
@@ -5282,6 +5430,16 @@
5282
  "punycode": "^2.1.0"
5283
  }
5284
  },
 
 
 
 
 
 
 
 
 
 
5285
  "node_modules/vfile": {
5286
  "version": "6.0.3",
5287
  "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz",
@@ -5426,6 +5584,16 @@
5426
  "url": "https://github.com/sponsors/sindresorhus"
5427
  }
5428
  },
 
 
 
 
 
 
 
 
 
 
5429
  "node_modules/zustand": {
5430
  "version": "5.0.10",
5431
  "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.10.tgz",
 
8
  "name": "hf-agent-frontend",
9
  "version": "1.0.0",
10
  "dependencies": {
11
+ "@ai-sdk/react": "^3.0.93",
12
  "@emotion/react": "^11.13.0",
13
  "@emotion/styled": "^11.13.0",
14
  "@mui/icons-material": "^6.1.0",
15
  "@mui/material": "^6.1.0",
16
+ "ai": "^6.0.91",
17
  "react": "^18.3.1",
18
  "react-dom": "^18.3.1",
19
  "react-markdown": "^9.0.1",
 
36
  "vite": "^5.4.10"
37
  }
38
  },
39
+ "node_modules/@ai-sdk/gateway": {
40
+ "version": "3.0.50",
41
+ "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.50.tgz",
42
+ "integrity": "sha512-Jdd1a8VgbD7l7r+COj0h5SuaYRfPvOJ/AO6l0OrmTPEcI2MUQPr3C4JttfpNkcheEN+gOdy0CtZWuG17bW2fjw==",
43
+ "license": "Apache-2.0",
44
+ "dependencies": {
45
+ "@ai-sdk/provider": "3.0.8",
46
+ "@ai-sdk/provider-utils": "4.0.15",
47
+ "@vercel/oidc": "3.1.0"
48
+ },
49
+ "engines": {
50
+ "node": ">=18"
51
+ },
52
+ "peerDependencies": {
53
+ "zod": "^3.25.76 || ^4.1.8"
54
+ }
55
+ },
56
+ "node_modules/@ai-sdk/provider": {
57
+ "version": "3.0.8",
58
+ "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz",
59
+ "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==",
60
+ "license": "Apache-2.0",
61
+ "dependencies": {
62
+ "json-schema": "^0.4.0"
63
+ },
64
+ "engines": {
65
+ "node": ">=18"
66
+ }
67
+ },
68
+ "node_modules/@ai-sdk/provider-utils": {
69
+ "version": "4.0.15",
70
+ "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz",
71
+ "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==",
72
+ "license": "Apache-2.0",
73
+ "dependencies": {
74
+ "@ai-sdk/provider": "3.0.8",
75
+ "@standard-schema/spec": "^1.1.0",
76
+ "eventsource-parser": "^3.0.6"
77
+ },
78
+ "engines": {
79
+ "node": ">=18"
80
+ },
81
+ "peerDependencies": {
82
+ "zod": "^3.25.76 || ^4.1.8"
83
+ }
84
+ },
85
+ "node_modules/@ai-sdk/react": {
86
+ "version": "3.0.93",
87
+ "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-3.0.93.tgz",
88
+ "integrity": "sha512-FY1HmeAfCpiAGLhIZh2QR8QFzHFZfhjMmkA9D5KC/O3eGqPeY7CwBABLkzRH+5Gkf+MfxXnEm4VF0MpmvDMjpg==",
89
+ "license": "Apache-2.0",
90
+ "dependencies": {
91
+ "@ai-sdk/provider-utils": "4.0.15",
92
+ "ai": "6.0.91",
93
+ "swr": "^2.2.5",
94
+ "throttleit": "2.1.0"
95
+ },
96
+ "engines": {
97
+ "node": ">=18"
98
+ },
99
+ "peerDependencies": {
100
+ "react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1"
101
+ }
102
+ },
103
  "node_modules/@babel/code-frame": {
104
  "version": "7.28.6",
105
  "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz",
 
1414
  }
1415
  }
1416
  },
1417
+ "node_modules/@opentelemetry/api": {
1418
+ "version": "1.9.0",
1419
+ "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz",
1420
+ "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==",
1421
+ "license": "Apache-2.0",
1422
+ "engines": {
1423
+ "node": ">=8.0.0"
1424
+ }
1425
+ },
1426
  "node_modules/@popperjs/core": {
1427
  "version": "2.11.8",
1428
  "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz",
 
1790
  "win32"
1791
  ]
1792
  },
1793
+ "node_modules/@standard-schema/spec": {
1794
+ "version": "1.1.0",
1795
+ "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz",
1796
+ "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==",
1797
+ "license": "MIT"
1798
+ },
1799
  "node_modules/@types/babel__core": {
1800
  "version": "7.20.5",
1801
  "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",
 
2236
  "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==",
2237
  "license": "ISC"
2238
  },
2239
+ "node_modules/@vercel/oidc": {
2240
+ "version": "3.1.0",
2241
+ "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz",
2242
+ "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==",
2243
+ "license": "Apache-2.0",
2244
+ "engines": {
2245
+ "node": ">= 20"
2246
+ }
2247
+ },
2248
  "node_modules/@vitejs/plugin-react": {
2249
  "version": "4.7.0",
2250
  "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz",
 
2290
  "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0"
2291
  }
2292
  },
2293
+ "node_modules/ai": {
2294
+ "version": "6.0.91",
2295
+ "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.91.tgz",
2296
+ "integrity": "sha512-k1/8BusZMhYVxxLZt0BUZzm9HVDCCh117nyWfWUx5xjR2+tWisJbXgysL7EBMq2lgyHwgpA1jDR3tVjWSdWZXw==",
2297
+ "license": "Apache-2.0",
2298
+ "dependencies": {
2299
+ "@ai-sdk/gateway": "3.0.50",
2300
+ "@ai-sdk/provider": "3.0.8",
2301
+ "@ai-sdk/provider-utils": "4.0.15",
2302
+ "@opentelemetry/api": "1.9.0"
2303
+ },
2304
+ "engines": {
2305
+ "node": ">=18"
2306
+ },
2307
+ "peerDependencies": {
2308
+ "zod": "^3.25.76 || ^4.1.8"
2309
+ }
2310
+ },
2311
  "node_modules/ajv": {
2312
  "version": "6.12.6",
2313
  "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
 
2956
  "node": ">=0.10.0"
2957
  }
2958
  },
2959
+ "node_modules/eventsource-parser": {
2960
+ "version": "3.0.6",
2961
+ "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz",
2962
+ "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==",
2963
+ "license": "MIT",
2964
+ "engines": {
2965
+ "node": ">=18.0.0"
2966
+ }
2967
+ },
2968
  "node_modules/extend": {
2969
  "version": "3.0.2",
2970
  "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz",
 
3473
  "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==",
3474
  "license": "MIT"
3475
  },
3476
+ "node_modules/json-schema": {
3477
+ "version": "0.4.0",
3478
+ "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz",
3479
+ "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==",
3480
+ "license": "(AFL-2.1 OR BSD-3-Clause)"
3481
+ },
3482
  "node_modules/json-schema-traverse": {
3483
  "version": "0.4.1",
3484
  "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz",
 
5175
  "url": "https://github.com/sponsors/ljharb"
5176
  }
5177
  },
5178
+ "node_modules/swr": {
5179
+ "version": "2.4.0",
5180
+ "resolved": "https://registry.npmjs.org/swr/-/swr-2.4.0.tgz",
5181
+ "integrity": "sha512-sUlC20T8EOt1pHmDiqueUWMmRRX03W7w5YxovWX7VR2KHEPCTMly85x05vpkP5i6Bu4h44ePSMD9Tc+G2MItFw==",
5182
+ "license": "MIT",
5183
+ "dependencies": {
5184
+ "dequal": "^2.0.3",
5185
+ "use-sync-external-store": "^1.6.0"
5186
+ },
5187
+ "peerDependencies": {
5188
+ "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
5189
+ }
5190
+ },
5191
+ "node_modules/throttleit": {
5192
+ "version": "2.1.0",
5193
+ "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz",
5194
+ "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==",
5195
+ "license": "MIT",
5196
+ "engines": {
5197
+ "node": ">=18"
5198
+ },
5199
+ "funding": {
5200
+ "url": "https://github.com/sponsors/sindresorhus"
5201
+ }
5202
+ },
5203
  "node_modules/tinyglobby": {
5204
  "version": "0.2.15",
5205
  "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz",
 
5430
  "punycode": "^2.1.0"
5431
  }
5432
  },
5433
+ "node_modules/use-sync-external-store": {
5434
+ "version": "1.6.0",
5435
+ "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz",
5436
+ "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==",
5437
+ "license": "MIT",
5438
+ "peer": true,
5439
+ "peerDependencies": {
5440
+ "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
5441
+ }
5442
+ },
5443
  "node_modules/vfile": {
5444
  "version": "6.0.3",
5445
  "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz",
 
5584
  "url": "https://github.com/sponsors/sindresorhus"
5585
  }
5586
  },
5587
+ "node_modules/zod": {
5588
+ "version": "4.3.6",
5589
+ "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz",
5590
+ "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==",
5591
+ "license": "MIT",
5592
+ "peer": true,
5593
+ "funding": {
5594
+ "url": "https://github.com/sponsors/colinhacks"
5595
+ }
5596
+ },
5597
  "node_modules/zustand": {
5598
  "version": "5.0.10",
5599
  "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.10.tgz",
frontend/package.json CHANGED
@@ -10,10 +10,12 @@
10
  "preview": "vite preview"
11
  },
12
  "dependencies": {
 
13
  "@emotion/react": "^11.13.0",
14
  "@emotion/styled": "^11.13.0",
15
  "@mui/icons-material": "^6.1.0",
16
  "@mui/material": "^6.1.0",
 
17
  "react": "^18.3.1",
18
  "react-dom": "^18.3.1",
19
  "react-markdown": "^9.0.1",
 
10
  "preview": "vite preview"
11
  },
12
  "dependencies": {
13
+ "@ai-sdk/react": "^3.0.93",
14
  "@emotion/react": "^11.13.0",
15
  "@emotion/styled": "^11.13.0",
16
  "@mui/icons-material": "^6.1.0",
17
  "@mui/material": "^6.1.0",
18
+ "ai": "^6.0.91",
19
  "react": "^18.3.1",
20
  "react-dom": "^18.3.1",
21
  "react-markdown": "^9.0.1",
frontend/src/App.tsx CHANGED
@@ -1,7 +1,12 @@
1
  import { Box } from '@mui/material';
2
  import AppLayout from '@/components/Layout/AppLayout';
 
3
 
4
  function App() {
 
 
 
 
5
  return (
6
  <Box sx={{ height: '100vh', display: 'flex' }}>
7
  <AppLayout />
 
1
  import { Box } from '@mui/material';
2
  import AppLayout from '@/components/Layout/AppLayout';
3
+ import { useAuth } from '@/hooks/useAuth';
4
 
5
  function App() {
6
+ // Non-blocking auth check β€” fires in background, updates store when done.
7
+ // If auth fails later, apiFetch redirects to /auth/login.
8
+ useAuth();
9
+
10
  return (
11
  <Box sx={{ height: '100vh', display: 'flex' }}>
12
  <AppLayout />
frontend/src/components/ApprovalModal/ApprovalModal.tsx DELETED
@@ -1,208 +0,0 @@
1
- import { useState, useCallback } from 'react';
2
- import {
3
- Dialog,
4
- DialogTitle,
5
- DialogContent,
6
- DialogActions,
7
- Button,
8
- Box,
9
- Typography,
10
- Checkbox,
11
- FormControlLabel,
12
- Accordion,
13
- AccordionSummary,
14
- AccordionDetails,
15
- TextField,
16
- Chip,
17
- } from '@mui/material';
18
- import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
19
- import WarningIcon from '@mui/icons-material/Warning';
20
- import { useAgentStore } from '@/store/agentStore';
21
-
22
- interface ApprovalModalProps {
23
- sessionId: string | null;
24
- }
25
-
26
- interface ApprovalState {
27
- [toolCallId: string]: {
28
- approved: boolean;
29
- feedback: string;
30
- };
31
- }
32
-
33
- export default function ApprovalModal({ sessionId }: ApprovalModalProps) {
34
- const { pendingApprovals, setPendingApprovals } = useAgentStore();
35
- const [approvalState, setApprovalState] = useState<ApprovalState>({});
36
-
37
- const isOpen = pendingApprovals !== null && pendingApprovals.tools.length > 0;
38
-
39
- const handleApprovalChange = useCallback(
40
- (toolCallId: string, approved: boolean) => {
41
- setApprovalState((prev) => ({
42
- ...prev,
43
- [toolCallId]: {
44
- ...prev[toolCallId],
45
- approved,
46
- feedback: prev[toolCallId]?.feedback || '',
47
- },
48
- }));
49
- },
50
- []
51
- );
52
-
53
- const handleFeedbackChange = useCallback(
54
- (toolCallId: string, feedback: string) => {
55
- setApprovalState((prev) => ({
56
- ...prev,
57
- [toolCallId]: {
58
- ...prev[toolCallId],
59
- feedback,
60
- },
61
- }));
62
- },
63
- []
64
- );
65
-
66
- const handleSubmit = useCallback(async () => {
67
- if (!sessionId || !pendingApprovals) return;
68
-
69
- const approvals = pendingApprovals.tools.map((tool) => ({
70
- tool_call_id: tool.tool_call_id,
71
- approved: approvalState[tool.tool_call_id]?.approved ?? false,
72
- feedback: approvalState[tool.tool_call_id]?.feedback || null,
73
- }));
74
-
75
- try {
76
- await fetch('/api/approve', {
77
- method: 'POST',
78
- headers: { 'Content-Type': 'application/json' },
79
- body: JSON.stringify({
80
- session_id: sessionId,
81
- approvals,
82
- }),
83
- });
84
- setPendingApprovals(null);
85
- setApprovalState({});
86
- } catch (e) {
87
- console.error('Approval submission failed:', e);
88
- }
89
- }, [sessionId, pendingApprovals, approvalState, setPendingApprovals]);
90
-
91
- const handleApproveAll = useCallback(() => {
92
- if (!pendingApprovals) return;
93
- const newState: ApprovalState = {};
94
- pendingApprovals.tools.forEach((tool) => {
95
- newState[tool.tool_call_id] = { approved: true, feedback: '' };
96
- });
97
- setApprovalState(newState);
98
- }, [pendingApprovals]);
99
-
100
- const handleRejectAll = useCallback(() => {
101
- if (!pendingApprovals) return;
102
- const newState: ApprovalState = {};
103
- pendingApprovals.tools.forEach((tool) => {
104
- newState[tool.tool_call_id] = { approved: false, feedback: '' };
105
- });
106
- setApprovalState(newState);
107
- }, [pendingApprovals]);
108
-
109
- if (!isOpen || !pendingApprovals) return null;
110
-
111
- const approvedCount = Object.values(approvalState).filter((s) => s.approved).length;
112
-
113
- return (
114
- <Dialog
115
- open={isOpen}
116
- maxWidth="md"
117
- fullWidth
118
- PaperProps={{
119
- sx: { bgcolor: 'background.paper' },
120
- }}
121
- >
122
- <DialogTitle sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
123
- <WarningIcon color="warning" />
124
- Approval Required
125
- <Chip
126
- label={`${pendingApprovals.count} tool${pendingApprovals.count > 1 ? 's' : ''}`}
127
- size="small"
128
- sx={{ ml: 1 }}
129
- />
130
- </DialogTitle>
131
- <DialogContent dividers>
132
- <Typography variant="body2" color="text.secondary" sx={{ mb: 2 }}>
133
- The following tool calls require your approval before execution:
134
- </Typography>
135
- {pendingApprovals.tools.map((tool, index) => (
136
- <Accordion key={tool.tool_call_id} defaultExpanded={index === 0}>
137
- <AccordionSummary expandIcon={<ExpandMoreIcon />}>
138
- <Box sx={{ display: 'flex', alignItems: 'center', gap: 2, width: '100%' }}>
139
- <FormControlLabel
140
- control={
141
- <Checkbox
142
- checked={approvalState[tool.tool_call_id]?.approved ?? false}
143
- onChange={(e) => {
144
- e.stopPropagation();
145
- handleApprovalChange(tool.tool_call_id, e.target.checked);
146
- }}
147
- onClick={(e) => e.stopPropagation()}
148
- />
149
- }
150
- label=""
151
- sx={{ m: 0 }}
152
- />
153
- <Chip label={tool.tool} size="small" color="primary" variant="outlined" />
154
- <Typography variant="body2" color="text.secondary" sx={{ ml: 'auto' }}>
155
- {approvalState[tool.tool_call_id]?.approved ? 'Approved' : 'Pending'}
156
- </Typography>
157
- </Box>
158
- </AccordionSummary>
159
- <AccordionDetails>
160
- <Typography variant="subtitle2" gutterBottom>
161
- Arguments:
162
- </Typography>
163
- <Box
164
- component="pre"
165
- sx={{
166
- bgcolor: 'background.default',
167
- p: 1.5,
168
- borderRadius: 1,
169
- overflow: 'auto',
170
- fontSize: '0.8rem',
171
- maxHeight: 200,
172
- }}
173
- >
174
- {JSON.stringify(tool.arguments, null, 2)}
175
- </Box>
176
- {!approvalState[tool.tool_call_id]?.approved && (
177
- <TextField
178
- fullWidth
179
- size="small"
180
- label="Feedback (optional)"
181
- placeholder="Explain why you're rejecting this..."
182
- value={approvalState[tool.tool_call_id]?.feedback || ''}
183
- onChange={(e) => handleFeedbackChange(tool.tool_call_id, e.target.value)}
184
- sx={{ mt: 2 }}
185
- />
186
- )}
187
- </AccordionDetails>
188
- </Accordion>
189
- ))}
190
- </DialogContent>
191
- <DialogActions sx={{ px: 3, py: 2 }}>
192
- <Button onClick={handleRejectAll} color="error" variant="outlined">
193
- Reject All
194
- </Button>
195
- <Button onClick={handleApproveAll} color="success" variant="outlined">
196
- Approve All
197
- </Button>
198
- <Box sx={{ flex: 1 }} />
199
- <Typography variant="body2" color="text.secondary" sx={{ mr: 2 }}>
200
- {approvedCount} of {pendingApprovals.count} approved
201
- </Typography>
202
- <Button onClick={handleSubmit} variant="contained" color="primary">
203
- Submit
204
- </Button>
205
- </DialogActions>
206
- </Dialog>
207
- );
208
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/components/Chat/ActivityStatusBar.tsx ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Box, Typography } from '@mui/material';
2
+ import { keyframes } from '@mui/system';
3
+ import { useAgentStore, type ActivityStatus } from '@/store/agentStore';
4
+
5
+ const shimmer = keyframes`
6
+ 0% { background-position: -100% center; }
7
+ 50% { background-position: 200% center; }
8
+ 100% { background-position: -100% center; }
9
+ `;
10
+
11
+ const TOOL_LABELS: Record<string, string> = {
12
+ hf_jobs: 'Running job',
13
+ hf_repo_files: 'Uploading file',
14
+ hf_repo_git: 'Git operation',
15
+ hf_inspect_dataset: 'Inspecting dataset',
16
+ hf_search: 'Searching',
17
+ plan_tool: 'Planning',
18
+ };
19
+
20
+ function statusLabel(status: ActivityStatus): string {
21
+ switch (status.type) {
22
+ case 'thinking': return 'Thinking';
23
+ case 'streaming': return 'Writing';
24
+ case 'tool': return TOOL_LABELS[status.toolName] || `Running ${status.toolName}`;
25
+ case 'waiting-approval': return 'Waiting for approval';
26
+ default: return '';
27
+ }
28
+ }
29
+
30
+ export default function ActivityStatusBar() {
31
+ const activityStatus = useAgentStore(s => s.activityStatus);
32
+
33
+ if (activityStatus.type === 'idle') return null;
34
+
35
+ const label = statusLabel(activityStatus);
36
+
37
+ return (
38
+ <Box sx={{ px: 2, py: 0.5, minHeight: 28, display: 'flex', alignItems: 'center' }}>
39
+ <Typography
40
+ sx={{
41
+ fontFamily: 'monospace',
42
+ fontSize: '0.72rem',
43
+ fontWeight: 500,
44
+ letterSpacing: '0.02em',
45
+ background: 'linear-gradient(90deg, var(--muted-text) 30%, var(--text) 50%, var(--muted-text) 70%)',
46
+ backgroundSize: '250% 100%',
47
+ backgroundClip: 'text',
48
+ WebkitBackgroundClip: 'text',
49
+ WebkitTextFillColor: 'transparent',
50
+ animation: `${shimmer} 4s ease-in-out infinite`,
51
+ }}
52
+ >
53
+ {label}…
54
+ </Typography>
55
+ </Box>
56
+ );
57
+ }
frontend/src/components/Chat/ApprovalFlow.tsx DELETED
@@ -1,515 +0,0 @@
1
- import { useState, useCallback, useEffect } from 'react';
2
- import { Box, Typography, Button, TextField, IconButton, Link } from '@mui/material';
3
- import SendIcon from '@mui/icons-material/Send';
4
- import OpenInNewIcon from '@mui/icons-material/OpenInNew';
5
- import CheckCircleIcon from '@mui/icons-material/CheckCircle';
6
- import CancelIcon from '@mui/icons-material/Cancel';
7
- import LaunchIcon from '@mui/icons-material/Launch';
8
- import { useAgentStore } from '@/store/agentStore';
9
- import { useLayoutStore } from '@/store/layoutStore';
10
- import { useSessionStore } from '@/store/sessionStore';
11
- import type { Message, ToolApproval } from '@/types/agent';
12
-
13
- interface ApprovalFlowProps {
14
- message: Message;
15
- }
16
-
17
- export default function ApprovalFlow({ message }: ApprovalFlowProps) {
18
- const { setPanelContent, setPanelTab, setActivePanelTab, clearPanelTabs, updateMessage } = useAgentStore();
19
- const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
20
- const { activeSessionId } = useSessionStore();
21
- const [currentIndex, setCurrentIndex] = useState(0);
22
- const [feedback, setFeedback] = useState('');
23
- const [decisions, setDecisions] = useState<ToolApproval[]>([]);
24
-
25
- const approvalData = message.approval;
26
-
27
- if (!approvalData) return null;
28
-
29
- const { batch, status } = approvalData;
30
-
31
- // Parse toolOutput to extract job info (URL, status, logs, errors)
32
- let logsContent = '';
33
- let showLogsButton = false;
34
- let jobUrl = '';
35
- let jobStatus = '';
36
- let jobFailed = false;
37
- let errorMessage = '';
38
-
39
- if (message.toolOutput) {
40
- const output = message.toolOutput;
41
-
42
- // Extract job URL: **View at:** https://...
43
- const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/);
44
- if (urlMatch) {
45
- jobUrl = urlMatch[1];
46
- }
47
-
48
- // Extract job status: **Final Status:** ...
49
- const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/);
50
- if (statusMatch) {
51
- jobStatus = statusMatch[1].trim();
52
- jobFailed = jobStatus.toLowerCase().includes('error') || jobStatus.toLowerCase().includes('failed');
53
- }
54
-
55
- // Extract logs
56
- if (output.includes('**Logs:**')) {
57
- const parts = output.split('**Logs:**');
58
- if (parts.length > 1) {
59
- const logsPart = parts[1].trim();
60
- const codeBlockMatch = logsPart.match(/```([\s\S]*?)```/);
61
- if (codeBlockMatch) {
62
- logsContent = codeBlockMatch[1].trim();
63
- showLogsButton = true;
64
- }
65
- }
66
- }
67
-
68
- // Detect errors - if output exists but doesn't have the expected job completion format
69
- // This catches early failures (validation errors, API errors, etc.)
70
- const isExpectedFormat = output.includes('**Job ID:**') || output.includes('**View at:**');
71
- const looksLikeError = output.toLowerCase().includes('error') ||
72
- output.toLowerCase().includes('failed') ||
73
- output.toLowerCase().includes('exception') ||
74
- output.includes('Traceback');
75
-
76
- if (!isExpectedFormat || (looksLikeError && !logsContent)) {
77
- // This is likely an error message - show it
78
- errorMessage = output;
79
- jobFailed = true;
80
- }
81
- }
82
-
83
- // Sync right panel with current tool
84
- useEffect(() => {
85
- if (!batch || currentIndex >= batch.tools.length) return;
86
-
87
- // Only auto-open panel if pending
88
- if (status !== 'pending') return;
89
-
90
- const tool = batch.tools[currentIndex];
91
- const args = tool.arguments as any;
92
-
93
- if (tool.tool === 'hf_jobs' && (args.operation === 'run' || args.operation === 'scheduled run') && args.script) {
94
- setPanelContent({
95
- title: 'Compute Job Script',
96
- content: args.script,
97
- language: 'python',
98
- parameters: args
99
- });
100
- // Don't auto-open if already resolved
101
- } else if (tool.tool === 'hf_repo_files' && args.operation === 'upload' && args.content) {
102
- setPanelContent({
103
- title: `File Upload: ${args.path || 'unnamed'}`,
104
- content: args.content,
105
- parameters: args
106
- });
107
- }
108
- }, [currentIndex, batch, status, setPanelContent]);
109
-
110
- const handleResolve = useCallback(async (approved: boolean) => {
111
- if (!batch || !activeSessionId) return;
112
-
113
- const currentTool = batch.tools[currentIndex];
114
- const newDecisions = [
115
- ...decisions,
116
- {
117
- tool_call_id: currentTool.tool_call_id,
118
- approved,
119
- feedback: approved ? null : feedback || 'Rejected by user',
120
- },
121
- ];
122
-
123
- if (currentIndex < batch.tools.length - 1) {
124
- setDecisions(newDecisions);
125
- setCurrentIndex(currentIndex + 1);
126
- setFeedback('');
127
- } else {
128
- // All tools in batch resolved
129
- try {
130
- await fetch('/api/approve', {
131
- method: 'POST',
132
- headers: { 'Content-Type': 'application/json' },
133
- body: JSON.stringify({
134
- session_id: activeSessionId,
135
- approvals: newDecisions,
136
- }),
137
- });
138
-
139
- // Update message status
140
- updateMessage(activeSessionId, message.id, {
141
- approval: {
142
- ...approvalData!,
143
- status: approved ? 'approved' : 'rejected',
144
- decisions: newDecisions
145
- }
146
- });
147
-
148
- } catch (e) {
149
- console.error('Approval submission failed:', e);
150
- }
151
- }
152
- }, [activeSessionId, message.id, batch, currentIndex, feedback, decisions, approvalData, updateMessage]);
153
-
154
- if (!batch || currentIndex >= batch.tools.length) return null;
155
-
156
- const currentTool = batch.tools[currentIndex];
157
-
158
- // Check if script contains push_to_hub or upload_file
159
- const args = currentTool.arguments as any;
160
- const containsPushToHub = currentTool.tool === 'hf_jobs' && args.script && (args.script.includes('push_to_hub') || args.script.includes('upload_file'));
161
-
162
- const getToolDescription = (toolName: string, args: any) => {
163
- if (toolName === 'hf_jobs') {
164
- return (
165
- <Box sx={{ flex: 1 }}>
166
- <Typography variant="body2" sx={{ color: 'var(--muted-text)' }}>
167
- The agent wants to execute <Box component="span" sx={{ color: 'var(--accent-yellow)', fontWeight: 500 }}>hf_jobs</Box> on{' '}
168
- <Box component="span" sx={{ fontWeight: 500, color: 'var(--text)' }}>{args.hardware_flavor || 'default'}</Box> with a timeout of{' '}
169
- <Box component="span" sx={{ fontWeight: 500, color: 'var(--text)' }}>{args.timeout || '30m'}</Box>
170
- </Typography>
171
- </Box>
172
- );
173
- }
174
- return (
175
- <Typography variant="body2" sx={{ color: 'var(--muted-text)', flex: 1 }}>
176
- The agent wants to execute <Box component="span" sx={{ color: 'var(--accent-yellow)', fontWeight: 500 }}>{toolName}</Box>
177
- </Typography>
178
- );
179
- };
180
-
181
- const showCode = () => {
182
- const args = currentTool.arguments as any;
183
- if (currentTool.tool === 'hf_jobs' && args.script) {
184
- // Clear existing tabs and set up script tab (and logs if available)
185
- clearPanelTabs();
186
- setPanelTab({
187
- id: 'script',
188
- title: 'Script',
189
- content: args.script,
190
- language: 'python',
191
- parameters: args
192
- });
193
- // If logs are available (job completed), also add logs tab
194
- if (logsContent) {
195
- setPanelTab({
196
- id: 'logs',
197
- title: 'Logs',
198
- content: logsContent,
199
- language: 'text'
200
- });
201
- }
202
- setActivePanelTab('script');
203
- setRightPanelOpen(true);
204
- setLeftSidebarOpen(false);
205
- } else {
206
- setPanelContent({
207
- title: `Tool: ${currentTool.tool}`,
208
- content: JSON.stringify(args, null, 2),
209
- language: 'json',
210
- parameters: args
211
- });
212
- setRightPanelOpen(true);
213
- setLeftSidebarOpen(false);
214
- }
215
- };
216
-
217
- const handleViewLogs = (e: React.MouseEvent) => {
218
- e.stopPropagation();
219
- const args = currentTool.arguments as any;
220
- // Set up both tabs so user can switch between script and logs
221
- clearPanelTabs();
222
- if (currentTool.tool === 'hf_jobs' && args.script) {
223
- setPanelTab({
224
- id: 'script',
225
- title: 'Script',
226
- content: args.script,
227
- language: 'python',
228
- parameters: args
229
- });
230
- }
231
- setPanelTab({
232
- id: 'logs',
233
- title: 'Logs',
234
- content: logsContent,
235
- language: 'text'
236
- });
237
- setActivePanelTab('logs');
238
- setRightPanelOpen(true);
239
- setLeftSidebarOpen(false);
240
- };
241
-
242
- return (
243
- <Box
244
- className="action-card"
245
- sx={{
246
- width: '100%',
247
- padding: '18px',
248
- borderRadius: 'var(--radius-md)',
249
- background: 'linear-gradient(180deg, rgba(255,255,255,0.015), transparent)',
250
- border: '1px solid rgba(255,255,255,0.03)',
251
- display: 'flex',
252
- flexDirection: 'column',
253
- gap: '12px',
254
- opacity: status !== 'pending' && !showLogsButton ? 0.8 : 1
255
- }}
256
- >
257
- <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
258
- <Typography variant="subtitle2" sx={{ fontWeight: 600, color: 'var(--text)' }}>
259
- {status === 'pending' ? 'Approval Required' : status === 'approved' ? 'Approved' : 'Rejected'}
260
- </Typography>
261
- <Typography variant="caption" sx={{ color: 'var(--muted-text)' }}>
262
- ({currentIndex + 1}/{batch.count})
263
- </Typography>
264
- {status === 'approved' && <CheckCircleIcon sx={{ fontSize: 18, color: 'var(--accent-green)' }} />}
265
- {status === 'rejected' && <CancelIcon sx={{ fontSize: 18, color: 'var(--accent-red)' }} />}
266
- </Box>
267
-
268
- <Box
269
- onClick={showCode}
270
- sx={{
271
- display: 'flex',
272
- alignItems: 'center',
273
- gap: 1,
274
- cursor: 'pointer',
275
- p: 1.5,
276
- borderRadius: '8px',
277
- bgcolor: 'rgba(0,0,0,0.2)',
278
- border: '1px solid rgba(255,255,255,0.05)',
279
- transition: 'all 0.2s',
280
- '&:hover': {
281
- bgcolor: 'rgba(255,255,255,0.03)',
282
- borderColor: 'var(--accent-primary)',
283
- }
284
- }}
285
- >
286
- {getToolDescription(currentTool.tool, currentTool.arguments)}
287
- <OpenInNewIcon sx={{ fontSize: 16, color: 'var(--muted-text)', opacity: 0.7 }} />
288
- </Box>
289
-
290
- {/* Script/Logs buttons for hf_jobs - always show when we have a script */}
291
- {currentTool.tool === 'hf_jobs' && args.script && (
292
- <Box sx={{ display: 'flex', flexDirection: 'column', gap: 1 }}>
293
- <Box sx={{ display: 'flex', gap: 1, flexWrap: 'wrap' }}>
294
- <Button
295
- variant="outlined"
296
- size="small"
297
- onClick={showCode}
298
- sx={{
299
- textTransform: 'none',
300
- borderColor: 'rgba(255,255,255,0.1)',
301
- color: 'var(--muted-text)',
302
- fontSize: '0.75rem',
303
- py: 0.5,
304
- '&:hover': {
305
- borderColor: 'var(--accent-primary)',
306
- color: 'var(--accent-primary)',
307
- bgcolor: 'rgba(255,255,255,0.03)'
308
- }
309
- }}
310
- >
311
- View Script
312
- </Button>
313
- <Button
314
- variant="outlined"
315
- size="small"
316
- onClick={handleViewLogs}
317
- disabled={!logsContent && status === 'pending'}
318
- sx={{
319
- textTransform: 'none',
320
- borderColor: 'rgba(255,255,255,0.1)',
321
- color: logsContent ? 'var(--accent-primary)' : 'var(--muted-text)',
322
- fontSize: '0.75rem',
323
- py: 0.5,
324
- '&:hover': {
325
- borderColor: 'var(--accent-primary)',
326
- bgcolor: 'rgba(255,255,255,0.03)'
327
- },
328
- '&.Mui-disabled': {
329
- color: 'rgba(255,255,255,0.3)',
330
- borderColor: 'rgba(255,255,255,0.05)',
331
- }
332
- }}
333
- >
334
- {logsContent ? 'View Logs' : 'Logs (waiting for job...)'}
335
- </Button>
336
- </Box>
337
-
338
- {/* Job URL - only show when we have a specific URL */}
339
- {jobUrl && (
340
- <Link
341
- href={jobUrl}
342
- target="_blank"
343
- rel="noopener noreferrer"
344
- sx={{
345
- display: 'flex',
346
- alignItems: 'center',
347
- gap: 0.5,
348
- color: 'var(--accent-primary)',
349
- fontSize: '0.75rem',
350
- textDecoration: 'none',
351
- opacity: 0.9,
352
- '&:hover': {
353
- opacity: 1,
354
- textDecoration: 'underline',
355
- }
356
- }}
357
- >
358
- <LaunchIcon sx={{ fontSize: 14 }} />
359
- View Job on Hugging Face
360
- </Link>
361
- )}
362
-
363
- {/* Show job status if available */}
364
- {jobStatus && (
365
- <Typography
366
- variant="caption"
367
- sx={{
368
- color: jobFailed ? 'var(--accent-red)' : 'var(--accent-green)',
369
- fontSize: '0.75rem',
370
- fontWeight: 500,
371
- }}
372
- >
373
- Status: {jobStatus}
374
- </Typography>
375
- )}
376
- </Box>
377
- )}
378
-
379
- {containsPushToHub && (
380
- <Typography variant="caption" sx={{ color: 'var(--accent-green)', fontSize: '0.75rem', opacity: 0.8, px: 0.5 }}>
381
- We've detected the result will be pushed to hub.
382
- </Typography>
383
- )}
384
-
385
- {/* Show error message if job failed */}
386
- {errorMessage && status !== 'pending' && (
387
- <Box
388
- sx={{
389
- p: 1.5,
390
- borderRadius: '8px',
391
- bgcolor: 'rgba(224, 90, 79, 0.1)',
392
- border: '1px solid rgba(224, 90, 79, 0.3)',
393
- }}
394
- >
395
- <Typography
396
- variant="caption"
397
- sx={{
398
- color: 'var(--accent-red)',
399
- fontWeight: 600,
400
- display: 'block',
401
- mb: 0.5,
402
- }}
403
- >
404
- Error
405
- </Typography>
406
- <Typography
407
- component="pre"
408
- sx={{
409
- color: 'var(--text)',
410
- fontSize: '0.75rem',
411
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
412
- whiteSpace: 'pre-wrap',
413
- wordBreak: 'break-word',
414
- m: 0,
415
- maxHeight: '150px',
416
- overflow: 'auto',
417
- }}
418
- >
419
- {errorMessage.length > 500 ? errorMessage.substring(0, 500) + '...' : errorMessage}
420
- </Typography>
421
- </Box>
422
- )}
423
-
424
-
425
- {status === 'pending' && (
426
- <Box sx={{ display: 'flex', flexDirection: 'column', gap: 2 }}>
427
- <Box sx={{ display: 'flex', gap: 1 }}>
428
- <TextField
429
- fullWidth
430
- size="small"
431
- placeholder="Feedback (optional)"
432
- value={feedback}
433
- onChange={(e) => setFeedback(e.target.value)}
434
- variant="outlined"
435
- sx={{
436
- '& .MuiOutlinedInput-root': {
437
- bgcolor: 'rgba(0,0,0,0.2)',
438
- fontFamily: 'inherit',
439
- fontSize: '0.9rem'
440
- }
441
- }}
442
- />
443
- <IconButton
444
- onClick={() => handleResolve(false)}
445
- disabled={!feedback}
446
- title="Reject with feedback"
447
- sx={{
448
- color: 'var(--accent-red)',
449
- border: '1px solid rgba(255,255,255,0.05)',
450
- borderRadius: '8px',
451
- width: 40,
452
- height: 40,
453
- '&:hover': {
454
- bgcolor: 'rgba(224, 90, 79, 0.1)',
455
- borderColor: 'var(--accent-red)',
456
- },
457
- '&.Mui-disabled': {
458
- color: 'rgba(255,255,255,0.1)',
459
- borderColor: 'rgba(255,255,255,0.02)'
460
- }
461
- }}
462
- >
463
- <SendIcon fontSize="small" />
464
- </IconButton>
465
- </Box>
466
-
467
- <Box className="action-buttons" sx={{ display: 'flex', gap: '10px' }}>
468
- <Button
469
- className="btn-reject"
470
- onClick={() => handleResolve(false)}
471
- sx={{
472
- flex: 1,
473
- background: 'transparent',
474
- border: '1px solid rgba(255,255,255,0.05)',
475
- color: 'var(--accent-red)',
476
- padding: '10px 14px',
477
- borderRadius: '10px',
478
- '&:hover': {
479
- bgcolor: 'rgba(224, 90, 79, 0.05)',
480
- borderColor: 'var(--accent-red)',
481
- }
482
- }}
483
- >
484
- Reject
485
- </Button>
486
- <Button
487
- className="btn-approve"
488
- onClick={() => handleResolve(true)}
489
- sx={{
490
- flex: 1,
491
- background: 'transparent',
492
- border: '1px solid rgba(255,255,255,0.05)',
493
- color: 'var(--accent-green)',
494
- padding: '10px 14px',
495
- borderRadius: '10px',
496
- '&:hover': {
497
- bgcolor: 'rgba(47, 204, 113, 0.05)',
498
- borderColor: 'var(--accent-green)',
499
- }
500
- }}
501
- >
502
- Approve
503
- </Button>
504
- </Box>
505
- </Box>
506
- )}
507
-
508
- {status === 'rejected' && decisions.some(d => d.feedback) && (
509
- <Typography variant="body2" sx={{ color: 'var(--accent-red)', mt: 1 }}>
510
- Feedback: {decisions.find(d => d.feedback)?.feedback}
511
- </Typography>
512
- )}
513
- </Box>
514
- );
515
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/components/Chat/AssistantMessage.tsx ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useMemo } from 'react';
2
+ import { Box, Stack, Typography } from '@mui/material';
3
+ import MarkdownContent from './MarkdownContent';
4
+ import ToolCallGroup from './ToolCallGroup';
5
+ import type { UIMessage } from 'ai';
6
+ import type { MessageMeta } from '@/types/agent';
7
+
8
+ interface AssistantMessageProps {
9
+ message: UIMessage;
10
+ isStreaming?: boolean;
11
+ approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
+ }
13
+
14
+ /**
15
+ * Groups consecutive tool parts together so they render as a single
16
+ * ToolCallGroup (visually identical to the old segments approach).
17
+ */
18
+ type DynamicToolPart = Extract<UIMessage['parts'][number], { type: 'dynamic-tool' }>;
19
+
20
+ function groupParts(parts: UIMessage['parts']) {
21
+ const groups: Array<
22
+ | { kind: 'text'; text: string; idx: number }
23
+ | { kind: 'tools'; tools: DynamicToolPart[]; idx: number }
24
+ > = [];
25
+
26
+ for (let i = 0; i < parts.length; i++) {
27
+ const part = parts[i];
28
+
29
+ if (part.type === 'text') {
30
+ groups.push({ kind: 'text', text: part.text, idx: i });
31
+ } else if (part.type === 'dynamic-tool') {
32
+ const toolPart = part as DynamicToolPart;
33
+ const last = groups[groups.length - 1];
34
+ if (last?.kind === 'tools') {
35
+ last.tools.push(toolPart);
36
+ } else {
37
+ groups.push({ kind: 'tools', tools: [toolPart], idx: i });
38
+ }
39
+ }
40
+ // step-start, step-end, etc. are ignored visually
41
+ }
42
+
43
+ return groups;
44
+ }
45
+
46
+ export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) {
47
+ const groups = useMemo(() => groupParts(message.parts), [message.parts]);
48
+
49
+ // Find the last text group index for streaming cursor
50
+ let lastTextIdx = -1;
51
+ for (let i = groups.length - 1; i >= 0; i--) {
52
+ if (groups[i].kind === 'text') { lastTextIdx = i; break; }
53
+ }
54
+
55
+ const meta = message.metadata as MessageMeta | undefined;
56
+ const timeStr = meta?.createdAt
57
+ ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
58
+ : null;
59
+
60
+ if (groups.length === 0) return null;
61
+
62
+ return (
63
+ <Box sx={{ minWidth: 0 }}>
64
+ <Stack direction="row" alignItems="baseline" spacing={1} sx={{ mb: 0.5 }}>
65
+ <Typography
66
+ variant="caption"
67
+ sx={{
68
+ fontWeight: 700,
69
+ fontSize: '0.72rem',
70
+ color: 'var(--muted-text)',
71
+ textTransform: 'uppercase',
72
+ letterSpacing: '0.04em',
73
+ }}
74
+ >
75
+ Assistant
76
+ </Typography>
77
+ {timeStr && (
78
+ <Typography variant="caption" sx={{ color: 'var(--muted-text)', fontSize: '0.7rem' }}>
79
+ {timeStr}
80
+ </Typography>
81
+ )}
82
+ </Stack>
83
+
84
+ <Box
85
+ sx={{
86
+ maxWidth: { xs: '95%', md: '85%' },
87
+ bgcolor: 'var(--surface)',
88
+ borderRadius: 1.5,
89
+ borderTopLeftRadius: 4,
90
+ px: { xs: 1.5, md: 2.5 },
91
+ py: 1.5,
92
+ border: '1px solid var(--border)',
93
+ }}
94
+ >
95
+ {groups.map((group, i) => {
96
+ if (group.kind === 'text' && group.text) {
97
+ return (
98
+ <MarkdownContent
99
+ key={group.idx}
100
+ content={group.text}
101
+ isStreaming={isStreaming && i === lastTextIdx}
102
+ />
103
+ );
104
+ }
105
+ if (group.kind === 'tools' && group.tools.length > 0) {
106
+ return (
107
+ <ToolCallGroup
108
+ key={group.idx}
109
+ tools={group.tools}
110
+ approveTools={approveTools}
111
+ />
112
+ );
113
+ }
114
+ return null;
115
+ })}
116
+ </Box>
117
+ </Box>
118
+ );
119
+ }
frontend/src/components/Chat/ChatInput.tsx CHANGED
@@ -1,14 +1,103 @@
1
- import { useState, useCallback, KeyboardEvent } from 'react';
2
- import { Box, TextField, IconButton, CircularProgress, Typography } from '@mui/material';
3
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  interface ChatInputProps {
6
  onSend: (text: string) => void;
7
  disabled?: boolean;
 
8
  }
9
 
10
- export default function ChatInput({ onSend, disabled = false }: ChatInputProps) {
11
  const [input, setInput] = useState('');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  const handleSend = useCallback(() => {
14
  if (input.trim() && !disabled) {
@@ -27,26 +116,48 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps)
27
  [handleSend]
28
  );
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  return (
31
  <Box
32
  sx={{
33
- pb: 4,
34
- pt: 2,
35
  position: 'relative',
36
  zIndex: 10,
37
  }}
38
  >
39
- <Box sx={{ maxWidth: '880px', mx: 'auto', width: '100%', px: 2 }}>
40
  <Box
41
  className="composer"
42
  sx={{
43
  display: 'flex',
44
  gap: '10px',
45
  alignItems: 'flex-start',
46
- bgcolor: 'rgba(255,255,255,0.01)',
47
  borderRadius: 'var(--radius-md)',
48
  p: '12px',
49
- border: '1px solid rgba(255,255,255,0.03)',
50
  transition: 'box-shadow 0.2s ease, border-color 0.2s ease',
51
  '&:focus-within': {
52
  borderColor: 'var(--accent-yellow)',
@@ -61,9 +172,10 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps)
61
  value={input}
62
  onChange={(e) => setInput(e.target.value)}
63
  onKeyDown={handleKeyDown}
64
- placeholder="Ask anything..."
65
  disabled={disabled}
66
  variant="standard"
 
67
  InputProps={{
68
  disableUnderline: true,
69
  sx: {
@@ -72,7 +184,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps)
72
  fontFamily: 'inherit',
73
  padding: 0,
74
  lineHeight: 1.5,
75
- minHeight: '56px',
76
  alignItems: 'flex-start',
77
  }
78
  }}
@@ -99,7 +211,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps)
99
  transition: 'all 0.2s',
100
  '&:hover': {
101
  color: 'var(--accent-yellow)',
102
- bgcolor: 'rgba(255,255,255,0.05)',
103
  },
104
  '&.Mui-disabled': {
105
  opacity: 0.3,
@@ -109,17 +221,108 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps)
109
  {disabled ? <CircularProgress size={20} color="inherit" /> : <ArrowUpwardIcon fontSize="small" />}
110
  </IconButton>
111
  </Box>
112
-
113
  {/* Powered By Badge */}
114
- <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', mt: 1.5, gap: 0.8, opacity: 0.5 }}>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  <Typography variant="caption" sx={{ fontSize: '10px', color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em', fontWeight: 500 }}>
116
  powered by
117
  </Typography>
118
- <img src="/claude-logo.png" alt="Claude" style={{ height: '12px', objectFit: 'contain' }} />
 
 
 
 
119
  <Typography variant="caption" sx={{ fontSize: '10px', color: 'var(--text)', fontWeight: 600, letterSpacing: '0.02em' }}>
120
- claude-opus-4-5-20251101
121
  </Typography>
 
122
  </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  </Box>
124
  </Box>
125
  );
 
1
+ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react';
2
+ import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
3
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
4
+ import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
5
+ import { apiFetch } from '@/utils/api';
6
+
7
+ // Model configuration
8
+ interface ModelOption {
9
+ id: string;
10
+ name: string;
11
+ description: string;
12
+ modelPath: string;
13
+ avatarUrl: string;
14
+ recommended?: boolean;
15
+ }
16
+
17
+ const getHfAvatarUrl = (modelId: string) => {
18
+ const org = modelId.split('/')[0];
19
+ return `https://huggingface.co/api/avatars/${org}`;
20
+ };
21
+
22
+ const MODEL_OPTIONS: ModelOption[] = [
23
+ {
24
+ id: 'minimax-m2.1',
25
+ name: 'MiniMax M2.1',
26
+ description: 'Via Novita',
27
+ modelPath: 'huggingface/novita/minimax/minimax-m2.1',
28
+ avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.1'),
29
+ recommended: true,
30
+ },
31
+ {
32
+ id: 'claude-opus',
33
+ name: 'Claude Opus 4.5',
34
+ description: 'Anthropic',
35
+ modelPath: 'anthropic/claude-opus-4-5-20251101',
36
+ avatarUrl: 'https://huggingface.co/api/avatars/Anthropic',
37
+ recommended: true,
38
+ },
39
+ {
40
+ id: 'kimi-k2.5',
41
+ name: 'Kimi K2.5',
42
+ description: 'Via Novita',
43
+ modelPath: 'huggingface/novita/moonshotai/kimi-k2.5',
44
+ avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.5'),
45
+ },
46
+ {
47
+ id: 'glm-5',
48
+ name: 'GLM 5',
49
+ description: 'Via Novita',
50
+ modelPath: 'huggingface/novita/zai-org/glm-5',
51
+ avatarUrl: getHfAvatarUrl('zai-org/GLM-5'),
52
+ },
53
+ ];
54
+
55
+ const findModelByPath = (path: string): ModelOption | undefined => {
56
+ return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id));
57
+ };
58
 
59
  interface ChatInputProps {
60
  onSend: (text: string) => void;
61
  disabled?: boolean;
62
+ placeholder?: string;
63
  }
64
 
65
+ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
66
  const [input, setInput] = useState('');
67
+ const inputRef = useRef<HTMLTextAreaElement>(null);
68
+ const [selectedModelId, setSelectedModelId] = useState<string>(() => {
69
+ try {
70
+ const stored = localStorage.getItem('hf-agent-model');
71
+ if (stored && MODEL_OPTIONS.some(m => m.id === stored)) return stored;
72
+ } catch { /* localStorage unavailable */ }
73
+ return MODEL_OPTIONS[0].id;
74
+ });
75
+ const [modelAnchorEl, setModelAnchorEl] = useState<null | HTMLElement>(null);
76
+
77
+ // Sync with backend on mount (backend is source of truth, localStorage is just a cache)
78
+ useEffect(() => {
79
+ fetch('/api/config/model')
80
+ .then((res) => (res.ok ? res.json() : null))
81
+ .then((data) => {
82
+ if (data?.current) {
83
+ const model = findModelByPath(data.current);
84
+ if (model) {
85
+ setSelectedModelId(model.id);
86
+ try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ }
87
+ }
88
+ }
89
+ })
90
+ .catch(() => { /* ignore */ });
91
+ }, []);
92
+
93
+ const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0];
94
+
95
+ // Auto-focus the textarea when the session becomes ready (disabled -> false)
96
+ useEffect(() => {
97
+ if (!disabled && inputRef.current) {
98
+ inputRef.current.focus();
99
+ }
100
+ }, [disabled]);
101
 
102
  const handleSend = useCallback(() => {
103
  if (input.trim() && !disabled) {
 
116
  [handleSend]
117
  );
118
 
119
+ const handleModelClick = (event: React.MouseEvent<HTMLElement>) => {
120
+ setModelAnchorEl(event.currentTarget);
121
+ };
122
+
123
+ const handleModelClose = () => {
124
+ setModelAnchorEl(null);
125
+ };
126
+
127
+ const handleSelectModel = async (model: ModelOption) => {
128
+ handleModelClose();
129
+ try {
130
+ const res = await apiFetch('/api/config/model', {
131
+ method: 'POST',
132
+ body: JSON.stringify({ model: model.modelPath }),
133
+ });
134
+ if (res.ok) {
135
+ setSelectedModelId(model.id);
136
+ try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ }
137
+ }
138
+ } catch { /* ignore */ }
139
+ };
140
+
141
  return (
142
  <Box
143
  sx={{
144
+ pb: { xs: 2, md: 4 },
145
+ pt: { xs: 1, md: 2 },
146
  position: 'relative',
147
  zIndex: 10,
148
  }}
149
  >
150
+ <Box sx={{ maxWidth: '880px', mx: 'auto', width: '100%', px: { xs: 0, sm: 1, md: 2 } }}>
151
  <Box
152
  className="composer"
153
  sx={{
154
  display: 'flex',
155
  gap: '10px',
156
  alignItems: 'flex-start',
157
+ bgcolor: 'var(--composer-bg)',
158
  borderRadius: 'var(--radius-md)',
159
  p: '12px',
160
+ border: '1px solid var(--border)',
161
  transition: 'box-shadow 0.2s ease, border-color 0.2s ease',
162
  '&:focus-within': {
163
  borderColor: 'var(--accent-yellow)',
 
172
  value={input}
173
  onChange={(e) => setInput(e.target.value)}
174
  onKeyDown={handleKeyDown}
175
+ placeholder={placeholder}
176
  disabled={disabled}
177
  variant="standard"
178
+ inputRef={inputRef}
179
  InputProps={{
180
  disableUnderline: true,
181
  sx: {
 
184
  fontFamily: 'inherit',
185
  padding: 0,
186
  lineHeight: 1.5,
187
+ minHeight: { xs: '44px', md: '56px' },
188
  alignItems: 'flex-start',
189
  }
190
  }}
 
211
  transition: 'all 0.2s',
212
  '&:hover': {
213
  color: 'var(--accent-yellow)',
214
+ bgcolor: 'var(--hover-bg)',
215
  },
216
  '&.Mui-disabled': {
217
  opacity: 0.3,
 
221
  {disabled ? <CircularProgress size={20} color="inherit" /> : <ArrowUpwardIcon fontSize="small" />}
222
  </IconButton>
223
  </Box>
224
+
225
  {/* Powered By Badge */}
226
+ <Box
227
+ onClick={handleModelClick}
228
+ sx={{
229
+ display: 'flex',
230
+ alignItems: 'center',
231
+ justifyContent: 'center',
232
+ mt: 1.5,
233
+ gap: 0.8,
234
+ opacity: 0.6,
235
+ cursor: 'pointer',
236
+ transition: 'opacity 0.2s',
237
+ '&:hover': {
238
+ opacity: 1
239
+ }
240
+ }}
241
+ >
242
  <Typography variant="caption" sx={{ fontSize: '10px', color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em', fontWeight: 500 }}>
243
  powered by
244
  </Typography>
245
+ <img
246
+ src={selectedModel.avatarUrl}
247
+ alt={selectedModel.name}
248
+ style={{ height: '14px', width: '14px', objectFit: 'contain', borderRadius: '2px' }}
249
+ />
250
  <Typography variant="caption" sx={{ fontSize: '10px', color: 'var(--text)', fontWeight: 600, letterSpacing: '0.02em' }}>
251
+ {selectedModel.name}
252
  </Typography>
253
+ <ArrowDropDownIcon sx={{ fontSize: '14px', color: 'var(--muted-text)' }} />
254
  </Box>
255
+
256
+ {/* Model Selection Menu */}
257
+ <Menu
258
+ anchorEl={modelAnchorEl}
259
+ open={Boolean(modelAnchorEl)}
260
+ onClose={handleModelClose}
261
+ anchorOrigin={{
262
+ vertical: 'top',
263
+ horizontal: 'center',
264
+ }}
265
+ transformOrigin={{
266
+ vertical: 'bottom',
267
+ horizontal: 'center',
268
+ }}
269
+ slotProps={{
270
+ paper: {
271
+ sx: {
272
+ bgcolor: 'var(--panel)',
273
+ border: '1px solid var(--divider)',
274
+ mb: 1,
275
+ maxHeight: '400px',
276
+ }
277
+ }
278
+ }}
279
+ >
280
+ {MODEL_OPTIONS.map((model) => (
281
+ <MenuItem
282
+ key={model.id}
283
+ onClick={() => handleSelectModel(model)}
284
+ selected={selectedModelId === model.id}
285
+ sx={{
286
+ py: 1.5,
287
+ '&.Mui-selected': {
288
+ bgcolor: 'rgba(255,255,255,0.05)',
289
+ }
290
+ }}
291
+ >
292
+ <ListItemIcon>
293
+ <img
294
+ src={model.avatarUrl}
295
+ alt={model.name}
296
+ style={{ width: 24, height: 24, borderRadius: '4px', objectFit: 'cover' }}
297
+ />
298
+ </ListItemIcon>
299
+ <ListItemText
300
+ primary={
301
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
302
+ {model.name}
303
+ {model.recommended && (
304
+ <Chip
305
+ label="Recommended"
306
+ size="small"
307
+ sx={{
308
+ height: '18px',
309
+ fontSize: '10px',
310
+ bgcolor: 'var(--accent-yellow)',
311
+ color: '#000',
312
+ fontWeight: 600,
313
+ }}
314
+ />
315
+ )}
316
+ </Box>
317
+ }
318
+ secondary={model.description}
319
+ secondaryTypographyProps={{
320
+ sx: { fontSize: '12px', color: 'var(--muted-text)' }
321
+ }}
322
+ />
323
+ </MenuItem>
324
+ ))}
325
+ </Menu>
326
  </Box>
327
  </Box>
328
  );
frontend/src/components/Chat/MarkdownContent.tsx ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useMemo, useRef, useState, useEffect } from 'react';
2
+ import { Box } from '@mui/material';
3
+ import ReactMarkdown from 'react-markdown';
4
+ import remarkGfm from 'remark-gfm';
5
+ import type { SxProps, Theme } from '@mui/material/styles';
6
+
7
+ interface MarkdownContentProps {
8
+ content: string;
9
+ sx?: SxProps<Theme>;
10
+ /** When true, shows a blinking cursor and throttles renders. */
11
+ isStreaming?: boolean;
12
+ }
13
+
14
+ /** Shared markdown styles β€” adapts to light/dark via CSS variables. */
15
+ const markdownSx: SxProps<Theme> = {
16
+ fontSize: '0.925rem',
17
+ lineHeight: 1.7,
18
+ color: 'var(--text)',
19
+ wordBreak: 'break-word',
20
+
21
+ '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } },
22
+
23
+ '& h1, & h2, & h3, & h4': { mt: 2.5, mb: 1, fontWeight: 600, lineHeight: 1.3 },
24
+ '& h1': { fontSize: '1.35rem' },
25
+ '& h2': { fontSize: '1.15rem' },
26
+ '& h3': { fontSize: '1.05rem' },
27
+
28
+ '& pre': {
29
+ bgcolor: 'var(--code-bg)',
30
+ p: 2,
31
+ borderRadius: 2,
32
+ overflow: 'auto',
33
+ fontSize: '0.82rem',
34
+ lineHeight: 1.6,
35
+ border: '1px solid var(--tool-border)',
36
+ my: 2,
37
+ },
38
+ '& code': {
39
+ bgcolor: 'var(--hover-bg)',
40
+ px: 0.75,
41
+ py: 0.25,
42
+ borderRadius: 0.5,
43
+ fontSize: '0.84rem',
44
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
45
+ },
46
+ '& pre code': { bgcolor: 'transparent', p: 0 },
47
+
48
+ '& a': {
49
+ color: 'var(--accent-yellow)',
50
+ textDecoration: 'none',
51
+ fontWeight: 500,
52
+ '&:hover': { textDecoration: 'underline' },
53
+ },
54
+
55
+ '& ul, & ol': { pl: 3, my: 1 },
56
+ '& li': { mb: 0.5 },
57
+ '& li::marker': { color: 'var(--muted-text)' },
58
+
59
+ '& blockquote': {
60
+ borderLeft: '3px solid var(--accent-yellow)',
61
+ pl: 2,
62
+ ml: 0,
63
+ my: 1.5,
64
+ color: 'var(--muted-text)',
65
+ fontStyle: 'italic',
66
+ },
67
+
68
+ '& table': {
69
+ borderCollapse: 'collapse',
70
+ width: '100%',
71
+ my: 2,
72
+ fontSize: '0.85rem',
73
+ },
74
+ '& th': {
75
+ borderBottom: '2px solid var(--border-hover)',
76
+ textAlign: 'left',
77
+ p: 1,
78
+ fontWeight: 600,
79
+ },
80
+ '& td': {
81
+ borderBottom: '1px solid var(--tool-border)',
82
+ p: 1,
83
+ },
84
+
85
+ '& hr': {
86
+ border: 'none',
87
+ borderTop: '1px solid var(--border)',
88
+ my: 2,
89
+ },
90
+
91
+ '& img': {
92
+ maxWidth: '100%',
93
+ borderRadius: 2,
94
+ },
95
+ };
96
+
97
+ /**
98
+ * Throttled content for streaming: render the full markdown through
99
+ * ReactMarkdown but only re-parse every ~80ms to avoid layout thrashing.
100
+ * This is the Claude approach β€” always render as markdown, never split
101
+ * into raw text. The parser handles incomplete tables gracefully.
102
+ */
103
+ function useThrottledValue(value: string, isStreaming: boolean, intervalMs = 80): string {
104
+ const [throttled, setThrottled] = useState(value);
105
+ const lastUpdate = useRef(0);
106
+ const pending = useRef<ReturnType<typeof setTimeout> | null>(null);
107
+ const latestValue = useRef(value);
108
+ latestValue.current = value;
109
+
110
+ useEffect(() => {
111
+ if (!isStreaming) {
112
+ // Not streaming β€” always use latest value immediately
113
+ setThrottled(value);
114
+ return;
115
+ }
116
+
117
+ const now = Date.now();
118
+ const elapsed = now - lastUpdate.current;
119
+
120
+ if (elapsed >= intervalMs) {
121
+ // Enough time passed β€” update immediately
122
+ setThrottled(value);
123
+ lastUpdate.current = now;
124
+ } else {
125
+ // Schedule an update for the remaining time
126
+ if (pending.current) clearTimeout(pending.current);
127
+ pending.current = setTimeout(() => {
128
+ setThrottled(latestValue.current);
129
+ lastUpdate.current = Date.now();
130
+ pending.current = null;
131
+ }, intervalMs - elapsed);
132
+ }
133
+
134
+ return () => {
135
+ if (pending.current) clearTimeout(pending.current);
136
+ };
137
+ }, [value, isStreaming, intervalMs]);
138
+
139
+ // When streaming ends, flush immediately
140
+ useEffect(() => {
141
+ if (!isStreaming) {
142
+ setThrottled(latestValue.current);
143
+ }
144
+ }, [isStreaming]);
145
+
146
+ return throttled;
147
+ }
148
+
149
+ export default function MarkdownContent({ content, sx, isStreaming = false }: MarkdownContentProps) {
150
+ // Throttle re-parses during streaming to ~12fps (every 80ms)
151
+ const displayContent = useThrottledValue(content, isStreaming);
152
+
153
+ const remarkPlugins = useMemo(() => [remarkGfm], []);
154
+
155
+ return (
156
+ <Box sx={[markdownSx, ...(Array.isArray(sx) ? sx : sx ? [sx] : [])]}>
157
+ <ReactMarkdown remarkPlugins={remarkPlugins}>{displayContent}</ReactMarkdown>
158
+ </Box>
159
+ );
160
+ }
frontend/src/components/Chat/MessageBubble.tsx CHANGED
@@ -1,215 +1,44 @@
1
- import { Box, Paper, Typography } from '@mui/material';
2
- import ReactMarkdown from 'react-markdown';
3
- import remarkGfm from 'remark-gfm';
4
- import ApprovalFlow from './ApprovalFlow';
5
- import type { Message, TraceLog } from '@/types/agent';
6
- import { useAgentStore } from '@/store/agentStore';
7
- import { useLayoutStore } from '@/store/layoutStore';
8
 
9
  interface MessageBubbleProps {
10
- message: Message;
 
 
 
 
 
11
  }
12
 
13
- // Render a tools segment with clickable tool calls
14
- function ToolsSegment({ tools }: { tools: TraceLog[] }) {
15
- const { showToolOutput } = useAgentStore();
16
- const { setRightPanelOpen } = useLayoutStore();
17
-
18
- const handleToolClick = (log: TraceLog) => {
19
- if (log.completed && log.output) {
20
- showToolOutput(log);
21
- setRightPanelOpen(true);
22
- }
23
- };
24
-
25
- return (
26
- <Box
27
- sx={{
28
- bgcolor: 'rgba(0,0,0,0.3)',
29
- borderRadius: 1,
30
- p: 1.5,
31
- border: 1,
32
- borderColor: 'rgba(255,255,255,0.05)',
33
- my: 1.5,
34
- }}
35
- >
36
- <Box sx={{ display: 'flex', flexDirection: 'column', gap: 0.5 }}>
37
- {tools.map((log) => {
38
- const isClickable = log.completed && log.output;
39
- return (
40
- <Typography
41
- key={log.id}
42
- variant="caption"
43
- component="div"
44
- onClick={() => handleToolClick(log)}
45
- sx={{
46
- color: 'var(--muted-text)',
47
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
48
- fontSize: '0.75rem',
49
- display: 'flex',
50
- alignItems: 'center',
51
- gap: 0.5,
52
- cursor: isClickable ? 'pointer' : 'default',
53
- borderRadius: 0.5,
54
- px: 0.5,
55
- mx: -0.5,
56
- transition: 'background-color 0.15s ease',
57
- '&:hover': isClickable ? {
58
- bgcolor: 'rgba(255,255,255,0.05)',
59
- } : {},
60
- }}
61
- >
62
- <span style={{
63
- color: log.completed
64
- ? (log.success === false ? '#F87171' : '#FDB022')
65
- : 'inherit',
66
- fontSize: '0.85rem',
67
- }}>
68
- {log.completed ? (log.success === false ? 'βœ—' : 'βœ“') : 'β€’'}
69
- </span>
70
- <span style={{
71
- fontWeight: 600,
72
- color: isClickable ? 'rgba(255, 255, 255, 0.9)' : 'inherit',
73
- textDecoration: isClickable ? 'underline' : 'none',
74
- textDecorationColor: 'rgba(255,255,255,0.3)',
75
- textUnderlineOffset: '2px',
76
- }}>
77
- {log.tool}
78
- </span>
79
- {!log.completed && <span style={{ opacity: 0.6 }}>...</span>}
80
- {isClickable && (
81
- <span style={{
82
- opacity: 0.4,
83
- fontSize: '0.65rem',
84
- marginLeft: 'auto',
85
- }}>
86
- click to view
87
- </span>
88
- )}
89
- </Typography>
90
- );
91
- })}
92
- </Box>
93
- </Box>
94
- );
95
- }
96
-
97
- // Markdown styles
98
- const markdownStyles = {
99
- '& p': { m: 0, mb: 1, '&:last-child': { mb: 0 } },
100
- '& pre': {
101
- bgcolor: 'rgba(0,0,0,0.5)',
102
- p: 1.5,
103
- borderRadius: 1,
104
- overflow: 'auto',
105
- fontSize: '0.85rem',
106
- border: '1px solid rgba(255,255,255,0.05)',
107
- },
108
- '& code': {
109
- bgcolor: 'rgba(255,255,255,0.05)',
110
- px: 0.5,
111
- py: 0.25,
112
- borderRadius: 0.5,
113
- fontSize: '0.85rem',
114
- fontFamily: '"JetBrains Mono", monospace',
115
- },
116
- '& pre code': { bgcolor: 'transparent', p: 0 },
117
- '& a': {
118
- color: 'var(--accent-yellow)',
119
- textDecoration: 'none',
120
- '&:hover': { textDecoration: 'underline' },
121
- },
122
- '& ul, & ol': { pl: 2, my: 1 },
123
- '& table': {
124
- borderCollapse: 'collapse',
125
- width: '100%',
126
- my: 2,
127
- fontSize: '0.875rem',
128
- },
129
- '& th': {
130
- borderBottom: '1px solid rgba(255,255,255,0.1)',
131
- textAlign: 'left',
132
- p: 1,
133
- bgcolor: 'rgba(255,255,255,0.02)',
134
- },
135
- '& td': {
136
- borderBottom: '1px solid rgba(255,255,255,0.05)',
137
- p: 1,
138
- },
139
- };
140
-
141
- export default function MessageBubble({ message }: MessageBubbleProps) {
142
- const isUser = message.role === 'user';
143
- const isAssistant = message.role === 'assistant';
144
-
145
- if (message.approval) {
146
  return (
147
- <Box sx={{ width: '100%', maxWidth: '880px', mx: 'auto', my: 2 }}>
148
- <ApprovalFlow message={message} />
149
- </Box>
 
 
 
150
  );
151
  }
152
 
153
- // Render segments chronologically if available, otherwise fall back to content
154
- const renderContent = () => {
155
- if (message.segments && message.segments.length > 0) {
156
- return message.segments.map((segment, idx) => {
157
- if (segment.type === 'text' && segment.content) {
158
- return (
159
- <Box key={idx} sx={markdownStyles}>
160
- <ReactMarkdown remarkPlugins={[remarkGfm]}>{segment.content}</ReactMarkdown>
161
- </Box>
162
- );
163
- }
164
- if (segment.type === 'tools' && segment.tools && segment.tools.length > 0) {
165
- return <ToolsSegment key={idx} tools={segment.tools} />;
166
- }
167
- return null;
168
- });
169
- }
170
- // Fallback: just render content
171
  return (
172
- <Box sx={markdownStyles}>
173
- <ReactMarkdown remarkPlugins={[remarkGfm]}>{message.content}</ReactMarkdown>
174
- </Box>
 
 
175
  );
176
- };
177
-
178
- return (
179
- <Box
180
- sx={{
181
- display: 'flex',
182
- justifyContent: isUser ? 'flex-end' : 'flex-start',
183
- width: '100%',
184
- maxWidth: '880px',
185
- mx: 'auto',
186
- }}
187
- >
188
- <Paper
189
- elevation={0}
190
- className={`message ${isUser ? 'user' : isAssistant ? 'assistant' : ''}`}
191
- sx={{
192
- p: '14px 18px',
193
- margin: '10px 0',
194
- maxWidth: '100%',
195
- borderRadius: 'var(--radius-lg)',
196
- borderTopLeftRadius: isAssistant ? '6px' : undefined,
197
- lineHeight: 1.45,
198
- boxShadow: 'var(--shadow-1)',
199
- border: '1px solid rgba(255,255,255,0.03)',
200
- background: 'linear-gradient(180deg, rgba(255,255,255,0.015), transparent)',
201
- }}
202
- >
203
- {renderContent()}
204
 
205
- <Typography
206
- className="meta"
207
- variant="caption"
208
- sx={{ display: 'block', textAlign: 'right', mt: 1, fontSize: '11px', opacity: 0.5 }}
209
- >
210
- {new Date(message.timestamp).toLocaleTimeString()}
211
- </Typography>
212
- </Paper>
213
- </Box>
214
- );
215
  }
 
1
+ import UserMessage from './UserMessage';
2
+ import AssistantMessage from './AssistantMessage';
3
+ import type { UIMessage } from 'ai';
 
 
 
 
4
 
5
  interface MessageBubbleProps {
6
+ message: UIMessage;
7
+ isLastTurn?: boolean;
8
+ onUndoTurn?: () => void;
9
+ isProcessing?: boolean;
10
+ isStreaming?: boolean;
11
+ approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
  }
13
 
14
+ export default function MessageBubble({
15
+ message,
16
+ isLastTurn = false,
17
+ onUndoTurn,
18
+ isProcessing = false,
19
+ isStreaming = false,
20
+ approveTools,
21
+ }: MessageBubbleProps) {
22
+ if (message.role === 'user') {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  return (
24
+ <UserMessage
25
+ message={message}
26
+ isLastTurn={isLastTurn}
27
+ onUndoTurn={onUndoTurn}
28
+ isProcessing={isProcessing}
29
+ />
30
  );
31
  }
32
 
33
+ if (message.role === 'assistant') {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return (
35
+ <AssistantMessage
36
+ message={message}
37
+ isStreaming={isStreaming}
38
+ approveTools={approveTools}
39
+ />
40
  );
41
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ return null;
 
 
 
 
 
 
 
 
 
44
  }
frontend/src/components/Chat/MessageList.tsx CHANGED
@@ -1,100 +1,151 @@
1
- import { useEffect, useRef } from 'react';
2
- import { Box, Typography } from '@mui/material';
3
- import { useSessionStore } from '@/store/sessionStore';
4
  import MessageBubble from './MessageBubble';
5
- import type { Message } from '@/types/agent';
 
 
6
 
7
  interface MessageListProps {
8
- messages: Message[];
9
  isProcessing: boolean;
 
 
10
  }
11
 
12
- const TechnicalIndicator = () => (
13
- <Box
14
- component="span"
15
- sx={{
16
- color: 'primary.main',
17
- fontFamily: 'monospace',
18
- fontWeight: 'bold',
19
- fontSize: '1.2rem',
20
- lineHeight: 0,
21
- display: 'inline-block',
22
- verticalAlign: 'middle',
23
- width: '1em',
24
- letterSpacing: '-3px',
25
- transform: 'scale(0.6) translateY(-2px)',
26
- '&::after': {
27
- content: '""',
28
- animation: 'dots 2s steps(4, end) infinite',
29
- },
30
- '@keyframes dots': {
31
- '0%': { content: '""' },
32
- '25%': { content: '"."' },
33
- '50%': { content: '".."' },
34
- '75%, 100%': { content: '"..."' },
35
- },
36
- }}
37
- />
38
- );
39
 
40
- export default function MessageList({ messages, isProcessing }: MessageListProps) {
41
- const bottomRef = useRef<HTMLDivElement>(null);
42
- const { activeSessionId } = useSessionStore();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- // Auto-scroll to bottom when new messages arrive
45
  useEffect(() => {
46
- bottomRef.current?.scrollIntoView({ behavior: 'smooth' });
47
- }, [messages, isProcessing]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  return (
50
  <Box
 
51
  sx={{
52
  flex: 1,
53
  overflow: 'auto',
54
- p: 2,
 
55
  display: 'flex',
56
  flexDirection: 'column',
57
  }}
58
  >
59
- <Box sx={{ maxWidth: 'md', mx: 'auto', width: '100%', display: 'flex', flexDirection: 'column', gap: 2 }}>
 
 
 
 
 
 
 
 
60
  {messages.length === 0 && !isProcessing ? (
61
- <Box
62
- sx={{
63
- flex: 1,
64
- display: 'flex',
65
- alignItems: 'center',
66
- justifyContent: 'center',
67
- py: 8,
68
- }}
69
- >
70
- <Typography color="text.secondary" sx={{ fontFamily: 'monospace' }}>
71
- Awaiting input…
72
- </Typography>
73
- </Box>
74
  ) : (
75
- messages.map((message) => (
76
- <MessageBubble key={message.id} message={message} />
 
 
 
 
 
 
 
 
77
  ))
78
  )}
79
-
80
- {isProcessing && (
81
- <Box sx={{ width: '100%', mb: 2 }}>
82
- <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mb: 1, px: 0.5 }}>
83
- <Typography variant="caption" color="text.secondary" sx={{ fontFamily: 'monospace', fontWeight: 600 }}>
84
- Thinking
85
- </Typography>
86
- <TechnicalIndicator />
87
- </Box>
88
- </Box>
89
- )}
90
 
91
- {activeSessionId && (
92
- // ApprovalFlow is now handled within messages
93
- null
94
- )}
95
-
96
- <div ref={bottomRef} />
97
- </Box>
98
  </Box>
99
  );
100
- }
 
1
+ import { useCallback, useEffect, useRef, useMemo } from 'react';
2
+ import { Box, Stack, Typography } from '@mui/material';
 
3
  import MessageBubble from './MessageBubble';
4
+ import ActivityStatusBar from './ActivityStatusBar';
5
+ import { useAgentStore } from '@/store/agentStore';
6
+ import type { UIMessage } from 'ai';
7
 
8
  interface MessageListProps {
9
+ messages: UIMessage[];
10
  isProcessing: boolean;
11
+ approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
12
+ onUndoLastTurn: () => void | Promise<void>;
13
  }
14
 
15
+ function getGreeting(): string {
16
+ const h = new Date().getHours();
17
+ if (h < 12) return 'Morning';
18
+ if (h < 17) return 'Afternoon';
19
+ return 'Evening';
20
+ }
21
+
22
+ function WelcomeGreeting() {
23
+ const { user } = useAgentStore();
24
+ const firstName = user?.name?.split(' ')[0] || user?.username;
25
+ const greeting = firstName ? `${getGreeting()}, ${firstName}` : getGreeting();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ return (
28
+ <Box
29
+ sx={{
30
+ flex: 1,
31
+ display: 'flex',
32
+ flexDirection: 'column',
33
+ alignItems: 'center',
34
+ justifyContent: 'center',
35
+ py: 8,
36
+ gap: 1.5,
37
+ }}
38
+ >
39
+ <Typography
40
+ sx={{
41
+ fontFamily: 'monospace',
42
+ fontSize: '1.6rem',
43
+ color: 'var(--text)',
44
+ fontWeight: 600,
45
+ }}
46
+ >
47
+ {greeting}
48
+ </Typography>
49
+ <Typography
50
+ color="text.secondary"
51
+ sx={{ fontFamily: 'monospace', fontSize: '0.9rem' }}
52
+ >
53
+ Let's build something impressive?
54
+ </Typography>
55
+ </Box>
56
+ );
57
+ }
58
+
59
+ export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn }: MessageListProps) {
60
+ const scrollContainerRef = useRef<HTMLDivElement>(null);
61
+ const stickToBottom = useRef(true);
62
+
63
+ const scrollToBottom = useCallback(() => {
64
+ const el = scrollContainerRef.current;
65
+ if (el) el.scrollTop = el.scrollHeight;
66
+ }, []);
67
+
68
+ useEffect(() => {
69
+ const el = scrollContainerRef.current;
70
+ if (!el) return;
71
+ const onScroll = () => {
72
+ const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
73
+ stickToBottom.current = distFromBottom < 80;
74
+ };
75
+ el.addEventListener('scroll', onScroll, { passive: true });
76
+ return () => el.removeEventListener('scroll', onScroll);
77
+ }, []);
78
+
79
+ useEffect(() => {
80
+ if (stickToBottom.current) scrollToBottom();
81
+ }, [messages, isProcessing, scrollToBottom]);
82
 
 
83
  useEffect(() => {
84
+ const el = scrollContainerRef.current;
85
+ if (!el) return;
86
+ const observer = new MutationObserver(() => {
87
+ if (stickToBottom.current) el.scrollTop = el.scrollHeight;
88
+ });
89
+ observer.observe(el, { childList: true, subtree: true, characterData: true });
90
+ return () => observer.disconnect();
91
+ }, []);
92
+
93
+ const lastUserMsgId = useMemo(() => {
94
+ for (let i = messages.length - 1; i >= 0; i--) {
95
+ if (messages[i].role === 'user') return messages[i].id;
96
+ }
97
+ return null;
98
+ }, [messages]);
99
+
100
+ // The last assistant message is "streaming" when we're processing
101
+ const lastAssistantId = useMemo(() => {
102
+ for (let i = messages.length - 1; i >= 0; i--) {
103
+ if (messages[i].role === 'assistant') return messages[i].id;
104
+ }
105
+ return null;
106
+ }, [messages]);
107
 
108
  return (
109
  <Box
110
+ ref={scrollContainerRef}
111
  sx={{
112
  flex: 1,
113
  overflow: 'auto',
114
+ px: { xs: 0.5, sm: 1, md: 2 },
115
+ py: { xs: 2, md: 3 },
116
  display: 'flex',
117
  flexDirection: 'column',
118
  }}
119
  >
120
+ <Stack
121
+ spacing={3}
122
+ sx={{
123
+ maxWidth: 880,
124
+ mx: 'auto',
125
+ width: '100%',
126
+ flex: messages.length === 0 && !isProcessing ? 1 : undefined,
127
+ }}
128
+ >
129
  {messages.length === 0 && !isProcessing ? (
130
+ <WelcomeGreeting />
 
 
 
 
 
 
 
 
 
 
 
 
131
  ) : (
132
+ messages.map((msg) => (
133
+ <MessageBubble
134
+ key={msg.id}
135
+ message={msg}
136
+ isLastTurn={msg.id === lastUserMsgId}
137
+ onUndoTurn={onUndoLastTurn}
138
+ isProcessing={isProcessing}
139
+ isStreaming={isProcessing && msg.id === lastAssistantId}
140
+ approveTools={approveTools}
141
+ />
142
  ))
143
  )}
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ <ActivityStatusBar />
146
+
147
+ <div />
148
+ </Stack>
 
 
 
149
  </Box>
150
  );
151
+ }
frontend/src/components/Chat/ThinkingIndicator.tsx ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Box, Typography } from '@mui/material';
2
+
3
+ /** Pulsing dots shown while the agent is processing. */
4
+ export default function ThinkingIndicator() {
5
+ return (
6
+ <Box sx={{ pt: 0.75 }}>
7
+ <Typography
8
+ variant="caption"
9
+ sx={{
10
+ fontWeight: 700,
11
+ fontSize: '0.72rem',
12
+ color: 'var(--muted-text)',
13
+ textTransform: 'uppercase',
14
+ letterSpacing: '0.04em',
15
+ display: 'flex',
16
+ alignItems: 'center',
17
+ gap: 0.75,
18
+ }}
19
+ >
20
+ Thinking
21
+ <Box
22
+ component="span"
23
+ sx={{
24
+ display: 'inline-flex',
25
+ gap: '3px',
26
+ '& span': {
27
+ width: 4,
28
+ height: 4,
29
+ borderRadius: '50%',
30
+ bgcolor: 'primary.main',
31
+ animation: 'dotPulse 1.4s ease-in-out infinite',
32
+ },
33
+ '& span:nth-of-type(2)': { animationDelay: '0.2s' },
34
+ '& span:nth-of-type(3)': { animationDelay: '0.4s' },
35
+ '@keyframes dotPulse': {
36
+ '0%, 80%, 100%': { opacity: 0.25, transform: 'scale(0.8)' },
37
+ '40%': { opacity: 1, transform: 'scale(1)' },
38
+ },
39
+ }}
40
+ >
41
+ <span />
42
+ <span />
43
+ <span />
44
+ </Box>
45
+ </Typography>
46
+ </Box>
47
+ );
48
+ }
frontend/src/components/Chat/ToolCallGroup.tsx ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useCallback, useMemo, useRef, useState } from 'react';
2
+ import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
3
+ import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
4
+ import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
5
+ import OpenInNewIcon from '@mui/icons-material/OpenInNew';
6
+ import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
7
+ import LaunchIcon from '@mui/icons-material/Launch';
8
+ import SendIcon from '@mui/icons-material/Send';
9
+ import BlockIcon from '@mui/icons-material/Block';
10
+ import { useAgentStore } from '@/store/agentStore';
11
+ import { useLayoutStore } from '@/store/layoutStore';
12
+ import { logger } from '@/utils/logger';
13
+ import type { UIMessage } from 'ai';
14
+
15
+ // ---------------------------------------------------------------------------
16
+ // Type helpers β€” extract the dynamic-tool part type from UIMessage
17
+ // ---------------------------------------------------------------------------
18
+ type DynamicToolPart = Extract<UIMessage['parts'][number], { type: 'dynamic-tool' }>;
19
+
20
+ type ToolPartState = DynamicToolPart['state'];
21
+
22
+ interface ToolCallGroupProps {
23
+ tools: DynamicToolPart[];
24
+ approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => Promise<boolean>;
25
+ }
26
+
27
+ // ---------------------------------------------------------------------------
28
+ // Visual helpers
29
+ // ---------------------------------------------------------------------------
30
+
31
+ function StatusIcon({ state }: { state: ToolPartState }) {
32
+ switch (state) {
33
+ case 'approval-requested':
34
+ return <HourglassEmptyIcon sx={{ fontSize: 16, color: 'var(--accent-yellow)' }} />;
35
+ case 'output-available':
36
+ return <CheckCircleOutlineIcon sx={{ fontSize: 16, color: 'success.main' }} />;
37
+ case 'output-error':
38
+ return <ErrorOutlineIcon sx={{ fontSize: 16, color: 'error.main' }} />;
39
+ case 'output-denied':
40
+ return <BlockIcon sx={{ fontSize: 16, color: 'var(--muted-text)' }} />;
41
+ case 'input-streaming':
42
+ case 'input-available':
43
+ default:
44
+ return <CircularProgress size={14} thickness={5} sx={{ color: 'var(--accent-yellow)' }} />;
45
+ }
46
+ }
47
+
48
+ function statusLabel(state: ToolPartState): string | null {
49
+ switch (state) {
50
+ case 'approval-requested': return 'awaiting approval';
51
+ case 'input-streaming':
52
+ case 'input-available': return 'running';
53
+ case 'output-denied': return 'denied';
54
+ case 'output-error': return 'error';
55
+ default: return null;
56
+ }
57
+ }
58
+
59
+ function statusColor(state: ToolPartState): string {
60
+ switch (state) {
61
+ case 'approval-requested': return 'var(--accent-yellow)';
62
+ case 'output-available': return 'var(--accent-green)';
63
+ case 'output-error': return 'var(--accent-red)';
64
+ case 'output-denied': return 'var(--muted-text)';
65
+ default: return 'var(--accent-yellow)';
66
+ }
67
+ }
68
+
69
+ // ---------------------------------------------------------------------------
70
+ // Inline approval UI (per-tool)
71
+ // ---------------------------------------------------------------------------
72
+
73
+ function InlineApproval({
74
+ toolCallId,
75
+ toolName,
76
+ input,
77
+ scriptLabel,
78
+ onResolve,
79
+ }: {
80
+ toolCallId: string;
81
+ toolName: string;
82
+ input: unknown;
83
+ scriptLabel: string;
84
+ onResolve: (toolCallId: string, approved: boolean, feedback?: string) => void;
85
+ }) {
86
+ const [feedback, setFeedback] = useState('');
87
+ const args = input as Record<string, unknown> | undefined;
88
+ const { setPanel, getEditedScript } = useAgentStore();
89
+ const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
90
+ const hasEditedScript = !!getEditedScript(toolCallId);
91
+
92
+ const handleScriptClick = useCallback(() => {
93
+ if (toolName === 'hf_jobs' && args?.script) {
94
+ const scriptContent = getEditedScript(toolCallId) || String(args.script);
95
+ setPanel(
96
+ { title: scriptLabel, script: { content: scriptContent, language: 'python' }, parameters: { tool_call_id: toolCallId } },
97
+ 'script',
98
+ true,
99
+ );
100
+ setRightPanelOpen(true);
101
+ setLeftSidebarOpen(false);
102
+ }
103
+ }, [toolCallId, toolName, args, scriptLabel, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen]);
104
+
105
+ return (
106
+ <Box sx={{ px: 1.5, py: 1.5, borderTop: '1px solid var(--tool-border)' }}>
107
+ {toolName === 'hf_jobs' && args && (
108
+ <Box sx={{ mb: 1.5 }}>
109
+ <Typography variant="body2" sx={{ color: 'var(--muted-text)', fontSize: '0.75rem', mb: 1 }}>
110
+ Execute <Box component="span" sx={{ color: 'var(--accent-yellow)', fontWeight: 500 }}>{scriptLabel.replace('Script', 'Job')}</Box> on{' '}
111
+ <Box component="span" sx={{ fontWeight: 500, color: 'var(--text)' }}>
112
+ {String(args.hardware_flavor || 'default')}
113
+ </Box>
114
+ {!!args.timeout && (
115
+ <> with timeout <Box component="span" sx={{ fontWeight: 500, color: 'var(--text)' }}>
116
+ {String(args.timeout)}
117
+ </Box></>
118
+ )}
119
+ </Typography>
120
+ {typeof args.script === 'string' && args.script && (
121
+ <Box
122
+ onClick={handleScriptClick}
123
+ sx={{
124
+ mt: 0.5,
125
+ p: 1.5,
126
+ bgcolor: 'var(--code-panel-bg)',
127
+ border: '1px solid var(--tool-border)',
128
+ borderRadius: '8px',
129
+ cursor: 'pointer',
130
+ transition: 'border-color 0.15s ease',
131
+ '&:hover': { borderColor: 'var(--accent-yellow)' },
132
+ }}
133
+ >
134
+ <Box
135
+ component="pre"
136
+ sx={{
137
+ m: 0,
138
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
139
+ fontSize: '0.7rem',
140
+ lineHeight: 1.5,
141
+ color: 'var(--text)',
142
+ overflow: 'hidden',
143
+ display: '-webkit-box',
144
+ WebkitLineClamp: 3,
145
+ WebkitBoxOrient: 'vertical',
146
+ whiteSpace: 'pre-wrap',
147
+ wordBreak: 'break-all',
148
+ }}
149
+ >
150
+ {String(args.script).trim()}
151
+ </Box>
152
+ <Typography
153
+ variant="caption"
154
+ sx={{
155
+ display: 'flex',
156
+ alignItems: 'center',
157
+ gap: 0.5,
158
+ mt: 1,
159
+ fontSize: '0.65rem',
160
+ color: 'var(--muted-text)',
161
+ '&:hover': { color: 'var(--accent-yellow)' },
162
+ }}
163
+ >
164
+ Click to view & edit
165
+ </Typography>
166
+ </Box>
167
+ )}
168
+ </Box>
169
+ )}
170
+
171
+ <Box sx={{ display: 'flex', gap: 1, mb: 1 }}>
172
+ <TextField
173
+ fullWidth
174
+ size="small"
175
+ placeholder="Feedback (optional)"
176
+ value={feedback}
177
+ onChange={(e) => setFeedback(e.target.value)}
178
+ variant="outlined"
179
+ sx={{
180
+ '& .MuiOutlinedInput-root': {
181
+ bgcolor: 'var(--hover-bg)',
182
+ fontFamily: 'inherit',
183
+ fontSize: '0.8rem',
184
+ '& fieldset': { borderColor: 'var(--tool-border)' },
185
+ '&:hover fieldset': { borderColor: 'var(--border-hover)' },
186
+ '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' },
187
+ },
188
+ '& .MuiOutlinedInput-input': {
189
+ color: 'var(--text)',
190
+ '&::placeholder': { color: 'var(--muted-text)', opacity: 0.7 },
191
+ },
192
+ }}
193
+ />
194
+ <IconButton
195
+ onClick={() => onResolve(toolCallId, false, feedback || 'Rejected by user')}
196
+ disabled={!feedback}
197
+ size="small"
198
+ sx={{
199
+ color: 'var(--accent-red)',
200
+ border: '1px solid var(--tool-border)',
201
+ borderRadius: '6px',
202
+ '&:hover': { bgcolor: 'rgba(224,90,79,0.1)', borderColor: 'var(--accent-red)' },
203
+ '&.Mui-disabled': { color: 'var(--muted-text)', opacity: 0.3 },
204
+ }}
205
+ >
206
+ <SendIcon sx={{ fontSize: 14 }} />
207
+ </IconButton>
208
+ </Box>
209
+
210
+ <Box sx={{ display: 'flex', gap: 1 }}>
211
+ <Button
212
+ size="small"
213
+ onClick={() => onResolve(toolCallId, false, feedback || 'Rejected by user')}
214
+ sx={{
215
+ flex: 1,
216
+ textTransform: 'none',
217
+ border: '1px solid rgba(255,255,255,0.05)',
218
+ color: 'var(--accent-red)',
219
+ fontSize: '0.75rem',
220
+ py: 0.75,
221
+ borderRadius: '8px',
222
+ '&:hover': { bgcolor: 'rgba(224,90,79,0.05)', borderColor: 'var(--accent-red)' },
223
+ }}
224
+ >
225
+ Reject
226
+ </Button>
227
+ <Button
228
+ size="small"
229
+ onClick={() => onResolve(toolCallId, true)}
230
+ sx={{
231
+ flex: 1,
232
+ textTransform: 'none',
233
+ border: hasEditedScript ? '1px solid var(--accent-green)' : '1px solid rgba(255,255,255,0.05)',
234
+ color: 'var(--accent-green)',
235
+ fontSize: '0.75rem',
236
+ py: 0.75,
237
+ borderRadius: '8px',
238
+ bgcolor: hasEditedScript ? 'rgba(47,204,113,0.08)' : 'transparent',
239
+ '&:hover': { bgcolor: 'rgba(47,204,113,0.05)', borderColor: 'var(--accent-green)' },
240
+ }}
241
+ >
242
+ {hasEditedScript ? 'Approve (edited)' : 'Approve'}
243
+ </Button>
244
+ </Box>
245
+ </Box>
246
+ );
247
+ }
248
+
249
+ // ---------------------------------------------------------------------------
250
+ // Main component
251
+ // ---------------------------------------------------------------------------
252
+
253
+ export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) {
254
+ const { setPanel, lockPanel, getJobUrl, getEditedScript } = useAgentStore();
255
+ const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
256
+
257
+ // ── Batch approval state ─────────────────��────────────────────────
258
+ const pendingTools = useMemo(
259
+ () => tools.filter(t => t.state === 'approval-requested'),
260
+ [tools],
261
+ );
262
+
263
+ const [decisions, setDecisions] = useState<Record<string, { approved: boolean; feedback?: string }>>({});
264
+ const [isSubmitting, setIsSubmitting] = useState(false);
265
+ const submittingRef = useRef(false);
266
+
267
+ const { scriptLabelMap, toolDisplayMap } = useMemo(() => {
268
+ const hfJobs = tools.filter(t => t.toolName === 'hf_jobs' && (t.input as Record<string, unknown>)?.script);
269
+ const scriptMap: Record<string, string> = {};
270
+ const displayMap: Record<string, string> = {};
271
+ for (let i = 0; i < hfJobs.length; i++) {
272
+ const id = hfJobs[i].toolCallId;
273
+ if (hfJobs.length > 1) {
274
+ scriptMap[id] = `Script ${i + 1}`;
275
+ displayMap[id] = `hf_jobs #${i + 1}`;
276
+ } else {
277
+ scriptMap[id] = 'Script';
278
+ displayMap[id] = 'hf_jobs';
279
+ }
280
+ }
281
+ return { scriptLabelMap: scriptMap, toolDisplayMap: displayMap };
282
+ }, [tools]);
283
+
284
+ // ── Send all decisions as a single batch ──────────────────────────
285
+ const sendBatch = useCallback(
286
+ async (batch: Record<string, { approved: boolean; feedback?: string }>) => {
287
+ if (submittingRef.current) return;
288
+ submittingRef.current = true;
289
+ setIsSubmitting(true);
290
+
291
+ const approvals = Object.entries(batch).map(([toolCallId, d]) => {
292
+ const editedScript = d.approved ? (getEditedScript(toolCallId) ?? null) : null;
293
+ if (editedScript) {
294
+ logger.log(`Sending edited script for ${toolCallId} (${editedScript.length} chars)`);
295
+ }
296
+ return {
297
+ tool_call_id: toolCallId,
298
+ approved: d.approved,
299
+ feedback: d.approved ? null : (d.feedback || 'Rejected by user'),
300
+ edited_script: editedScript,
301
+ };
302
+ });
303
+
304
+ const ok = await approveTools(approvals);
305
+ if (ok) {
306
+ lockPanel();
307
+ } else {
308
+ logger.error('Batch approval failed');
309
+ submittingRef.current = false;
310
+ setIsSubmitting(false);
311
+ }
312
+ },
313
+ [approveTools, lockPanel, getEditedScript],
314
+ );
315
+
316
+ const handleApproveAll = useCallback(() => {
317
+ const batch: Record<string, { approved: boolean }> = {};
318
+ for (const t of pendingTools) batch[t.toolCallId] = { approved: true };
319
+ sendBatch(batch);
320
+ }, [pendingTools, sendBatch]);
321
+
322
+ const handleRejectAll = useCallback(() => {
323
+ const batch: Record<string, { approved: boolean }> = {};
324
+ for (const t of pendingTools) batch[t.toolCallId] = { approved: false };
325
+ sendBatch(batch);
326
+ }, [pendingTools, sendBatch]);
327
+
328
+ const handleIndividualDecision = useCallback(
329
+ (toolCallId: string, approved: boolean, feedback?: string) => {
330
+ setDecisions(prev => {
331
+ const next = { ...prev, [toolCallId]: { approved, feedback } };
332
+ if (pendingTools.every(t => next[t.toolCallId])) {
333
+ queueMicrotask(() => sendBatch(next));
334
+ }
335
+ return next;
336
+ });
337
+ },
338
+ [pendingTools, sendBatch],
339
+ );
340
+
341
+ const undoDecision = useCallback((toolCallId: string) => {
342
+ setDecisions(prev => {
343
+ const next = { ...prev };
344
+ delete next[toolCallId];
345
+ return next;
346
+ });
347
+ }, []);
348
+
349
+ // ── Panel click handler ───────────────────────────────────────────
350
+ const handleClick = useCallback(
351
+ (tool: DynamicToolPart) => {
352
+ const args = tool.input as Record<string, unknown> | undefined;
353
+ const displayName = toolDisplayMap[tool.toolCallId] || tool.toolName;
354
+
355
+ if (tool.toolName === 'hf_jobs' && args?.script) {
356
+ const hasOutput = (tool.state === 'output-available' || tool.state === 'output-error') && tool.output;
357
+ const scriptContent = getEditedScript(tool.toolCallId) || String(args.script);
358
+ setPanel(
359
+ {
360
+ title: displayName,
361
+ script: { content: scriptContent, language: 'python' },
362
+ ...(hasOutput ? { output: { content: String(tool.output), language: 'markdown' } } : {}),
363
+ parameters: { tool_call_id: tool.toolCallId },
364
+ },
365
+ hasOutput ? 'output' : 'script',
366
+ );
367
+ setRightPanelOpen(true);
368
+ setLeftSidebarOpen(false);
369
+ return;
370
+ }
371
+
372
+ if ((tool.state === 'output-available' || tool.state === 'output-error') && tool.output) {
373
+ let language = 'text';
374
+ const content = String(tool.output);
375
+ if (content.trim().startsWith('{') || content.trim().startsWith('[')) language = 'json';
376
+ else if (content.includes('```')) language = 'markdown';
377
+
378
+ setPanel({ title: displayName, output: { content, language } }, 'output');
379
+ setRightPanelOpen(true);
380
+ } else if (args) {
381
+ const content = JSON.stringify(args, null, 2);
382
+ setPanel({ title: displayName, output: { content, language: 'json' } }, 'output');
383
+ setRightPanelOpen(true);
384
+ }
385
+ },
386
+ [toolDisplayMap, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen],
387
+ );
388
+
389
+ // ── Parse hf_jobs metadata from output ────────────────────────────
390
+ function parseJobMeta(output: unknown): { jobUrl?: string; jobStatus?: string } {
391
+ if (typeof output !== 'string') return {};
392
+ const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/);
393
+ const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/);
394
+ return {
395
+ jobUrl: urlMatch?.[1],
396
+ jobStatus: statusMatch?.[1]?.trim(),
397
+ };
398
+ }
399
+
400
+ // ── Render ────────────────────────────────────────────────────────
401
+ const decidedCount = pendingTools.filter(t => decisions[t.toolCallId]).length;
402
+
403
+ return (
404
+ <Box
405
+ sx={{
406
+ borderRadius: 2,
407
+ border: '1px solid var(--tool-border)',
408
+ bgcolor: 'var(--tool-bg)',
409
+ overflow: 'hidden',
410
+ my: 1,
411
+ }}
412
+ >
413
+ {/* Batch approval header β€” hidden once user starts deciding individually */}
414
+ {pendingTools.length > 1 && !isSubmitting && decidedCount === 0 && (
415
+ <Box
416
+ sx={{
417
+ display: 'flex',
418
+ alignItems: 'center',
419
+ gap: 1,
420
+ px: 1.5,
421
+ py: 1,
422
+ borderBottom: '1px solid var(--tool-border)',
423
+ }}
424
+ >
425
+ <Typography
426
+ variant="body2"
427
+ sx={{ fontSize: '0.72rem', color: 'var(--muted-text)', mr: 'auto', whiteSpace: 'nowrap' }}
428
+ >
429
+ {`${pendingTools.length} tool${pendingTools.length > 1 ? 's' : ''} pending`}
430
+ </Typography>
431
+ <Button
432
+ size="small"
433
+ onClick={handleRejectAll}
434
+ sx={{
435
+ textTransform: 'none',
436
+ color: 'var(--accent-red)',
437
+ border: '1px solid rgba(255,255,255,0.05)',
438
+ fontSize: '0.72rem',
439
+ py: 0.5,
440
+ px: 1.5,
441
+ borderRadius: '8px',
442
+ '&:hover': { bgcolor: 'rgba(224,90,79,0.05)', borderColor: 'var(--accent-red)' },
443
+ }}
444
+ >
445
+ Reject all
446
+ </Button>
447
+ <Button
448
+ size="small"
449
+ onClick={handleApproveAll}
450
+ sx={{
451
+ textTransform: 'none',
452
+ color: 'var(--accent-green)',
453
+ border: '1px solid var(--accent-green)',
454
+ fontSize: '0.72rem',
455
+ fontWeight: 600,
456
+ py: 0.5,
457
+ px: 1.5,
458
+ borderRadius: '8px',
459
+ '&:hover': { bgcolor: 'rgba(47,204,113,0.1)' },
460
+ }}
461
+ >
462
+ Approve all{pendingTools.length > 1 ? ` (${pendingTools.length})` : ''}
463
+ </Button>
464
+ </Box>
465
+ )}
466
+
467
+ {/* Tool list */}
468
+ <Stack divider={<Box sx={{ borderBottom: '1px solid var(--tool-border)' }} />}>
469
+ {tools.map((tool) => {
470
+ const state = tool.state;
471
+ const isPending = state === 'approval-requested';
472
+ const clickable =
473
+ state === 'output-available' ||
474
+ state === 'output-error' ||
475
+ !!tool.input;
476
+ const localDecision = decisions[tool.toolCallId];
477
+
478
+ const displayState = isPending && localDecision
479
+ ? (localDecision.approved ? 'input-available' : 'output-denied')
480
+ : state;
481
+ const label = statusLabel(displayState as ToolPartState);
482
+
483
+ // Parse job metadata from hf_jobs output and store
484
+ const jobUrlFromStore = tool.toolName === 'hf_jobs' ? getJobUrl(tool.toolCallId) : undefined;
485
+ const jobMetaFromOutput = tool.toolName === 'hf_jobs' && tool.state === 'output-available'
486
+ ? parseJobMeta(tool.output)
487
+ : {};
488
+
489
+ // Combine job URL from store (available immediately) with output metadata (available at completion)
490
+ const jobMeta = {
491
+ jobUrl: jobUrlFromStore || jobMetaFromOutput.jobUrl,
492
+ jobStatus: jobMetaFromOutput.jobStatus,
493
+ };
494
+
495
+ return (
496
+ <Box key={tool.toolCallId}>
497
+ {/* Main tool row */}
498
+ <Stack
499
+ direction="row"
500
+ alignItems="center"
501
+ spacing={1}
502
+ onClick={() => !isPending && handleClick(tool)}
503
+ sx={{
504
+ px: 1.5,
505
+ py: 1,
506
+ cursor: isPending ? 'default' : clickable ? 'pointer' : 'default',
507
+ transition: 'background-color 0.15s',
508
+ '&:hover': clickable && !isPending ? { bgcolor: 'var(--hover-bg)' } : {},
509
+ }}
510
+ >
511
+ <StatusIcon state={
512
+ (tool.toolName === 'hf_jobs' && jobMeta.jobStatus && ['ERROR', 'FAILED', 'CANCELLED'].includes(jobMeta.jobStatus) && displayState === 'output-available')
513
+ ? 'output-error'
514
+ : displayState as ToolPartState
515
+ } />
516
+
517
+ <Typography
518
+ variant="body2"
519
+ sx={{
520
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
521
+ fontWeight: 600,
522
+ fontSize: '0.78rem',
523
+ color: 'var(--text)',
524
+ flex: 1,
525
+ minWidth: 0,
526
+ overflow: 'hidden',
527
+ textOverflow: 'ellipsis',
528
+ whiteSpace: 'nowrap',
529
+ }}
530
+ >
531
+ {toolDisplayMap[tool.toolCallId] || tool.toolName}
532
+ </Typography>
533
+
534
+ {/* Status chip (non hf_jobs, or hf_jobs without final status) */}
535
+ {label && !(tool.toolName === 'hf_jobs' && jobMeta.jobStatus) && (
536
+ <Chip
537
+ label={label}
538
+ size="small"
539
+ sx={{
540
+ height: 20,
541
+ fontSize: '0.65rem',
542
+ fontWeight: 600,
543
+ bgcolor: displayState === 'output-error' ? 'rgba(224,90,79,0.12)'
544
+ : displayState === 'output-denied' ? 'rgba(255,255,255,0.05)'
545
+ : 'var(--accent-yellow-weak)',
546
+ color: statusColor(displayState as ToolPartState),
547
+ letterSpacing: '0.03em',
548
+ }}
549
+ />
550
+ )}
551
+
552
+ {/* HF Jobs: final status chip from job metadata */}
553
+ {tool.toolName === 'hf_jobs' && jobMeta.jobStatus && (
554
+ <Chip
555
+ label={jobMeta.jobStatus}
556
+ size="small"
557
+ sx={{
558
+ height: 20,
559
+ fontSize: '0.65rem',
560
+ fontWeight: 600,
561
+ bgcolor: jobMeta.jobStatus === 'COMPLETED'
562
+ ? 'rgba(47,204,113,0.12)'
563
+ : ['ERROR', 'FAILED', 'CANCELLED'].includes(jobMeta.jobStatus!)
564
+ ? 'rgba(224,90,79,0.12)'
565
+ : 'rgba(255,193,59,0.12)',
566
+ color: jobMeta.jobStatus === 'COMPLETED'
567
+ ? 'var(--accent-green)'
568
+ : ['ERROR', 'FAILED', 'CANCELLED'].includes(jobMeta.jobStatus!)
569
+ ? 'var(--accent-red)'
570
+ : 'var(--accent-yellow)',
571
+ letterSpacing: '0.03em',
572
+ }}
573
+ />
574
+ )}
575
+
576
+ {/* View on HF link β€” single place, shown whenever URL is available */}
577
+ {tool.toolName === 'hf_jobs' && jobMeta.jobUrl && (
578
+ <Link
579
+ href={jobMeta.jobUrl}
580
+ target="_blank"
581
+ rel="noopener noreferrer"
582
+ onClick={(e) => e.stopPropagation()}
583
+ sx={{
584
+ display: 'inline-flex',
585
+ alignItems: 'center',
586
+ gap: 0.5,
587
+ color: 'var(--accent-yellow)',
588
+ fontSize: '0.68rem',
589
+ textDecoration: 'none',
590
+ ml: 0.5,
591
+ '&:hover': { textDecoration: 'underline' },
592
+ }}
593
+ >
594
+ <LaunchIcon sx={{ fontSize: 12 }} />
595
+ View on HF
596
+ </Link>
597
+ )}
598
+
599
+ {clickable && !isPending && (
600
+ <OpenInNewIcon sx={{ fontSize: 14, color: 'var(--muted-text)', opacity: 0.6 }} />
601
+ )}
602
+ </Stack>
603
+
604
+
605
+ {/* Per-tool approval: undecided */}
606
+ {isPending && !localDecision && !isSubmitting && (
607
+ <InlineApproval
608
+ toolCallId={tool.toolCallId}
609
+ toolName={tool.toolName}
610
+ input={tool.input}
611
+ scriptLabel={scriptLabelMap[tool.toolCallId] || 'Script'}
612
+ onResolve={handleIndividualDecision}
613
+ />
614
+ )}
615
+
616
+ {/* Per-tool approval: locally decided (undo available) */}
617
+ {isPending && localDecision && !isSubmitting && (
618
+ <Box
619
+ sx={{
620
+ display: 'flex',
621
+ alignItems: 'center',
622
+ justifyContent: 'space-between',
623
+ px: 1.5,
624
+ py: 0.75,
625
+ borderTop: '1px solid var(--tool-border)',
626
+ }}
627
+ >
628
+ <Typography variant="body2" sx={{ fontSize: '0.72rem', color: 'var(--muted-text)' }}>
629
+ {localDecision.approved
630
+ ? 'Marked for approval'
631
+ : `Marked for rejection${localDecision.feedback ? `: ${localDecision.feedback}` : ''}`}
632
+ </Typography>
633
+ <Button
634
+ size="small"
635
+ onClick={() => undoDecision(tool.toolCallId)}
636
+ sx={{
637
+ textTransform: 'none',
638
+ fontSize: '0.7rem',
639
+ color: 'var(--muted-text)',
640
+ minWidth: 'auto',
641
+ px: 1,
642
+ '&:hover': { color: 'var(--text)' },
643
+ }}
644
+ >
645
+ Undo
646
+ </Button>
647
+ </Box>
648
+ )}
649
+ </Box>
650
+ );
651
+ })}
652
+ </Stack>
653
+ </Box>
654
+ );
655
+ }
frontend/src/components/Chat/UserMessage.tsx ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { Box, Stack, Typography, IconButton, Tooltip } from '@mui/material';
2
+ import CloseIcon from '@mui/icons-material/Close';
3
+ import type { UIMessage } from 'ai';
4
+ import type { MessageMeta } from '@/types/agent';
5
+
6
+ interface UserMessageProps {
7
+ message: UIMessage;
8
+ isLastTurn?: boolean;
9
+ onUndoTurn?: () => void;
10
+ isProcessing?: boolean;
11
+ }
12
+
13
+ function extractText(message: UIMessage): string {
14
+ return message.parts
15
+ .filter((p): p is Extract<typeof p, { type: 'text' }> => p.type === 'text')
16
+ .map(p => p.text)
17
+ .join('');
18
+ }
19
+
20
+ export default function UserMessage({
21
+ message,
22
+ isLastTurn = false,
23
+ onUndoTurn,
24
+ isProcessing = false,
25
+ }: UserMessageProps) {
26
+ const showUndo = isLastTurn && !isProcessing && !!onUndoTurn;
27
+ const text = extractText(message);
28
+ const meta = message.metadata as MessageMeta | undefined;
29
+ const timeStr = meta?.createdAt
30
+ ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
31
+ : null;
32
+ return (
33
+ <Stack
34
+ direction="row"
35
+ spacing={1.5}
36
+ justifyContent="flex-end"
37
+ alignItems="flex-start"
38
+ sx={{
39
+ '& .undo-btn': {
40
+ opacity: 0,
41
+ transition: 'opacity 0.15s ease',
42
+ },
43
+ '&:hover .undo-btn': {
44
+ opacity: 1,
45
+ },
46
+ }}
47
+ >
48
+ {showUndo && (
49
+ <Box className="undo-btn" sx={{ display: 'flex', alignItems: 'center', mt: 0.75 }}>
50
+ <Tooltip title="Remove this turn" placement="left">
51
+ <IconButton
52
+ onClick={onUndoTurn}
53
+ size="small"
54
+ sx={{
55
+ width: 24,
56
+ height: 24,
57
+ color: 'var(--muted-text)',
58
+ '&:hover': {
59
+ color: 'var(--accent-red)',
60
+ bgcolor: 'rgba(244,67,54,0.08)',
61
+ },
62
+ }}
63
+ >
64
+ <CloseIcon sx={{ fontSize: 14 }} />
65
+ </IconButton>
66
+ </Tooltip>
67
+ </Box>
68
+ )}
69
+
70
+ <Box
71
+ sx={{
72
+ maxWidth: { xs: '88%', md: '72%' },
73
+ bgcolor: 'var(--surface)',
74
+ borderRadius: 1.5,
75
+ borderTopRightRadius: 4,
76
+ px: { xs: 1.5, md: 2.5 },
77
+ py: 1.5,
78
+ border: '1px solid var(--border)',
79
+ }}
80
+ >
81
+ <Typography
82
+ variant="body1"
83
+ sx={{
84
+ fontSize: '0.925rem',
85
+ lineHeight: 1.65,
86
+ color: 'var(--text)',
87
+ whiteSpace: 'pre-wrap',
88
+ wordBreak: 'break-word',
89
+ }}
90
+ >
91
+ {text}
92
+ </Typography>
93
+
94
+ {timeStr && (
95
+ <Typography
96
+ variant="caption"
97
+ sx={{ color: 'var(--muted-text)', mt: 0.5, display: 'block', textAlign: 'right', fontSize: '0.7rem' }}
98
+ >
99
+ {timeStr}
100
+ </Typography>
101
+ )}
102
+ </Box>
103
+ </Stack>
104
+ );
105
+ }
frontend/src/components/CodePanel/CodePanel.tsx CHANGED
@@ -1,138 +1,463 @@
1
- import { useRef, useEffect, useMemo } from 'react';
2
- import { Box, Typography, IconButton } from '@mui/material';
3
  import CloseIcon from '@mui/icons-material/Close';
4
  import RadioButtonUncheckedIcon from '@mui/icons-material/RadioButtonUnchecked';
5
  import CheckCircleIcon from '@mui/icons-material/CheckCircle';
6
  import PlayCircleOutlineIcon from '@mui/icons-material/PlayCircleOutline';
7
  import CodeIcon from '@mui/icons-material/Code';
8
- import TerminalIcon from '@mui/icons-material/Terminal';
9
  import ArticleIcon from '@mui/icons-material/Article';
 
 
 
 
10
  import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
11
- import { vscDarkPlus } from 'react-syntax-highlighter/dist/esm/styles/prism';
12
  import ReactMarkdown from 'react-markdown';
13
  import remarkGfm from 'remark-gfm';
14
  import { useAgentStore } from '@/store/agentStore';
15
  import { useLayoutStore } from '@/store/layoutStore';
16
  import { processLogs } from '@/utils/logProcessor';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  export default function CodePanel() {
19
- const { panelContent, panelTabs, activePanelTab, setActivePanelTab, removePanelTab, plan } = useAgentStore();
20
- const { setRightPanelOpen } = useLayoutStore();
 
21
  const scrollRef = useRef<HTMLDivElement>(null);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- // Get the active tab content, or fall back to panelContent for backwards compatibility
24
- const activeTab = panelTabs.find(t => t.id === activePanelTab);
25
- const currentContent = activeTab || panelContent;
 
 
 
 
 
 
 
 
 
26
 
27
  const displayContent = useMemo(() => {
28
- if (!currentContent?.content) return '';
29
- // Apply log processing only for text/logs, not for code/json
30
- if (!currentContent.language || currentContent.language === 'text') {
31
- return processLogs(currentContent.content);
32
  }
33
- return currentContent.content;
34
- }, [currentContent?.content, currentContent?.language]);
35
 
36
  useEffect(() => {
37
- // Auto-scroll only for logs tab
38
- if (scrollRef.current && activePanelTab === 'logs') {
39
  scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
40
  }
41
- }, [displayContent, activePanelTab]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- const hasTabs = panelTabs.length > 0;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  return (
46
  <Box sx={{ height: '100%', display: 'flex', flexDirection: 'column', bgcolor: 'var(--panel)' }}>
47
- {/* Header - Fixed 60px to align */}
48
- <Box sx={{
49
- height: '60px',
50
- display: 'flex',
51
- alignItems: 'center',
52
- justifyContent: 'space-between',
53
- px: 2,
54
- borderBottom: '1px solid rgba(255,255,255,0.03)'
55
- }}>
56
- {hasTabs ? (
57
- <Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5, flexWrap: 'wrap' }}>
58
- {panelTabs.map((tab) => {
59
- const isActive = activePanelTab === tab.id;
60
- // Choose icon based on tab type
61
- let icon = <TerminalIcon sx={{ fontSize: 14 }} />;
62
- if (tab.id === 'script' || tab.language === 'python') {
63
- icon = <CodeIcon sx={{ fontSize: 14 }} />;
64
- } else if (tab.id === 'tool_output' || tab.language === 'markdown' || tab.language === 'json') {
65
- icon = <ArticleIcon sx={{ fontSize: 14 }} />;
66
- }
67
- return (
68
- <Box
69
- key={tab.id}
70
- onClick={() => setActivePanelTab(tab.id)}
71
- sx={{
72
- display: 'flex',
73
- alignItems: 'center',
74
- gap: 0.5,
75
- px: 1.5,
76
- py: 0.75,
77
- borderRadius: 1,
78
- cursor: 'pointer',
79
- fontSize: '0.7rem',
80
- fontWeight: 600,
81
- textTransform: 'uppercase',
82
- letterSpacing: '0.05em',
83
- color: isActive ? 'var(--text)' : 'var(--muted-text)',
84
- bgcolor: isActive ? 'rgba(255,255,255,0.08)' : 'transparent',
85
- border: '1px solid',
86
- borderColor: isActive ? 'rgba(255,255,255,0.1)' : 'transparent',
87
- transition: 'all 0.15s ease',
88
- '&:hover': {
89
- bgcolor: 'rgba(255,255,255,0.05)',
90
- },
91
- }}
92
- >
93
- {icon}
94
- <span>{tab.title}</span>
95
- <Box
96
- component="span"
97
- onClick={(e) => {
98
- e.stopPropagation();
99
- removePanelTab(tab.id);
100
- }}
101
- sx={{
102
- display: 'flex',
103
- alignItems: 'center',
104
- justifyContent: 'center',
105
- ml: 0.5,
106
- width: 16,
107
- height: 16,
108
- borderRadius: '50%',
109
- fontSize: '0.65rem',
110
- opacity: 0.5,
111
- '&:hover': {
112
- opacity: 1,
113
- bgcolor: 'rgba(255,255,255,0.1)',
114
- },
115
- }}
116
- >
117
- βœ•
118
- </Box>
119
  </Box>
120
- );
121
- })}
122
- </Box>
123
- ) : (
124
- <Typography variant="caption" sx={{ fontWeight: 600, color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em' }}>
125
- {currentContent?.title || 'Code Panel'}
126
- </Typography>
127
- )}
128
- <IconButton size="small" onClick={() => setRightPanelOpen(false)} sx={{ color: 'var(--muted-text)' }}>
129
- <CloseIcon fontSize="small" />
130
- </IconButton>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  </Box>
132
 
133
- {/* Main Content Area */}
134
  <Box sx={{ flex: 1, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
135
- {!currentContent ? (
136
  <Box sx={{ flex: 1, display: 'flex', alignItems: 'center', justifyContent: 'center', p: 4 }}>
137
  <Typography variant="body2" color="text.secondary" sx={{ opacity: 0.5 }}>
138
  NO DATA LOADED
@@ -144,174 +469,72 @@ export default function CodePanel() {
144
  ref={scrollRef}
145
  className="code-panel"
146
  sx={{
147
- background: '#0A0B0C',
148
  borderRadius: 'var(--radius-md)',
149
- padding: '18px',
150
- border: '1px solid rgba(255,255,255,0.03)',
151
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", monospace',
152
  fontSize: '13px',
153
  lineHeight: 1.55,
154
  height: '100%',
155
  overflow: 'auto',
156
  }}
157
  >
158
- {currentContent.content ? (
159
- currentContent.language === 'python' ? (
160
- <SyntaxHighlighter
161
- language="python"
162
- style={vscDarkPlus}
163
- customStyle={{
164
- margin: 0,
165
- padding: 0,
166
- background: 'transparent',
167
- fontSize: '13px',
168
- fontFamily: 'inherit',
169
- }}
170
- wrapLines={true}
171
- wrapLongLines={true}
172
- >
173
- {displayContent}
174
- </SyntaxHighlighter>
175
- ) : currentContent.language === 'json' ? (
176
- <SyntaxHighlighter
177
- language="json"
178
- style={vscDarkPlus}
179
- customStyle={{
180
- margin: 0,
181
- padding: 0,
182
- background: 'transparent',
183
- fontSize: '13px',
184
- fontFamily: 'inherit',
185
- }}
186
- wrapLines={true}
187
- wrapLongLines={true}
188
- >
189
- {displayContent}
190
- </SyntaxHighlighter>
191
- ) : currentContent.language === 'markdown' ? (
192
- <Box sx={{
193
- color: 'var(--text)',
194
- fontSize: '13px',
195
- lineHeight: 1.6,
196
- '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } },
197
- '& pre': {
198
- bgcolor: 'rgba(0,0,0,0.4)',
199
- p: 1.5,
200
- borderRadius: 1,
201
- overflow: 'auto',
202
- fontSize: '12px',
203
- border: '1px solid rgba(255,255,255,0.05)',
204
- },
205
- '& code': {
206
- bgcolor: 'rgba(255,255,255,0.05)',
207
- px: 0.5,
208
- py: 0.25,
209
- borderRadius: 0.5,
210
- fontSize: '12px',
211
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
212
- },
213
- '& pre code': { bgcolor: 'transparent', p: 0 },
214
- '& a': {
215
- color: 'var(--accent-yellow)',
216
- textDecoration: 'none',
217
- '&:hover': { textDecoration: 'underline' },
218
- },
219
- '& ul, & ol': { pl: 2.5, my: 1 },
220
- '& li': { mb: 0.5 },
221
- '& table': {
222
- borderCollapse: 'collapse',
223
- width: '100%',
224
- my: 2,
225
- fontSize: '12px',
226
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
227
- },
228
- '& th': {
229
- borderBottom: '2px solid rgba(255,255,255,0.15)',
230
- textAlign: 'left',
231
- p: 1,
232
- fontWeight: 600,
233
- },
234
- '& td': {
235
- borderBottom: '1px solid rgba(255,255,255,0.05)',
236
- p: 1,
237
- },
238
- '& h1, & h2, & h3, & h4': {
239
- mt: 2,
240
- mb: 1,
241
- fontWeight: 600,
242
- },
243
- '& h1': { fontSize: '1.25rem' },
244
- '& h2': { fontSize: '1.1rem' },
245
- '& h3': { fontSize: '1rem' },
246
- '& blockquote': {
247
- borderLeft: '3px solid rgba(255,255,255,0.2)',
248
- pl: 2,
249
- ml: 0,
250
- color: 'var(--muted-text)',
251
- },
252
- }}>
253
- <ReactMarkdown remarkPlugins={[remarkGfm]}>{displayContent}</ReactMarkdown>
254
- </Box>
255
- ) : (
256
- <Box component="pre" sx={{
257
- m: 0,
258
- fontFamily: 'inherit',
259
- color: 'var(--text)',
260
- whiteSpace: 'pre-wrap',
261
- wordBreak: 'break-all'
262
- }}>
263
- <code>{displayContent}</code>
264
- </Box>
265
- )
266
- ) : (
267
- <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', height: '100%', opacity: 0.5 }}>
268
- <Typography variant="caption">
269
- NO CONTENT TO DISPLAY
270
- </Typography>
271
- </Box>
272
- )}
273
  </Box>
274
  </Box>
275
  )}
276
  </Box>
277
 
278
- {/* Plan Display at Bottom */}
279
  {plan && plan.length > 0 && (
280
- <Box sx={{
281
- borderTop: '1px solid rgba(255,255,255,0.03)',
282
- bgcolor: 'rgba(0,0,0,0.2)',
 
283
  maxHeight: '30%',
284
  display: 'flex',
285
- flexDirection: 'column'
286
- }}>
287
- <Box sx={{ p: 1.5, borderBottom: '1px solid rgba(255,255,255,0.03)', display: 'flex', alignItems: 'center', gap: 1 }}>
288
- <Typography variant="caption" sx={{ fontWeight: 600, color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em' }}>
289
- CURRENT PLAN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  </Typography>
291
- </Box>
292
- <Box sx={{ p: 2, overflow: 'auto', display: 'flex', flexDirection: 'column', gap: 1 }}>
293
- {plan.map((item) => (
294
- <Box key={item.id} sx={{ display: 'flex', alignItems: 'flex-start', gap: 1.5 }}>
295
- <Box sx={{ mt: 0.2 }}>
296
- {item.status === 'completed' && <CheckCircleIcon sx={{ fontSize: 16, color: 'var(--accent-green)' }} />}
297
- {item.status === 'in_progress' && <PlayCircleOutlineIcon sx={{ fontSize: 16, color: 'var(--accent-yellow)' }} />}
298
- {item.status === 'pending' && <RadioButtonUncheckedIcon sx={{ fontSize: 16, color: 'var(--muted-text)', opacity: 0.5 }} />}
299
- </Box>
300
- <Typography
301
- variant="body2"
302
- sx={{
303
- fontSize: '13px',
304
- fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
305
- color: item.status === 'completed' ? 'var(--muted-text)' : 'var(--text)',
306
- textDecoration: item.status === 'completed' ? 'line-through' : 'none',
307
- opacity: item.status === 'pending' ? 0.7 : 1
308
- }}
309
- >
310
- {item.content}
311
- </Typography>
312
- </Box>
313
- ))}
314
- </Box>
315
  </Box>
316
  )}
317
  </Box>
 
1
+ import { useRef, useEffect, useMemo, useState, useCallback } from 'react';
2
+ import { Box, Stack, Typography, IconButton, Button, Tooltip } from '@mui/material';
3
  import CloseIcon from '@mui/icons-material/Close';
4
  import RadioButtonUncheckedIcon from '@mui/icons-material/RadioButtonUnchecked';
5
  import CheckCircleIcon from '@mui/icons-material/CheckCircle';
6
  import PlayCircleOutlineIcon from '@mui/icons-material/PlayCircleOutline';
7
  import CodeIcon from '@mui/icons-material/Code';
 
8
  import ArticleIcon from '@mui/icons-material/Article';
9
+ import EditIcon from '@mui/icons-material/Edit';
10
+ import UndoIcon from '@mui/icons-material/Undo';
11
+ import ContentCopyIcon from '@mui/icons-material/ContentCopy';
12
+ import CheckIcon from '@mui/icons-material/Check';
13
  import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
14
+ import { vscDarkPlus, vs } from 'react-syntax-highlighter/dist/esm/styles/prism';
15
  import ReactMarkdown from 'react-markdown';
16
  import remarkGfm from 'remark-gfm';
17
  import { useAgentStore } from '@/store/agentStore';
18
  import { useLayoutStore } from '@/store/layoutStore';
19
  import { processLogs } from '@/utils/logProcessor';
20
+ import type { PanelView } from '@/store/agentStore';
21
+
22
+ // ── Helpers ──────────────────────────────────────────────────────
23
+
24
+ function PlanStatusIcon({ status }: { status: string }) {
25
+ if (status === 'completed') return <CheckCircleIcon sx={{ fontSize: 16, color: 'var(--accent-green)' }} />;
26
+ if (status === 'in_progress') return <PlayCircleOutlineIcon sx={{ fontSize: 16, color: 'var(--accent-yellow)' }} />;
27
+ return <RadioButtonUncheckedIcon sx={{ fontSize: 16, color: 'var(--muted-text)', opacity: 0.5 }} />;
28
+ }
29
+
30
+ // ── Markdown styles (adapts via CSS vars) ────────────────────────
31
+ const markdownSx = {
32
+ color: 'var(--text)',
33
+ fontSize: '13px',
34
+ lineHeight: 1.6,
35
+ '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } },
36
+ '& pre': {
37
+ bgcolor: 'var(--code-bg)',
38
+ p: 1.5,
39
+ borderRadius: 1,
40
+ overflow: 'auto',
41
+ fontSize: '12px',
42
+ border: '1px solid var(--tool-border)',
43
+ },
44
+ '& code': {
45
+ bgcolor: 'var(--hover-bg)',
46
+ px: 0.5,
47
+ py: 0.25,
48
+ borderRadius: 0.5,
49
+ fontSize: '12px',
50
+ fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
51
+ },
52
+ '& pre code': { bgcolor: 'transparent', p: 0 },
53
+ '& a': {
54
+ color: 'var(--accent-yellow)',
55
+ textDecoration: 'none',
56
+ '&:hover': { textDecoration: 'underline' },
57
+ },
58
+ '& ul, & ol': { pl: 2.5, my: 1 },
59
+ '& li': { mb: 0.5 },
60
+ '& table': {
61
+ borderCollapse: 'collapse',
62
+ width: '100%',
63
+ my: 2,
64
+ fontSize: '12px',
65
+ fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
66
+ },
67
+ '& th': {
68
+ borderBottom: '2px solid var(--border-hover)',
69
+ textAlign: 'left',
70
+ p: 1,
71
+ fontWeight: 600,
72
+ },
73
+ '& td': {
74
+ borderBottom: '1px solid var(--tool-border)',
75
+ p: 1,
76
+ },
77
+ '& h1, & h2, & h3, & h4': { mt: 2, mb: 1, fontWeight: 600 },
78
+ '& h1': { fontSize: '1.25rem' },
79
+ '& h2': { fontSize: '1.1rem' },
80
+ '& h3': { fontSize: '1rem' },
81
+ '& blockquote': {
82
+ borderLeft: '3px solid var(--accent-yellow)',
83
+ pl: 2,
84
+ ml: 0,
85
+ color: 'var(--muted-text)',
86
+ },
87
+ } as const;
88
+
89
+ // ── View toggle button ──────────────────────────────────────────
90
+
91
+ function ViewToggle({ view, icon, label, isActive, onClick }: {
92
+ view: PanelView;
93
+ icon: React.ReactNode;
94
+ label: string;
95
+ isActive: boolean;
96
+ onClick: (v: PanelView) => void;
97
+ }) {
98
+ return (
99
+ <Box
100
+ onClick={() => onClick(view)}
101
+ sx={{
102
+ display: 'flex',
103
+ alignItems: 'center',
104
+ gap: 0.5,
105
+ px: 1.5,
106
+ py: 0.75,
107
+ borderRadius: 1,
108
+ cursor: 'pointer',
109
+ fontSize: '0.7rem',
110
+ fontWeight: 600,
111
+ textTransform: 'uppercase',
112
+ letterSpacing: '0.05em',
113
+ whiteSpace: 'nowrap',
114
+ color: isActive ? 'var(--text)' : 'var(--muted-text)',
115
+ bgcolor: isActive ? 'var(--tab-active-bg)' : 'transparent',
116
+ border: '1px solid',
117
+ borderColor: isActive ? 'var(--tab-active-border)' : 'transparent',
118
+ transition: 'all 0.15s ease',
119
+ '&:hover': { bgcolor: 'var(--tab-hover-bg)' },
120
+ }}
121
+ >
122
+ {icon}
123
+ <span>{label}</span>
124
+ </Box>
125
+ );
126
+ }
127
+
128
+ // ── Component ────────────────────────────────────────────────────
129
 
130
  export default function CodePanel() {
131
+ const { panelData, panelView, panelEditable, setPanelView, updatePanelScript, setEditedScript, plan } =
132
+ useAgentStore();
133
+ const { setRightPanelOpen, themeMode } = useLayoutStore();
134
  const scrollRef = useRef<HTMLDivElement>(null);
135
+ const textareaRef = useRef<HTMLTextAreaElement>(null);
136
+ const [isEditing, setIsEditing] = useState(false);
137
+ const [editedContent, setEditedContent] = useState('');
138
+ const [originalContent, setOriginalContent] = useState('');
139
+ const [copied, setCopied] = useState(false);
140
+
141
+ const isDark = themeMode === 'dark';
142
+ const syntaxTheme = isDark ? vscDarkPlus : vs;
143
+
144
+ const activeSection = panelView === 'script' ? panelData?.script : panelData?.output;
145
+ const hasScript = !!panelData?.script;
146
+ const hasOutput = !!panelData?.output;
147
+ const hasBothViews = hasScript && hasOutput;
148
+
149
+ const isEditableScript = panelView === 'script' && panelEditable;
150
+ const hasUnsavedChanges = isEditing && editedContent !== originalContent;
151
+
152
+ // Sync edited content when panel data changes
153
+ useEffect(() => {
154
+ if (panelData?.script?.content && panelView === 'script' && panelEditable) {
155
+ setOriginalContent(panelData.script.content);
156
+ if (!isEditing) {
157
+ setEditedContent(panelData.script.content);
158
+ }
159
+ }
160
+ }, [panelData?.script?.content, panelView, panelEditable, isEditing]);
161
+
162
+ // Exit editing when switching away from script view or losing editable
163
+ useEffect(() => {
164
+ if (!isEditableScript && isEditing) {
165
+ setIsEditing(false);
166
+ }
167
+ }, [isEditableScript, isEditing]);
168
+
169
+ const handleStartEdit = useCallback(() => {
170
+ if (panelData?.script?.content) {
171
+ setEditedContent(panelData.script.content);
172
+ setOriginalContent(panelData.script.content);
173
+ setIsEditing(true);
174
+ setTimeout(() => textareaRef.current?.focus(), 0);
175
+ }
176
+ }, [panelData?.script?.content]);
177
+
178
+ const handleCancelEdit = useCallback(() => {
179
+ setEditedContent(originalContent);
180
+ setIsEditing(false);
181
+ }, [originalContent]);
182
+
183
+ const handleSaveEdit = useCallback(() => {
184
+ if (editedContent !== originalContent) {
185
+ updatePanelScript(editedContent);
186
+ const toolCallId = panelData?.parameters?.tool_call_id as string | undefined;
187
+ if (toolCallId) {
188
+ setEditedScript(toolCallId, editedContent);
189
+ }
190
+ setOriginalContent(editedContent);
191
+ }
192
+ setIsEditing(false);
193
+ }, [panelData?.parameters?.tool_call_id, editedContent, originalContent, updatePanelScript, setEditedScript]);
194
 
195
+ const handleCopy = useCallback(async () => {
196
+ const contentToCopy = isEditing ? editedContent : (activeSection?.content || '');
197
+ if (contentToCopy) {
198
+ try {
199
+ await navigator.clipboard.writeText(contentToCopy);
200
+ setCopied(true);
201
+ setTimeout(() => setCopied(false), 2000);
202
+ } catch (err) {
203
+ console.error('Failed to copy:', err);
204
+ }
205
+ }
206
+ }, [isEditing, editedContent, activeSection?.content]);
207
 
208
  const displayContent = useMemo(() => {
209
+ if (!activeSection?.content) return '';
210
+ if (!activeSection.language || activeSection.language === 'text') {
211
+ return processLogs(activeSection.content);
 
212
  }
213
+ return activeSection.content;
214
+ }, [activeSection?.content, activeSection?.language]);
215
 
216
  useEffect(() => {
217
+ if (scrollRef.current && panelView === 'output') {
 
218
  scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
219
  }
220
+ }, [displayContent, panelView]);
221
+
222
+ // ── Syntax-highlighted code block (DRY) ────────────────────────
223
+ const renderSyntaxBlock = (language: string) => (
224
+ <SyntaxHighlighter
225
+ language={language}
226
+ style={syntaxTheme}
227
+ customStyle={{
228
+ margin: 0,
229
+ padding: 0,
230
+ background: 'transparent',
231
+ fontSize: '13px',
232
+ fontFamily: 'inherit',
233
+ }}
234
+ wrapLines
235
+ wrapLongLines
236
+ >
237
+ {displayContent}
238
+ </SyntaxHighlighter>
239
+ );
240
+
241
+ // ── Content renderer ───────────────────────────────────────────
242
+ const renderContent = () => {
243
+ if (!activeSection?.content) {
244
+ return (
245
+ <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'center', height: '100%', opacity: 0.5 }}>
246
+ <Typography variant="caption">NO CONTENT TO DISPLAY</Typography>
247
+ </Box>
248
+ );
249
+ }
250
+
251
+ if (isEditing && isEditableScript) {
252
+ return (
253
+ <Box sx={{ position: 'relative', width: '100%', height: '100%' }}>
254
+ <SyntaxHighlighter
255
+ language={activeSection?.language === 'python' ? 'python' : activeSection?.language === 'json' ? 'json' : 'text'}
256
+ style={syntaxTheme}
257
+ customStyle={{
258
+ margin: 0,
259
+ padding: 0,
260
+ background: 'transparent',
261
+ fontSize: '13px',
262
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
263
+ lineHeight: 1.55,
264
+ pointerEvents: 'none',
265
+ }}
266
+ wrapLines
267
+ wrapLongLines
268
+ >
269
+ {editedContent || ' '}
270
+ </SyntaxHighlighter>
271
+ <textarea
272
+ ref={textareaRef}
273
+ value={editedContent}
274
+ onChange={(e) => setEditedContent(e.target.value)}
275
+ spellCheck={false}
276
+ style={{
277
+ position: 'absolute',
278
+ top: 0,
279
+ left: 0,
280
+ width: '100%',
281
+ height: '100%',
282
+ background: 'transparent',
283
+ border: 'none',
284
+ outline: 'none',
285
+ resize: 'none',
286
+ color: 'transparent',
287
+ caretColor: 'var(--text)',
288
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
289
+ fontSize: '13px',
290
+ lineHeight: 1.55,
291
+ overflow: 'hidden',
292
+ }}
293
+ />
294
+ </Box>
295
+ );
296
+ }
297
 
298
+ const lang = activeSection.language;
299
+ if (lang === 'python') return renderSyntaxBlock('python');
300
+ if (lang === 'json') return renderSyntaxBlock('json');
301
+
302
+ if (lang === 'markdown') {
303
+ return (
304
+ <Box sx={markdownSx}>
305
+ <ReactMarkdown remarkPlugins={[remarkGfm]}>{displayContent}</ReactMarkdown>
306
+ </Box>
307
+ );
308
+ }
309
+
310
+ return (
311
+ <Box
312
+ component="pre"
313
+ sx={{ m: 0, fontFamily: 'inherit', color: 'var(--text)', whiteSpace: 'pre-wrap', wordBreak: 'break-all' }}
314
+ >
315
+ <code>{displayContent}</code>
316
+ </Box>
317
+ );
318
+ };
319
 
320
  return (
321
  <Box sx={{ height: '100%', display: 'flex', flexDirection: 'column', bgcolor: 'var(--panel)' }}>
322
+ {/* ── Header ─────────────────────────────────────────────── */}
323
+ <Box
324
+ sx={{
325
+ height: 60,
326
+ display: 'flex',
327
+ alignItems: 'center',
328
+ justifyContent: 'space-between',
329
+ px: 2,
330
+ borderBottom: '1px solid var(--border)',
331
+ flexShrink: 0,
332
+ }}
333
+ >
334
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, flex: 1, minWidth: 0 }}>
335
+ {panelData ? (
336
+ <>
337
+ <Typography
338
+ variant="caption"
339
+ sx={{
340
+ fontWeight: 600,
341
+ color: 'var(--muted-text)',
342
+ textTransform: 'uppercase',
343
+ letterSpacing: '0.05em',
344
+ fontSize: '0.7rem',
345
+ flexShrink: 0,
346
+ }}
347
+ >
348
+ {panelData.title}
349
+ </Typography>
350
+ {hasBothViews && (
351
+ <Box sx={{ display: 'flex', gap: 0.5, ml: 1 }}>
352
+ <ViewToggle
353
+ view="script"
354
+ icon={<CodeIcon sx={{ fontSize: 14 }} />}
355
+ label="Script"
356
+ isActive={panelView === 'script'}
357
+ onClick={setPanelView}
358
+ />
359
+ <ViewToggle
360
+ view="output"
361
+ icon={<ArticleIcon sx={{ fontSize: 14 }} />}
362
+ label="Result"
363
+ isActive={panelView === 'output'}
364
+ onClick={setPanelView}
365
+ />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  </Box>
367
+ )}
368
+ </>
369
+ ) : (
370
+ <Typography
371
+ variant="caption"
372
+ sx={{ fontWeight: 600, color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em' }}
373
+ >
374
+ Code Panel
375
+ </Typography>
376
+ )}
377
+ </Box>
378
+
379
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
380
+ {activeSection?.content && (
381
+ <Tooltip title={copied ? 'Copied!' : 'Copy'} placement="top">
382
+ <IconButton
383
+ size="small"
384
+ onClick={handleCopy}
385
+ sx={{
386
+ color: copied ? 'var(--accent-green)' : 'var(--muted-text)',
387
+ '&:hover': { color: 'var(--accent-yellow)', bgcolor: 'var(--hover-bg)' },
388
+ }}
389
+ >
390
+ {copied ? <CheckIcon sx={{ fontSize: 18 }} /> : <ContentCopyIcon sx={{ fontSize: 18 }} />}
391
+ </IconButton>
392
+ </Tooltip>
393
+ )}
394
+ {isEditableScript && !isEditing && (
395
+ <Button
396
+ size="small"
397
+ startIcon={<EditIcon sx={{ fontSize: 14 }} />}
398
+ onClick={handleStartEdit}
399
+ sx={{
400
+ textTransform: 'none',
401
+ color: 'var(--muted-text)',
402
+ fontSize: '0.75rem',
403
+ py: 0.5,
404
+ '&:hover': { color: 'var(--accent-yellow)', bgcolor: 'var(--hover-bg)' },
405
+ }}
406
+ >
407
+ Edit
408
+ </Button>
409
+ )}
410
+ {isEditing && (
411
+ <>
412
+ <Button
413
+ size="small"
414
+ startIcon={<UndoIcon sx={{ fontSize: 14 }} />}
415
+ onClick={handleCancelEdit}
416
+ sx={{
417
+ textTransform: 'none',
418
+ color: 'var(--muted-text)',
419
+ fontSize: '0.75rem',
420
+ py: 0.5,
421
+ '&:hover': { color: 'var(--accent-red)', bgcolor: 'var(--hover-bg)' },
422
+ }}
423
+ >
424
+ Cancel
425
+ </Button>
426
+ <Button
427
+ size="small"
428
+ variant="contained"
429
+ onClick={handleSaveEdit}
430
+ disabled={!hasUnsavedChanges}
431
+ sx={{
432
+ textTransform: 'none',
433
+ fontSize: '0.75rem',
434
+ py: 0.5,
435
+ bgcolor: hasUnsavedChanges ? 'var(--accent-yellow)' : 'var(--hover-bg)',
436
+ color: hasUnsavedChanges ? '#000' : 'var(--muted-text)',
437
+ '&:hover': {
438
+ bgcolor: hasUnsavedChanges ? 'var(--accent-yellow)' : 'var(--hover-bg)',
439
+ opacity: 0.9,
440
+ },
441
+ '&.Mui-disabled': {
442
+ bgcolor: 'var(--hover-bg)',
443
+ color: 'var(--muted-text)',
444
+ opacity: 0.5,
445
+ },
446
+ }}
447
+ >
448
+ Save
449
+ </Button>
450
+ </>
451
+ )}
452
+ <IconButton size="small" onClick={() => setRightPanelOpen(false)} sx={{ color: 'var(--muted-text)' }}>
453
+ <CloseIcon fontSize="small" />
454
+ </IconButton>
455
+ </Box>
456
  </Box>
457
 
458
+ {/* ── Main content area ─────────────────────────────────── */}
459
  <Box sx={{ flex: 1, overflow: 'hidden', display: 'flex', flexDirection: 'column' }}>
460
+ {!panelData ? (
461
  <Box sx={{ flex: 1, display: 'flex', alignItems: 'center', justifyContent: 'center', p: 4 }}>
462
  <Typography variant="body2" color="text.secondary" sx={{ opacity: 0.5 }}>
463
  NO DATA LOADED
 
469
  ref={scrollRef}
470
  className="code-panel"
471
  sx={{
472
+ bgcolor: 'var(--code-panel-bg)',
473
  borderRadius: 'var(--radius-md)',
474
+ p: '18px',
475
+ border: '1px solid var(--border)',
476
+ fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
477
  fontSize: '13px',
478
  lineHeight: 1.55,
479
  height: '100%',
480
  overflow: 'auto',
481
  }}
482
  >
483
+ {renderContent()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  </Box>
485
  </Box>
486
  )}
487
  </Box>
488
 
489
+ {/* ── Plan display (bottom) ─────────────────────────────── */}
490
  {plan && plan.length > 0 && (
491
+ <Box
492
+ sx={{
493
+ borderTop: '1px solid var(--border)',
494
+ bgcolor: 'var(--plan-bg)',
495
  maxHeight: '30%',
496
  display: 'flex',
497
+ flexDirection: 'column',
498
+ }}
499
+ >
500
+ <Box
501
+ sx={{
502
+ p: 1.5,
503
+ borderBottom: '1px solid var(--border)',
504
+ display: 'flex',
505
+ alignItems: 'center',
506
+ gap: 1,
507
+ }}
508
+ >
509
+ <Typography
510
+ variant="caption"
511
+ sx={{ fontWeight: 600, color: 'var(--muted-text)', textTransform: 'uppercase', letterSpacing: '0.05em' }}
512
+ >
513
+ CURRENT PLAN
514
+ </Typography>
515
+ </Box>
516
+
517
+ <Stack spacing={1} sx={{ p: 2, overflow: 'auto' }}>
518
+ {plan.map((item) => (
519
+ <Stack key={item.id} direction="row" alignItems="flex-start" spacing={1.5}>
520
+ <Box sx={{ mt: 0.2 }}>
521
+ <PlanStatusIcon status={item.status} />
522
+ </Box>
523
+ <Typography
524
+ variant="body2"
525
+ sx={{
526
+ fontSize: '13px',
527
+ fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace',
528
+ color: item.status === 'completed' ? 'var(--muted-text)' : 'var(--text)',
529
+ textDecoration: item.status === 'completed' ? 'line-through' : 'none',
530
+ opacity: item.status === 'pending' ? 0.7 : 1,
531
+ }}
532
+ >
533
+ {item.content}
534
  </Typography>
535
+ </Stack>
536
+ ))}
537
+ </Stack>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
538
  </Box>
539
  )}
540
  </Box>
frontend/src/components/Layout/AppLayout.tsx CHANGED
@@ -1,65 +1,83 @@
1
- import { useCallback, useRef, useEffect } from 'react';
2
  import {
 
3
  Box,
4
  Drawer,
5
  Typography,
6
  IconButton,
 
 
 
 
 
7
  } from '@mui/material';
8
  import MenuIcon from '@mui/icons-material/Menu';
9
  import ChevronLeftIcon from '@mui/icons-material/ChevronLeft';
10
  import DragIndicatorIcon from '@mui/icons-material/DragIndicator';
 
 
 
11
 
12
  import { useSessionStore } from '@/store/sessionStore';
13
  import { useAgentStore } from '@/store/agentStore';
14
  import { useLayoutStore } from '@/store/layoutStore';
15
- import { useAgentWebSocket } from '@/hooks/useAgentWebSocket';
16
  import SessionSidebar from '@/components/SessionSidebar/SessionSidebar';
17
  import CodePanel from '@/components/CodePanel/CodePanel';
18
  import ChatInput from '@/components/Chat/ChatInput';
19
  import MessageList from '@/components/Chat/MessageList';
20
- import type { Message } from '@/types/agent';
 
21
 
22
  const DRAWER_WIDTH = 260;
23
 
24
  export default function AppLayout() {
25
- const { activeSessionId } = useSessionStore();
26
- const { isConnected, isProcessing, getMessages, addMessage } = useAgentStore();
27
  const {
28
  isLeftSidebarOpen,
29
  isRightPanelOpen,
30
  rightPanelWidth,
 
31
  setRightPanelWidth,
 
32
  toggleLeftSidebar,
33
- toggleRightPanel
34
  } = useLayoutStore();
35
 
36
- const isResizing = useRef(false);
 
37
 
38
- const startResizing = useCallback((e: React.MouseEvent) => {
39
- e.preventDefault();
40
- isResizing.current = true;
41
- document.addEventListener('mousemove', handleMouseMove);
42
- document.addEventListener('mouseup', stopResizing);
43
- document.body.style.cursor = 'col-resize';
44
- }, []);
45
 
46
- const stopResizing = useCallback(() => {
47
- isResizing.current = false;
48
- document.removeEventListener('mousemove', handleMouseMove);
49
- document.removeEventListener('mouseup', stopResizing);
50
- document.body.style.cursor = 'default';
51
- }, []);
52
 
53
  const handleMouseMove = useCallback((e: MouseEvent) => {
54
  if (!isResizing.current) return;
55
  const newWidth = window.innerWidth - e.clientX;
56
- const maxWidth = window.innerWidth * 0.8;
57
  const minWidth = 300;
58
  if (newWidth > minWidth && newWidth < maxWidth) {
59
  setRightPanelWidth(newWidth);
60
  }
61
  }, [setRightPanelWidth]);
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  useEffect(() => {
64
  return () => {
65
  document.removeEventListener('mousemove', handleMouseMove);
@@ -67,75 +85,157 @@ export default function AppLayout() {
67
  };
68
  }, [handleMouseMove, stopResizing]);
69
 
70
- const messages = activeSessionId ? getMessages(activeSessionId) : [];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- useAgentWebSocket({
73
  sessionId: activeSessionId,
74
- onReady: () => console.log('Agent ready'),
75
- onError: (error) => console.error('Agent error:', error),
 
 
 
 
76
  });
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  const handleSendMessage = useCallback(
79
  async (text: string) => {
80
- if (!activeSessionId || !text.trim()) return;
81
-
82
- const userMsg: Message = {
83
- id: `user_${Date.now()}`,
84
- role: 'user',
85
- content: text.trim(),
86
- timestamp: new Date().toISOString(),
87
- };
88
- addMessage(activeSessionId, userMsg);
89
 
90
- try {
91
- await fetch('/api/submit', {
 
 
 
 
 
 
92
  method: 'POST',
93
- headers: { 'Content-Type': 'application/json' },
94
- body: JSON.stringify({
95
- session_id: activeSessionId,
96
- text: text.trim(),
97
- }),
98
- });
99
- } catch (e) {
100
- console.error('Send failed:', e);
 
 
101
  }
102
  },
103
- [activeSessionId, addMessage]
104
  );
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  return (
107
  <Box sx={{ display: 'flex', width: '100%', height: '100%' }}>
108
- {/* Left Sidebar Drawer */}
109
- <Box
110
- component="nav"
111
- sx={{
112
- width: { md: isLeftSidebarOpen ? DRAWER_WIDTH : 0 },
113
- flexShrink: { md: 0 },
114
- transition: isResizing.current ? 'none' : 'width 0.2s',
115
- overflow: 'hidden',
116
- }}
117
- >
118
- <Drawer
119
- variant="persistent"
120
  sx={{
121
- display: { xs: 'none', md: 'block' },
122
- '& .MuiDrawer-paper': {
123
- boxSizing: 'border-box',
124
- width: DRAWER_WIDTH,
125
- borderRight: '1px solid',
126
- borderColor: 'divider',
127
- top: 0,
128
- height: '100%',
129
- bgcolor: 'var(--panel)', // Ensure correct background matches sidebar
130
- },
131
  }}
132
- open={isLeftSidebarOpen}
133
  >
134
- <SessionSidebar />
135
- </Drawer>
136
- </Box>
137
 
138
- {/* Main Content Area */}
139
  <Box
140
  sx={{
141
  flexGrow: 1,
@@ -143,142 +243,226 @@ export default function AppLayout() {
143
  display: 'flex',
144
  flexDirection: 'column',
145
  transition: isResizing.current ? 'none' : 'width 0.2s',
146
- position: 'relative',
147
  overflow: 'hidden',
 
148
  }}
149
  >
150
- {/* Top Header Bar (Fixed) */}
151
  <Box sx={{
152
- height: '60px',
153
- px: 1,
154
  display: 'flex',
155
  alignItems: 'center',
156
  borderBottom: 1,
157
  borderColor: 'divider',
158
  bgcolor: 'background.default',
159
  zIndex: 1200,
 
160
  }}>
161
  <IconButton onClick={toggleLeftSidebar} size="small">
162
- {isLeftSidebarOpen ? <ChevronLeftIcon /> : <MenuIcon />}
163
  </IconButton>
164
 
165
- <Box sx={{ flex: 1, display: 'flex', justifyContent: 'center' }}>
166
- <img
167
- src="/hf-logo-white.png"
168
- alt="Hugging Face"
169
- style={{ height: '40px', objectFit: 'contain' }}
 
170
  />
 
 
 
 
 
 
 
 
 
 
 
171
  </Box>
172
 
173
- <IconButton
174
- onClick={toggleRightPanel}
175
- size="small"
176
- sx={{ visibility: isRightPanelOpen ? 'hidden' : 'visible' }}
177
- >
178
- <MenuIcon />
179
- </IconButton>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  </Box>
181
 
 
182
  <Box
183
- component="main"
184
- className="chat-pane"
185
  sx={{
186
  flexGrow: 1,
187
  display: 'flex',
188
- flexDirection: 'column',
189
  overflow: 'hidden',
190
- background: 'linear-gradient(180deg, var(--bg), var(--panel))',
191
- padding: '24px',
192
  }}
193
  >
194
- {activeSessionId ? (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  <>
196
- <MessageList messages={messages} isProcessing={isProcessing} />
197
- <ChatInput
198
- onSend={handleSendMessage}
199
- disabled={isProcessing || !isConnected}
200
- />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  </>
202
- ) : (
203
- <Box
204
- sx={{
205
- flex: 1,
206
- display: 'flex',
207
- alignItems: 'center',
208
- justifyContent: 'center',
209
- flexDirection: 'column',
210
- gap: 2,
211
- }}
212
- >
213
- <Typography variant="h5" color="text.secondary" sx={{ fontFamily: 'monospace' }}>
214
- NO SESSION SELECTED
215
- </Typography>
216
- <Typography variant="body2" color="text.secondary" sx={{ fontFamily: 'monospace' }}>
217
- Initialize a session via the sidebar
218
- </Typography>
219
- </Box>
220
  )}
221
  </Box>
222
  </Box>
223
 
224
- {/* Resize Handle */}
225
- {isRightPanelOpen && (
226
- <Box
227
- onMouseDown={startResizing}
228
- sx={{
229
- width: '4px',
230
- cursor: 'col-resize',
231
- bgcolor: 'divider',
232
- display: 'flex',
233
- alignItems: 'center',
234
- justifyContent: 'center',
235
- transition: 'background-color 0.2s',
236
- zIndex: 1300,
237
- overflow: 'hidden',
238
- '&:hover': {
239
- bgcolor: 'primary.main',
240
- },
241
- }}
242
- >
243
- <DragIndicatorIcon
244
- sx={{
245
- fontSize: '0.8rem',
246
- color: 'text.secondary',
247
- pointerEvents: 'none',
248
- }}
249
- />
250
- </Box>
251
- )}
252
-
253
- {/* Right Panel Drawer */}
254
- <Box
255
- component="nav"
256
- sx={{
257
- width: { md: isRightPanelOpen ? rightPanelWidth : 0 },
258
- flexShrink: { md: 0 },
259
- transition: isResizing.current ? 'none' : 'width 0.2s',
260
- overflow: 'hidden',
261
- }}
262
- >
263
  <Drawer
264
- anchor="right"
265
- variant="persistent"
 
266
  sx={{
267
- display: { xs: 'none', md: 'block' },
268
  '& .MuiDrawer-paper': {
269
- boxSizing: 'border-box',
270
- width: rightPanelWidth,
271
- borderLeft: 'none',
272
- top: 0,
273
- height: '100%',
274
  bgcolor: 'var(--panel)',
275
  },
276
  }}
277
- open={isRightPanelOpen}
278
  >
279
  <CodePanel />
280
  </Drawer>
281
- </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  </Box>
283
  );
284
  }
 
1
+ import { useCallback, useRef, useEffect, useState } from 'react';
2
  import {
3
+ Avatar,
4
  Box,
5
  Drawer,
6
  Typography,
7
  IconButton,
8
+ Alert,
9
+ AlertTitle,
10
+ Snackbar,
11
+ useMediaQuery,
12
+ useTheme,
13
  } from '@mui/material';
14
  import MenuIcon from '@mui/icons-material/Menu';
15
  import ChevronLeftIcon from '@mui/icons-material/ChevronLeft';
16
  import DragIndicatorIcon from '@mui/icons-material/DragIndicator';
17
+ import DarkModeOutlinedIcon from '@mui/icons-material/DarkModeOutlined';
18
+ import LightModeOutlinedIcon from '@mui/icons-material/LightModeOutlined';
19
+ import { logger } from '@/utils/logger';
20
 
21
  import { useSessionStore } from '@/store/sessionStore';
22
  import { useAgentStore } from '@/store/agentStore';
23
  import { useLayoutStore } from '@/store/layoutStore';
24
+ import { useAgentChat } from '@/hooks/useAgentChat';
25
  import SessionSidebar from '@/components/SessionSidebar/SessionSidebar';
26
  import CodePanel from '@/components/CodePanel/CodePanel';
27
  import ChatInput from '@/components/Chat/ChatInput';
28
  import MessageList from '@/components/Chat/MessageList';
29
+ import WelcomeScreen from '@/components/WelcomeScreen/WelcomeScreen';
30
+ import { apiFetch } from '@/utils/api';
31
 
32
  const DRAWER_WIDTH = 260;
33
 
34
  export default function AppLayout() {
35
+ const { sessions, activeSessionId, deleteSession, updateSessionTitle } = useSessionStore();
36
+ const { isConnected, isProcessing, setProcessing, activityStatus, llmHealthError, setLlmHealthError, user } = useAgentStore();
37
  const {
38
  isLeftSidebarOpen,
39
  isRightPanelOpen,
40
  rightPanelWidth,
41
+ themeMode,
42
  setRightPanelWidth,
43
+ setLeftSidebarOpen,
44
  toggleLeftSidebar,
45
+ toggleTheme,
46
  } = useLayoutStore();
47
 
48
+ const theme = useTheme();
49
+ const isMobile = useMediaQuery(theme.breakpoints.down('md'));
50
 
51
+ const [showExpiredToast, setShowExpiredToast] = useState(false);
52
+ const disconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
 
 
 
 
 
53
 
54
+ const isResizing = useRef(false);
 
 
 
 
 
55
 
56
  const handleMouseMove = useCallback((e: MouseEvent) => {
57
  if (!isResizing.current) return;
58
  const newWidth = window.innerWidth - e.clientX;
59
+ const maxWidth = window.innerWidth * 0.6;
60
  const minWidth = 300;
61
  if (newWidth > minWidth && newWidth < maxWidth) {
62
  setRightPanelWidth(newWidth);
63
  }
64
  }, [setRightPanelWidth]);
65
 
66
+ const stopResizing = useCallback(() => {
67
+ isResizing.current = false;
68
+ document.removeEventListener('mousemove', handleMouseMove);
69
+ document.removeEventListener('mouseup', stopResizing);
70
+ document.body.style.cursor = 'default';
71
+ }, [handleMouseMove]);
72
+
73
+ const startResizing = useCallback((e: React.MouseEvent) => {
74
+ e.preventDefault();
75
+ isResizing.current = true;
76
+ document.addEventListener('mousemove', handleMouseMove);
77
+ document.addEventListener('mouseup', stopResizing);
78
+ document.body.style.cursor = 'col-resize';
79
+ }, [handleMouseMove, stopResizing]);
80
+
81
  useEffect(() => {
82
  return () => {
83
  document.removeEventListener('mousemove', handleMouseMove);
 
85
  };
86
  }, [handleMouseMove, stopResizing]);
87
 
88
+ // ── LLM health check on mount ───────────────────────────────────
89
+ useEffect(() => {
90
+ let cancelled = false;
91
+ (async () => {
92
+ try {
93
+ const res = await apiFetch('/api/health/llm');
94
+ const data = await res.json();
95
+ if (!cancelled && data.status === 'error') {
96
+ setLlmHealthError({
97
+ error: data.error || 'Unknown LLM error',
98
+ errorType: data.error_type || 'unknown',
99
+ model: data.model,
100
+ });
101
+ } else if (!cancelled) {
102
+ setLlmHealthError(null);
103
+ }
104
+ } catch {
105
+ // Backend unreachable β€” not an LLM issue, ignore
106
+ }
107
+ })();
108
+ return () => { cancelled = true; };
109
+ }, []); // eslint-disable-line react-hooks/exhaustive-deps
110
+
111
+ const hasAnySessions = sessions.length > 0;
112
 
113
+ const { messages, sendMessage, undoLastTurn, approveTools } = useAgentChat({
114
  sessionId: activeSessionId,
115
+ onReady: () => logger.log('Agent ready'),
116
+ onError: (error) => logger.error('Agent error:', error),
117
+ onSessionDead: (deadSessionId) => {
118
+ logger.log('Removing dead session:', deadSessionId);
119
+ deleteSession(deadSessionId);
120
+ },
121
  });
122
 
123
+ // Debounced "session expired" toast β€” only fires after 2s of sustained disconnect
124
+ useEffect(() => {
125
+ if (!isConnected && messages.length > 0 && activeSessionId) {
126
+ disconnectTimer.current = setTimeout(() => setShowExpiredToast(true), 2000);
127
+ } else {
128
+ if (disconnectTimer.current) clearTimeout(disconnectTimer.current);
129
+ disconnectTimer.current = null;
130
+ setShowExpiredToast(false);
131
+ }
132
+ return () => {
133
+ if (disconnectTimer.current) clearTimeout(disconnectTimer.current);
134
+ };
135
+ }, [isConnected, messages.length, activeSessionId]);
136
+
137
  const handleSendMessage = useCallback(
138
  async (text: string) => {
139
+ if (!activeSessionId || !text.trim() || isProcessing) return;
 
 
 
 
 
 
 
 
140
 
141
+ setProcessing(true);
142
+ sendMessage({ text: text.trim(), metadata: { createdAt: new Date().toISOString() } });
143
+
144
+ // Auto-title the session from the first user message (async, non-blocking)
145
+ const isFirstMessage = messages.filter((m) => m.role === 'user').length <= 1;
146
+ if (isFirstMessage) {
147
+ const sessionId = activeSessionId;
148
+ apiFetch('/api/title', {
149
  method: 'POST',
150
+ body: JSON.stringify({ session_id: sessionId, text: text.trim() }),
151
+ })
152
+ .then((res) => res.json())
153
+ .then((data) => {
154
+ if (data.title) updateSessionTitle(sessionId, data.title);
155
+ })
156
+ .catch(() => {
157
+ const raw = text.trim();
158
+ updateSessionTitle(sessionId, raw.length > 40 ? raw.slice(0, 40) + '…' : raw);
159
+ });
160
  }
161
  },
162
+ [activeSessionId, sendMessage, messages, updateSessionTitle, isProcessing, setProcessing],
163
  );
164
 
165
+ // Close sidebar on mobile after selecting a session
166
+ const handleSidebarClose = useCallback(() => {
167
+ if (isMobile) setLeftSidebarOpen(false);
168
+ }, [isMobile, setLeftSidebarOpen]);
169
+
170
+ // ── LLM error toast helper ──────────────────────────────────────────
171
+ const llmErrorTitle = llmHealthError
172
+ ? llmHealthError.errorType === 'credits'
173
+ ? 'API Credits Exhausted'
174
+ : llmHealthError.errorType === 'auth'
175
+ ? 'Invalid API Key'
176
+ : llmHealthError.errorType === 'rate_limit'
177
+ ? 'Rate Limited'
178
+ : llmHealthError.errorType === 'network'
179
+ ? 'LLM Provider Unreachable'
180
+ : 'LLM Error'
181
+ : '';
182
+
183
+ // ── Welcome screen: no sessions at all ────────────────────────────
184
+ if (!hasAnySessions) {
185
+ return (
186
+ <Box sx={{ width: '100%', height: '100%', display: 'flex', flexDirection: 'column' }}>
187
+ <WelcomeScreen />
188
+ </Box>
189
+ );
190
+ }
191
+
192
+ // ── Sidebar drawer ────────────────────────────────────────────────
193
+ const sidebarDrawer = (
194
+ <Drawer
195
+ variant={isMobile ? 'temporary' : 'persistent'}
196
+ anchor="left"
197
+ open={isLeftSidebarOpen}
198
+ onClose={() => setLeftSidebarOpen(false)}
199
+ ModalProps={{ keepMounted: true }} // Better mobile perf
200
+ sx={{
201
+ '& .MuiDrawer-paper': {
202
+ boxSizing: 'border-box',
203
+ width: DRAWER_WIDTH,
204
+ borderRight: '1px solid',
205
+ borderColor: 'divider',
206
+ top: 0,
207
+ height: '100%',
208
+ bgcolor: 'var(--panel)',
209
+ },
210
+ }}
211
+ >
212
+ <SessionSidebar onClose={handleSidebarClose} />
213
+ </Drawer>
214
+ );
215
+
216
+ // ── Main chat interface ───────────────────────────────────────────
217
  return (
218
  <Box sx={{ display: 'flex', width: '100%', height: '100%' }}>
219
+ {/* ── Left Sidebar ─────────────────────────────────────────── */}
220
+ {isMobile ? (
221
+ // Mobile: temporary overlay drawer (no reserved width)
222
+ sidebarDrawer
223
+ ) : (
224
+ // Desktop: persistent drawer with reserved width
225
+ <Box
226
+ component="nav"
 
 
 
 
227
  sx={{
228
+ width: isLeftSidebarOpen ? DRAWER_WIDTH : 0,
229
+ flexShrink: 0,
230
+ transition: isResizing.current ? 'none' : 'width 0.2s',
231
+ overflow: 'hidden',
 
 
 
 
 
 
232
  }}
 
233
  >
234
+ {sidebarDrawer}
235
+ </Box>
236
+ )}
237
 
238
+ {/* ── Main Content (header + chat + code panel) ────────────── */}
239
  <Box
240
  sx={{
241
  flexGrow: 1,
 
243
  display: 'flex',
244
  flexDirection: 'column',
245
  transition: isResizing.current ? 'none' : 'width 0.2s',
 
246
  overflow: 'hidden',
247
+ minWidth: 0,
248
  }}
249
  >
250
+ {/* ── Top Header Bar ─────────────────────────────────────── */}
251
  <Box sx={{
252
+ height: { xs: 52, md: 60 },
253
+ px: { xs: 1, md: 2 },
254
  display: 'flex',
255
  alignItems: 'center',
256
  borderBottom: 1,
257
  borderColor: 'divider',
258
  bgcolor: 'background.default',
259
  zIndex: 1200,
260
+ flexShrink: 0,
261
  }}>
262
  <IconButton onClick={toggleLeftSidebar} size="small">
263
+ {isLeftSidebarOpen && !isMobile ? <ChevronLeftIcon /> : <MenuIcon />}
264
  </IconButton>
265
 
266
+ <Box sx={{ flex: 1, display: 'flex', justifyContent: 'center', alignItems: 'center', gap: 0.75 }}>
267
+ <Box
268
+ component="img"
269
+ src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
270
+ alt="HF"
271
+ sx={{ width: { xs: 20, md: 22 }, height: { xs: 20, md: 22 } }}
272
  />
273
+ <Typography
274
+ variant="subtitle1"
275
+ sx={{
276
+ fontWeight: 700,
277
+ color: 'var(--text)',
278
+ letterSpacing: '-0.01em',
279
+ fontSize: { xs: '0.88rem', md: '0.95rem' },
280
+ }}
281
+ >
282
+ HF Agent
283
+ </Typography>
284
  </Box>
285
 
286
+ <Box sx={{ display: 'flex', alignItems: 'center', gap: 0.5 }}>
287
+ <IconButton
288
+ onClick={toggleTheme}
289
+ size="small"
290
+ sx={{
291
+ color: 'text.secondary',
292
+ '&:hover': { color: 'primary.main' },
293
+ }}
294
+ >
295
+ {themeMode === 'dark' ? <LightModeOutlinedIcon fontSize="small" /> : <DarkModeOutlinedIcon fontSize="small" />}
296
+ </IconButton>
297
+
298
+ {user?.picture ? (
299
+ <Avatar
300
+ src={user.picture}
301
+ alt={user.username || 'User'}
302
+ sx={{ width: 28, height: 28, ml: 0.5 }}
303
+ />
304
+ ) : user?.username ? (
305
+ <Avatar
306
+ sx={{
307
+ width: 28,
308
+ height: 28,
309
+ ml: 0.5,
310
+ bgcolor: 'primary.main',
311
+ fontSize: '0.75rem',
312
+ fontWeight: 700,
313
+ }}
314
+ >
315
+ {user.username[0].toUpperCase()}
316
+ </Avatar>
317
+ ) : null}
318
+ </Box>
319
  </Box>
320
 
321
+ {/* ── Chat + Code Panel ─────────────────────────���────────── */}
322
  <Box
 
 
323
  sx={{
324
  flexGrow: 1,
325
  display: 'flex',
 
326
  overflow: 'hidden',
 
 
327
  }}
328
  >
329
+ {/* Chat area */}
330
+ <Box
331
+ component="main"
332
+ className="chat-pane"
333
+ sx={{
334
+ flexGrow: 1,
335
+ display: 'flex',
336
+ flexDirection: 'column',
337
+ overflow: 'hidden',
338
+ background: 'var(--body-gradient)',
339
+ p: { xs: 1.5, sm: 2, md: 3 },
340
+ minWidth: 0,
341
+ }}
342
+ >
343
+ {activeSessionId ? (
344
+ <>
345
+ <MessageList messages={messages} isProcessing={isProcessing} approveTools={approveTools} onUndoLastTurn={undoLastTurn} />
346
+ <ChatInput
347
+ onSend={handleSendMessage}
348
+ disabled={isProcessing || !isConnected || activityStatus.type === 'waiting-approval'}
349
+ placeholder={activityStatus.type === 'waiting-approval' ? 'Approve or reject pending tools first...' : undefined}
350
+ />
351
+ </>
352
+ ) : (
353
+ <Box
354
+ sx={{
355
+ flex: 1,
356
+ display: 'flex',
357
+ alignItems: 'center',
358
+ justifyContent: 'center',
359
+ flexDirection: 'column',
360
+ gap: 2,
361
+ px: 2,
362
+ }}
363
+ >
364
+ <Typography variant="h5" color="text.secondary" sx={{ fontFamily: 'monospace', fontSize: { xs: '1rem', md: '1.5rem' } }}>
365
+ NO SESSION SELECTED
366
+ </Typography>
367
+ <Typography variant="body2" color="text.secondary" sx={{ fontFamily: 'monospace', fontSize: { xs: '0.75rem', md: '0.875rem' } }}>
368
+ Initialize a session via the sidebar
369
+ </Typography>
370
+ </Box>
371
+ )}
372
+ </Box>
373
+
374
+ {/* Code panel β€” inline on desktop, overlay drawer on mobile */}
375
+ {isRightPanelOpen && !isMobile && (
376
  <>
377
+ <Box
378
+ onMouseDown={startResizing}
379
+ sx={{
380
+ width: '4px',
381
+ cursor: 'col-resize',
382
+ bgcolor: 'divider',
383
+ display: 'flex',
384
+ alignItems: 'center',
385
+ justifyContent: 'center',
386
+ transition: 'background-color 0.2s',
387
+ flexShrink: 0,
388
+ '&:hover': { bgcolor: 'primary.main' },
389
+ }}
390
+ >
391
+ <DragIndicatorIcon
392
+ sx={{ fontSize: '0.8rem', color: 'text.secondary', pointerEvents: 'none' }}
393
+ />
394
+ </Box>
395
+ <Box
396
+ sx={{
397
+ width: rightPanelWidth,
398
+ flexShrink: 0,
399
+ height: '100%',
400
+ overflow: 'hidden',
401
+ borderLeft: '1px solid',
402
+ borderColor: 'divider',
403
+ bgcolor: 'var(--panel)',
404
+ }}
405
+ >
406
+ <CodePanel />
407
+ </Box>
408
  </>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  )}
410
  </Box>
411
  </Box>
412
 
413
+ {/* Code panel β€” drawer overlay on mobile */}
414
+ {isMobile && (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  <Drawer
416
+ anchor="bottom"
417
+ open={isRightPanelOpen}
418
+ onClose={() => useLayoutStore.getState().setRightPanelOpen(false)}
419
  sx={{
 
420
  '& .MuiDrawer-paper': {
421
+ height: '75vh',
422
+ borderTopLeftRadius: 16,
423
+ borderTopRightRadius: 16,
 
 
424
  bgcolor: 'var(--panel)',
425
  },
426
  }}
 
427
  >
428
  <CodePanel />
429
  </Drawer>
430
+ )}
431
+ <Snackbar
432
+ open={showExpiredToast}
433
+ anchorOrigin={{ vertical: 'bottom', horizontal: 'center' }}
434
+ onClose={() => setShowExpiredToast(false)}
435
+ >
436
+ <Alert
437
+ severity="warning"
438
+ variant="filled"
439
+ onClose={() => setShowExpiredToast(false)}
440
+ sx={{ fontFamily: 'monospace', fontSize: '0.8rem' }}
441
+ >
442
+ Session expired β€” create a new session to continue.
443
+ </Alert>
444
+ </Snackbar>
445
+ <Snackbar
446
+ open={!!llmHealthError}
447
+ anchorOrigin={{ vertical: 'top', horizontal: 'center' }}
448
+ onClose={() => setLlmHealthError(null)}
449
+ >
450
+ <Alert
451
+ severity="error"
452
+ variant="filled"
453
+ onClose={() => setLlmHealthError(null)}
454
+ sx={{ fontSize: '0.8rem', maxWidth: 480 }}
455
+ >
456
+ <AlertTitle sx={{ fontWeight: 700, fontSize: '0.85rem' }}>
457
+ {llmErrorTitle}
458
+ </AlertTitle>
459
+ {llmHealthError && (
460
+ <Typography variant="body2" sx={{ fontSize: '0.78rem', opacity: 0.9 }}>
461
+ {llmHealthError.model} β€” {llmHealthError.error.slice(0, 150)}
462
+ </Typography>
463
+ )}
464
+ </Alert>
465
+ </Snackbar>
466
  </Box>
467
  );
468
  }
frontend/src/components/SessionSidebar/SessionSidebar.tsx CHANGED
@@ -1,246 +1,344 @@
1
- import { useCallback } from 'react';
2
  import {
 
3
  Box,
4
- List,
5
- ListItem,
6
  IconButton,
7
  Typography,
8
- Button,
9
- Tooltip,
10
  } from '@mui/material';
11
- import DeleteIcon from '@mui/icons-material/Delete';
12
- import UndoIcon from '@mui/icons-material/Undo';
 
13
  import { useSessionStore } from '@/store/sessionStore';
14
  import { useAgentStore } from '@/store/agentStore';
 
15
 
16
  interface SessionSidebarProps {
17
  onClose?: () => void;
18
  }
19
 
20
- const StatusDiode = ({ connected }: { connected: boolean }) => (
 
21
  <Box
22
  sx={{
23
- width: 10,
24
- height: 10,
25
  borderRadius: '50%',
26
- bgcolor: connected ? 'var(--accent-green)' : 'var(--accent-red)', // Use green/red for connection status
27
- boxShadow: connected ? '0 0 6px rgba(47, 204, 113, 0.4)' : 'none',
28
- transition: 'all 0.3s ease',
29
  }}
30
  />
31
  );
32
 
33
- const RunningIndicator = () => (
34
- <Box
35
- className="running-indicator"
36
- sx={{
37
- width: 10,
38
- height: 10,
39
- borderRadius: '50%',
40
- bgcolor: 'var(--accent-yellow)',
41
- boxShadow: '0 0 6px rgba(199,165,0,0.18)',
42
- }}
43
- />
44
- );
45
-
46
  export default function SessionSidebar({ onClose }: SessionSidebarProps) {
47
  const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
48
  useSessionStore();
49
- const { clearMessages, isConnected, isProcessing, setPlan, setPanelContent } = useAgentStore();
 
 
 
 
 
50
 
51
  const handleNewSession = useCallback(async () => {
 
 
 
52
  try {
53
- const response = await fetch('/api/session', { method: 'POST' });
 
 
 
 
 
54
  const data = await response.json();
55
  createSession(data.session_id);
56
- // Clear plan and code panel for new session
57
  setPlan([]);
58
- setPanelContent(null);
59
  onClose?.();
60
- } catch (e) {
61
- console.error('Failed to create session:', e);
 
 
62
  }
63
- }, [createSession, setPlan, setPanelContent, onClose]);
64
 
65
- const handleDeleteSession = useCallback(
66
  async (sessionId: string, e: React.MouseEvent) => {
67
  e.stopPropagation();
68
  try {
69
- await fetch(`/api/session/${sessionId}`, { method: 'DELETE' });
 
 
 
70
  deleteSession(sessionId);
71
- clearMessages(sessionId);
72
- } catch (e) {
73
- console.error('Failed to delete session:', e);
74
  }
75
  },
76
- [deleteSession, clearMessages]
77
  );
78
 
79
- const handleSelectSession = useCallback(
80
  (sessionId: string) => {
81
  switchSession(sessionId);
82
- // Clear plan and code panel when switching sessions
83
  setPlan([]);
84
- setPanelContent(null);
85
  onClose?.();
86
  },
87
- [switchSession, setPlan, setPanelContent, onClose]
88
  );
89
 
90
- const handleUndo = useCallback(async () => {
91
- if (!activeSessionId) return;
92
- try {
93
- await fetch(`/api/undo/${activeSessionId}`, { method: 'POST' });
94
- } catch (e) {
95
- console.error('Undo failed:', e);
96
- }
97
- }, [activeSessionId]);
98
 
99
- const formatTime = (dateString: string) => {
100
- return new Date(dateString).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
101
- };
102
 
103
  return (
104
- <Box className="sidebar" sx={{ height: '100%', display: 'flex', flexDirection: 'column', bgcolor: 'var(--panel)' }}>
105
- {/* Header - Aligned with AppLayout 60px */}
106
- <Box sx={{
107
- height: '60px',
108
- display: 'flex',
109
- alignItems: 'center',
110
- px: 2,
111
- borderBottom: '1px solid rgba(255,255,255,0.03)'
112
- }}>
113
- <Box className="brand-logo" sx={{ display: 'flex' }}>
114
- <img
115
- src="/hf-log-only-white.png"
116
- alt="HF Agent"
117
- style={{ height: '24px', objectFit: 'contain' }}
118
- />
119
- </Box>
 
 
 
 
 
 
120
  </Box>
121
 
122
- {/* Content */}
123
- <Box sx={{ flex: 1, display: 'flex', flexDirection: 'column', p: 2, overflow: 'hidden' }}>
124
- {/* System Info / Status */}
125
- <Box sx={{ mb: 2, display: 'flex', alignItems: 'center', gap: 1 }}>
126
- <StatusDiode connected={isConnected} />
127
- <Typography variant="caption" sx={{ color: 'var(--muted-text)', fontFamily: 'inherit' }}>
128
- {isConnected ? 'System Online' : 'Disconnected'}
129
- </Typography>
130
- </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- <Button
133
- fullWidth
134
- className="create-session"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  onClick={handleNewSession}
 
136
  sx={{
137
  display: 'inline-flex',
138
  alignItems: 'center',
139
- justifyContent: 'flex-start',
140
- gap: '10px',
141
- padding: '10px 14px',
142
- borderRadius: 'var(--radius-md)',
143
- border: '1px solid rgba(255,255,255,0.06)',
144
- bgcolor: 'transparent',
145
- color: 'var(--text)',
146
- fontWeight: 600,
147
- textTransform: 'none',
148
- mb: 3,
 
 
 
149
  '&:hover': {
150
- bgcolor: 'rgba(255,255,255,0.02)',
151
- border: '1px solid rgba(255,255,255,0.1)',
 
 
 
152
  },
153
- '&::before': {
154
- content: '""',
155
- width: '4px',
156
- height: '20px',
157
- background: 'linear-gradient(180deg, var(--accent-yellow), rgba(199,165,0,0.9))',
158
- borderRadius: '4px',
159
- }
160
  }}
161
  >
162
- New Session
163
- </Button>
164
-
165
- {/* Session List */}
166
- <Box sx={{ flex: 1, overflow: 'auto', mx: -1, px: 1 }}>
167
- <List disablePadding sx={{ display: 'flex', flexDirection: 'column', gap: 1 }}>
168
- {[...sessions].reverse().map((session, index) => {
169
- const sessionNumber = sessions.length - index;
170
- const isSelected = session.id === activeSessionId;
171
- return (
172
- <ListItem
173
- key={session.id}
174
- disablePadding
175
- className="session-item"
176
- onClick={() => handleSelectSession(session.id)}
177
- sx={{
178
- display: 'flex',
179
- alignItems: 'center',
180
- gap: '12px',
181
- padding: '10px',
182
- borderRadius: 'var(--radius-md)',
183
- bgcolor: isSelected ? 'rgba(255,255,255,0.05)' : 'transparent',
184
- cursor: 'pointer',
185
- transition: 'background 0.18s ease, transform 0.08s ease',
186
- '&:hover': {
187
- bgcolor: 'rgba(255,255,255,0.02)',
188
- transform: 'translateY(-1px)',
189
- },
190
- '& .delete-btn': {
191
- opacity: 0,
192
- transition: 'opacity 0.2s',
193
- },
194
- '&:hover .delete-btn': {
195
- opacity: 1,
196
- }
197
- }}
198
- >
199
- <Box sx={{ flex: 1, overflow: 'hidden' }}>
200
- <Typography variant="body2" sx={{ fontWeight: 500, color: 'var(--text)', whiteSpace: 'nowrap', overflow: 'hidden', textOverflow: 'ellipsis' }}>
201
- Session {String(sessionNumber).padStart(2, '0')}
202
- </Typography>
203
- <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mt: 0.5 }}>
204
- {session.isActive && <RunningIndicator />}
205
- <Typography className="time" variant="caption" sx={{ fontSize: '12px', color: 'var(--muted-text)' }}>
206
- {formatTime(session.createdAt)}
207
- </Typography>
208
- </Box>
209
- </Box>
210
-
211
- <IconButton
212
- className="delete-btn"
213
- size="small"
214
- onClick={(e) => handleDeleteSession(session.id, e)}
215
- sx={{ color: 'var(--muted-text)', '&:hover': { color: 'var(--accent-red)' } }}
216
- >
217
- <DeleteIcon fontSize="small" />
218
- </IconButton>
219
- </ListItem>
220
- );
221
- })}
222
- </List>
223
  </Box>
224
- </Box>
225
 
226
- {/* Footer */}
227
- <Box sx={{ p: 2, borderTop: '1px solid rgba(255,255,255,0.03)' }}>
228
- <Box sx={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
229
- <Typography variant="caption" className="small-note" sx={{ fontSize: '12px', color: 'var(--muted-text)' }}>
230
- {sessions.length} active
231
- </Typography>
232
- <Tooltip title="Undo last turn">
233
- <span>
234
- <IconButton
235
- onClick={handleUndo}
236
- disabled={!activeSessionId || isProcessing}
237
- size="small"
238
- sx={{ color: 'var(--muted-text)', '&:hover': { color: 'var(--text)' } }}
239
- >
240
- <UndoIcon fontSize="small" />
241
- </IconButton>
242
- </span>
243
- </Tooltip>
244
  </Box>
245
  </Box>
246
  </Box>
 
1
+ import { useCallback, useState } from 'react';
2
  import {
3
+ Alert,
4
  Box,
 
 
5
  IconButton,
6
  Typography,
7
+ CircularProgress,
8
+ Divider,
9
  } from '@mui/material';
10
+ import AddIcon from '@mui/icons-material/Add';
11
+ import DeleteOutlineIcon from '@mui/icons-material/DeleteOutline';
12
+ import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
13
  import { useSessionStore } from '@/store/sessionStore';
14
  import { useAgentStore } from '@/store/agentStore';
15
+ import { apiFetch } from '@/utils/api';
16
 
17
  interface SessionSidebarProps {
18
  onClose?: () => void;
19
  }
20
 
21
+ /** Small coloured dot for connection status */
22
+ const StatusDot = ({ connected }: { connected: boolean }) => (
23
  <Box
24
  sx={{
25
+ width: 6,
26
+ height: 6,
27
  borderRadius: '50%',
28
+ bgcolor: connected ? 'var(--accent-green)' : 'var(--accent-red)',
29
+ boxShadow: connected ? '0 0 4px rgba(76,175,80,0.4)' : 'none',
30
+ flexShrink: 0,
31
  }}
32
  />
33
  );
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  export default function SessionSidebar({ onClose }: SessionSidebarProps) {
36
  const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
37
  useSessionStore();
38
+ const { isConnected, setPlan, clearPanel } =
39
+ useAgentStore();
40
+ const [isCreatingSession, setIsCreatingSession] = useState(false);
41
+ const [capacityError, setCapacityError] = useState<string | null>(null);
42
+
43
+ // ── Handlers ──────────────────────────────────────────────────────
44
 
45
  const handleNewSession = useCallback(async () => {
46
+ if (isCreatingSession) return;
47
+ setIsCreatingSession(true);
48
+ setCapacityError(null);
49
  try {
50
+ const response = await apiFetch('/api/session', { method: 'POST' });
51
+ if (response.status === 503) {
52
+ const data = await response.json();
53
+ setCapacityError(data.detail || 'Server is at capacity.');
54
+ return;
55
+ }
56
  const data = await response.json();
57
  createSession(data.session_id);
 
58
  setPlan([]);
59
+ clearPanel();
60
  onClose?.();
61
+ } catch {
62
+ setCapacityError('Failed to create session.');
63
+ } finally {
64
+ setIsCreatingSession(false);
65
  }
66
+ }, [isCreatingSession, createSession, setPlan, clearPanel, onClose]);
67
 
68
+ const handleDelete = useCallback(
69
  async (sessionId: string, e: React.MouseEvent) => {
70
  e.stopPropagation();
71
  try {
72
+ await apiFetch(`/api/session/${sessionId}`, { method: 'DELETE' });
73
+ deleteSession(sessionId);
74
+ } catch {
75
+ // Delete locally even if backend fails (session may already be gone)
76
  deleteSession(sessionId);
 
 
 
77
  }
78
  },
79
+ [deleteSession],
80
  );
81
 
82
+ const handleSelect = useCallback(
83
  (sessionId: string) => {
84
  switchSession(sessionId);
 
85
  setPlan([]);
86
+ clearPanel();
87
  onClose?.();
88
  },
89
+ [switchSession, setPlan, clearPanel, onClose],
90
  );
91
 
92
+ const formatTime = (d: string) =>
93
+ new Date(d).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' });
 
 
 
 
 
 
94
 
95
+ // ── Render ────────────────────────────────────────────────────────
 
 
96
 
97
  return (
98
+ <Box
99
+ sx={{
100
+ height: '100%',
101
+ display: 'flex',
102
+ flexDirection: 'column',
103
+ bgcolor: 'var(--panel)',
104
+ }}
105
+ >
106
+ {/* ── Header ─────────────────────────────────────────────────── */}
107
+ <Box sx={{ px: 1.75, pt: 2, pb: 0 }}>
108
+ <Typography
109
+ variant="caption"
110
+ sx={{
111
+ color: 'var(--muted-text)',
112
+ fontSize: '0.65rem',
113
+ fontWeight: 600,
114
+ textTransform: 'uppercase',
115
+ letterSpacing: '0.08em',
116
+ }}
117
+ >
118
+ Recent chats
119
+ </Typography>
120
  </Box>
121
 
122
+ {/* ── Capacity error ─────────────────────────────────────────── */}
123
+ {capacityError && (
124
+ <Alert
125
+ severity="warning"
126
+ variant="outlined"
127
+ onClose={() => setCapacityError(null)}
128
+ sx={{
129
+ m: 1,
130
+ fontSize: '0.7rem',
131
+ py: 0.25,
132
+ '& .MuiAlert-message': { py: 0 },
133
+ borderColor: '#FF9D00',
134
+ color: 'var(--text)',
135
+ }}
136
+ >
137
+ {capacityError}
138
+ </Alert>
139
+ )}
140
+
141
+ {/* ── Session list ───────────────────────────────────────────── */}
142
+ <Box
143
+ sx={{
144
+ flex: 1,
145
+ overflow: 'auto',
146
+ py: 1,
147
+ // Thinner scrollbar
148
+ '&::-webkit-scrollbar': { width: 4 },
149
+ '&::-webkit-scrollbar-thumb': {
150
+ bgcolor: 'var(--scrollbar-thumb)',
151
+ borderRadius: 2,
152
+ },
153
+ }}
154
+ >
155
+ {sessions.length === 0 ? (
156
+ <Box
157
+ sx={{
158
+ display: 'flex',
159
+ flexDirection: 'column',
160
+ alignItems: 'center',
161
+ justifyContent: 'center',
162
+ py: 8,
163
+ px: 3,
164
+ gap: 1.5,
165
+ }}
166
+ >
167
+ <ChatBubbleOutlineIcon
168
+ sx={{ fontSize: 28, color: 'var(--muted-text)', opacity: 0.25 }}
169
+ />
170
+ <Typography
171
+ variant="caption"
172
+ sx={{
173
+ color: 'var(--muted-text)',
174
+ opacity: 0.5,
175
+ textAlign: 'center',
176
+ lineHeight: 1.5,
177
+ fontSize: '0.72rem',
178
+ }}
179
+ >
180
+ No sessions yet
181
+ </Typography>
182
+ </Box>
183
+ ) : (
184
+ [...sessions].reverse().map((session, index) => {
185
+ const num = sessions.length - index;
186
+ const isSelected = session.id === activeSessionId;
187
 
188
+ return (
189
+ <Box
190
+ key={session.id}
191
+ onClick={() => handleSelect(session.id)}
192
+ sx={{
193
+ display: 'flex',
194
+ alignItems: 'center',
195
+ gap: 1,
196
+ px: 1.5,
197
+ py: 0.875,
198
+ mx: 0.75,
199
+ borderRadius: '10px',
200
+ cursor: 'pointer',
201
+ transition: 'background-color 0.12s ease',
202
+ bgcolor: isSelected
203
+ ? 'var(--hover-bg)'
204
+ : 'transparent',
205
+ '&:hover': {
206
+ bgcolor: 'var(--hover-bg)',
207
+ },
208
+ '& .delete-btn': {
209
+ opacity: 0,
210
+ transition: 'opacity 0.12s',
211
+ },
212
+ '&:hover .delete-btn': {
213
+ opacity: 1,
214
+ },
215
+ }}
216
+ >
217
+ <ChatBubbleOutlineIcon
218
+ sx={{
219
+ fontSize: 15,
220
+ color: isSelected ? 'var(--text)' : 'var(--muted-text)',
221
+ opacity: isSelected ? 0.8 : 0.4,
222
+ flexShrink: 0,
223
+ }}
224
+ />
225
+
226
+ <Box sx={{ flex: 1, minWidth: 0 }}>
227
+ <Typography
228
+ variant="body2"
229
+ sx={{
230
+ fontWeight: isSelected ? 600 : 400,
231
+ color: 'var(--text)',
232
+ fontSize: '0.84rem',
233
+ lineHeight: 1.4,
234
+ whiteSpace: 'nowrap',
235
+ overflow: 'hidden',
236
+ textOverflow: 'ellipsis',
237
+ }}
238
+ >
239
+ {session.title.startsWith('Chat ') ? `Session ${String(num).padStart(2, '0')}` : session.title}
240
+ </Typography>
241
+ <Typography
242
+ variant="caption"
243
+ sx={{
244
+ color: 'var(--muted-text)',
245
+ fontSize: '0.65rem',
246
+ lineHeight: 1.2,
247
+ }}
248
+ >
249
+ {formatTime(session.createdAt)}
250
+ </Typography>
251
+ </Box>
252
+
253
+ <IconButton
254
+ className="delete-btn"
255
+ size="small"
256
+ onClick={(e) => handleDelete(session.id, e)}
257
+ sx={{
258
+ color: 'var(--muted-text)',
259
+ width: 26,
260
+ height: 26,
261
+ flexShrink: 0,
262
+ '&:hover': { color: 'var(--accent-red)', bgcolor: 'rgba(244,67,54,0.08)' },
263
+ }}
264
+ >
265
+ <DeleteOutlineIcon sx={{ fontSize: 15 }} />
266
+ </IconButton>
267
+ </Box>
268
+ );
269
+ })
270
+ )}
271
+ </Box>
272
+
273
+ {/* ── Footer: New Session + status ──────────────────────────── */}
274
+ <Divider sx={{ opacity: 0.5 }} />
275
+ <Box
276
+ sx={{
277
+ px: 1.5,
278
+ py: 1.5,
279
+ display: 'flex',
280
+ flexDirection: 'column',
281
+ gap: 1,
282
+ flexShrink: 0,
283
+ }}
284
+ >
285
+ <Box
286
+ component="button"
287
  onClick={handleNewSession}
288
+ disabled={isCreatingSession}
289
  sx={{
290
  display: 'inline-flex',
291
  alignItems: 'center',
292
+ justifyContent: 'center',
293
+ gap: 0.75,
294
+ width: '100%',
295
+ px: 1.5,
296
+ py: 1.25,
297
+ border: 'none',
298
+ borderRadius: '10px',
299
+ bgcolor: '#FF9D00',
300
+ color: '#000',
301
+ fontSize: '0.85rem',
302
+ fontWeight: 700,
303
+ cursor: 'pointer',
304
+ transition: 'all 0.12s ease',
305
  '&:hover': {
306
+ bgcolor: '#FFB340',
307
+ },
308
+ '&:disabled': {
309
+ opacity: 0.5,
310
+ cursor: 'not-allowed',
311
  },
 
 
 
 
 
 
 
312
  }}
313
  >
314
+ {isCreatingSession ? (
315
+ <>
316
+ <CircularProgress size={12} sx={{ color: '#000' }} />
317
+ Creating...
318
+ </>
319
+ ) : (
320
+ <>
321
+ <AddIcon sx={{ fontSize: 16 }} />
322
+ New Session
323
+ </>
324
+ )}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  </Box>
 
326
 
327
+ <Box
328
+ sx={{
329
+ display: 'flex',
330
+ alignItems: 'center',
331
+ justifyContent: 'center',
332
+ gap: 0.5,
333
+ }}
334
+ >
335
+ <StatusDot connected={isConnected} />
336
+ <Typography
337
+ variant="caption"
338
+ sx={{ color: 'var(--muted-text)', fontSize: '0.62rem', letterSpacing: '0.02em' }}
339
+ >
340
+ {sessions.length} session{sessions.length !== 1 ? 's' : ''} &middot; Backend {isConnected ? 'online' : 'offline'}
341
+ </Typography>
 
 
 
342
  </Box>
343
  </Box>
344
  </Box>
frontend/src/components/WelcomeScreen/WelcomeScreen.tsx ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useCallback } from 'react';
2
+ import {
3
+ Box,
4
+ Typography,
5
+ Button,
6
+ CircularProgress,
7
+ Alert,
8
+ } from '@mui/material';
9
+ import OpenInNewIcon from '@mui/icons-material/OpenInNew';
10
+ import { useSessionStore } from '@/store/sessionStore';
11
+ import { useAgentStore } from '@/store/agentStore';
12
+ import { apiFetch } from '@/utils/api';
13
+ import { isInIframe, triggerLogin } from '@/hooks/useAuth';
14
+
15
+ /** HF brand orange */
16
+ const HF_ORANGE = '#FF9D00';
17
+
18
+ export default function WelcomeScreen() {
19
+ const { createSession } = useSessionStore();
20
+ const { setPlan, clearPanel, user } = useAgentStore();
21
+ const [isCreating, setIsCreating] = useState(false);
22
+ const [error, setError] = useState<string | null>(null);
23
+
24
+ const inIframe = isInIframe();
25
+ const isAuthenticated = user?.authenticated;
26
+ const isDevUser = user?.username === 'dev';
27
+
28
+ const handleStart = useCallback(async () => {
29
+ if (isCreating) return;
30
+
31
+ // Not authenticated and not dev β†’ need to login
32
+ if (!isAuthenticated && !isDevUser) {
33
+ // In iframe: can't redirect (cookies blocked) β€” user needs to open in new tab
34
+ // This shouldn't happen because we show a different button in iframe
35
+ // But just in case:
36
+ if (inIframe) return;
37
+ triggerLogin();
38
+ return;
39
+ }
40
+
41
+ setIsCreating(true);
42
+ setError(null);
43
+
44
+ try {
45
+ const response = await apiFetch('/api/session', { method: 'POST' });
46
+ if (response.status === 503) {
47
+ const data = await response.json();
48
+ setError(data.detail || 'Server is at capacity. Please try again later.');
49
+ return;
50
+ }
51
+ if (response.status === 401) {
52
+ triggerLogin();
53
+ return;
54
+ }
55
+ if (!response.ok) {
56
+ setError('Failed to create session. Please try again.');
57
+ return;
58
+ }
59
+ const data = await response.json();
60
+ createSession(data.session_id);
61
+ setPlan([]);
62
+ clearPanel();
63
+ } catch {
64
+ // Redirect may throw β€” ignore
65
+ } finally {
66
+ setIsCreating(false);
67
+ }
68
+ }, [isCreating, createSession, setPlan, clearPanel, isAuthenticated, isDevUser, inIframe]);
69
+
70
+ // Build the direct Space URL for the "open in new tab" link
71
+ const spaceHost = typeof window !== 'undefined'
72
+ ? window.location.hostname.includes('.hf.space')
73
+ ? window.location.origin
74
+ : `https://smolagents-ml-agent.hf.space`
75
+ : '';
76
+
77
+ return (
78
+ <Box
79
+ sx={{
80
+ width: '100%',
81
+ height: '100%',
82
+ display: 'flex',
83
+ flexDirection: 'column',
84
+ alignItems: 'center',
85
+ justifyContent: 'center',
86
+ background: 'var(--body-gradient)',
87
+ py: 8,
88
+ }}
89
+ >
90
+ {/* HF Logo */}
91
+ <Box
92
+ component="img"
93
+ src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
94
+ alt="Hugging Face"
95
+ sx={{ width: 96, height: 96, mb: 3, display: 'block' }}
96
+ />
97
+
98
+ {/* Title */}
99
+ <Typography
100
+ variant="h2"
101
+ sx={{
102
+ fontWeight: 800,
103
+ color: 'var(--text)',
104
+ mb: 1.5,
105
+ letterSpacing: '-0.02em',
106
+ fontSize: { xs: '2rem', md: '2.8rem' },
107
+ }}
108
+ >
109
+ HF Agent
110
+ </Typography>
111
+
112
+ {/* Description */}
113
+ <Typography
114
+ variant="body1"
115
+ sx={{
116
+ color: 'var(--muted-text)',
117
+ maxWidth: 520,
118
+ mb: 5,
119
+ lineHeight: 1.8,
120
+ fontSize: '0.95rem',
121
+ textAlign: 'center',
122
+ px: 2,
123
+ '& strong': { color: 'var(--text)', fontWeight: 600 },
124
+ }}
125
+ >
126
+ A general-purpose AI agent for <strong>machine learning engineering</strong>.
127
+ It browses <strong>Hugging Face documentation</strong>, manages{' '}
128
+ <strong>repositories</strong>, launches <strong>training jobs</strong>,
129
+ and explores <strong>datasets</strong> β€” all through natural conversation.
130
+ </Typography>
131
+
132
+ {/* Action button β€” depends on context */}
133
+ {inIframe && !isAuthenticated && !isDevUser ? (
134
+ // In iframe + not logged in β†’ link to open Space directly
135
+ <Button
136
+ variant="contained"
137
+ size="large"
138
+ component="a"
139
+ href={spaceHost}
140
+ target="_blank"
141
+ rel="noopener noreferrer"
142
+ endIcon={<OpenInNewIcon />}
143
+ sx={{
144
+ px: 5,
145
+ py: 1.5,
146
+ fontSize: '1rem',
147
+ fontWeight: 700,
148
+ textTransform: 'none',
149
+ borderRadius: '12px',
150
+ bgcolor: HF_ORANGE,
151
+ color: '#000',
152
+ boxShadow: '0 4px 24px rgba(255, 157, 0, 0.3)',
153
+ textDecoration: 'none',
154
+ '&:hover': {
155
+ bgcolor: '#FFB340',
156
+ boxShadow: '0 6px 32px rgba(255, 157, 0, 0.45)',
157
+ },
158
+ }}
159
+ >
160
+ Open HF Agent
161
+ </Button>
162
+ ) : !isAuthenticated && !isDevUser ? (
163
+ // Direct access + not logged in β†’ sign in button
164
+ <Button
165
+ variant="contained"
166
+ size="large"
167
+ onClick={() => triggerLogin()}
168
+ sx={{
169
+ px: 5,
170
+ py: 1.5,
171
+ fontSize: '1rem',
172
+ fontWeight: 700,
173
+ textTransform: 'none',
174
+ borderRadius: '12px',
175
+ bgcolor: HF_ORANGE,
176
+ color: '#000',
177
+ boxShadow: '0 4px 24px rgba(255, 157, 0, 0.3)',
178
+ '&:hover': {
179
+ bgcolor: '#FFB340',
180
+ boxShadow: '0 6px 32px rgba(255, 157, 0, 0.45)',
181
+ },
182
+ }}
183
+ >
184
+ Sign in with Hugging Face
185
+ </Button>
186
+ ) : (
187
+ // Authenticated or dev β†’ start session
188
+ <Button
189
+ variant="contained"
190
+ size="large"
191
+ onClick={handleStart}
192
+ disabled={isCreating}
193
+ startIcon={
194
+ isCreating ? <CircularProgress size={20} color="inherit" /> : null
195
+ }
196
+ sx={{
197
+ px: 5,
198
+ py: 1.5,
199
+ fontSize: '1rem',
200
+ fontWeight: 700,
201
+ textTransform: 'none',
202
+ borderRadius: '12px',
203
+ bgcolor: HF_ORANGE,
204
+ color: '#000',
205
+ boxShadow: '0 4px 24px rgba(255, 157, 0, 0.3)',
206
+ '&:hover': {
207
+ bgcolor: '#FFB340',
208
+ boxShadow: '0 6px 32px rgba(255, 157, 0, 0.45)',
209
+ },
210
+ '&.Mui-disabled': {
211
+ bgcolor: 'rgba(255, 157, 0, 0.35)',
212
+ color: 'rgba(0,0,0,0.45)',
213
+ },
214
+ }}
215
+ >
216
+ {isCreating ? 'Initializing...' : 'Start Session'}
217
+ </Button>
218
+ )}
219
+
220
+ {/* Error */}
221
+ {error && (
222
+ <Alert
223
+ severity="warning"
224
+ variant="outlined"
225
+ onClose={() => setError(null)}
226
+ sx={{
227
+ mt: 3,
228
+ maxWidth: 400,
229
+ fontSize: '0.8rem',
230
+ borderColor: HF_ORANGE,
231
+ color: 'var(--text)',
232
+ }}
233
+ >
234
+ {error}
235
+ </Alert>
236
+ )}
237
+
238
+ {/* Footnote */}
239
+ <Typography
240
+ variant="caption"
241
+ sx={{ mt: 5, color: 'var(--muted-text)', opacity: 0.5, fontSize: '0.7rem' }}
242
+ >
243
+ Conversations are stored locally in your browser.
244
+ </Typography>
245
+ </Box>
246
+ );
247
+ }
frontend/src/hooks/useAgentChat.ts ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Central hook wiring the Vercel AI SDK's useChat with our custom
3
+ * WebSocketChatTransport. Replaces the old useAgentWebSocket + agentStore
4
+ * message management.
5
+ */
6
+ import { useCallback, useEffect, useMemo, useRef } from 'react';
7
+ import { useChat } from '@ai-sdk/react';
8
+ import type { UIMessage } from 'ai';
9
+ import { WebSocketChatTransport, type SideChannelCallbacks } from '@/lib/ws-chat-transport';
10
+ import { loadMessages, saveMessages } from '@/lib/chat-message-store';
11
+ import { apiFetch } from '@/utils/api';
12
+ import { useAgentStore } from '@/store/agentStore';
13
+ import { useSessionStore } from '@/store/sessionStore';
14
+ import { useLayoutStore } from '@/store/layoutStore';
15
+ import { logger } from '@/utils/logger';
16
+
17
+ interface UseAgentChatOptions {
18
+ sessionId: string | null;
19
+ onReady?: () => void;
20
+ onError?: (error: string) => void;
21
+ onSessionDead?: (sessionId: string) => void;
22
+ }
23
+
24
+ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: UseAgentChatOptions) {
25
+ const callbacksRef = useRef({ onReady, onError, onSessionDead });
26
+ callbacksRef.current = { onReady, onError, onSessionDead };
27
+
28
+ const {
29
+ setProcessing,
30
+ setConnected,
31
+ setActivityStatus,
32
+ setError,
33
+ setPanel,
34
+ setPanelOutput,
35
+ } = useAgentStore();
36
+
37
+ const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
38
+ const { setSessionActive } = useSessionStore();
39
+
40
+ // ── Build side-channel callbacks (stable ref) ────────────────────
41
+ const sideChannel = useMemo<SideChannelCallbacks>(
42
+ () => ({
43
+ onReady: () => {
44
+ setConnected(true);
45
+ setProcessing(false);
46
+ if (sessionId) setSessionActive(sessionId, true);
47
+ callbacksRef.current.onReady?.();
48
+ },
49
+ onShutdown: () => {
50
+ setConnected(false);
51
+ setProcessing(false);
52
+ },
53
+ onError: (error: string) => {
54
+ setError(error);
55
+ setProcessing(false);
56
+ callbacksRef.current.onError?.(error);
57
+ },
58
+ onProcessing: () => {
59
+ setProcessing(true);
60
+ setActivityStatus({ type: 'thinking' });
61
+ },
62
+ onProcessingDone: () => {
63
+ setProcessing(false);
64
+ },
65
+ onUndoComplete: () => {
66
+ setProcessing(false);
67
+ // Remove the last turn (user msg + assistant response) from useChat state
68
+ const setMsgs = chatActionsRef.current.setMessages;
69
+ const msgs = chatActionsRef.current.messages;
70
+ if (setMsgs && msgs.length > 0) {
71
+ let lastUserIdx = -1;
72
+ for (let i = msgs.length - 1; i >= 0; i--) {
73
+ if (msgs[i].role === 'user') { lastUserIdx = i; break; }
74
+ }
75
+ const updated = lastUserIdx > 0 ? msgs.slice(0, lastUserIdx) : [];
76
+ setMsgs(updated);
77
+ if (sessionId) saveMessages(sessionId, updated);
78
+ }
79
+ },
80
+ onCompacted: (oldTokens: number, newTokens: number) => {
81
+ logger.log(`Context compacted: ${oldTokens} β†’ ${newTokens} tokens`);
82
+ },
83
+ onPlanUpdate: (plan) => {
84
+ useAgentStore.getState().setPlan(plan as Array<{ id: string; content: string; status: 'pending' | 'in_progress' | 'completed' }>);
85
+ if (!useLayoutStore.getState().isRightPanelOpen) {
86
+ setRightPanelOpen(true);
87
+ }
88
+ },
89
+ onToolLog: (tool: string, log: string) => {
90
+ if (tool === 'hf_jobs') {
91
+ const state = useAgentStore.getState();
92
+ const existingOutput = state.panelData?.output?.content || '';
93
+ const newContent = existingOutput
94
+ ? existingOutput + '\n' + log
95
+ : '--- Job execution started ---\n' + log;
96
+
97
+ setPanelOutput({ content: newContent, language: 'text' });
98
+
99
+ if (!useLayoutStore.getState().isRightPanelOpen) {
100
+ setRightPanelOpen(true);
101
+ }
102
+ }
103
+ },
104
+ onConnectionChange: (connected: boolean) => {
105
+ setConnected(connected);
106
+ },
107
+ onSessionDead: (deadSessionId: string) => {
108
+ logger.warn(`Session ${deadSessionId} dead, removing`);
109
+ callbacksRef.current.onSessionDead?.(deadSessionId);
110
+ },
111
+ onApprovalRequired: (tools) => {
112
+ if (!tools.length) return;
113
+ setActivityStatus({ type: 'waiting-approval' });
114
+ const firstTool = tools[0];
115
+ const args = firstTool.arguments as Record<string, string | undefined>;
116
+
117
+ if (firstTool.tool === 'hf_jobs' && args.script) {
118
+ setPanel(
119
+ { title: 'Script', script: { content: args.script, language: 'python' }, parameters: firstTool.arguments as Record<string, unknown> },
120
+ 'script',
121
+ true,
122
+ );
123
+ } else if (firstTool.tool === 'hf_repo_files' && args.content) {
124
+ const filename = args.path || 'file';
125
+ setPanel({
126
+ title: filename.split('/').pop() || 'Content',
127
+ script: { content: args.content, language: filename.endsWith('.py') ? 'python' : 'text' },
128
+ parameters: firstTool.arguments as Record<string, unknown>,
129
+ });
130
+ } else {
131
+ setPanel({
132
+ title: firstTool.tool,
133
+ output: { content: JSON.stringify(firstTool.arguments, null, 2), language: 'json' },
134
+ }, 'output');
135
+ }
136
+
137
+ setRightPanelOpen(true);
138
+ setLeftSidebarOpen(false);
139
+ },
140
+ onToolCallPanel: (toolName: string, args: Record<string, unknown>) => {
141
+ if (toolName === 'hf_jobs' && args.operation && args.script) {
142
+ setPanel(
143
+ { title: 'Script', script: { content: String(args.script), language: 'python' }, parameters: args },
144
+ 'script',
145
+ );
146
+ setRightPanelOpen(true);
147
+ setLeftSidebarOpen(false);
148
+ } else if (toolName === 'hf_repo_files' && args.operation === 'upload' && args.content) {
149
+ setPanel({
150
+ title: `File Upload: ${String(args.path || 'unnamed')}`,
151
+ script: { content: String(args.content), language: String(args.path || '').endsWith('.py') ? 'python' : 'text' },
152
+ parameters: args,
153
+ });
154
+ setRightPanelOpen(true);
155
+ setLeftSidebarOpen(false);
156
+ }
157
+ },
158
+ onToolOutputPanel: (toolName: string, _toolCallId: string, output: string, success: boolean) => {
159
+ if (toolName === 'hf_jobs' && output) {
160
+ setPanelOutput({ content: output, language: 'markdown' });
161
+ if (!success) useAgentStore.getState().setPanelView('output');
162
+ }
163
+ },
164
+ onStreaming: () => {
165
+ setActivityStatus({ type: 'streaming' });
166
+ },
167
+ onToolRunning: (toolName: string) => {
168
+ setActivityStatus({ type: 'tool', toolName });
169
+ },
170
+ }),
171
+ // Zustand setters are stable
172
+ // eslint-disable-next-line react-hooks/exhaustive-deps
173
+ [sessionId],
174
+ );
175
+
176
+ // ── Create transport (single stable instance for the lifetime of this hook) ──
177
+ const transportRef = useRef<WebSocketChatTransport | null>(null);
178
+ if (!transportRef.current) {
179
+ transportRef.current = new WebSocketChatTransport({ sideChannel });
180
+ }
181
+
182
+ // Keep side-channel callbacks in sync (they capture sessionId)
183
+ useEffect(() => {
184
+ transportRef.current?.updateSideChannel(sideChannel);
185
+ }, [sideChannel]);
186
+
187
+ // Connect / disconnect WebSocket when session changes
188
+ useEffect(() => {
189
+ transportRef.current?.connectToSession(sessionId);
190
+ return () => {
191
+ transportRef.current?.connectToSession(null);
192
+ };
193
+ }, [sessionId]);
194
+
195
+ // ── Restore persisted messages for this session ─────────────────
196
+ const initialMessages = useMemo(
197
+ () => (sessionId ? loadMessages(sessionId) : []),
198
+ [sessionId],
199
+ );
200
+
201
+ // ── Ref for chat actions (used by sideChannel callbacks created before chat) ──
202
+ const chatActionsRef = useRef<{
203
+ setMessages: ((msgs: UIMessage[]) => void) | null;
204
+ messages: UIMessage[];
205
+ }>({ setMessages: null, messages: [] });
206
+
207
+ // ── useChat from Vercel AI SDK ───────────────────────────────────
208
+ const chat = useChat({
209
+ id: sessionId || '__no_session__',
210
+ messages: initialMessages,
211
+ transport: transportRef.current!,
212
+ experimental_throttle: 80,
213
+ onFinish: ({ messages, isAbort, isError }) => {
214
+ if (isAbort || isError) return;
215
+ if (sessionId && messages.length > 0) {
216
+ saveMessages(sessionId, messages);
217
+ }
218
+ },
219
+ onError: (error) => {
220
+ logger.error('useChat error:', error);
221
+ setError(error.message);
222
+ setProcessing(false);
223
+ },
224
+ });
225
+
226
+ // Keep chatActionsRef in sync every render
227
+ chatActionsRef.current.setMessages = chat.setMessages;
228
+ chatActionsRef.current.messages = chat.messages;
229
+
230
+ // ── Persist messages on every user send (onFinish covers assistant turns) ──
231
+ const prevLenRef = useRef(initialMessages.length);
232
+ useEffect(() => {
233
+ if (!sessionId || chat.messages.length === 0) return;
234
+ if (chat.messages.length !== prevLenRef.current) {
235
+ prevLenRef.current = chat.messages.length;
236
+ saveMessages(sessionId, chat.messages);
237
+ }
238
+ }, [sessionId, chat.messages]);
239
+
240
+ // ── Undo last turn (calls backend + syncs useChat + localStorage) ──
241
+ const undoLastTurn = useCallback(async () => {
242
+ if (!sessionId) return;
243
+ try {
244
+ const res = await apiFetch(`/api/undo/${sessionId}`, { method: 'POST' });
245
+ if (!res.ok) {
246
+ logger.error('Undo API returned', res.status);
247
+ return;
248
+ }
249
+ } catch (e) {
250
+ logger.error('Undo failed:', e);
251
+ }
252
+ // Backend will also send undo_complete, but we apply optimistically
253
+ // so the UI updates immediately.
254
+ }, [sessionId]);
255
+
256
+ // ── Convenience: approve tools via transport ─────────────────────
257
+ const approveTools = useCallback(
258
+ async (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => {
259
+ if (!sessionId || !transportRef.current) return false;
260
+ const ok = await transportRef.current.approveTools(sessionId, approvals);
261
+ if (ok) {
262
+ const hasApproved = approvals.some(a => a.approved);
263
+ if (hasApproved) setProcessing(true);
264
+ }
265
+ return ok;
266
+ },
267
+ [sessionId, setProcessing],
268
+ );
269
+
270
+ return {
271
+ messages: chat.messages,
272
+ sendMessage: chat.sendMessage,
273
+ status: chat.status,
274
+ undoLastTurn,
275
+ approveTools,
276
+ transport: transportRef.current,
277
+ };
278
+ }
frontend/src/hooks/useAgentWebSocket.ts DELETED
@@ -1,503 +0,0 @@
1
- import { useCallback, useEffect, useRef } from 'react';
2
- import { useAgentStore } from '@/store/agentStore';
3
- import { useSessionStore } from '@/store/sessionStore';
4
- import { useLayoutStore } from '@/store/layoutStore';
5
- import type { AgentEvent } from '@/types/events';
6
- import type { Message, TraceLog } from '@/types/agent';
7
-
8
- const WS_RECONNECT_DELAY = 1000;
9
- const WS_MAX_RECONNECT_DELAY = 30000;
10
-
11
- interface UseAgentWebSocketOptions {
12
- sessionId: string | null;
13
- onReady?: () => void;
14
- onError?: (error: string) => void;
15
- }
16
-
17
- export function useAgentWebSocket({
18
- sessionId,
19
- onReady,
20
- onError,
21
- }: UseAgentWebSocketOptions) {
22
- const wsRef = useRef<WebSocket | null>(null);
23
- const reconnectTimeoutRef = useRef<number | null>(null);
24
- const reconnectDelayRef = useRef(WS_RECONNECT_DELAY);
25
-
26
- const {
27
- addMessage,
28
- updateMessage,
29
- setProcessing,
30
- setConnected,
31
- setPendingApprovals,
32
- setError,
33
- addTraceLog,
34
- updateTraceLog,
35
- clearTraceLogs,
36
- setPanelContent,
37
- setPanelTab,
38
- setActivePanelTab,
39
- clearPanelTabs,
40
- setPlan,
41
- setCurrentTurnMessageId,
42
- updateCurrentTurnTrace,
43
- } = useAgentStore();
44
-
45
- const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();
46
-
47
- const { setSessionActive } = useSessionStore();
48
-
49
- const handleEvent = useCallback(
50
- (event: AgentEvent) => {
51
- if (!sessionId) return;
52
-
53
- switch (event.event_type) {
54
- case 'ready':
55
- setConnected(true);
56
- setProcessing(false);
57
- setSessionActive(sessionId, true);
58
- onReady?.();
59
- break;
60
-
61
- case 'processing':
62
- setProcessing(true);
63
- clearTraceLogs();
64
- // Don't clear panel tabs here - they should persist during approval flow
65
- // Tabs will be cleared when a new tool_call sets up new content
66
- setCurrentTurnMessageId(null); // Start a new turn
67
- break;
68
-
69
- case 'assistant_message': {
70
- const content = (event.data?.content as string) || '';
71
- const currentTrace = useAgentStore.getState().traceLogs;
72
- const currentTurnMsgId = useAgentStore.getState().currentTurnMessageId;
73
-
74
- if (currentTurnMsgId) {
75
- // Update existing message - add segments chronologically
76
- const messages = useAgentStore.getState().getMessages(sessionId);
77
- const existingMsg = messages.find(m => m.id === currentTurnMsgId);
78
-
79
- if (existingMsg) {
80
- const segments = existingMsg.segments ? [...existingMsg.segments] : [];
81
-
82
- // If there are pending traces, add them as a tools segment first
83
- if (currentTrace.length > 0) {
84
- segments.push({ type: 'tools', tools: [...currentTrace] });
85
- clearTraceLogs();
86
- }
87
-
88
- // Add the new text segment
89
- if (content) {
90
- segments.push({ type: 'text', content });
91
- }
92
-
93
- updateMessage(sessionId, currentTurnMsgId, {
94
- content: existingMsg.content + '\n\n' + content,
95
- segments,
96
- });
97
- }
98
- } else {
99
- // Create new message
100
- const messageId = `msg_${Date.now()}`;
101
- const segments: Array<{ type: 'text' | 'tools'; content?: string; tools?: typeof currentTrace }> = [];
102
-
103
- // Add any pending traces first
104
- if (currentTrace.length > 0) {
105
- segments.push({ type: 'tools', tools: [...currentTrace] });
106
- clearTraceLogs();
107
- }
108
-
109
- // Add the text
110
- if (content) {
111
- segments.push({ type: 'text', content });
112
- }
113
-
114
- const message: Message = {
115
- id: messageId,
116
- role: 'assistant',
117
- content,
118
- timestamp: new Date().toISOString(),
119
- segments,
120
- };
121
- addMessage(sessionId, message);
122
- setCurrentTurnMessageId(messageId);
123
- }
124
- break;
125
- }
126
-
127
- case 'tool_call': {
128
- const toolName = (event.data?.tool as string) || 'unknown';
129
- const args = (event.data?.arguments as Record<string, any>) || {};
130
-
131
- // Don't display plan_tool in trace logs (it shows up elsewhere in the UI)
132
- if (toolName !== 'plan_tool') {
133
- const log: TraceLog = {
134
- id: `tool_${Date.now()}`,
135
- type: 'call',
136
- text: `Agent is executing ${toolName}...`,
137
- tool: toolName,
138
- timestamp: new Date().toISOString(),
139
- completed: false,
140
- // Store args for auto-exec message creation later
141
- args: toolName === 'hf_jobs' ? args : undefined,
142
- };
143
- addTraceLog(log);
144
- // Update the current turn message's trace in real-time
145
- updateCurrentTurnTrace(sessionId);
146
- }
147
-
148
- // Auto-expand Right Panel for specific tools
149
- if (toolName === 'hf_jobs' && (args.operation === 'run' || args.operation === 'scheduled run') && args.script) {
150
- // Clear any existing tabs from previous jobs before setting new script
151
- clearPanelTabs();
152
- // Use tab system for jobs - add script tab immediately
153
- setPanelTab({
154
- id: 'script',
155
- title: 'Script',
156
- content: args.script,
157
- language: 'python',
158
- parameters: args
159
- });
160
- setActivePanelTab('script');
161
- setRightPanelOpen(true);
162
- setLeftSidebarOpen(false);
163
- } else if (toolName === 'hf_repo_files' && args.operation === 'upload' && args.content) {
164
- setPanelContent({
165
- title: `File Upload: ${args.path || 'unnamed'}`,
166
- content: args.content,
167
- parameters: args,
168
- language: args.path?.endsWith('.py') ? 'python' : undefined
169
- });
170
- setRightPanelOpen(true);
171
- setLeftSidebarOpen(false);
172
- }
173
-
174
- console.log('Tool call:', toolName, args);
175
- break;
176
- }
177
-
178
- case 'tool_output': {
179
- const toolName = (event.data?.tool as string) || 'unknown';
180
- const output = (event.data?.output as string) || '';
181
- const success = event.data?.success as boolean;
182
-
183
- // Mark the corresponding trace log as completed and store the output
184
- updateTraceLog(toolName, { completed: true, output, success });
185
- // Update the current turn message's trace in real-time
186
- updateCurrentTurnTrace(sessionId);
187
-
188
- // Special handling for hf_jobs - update or create job message with output
189
- if (toolName === 'hf_jobs') {
190
- const messages = useAgentStore.getState().getMessages(sessionId);
191
- const traceLogs = useAgentStore.getState().traceLogs;
192
-
193
- // Find existing approval message for this job
194
- let jobMsg = [...messages].reverse().find(m => m.approval);
195
-
196
- if (!jobMsg) {
197
- // No approval message exists - this was an auto-executed job
198
- // Create a job execution message so user can see results
199
- const jobTrace = [...traceLogs].reverse().find(t => t.tool === 'hf_jobs');
200
- const args = jobTrace?.args || {};
201
-
202
- const autoExecMessage: Message = {
203
- id: `msg_auto_${Date.now()}`,
204
- role: 'assistant',
205
- content: '',
206
- timestamp: new Date().toISOString(),
207
- approval: {
208
- status: 'approved', // Auto-approved (no user action needed)
209
- batch: {
210
- tools: [{
211
- tool: toolName,
212
- arguments: args,
213
- tool_call_id: `auto_${Date.now()}`
214
- }],
215
- count: 1
216
- }
217
- },
218
- toolOutput: output
219
- };
220
- addMessage(sessionId, autoExecMessage);
221
- console.log('Created auto-exec message with tool output:', toolName);
222
- } else {
223
- // Update existing approval message
224
- const currentOutput = jobMsg.toolOutput || '';
225
- const newOutput = currentOutput ? currentOutput + '\n\n' + output : output;
226
-
227
- useAgentStore.getState().updateMessage(sessionId, jobMsg.id, {
228
- toolOutput: newOutput
229
- });
230
- console.log('Updated job message with tool output:', toolName);
231
- }
232
- }
233
-
234
- // Don't create message bubbles for tool outputs - they only show in trace logs
235
- console.log('Tool output:', toolName, success);
236
- break;
237
- }
238
-
239
- case 'tool_log': {
240
- const toolName = (event.data?.tool as string) || 'unknown';
241
- const log = (event.data?.log as string) || '';
242
-
243
- if (toolName === 'hf_jobs') {
244
- const currentTabs = useAgentStore.getState().panelTabs;
245
- const logsTab = currentTabs.find(t => t.id === 'logs');
246
-
247
- // Append to existing logs tab or create new one
248
- const newContent = logsTab
249
- ? logsTab.content + '\n' + log
250
- : '--- Job execution started ---\n' + log;
251
-
252
- setPanelTab({
253
- id: 'logs',
254
- title: 'Logs',
255
- content: newContent,
256
- language: 'text'
257
- });
258
-
259
- // Auto-switch to logs tab when logs start streaming
260
- setActivePanelTab('logs');
261
-
262
- if (!useLayoutStore.getState().isRightPanelOpen) {
263
- setRightPanelOpen(true);
264
- }
265
- }
266
- break;
267
- }
268
-
269
- case 'plan_update': {
270
- const plan = (event.data?.plan as any[]) || [];
271
- setPlan(plan);
272
- if (!useLayoutStore.getState().isRightPanelOpen) {
273
- setRightPanelOpen(true);
274
- }
275
- break;
276
- }
277
-
278
- case 'approval_required': {
279
- const tools = event.data?.tools as Array<{
280
- tool: string;
281
- arguments: Record<string, unknown>;
282
- tool_call_id: string;
283
- }>;
284
- const count = (event.data?.count as number) || 0;
285
-
286
- // Create a persistent message for the approval request
287
- const message: Message = {
288
- id: `msg_approval_${Date.now()}`,
289
- role: 'assistant',
290
- content: '', // Content is handled by the approval UI
291
- timestamp: new Date().toISOString(),
292
- approval: {
293
- status: 'pending',
294
- batch: { tools, count }
295
- }
296
- };
297
- addMessage(sessionId, message);
298
-
299
- // Show the first tool's content in the panel so users see what they're approving
300
- if (tools && tools.length > 0) {
301
- const firstTool = tools[0];
302
- const args = firstTool.arguments as Record<string, any>;
303
-
304
- clearPanelTabs();
305
-
306
- if (firstTool.tool === 'hf_jobs' && args.script) {
307
- setPanelTab({
308
- id: 'script',
309
- title: 'Script',
310
- content: args.script,
311
- language: 'python',
312
- parameters: args
313
- });
314
- setActivePanelTab('script');
315
- } else if (firstTool.tool === 'hf_repo_files' && args.content) {
316
- const filename = args.path || 'file';
317
- const isPython = filename.endsWith('.py');
318
- setPanelTab({
319
- id: 'content',
320
- title: filename.split('/').pop() || 'Content',
321
- content: args.content,
322
- language: isPython ? 'python' : 'text',
323
- parameters: args
324
- });
325
- setActivePanelTab('content');
326
- } else {
327
- // For other tools, show args as JSON
328
- setPanelTab({
329
- id: 'args',
330
- title: firstTool.tool,
331
- content: JSON.stringify(args, null, 2),
332
- language: 'json',
333
- parameters: args
334
- });
335
- setActivePanelTab('args');
336
- }
337
-
338
- setRightPanelOpen(true);
339
- setLeftSidebarOpen(false);
340
- }
341
-
342
- // Clear currentTurnMessageId so subsequent assistant_message events create a new message below the approval
343
- setCurrentTurnMessageId(null);
344
-
345
- // We don't set pendingApprovals in the global store anymore as the message handles the UI
346
- setPendingApprovals(null);
347
- setProcessing(false);
348
- break;
349
- }
350
-
351
- case 'turn_complete':
352
- setProcessing(false);
353
- setCurrentTurnMessageId(null); // Clear the current turn
354
- break;
355
-
356
- case 'compacted': {
357
- const oldTokens = event.data?.old_tokens as number;
358
- const newTokens = event.data?.new_tokens as number;
359
- console.log(`Context compacted: ${oldTokens} -> ${newTokens} tokens`);
360
- break;
361
- }
362
-
363
- case 'error': {
364
- const errorMsg = (event.data?.error as string) || 'Unknown error';
365
- setError(errorMsg);
366
- setProcessing(false);
367
- onError?.(errorMsg);
368
- break;
369
- }
370
-
371
- case 'shutdown':
372
- setConnected(false);
373
- setProcessing(false);
374
- break;
375
-
376
- case 'interrupted':
377
- setProcessing(false);
378
- break;
379
-
380
- case 'undo_complete':
381
- // Could remove last messages from store
382
- break;
383
-
384
- default:
385
- console.log('Unknown event:', event);
386
- }
387
- },
388
- // Zustand setters are stable, so we don't need them in deps
389
- // eslint-disable-next-line react-hooks/exhaustive-deps
390
- [sessionId, onReady, onError]
391
- );
392
-
393
- const connect = useCallback(() => {
394
- if (!sessionId) return;
395
-
396
- // Don't connect if already connected or connecting
397
- if (wsRef.current?.readyState === WebSocket.OPEN ||
398
- wsRef.current?.readyState === WebSocket.CONNECTING) {
399
- return;
400
- }
401
-
402
- // Connect directly to backend (Vite doesn't proxy WebSockets)
403
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
404
- // In development, connect directly to backend port 7860
405
- // In production, use the same host
406
- const isDev = import.meta.env.DEV;
407
- const host = isDev ? '127.0.0.1:7860' : window.location.host;
408
- const wsUrl = `${protocol}//${host}/api/ws/${sessionId}`;
409
-
410
- console.log('Connecting to WebSocket:', wsUrl);
411
- const ws = new WebSocket(wsUrl);
412
-
413
- ws.onopen = () => {
414
- console.log('WebSocket connected');
415
- setConnected(true);
416
- reconnectDelayRef.current = WS_RECONNECT_DELAY;
417
- };
418
-
419
- ws.onmessage = (event) => {
420
- try {
421
- const data = JSON.parse(event.data) as AgentEvent;
422
- handleEvent(data);
423
- } catch (e) {
424
- console.error('Failed to parse WebSocket message:', e);
425
- }
426
- };
427
-
428
- ws.onerror = (error) => {
429
- console.error('WebSocket error:', error);
430
- };
431
-
432
- ws.onclose = (event) => {
433
- console.log('WebSocket closed', event.code, event.reason);
434
- setConnected(false);
435
-
436
- // Only reconnect if it wasn't a normal closure and session still exists
437
- if (event.code !== 1000 && sessionId) {
438
- // Attempt to reconnect with exponential backoff
439
- if (reconnectTimeoutRef.current) {
440
- clearTimeout(reconnectTimeoutRef.current);
441
- }
442
- reconnectTimeoutRef.current = window.setTimeout(() => {
443
- reconnectDelayRef.current = Math.min(
444
- reconnectDelayRef.current * 2,
445
- WS_MAX_RECONNECT_DELAY
446
- );
447
- connect();
448
- }, reconnectDelayRef.current);
449
- }
450
- };
451
-
452
- wsRef.current = ws;
453
- }, [sessionId, handleEvent]);
454
-
455
- const disconnect = useCallback(() => {
456
- if (reconnectTimeoutRef.current) {
457
- clearTimeout(reconnectTimeoutRef.current);
458
- reconnectTimeoutRef.current = null;
459
- }
460
- if (wsRef.current) {
461
- wsRef.current.close();
462
- wsRef.current = null;
463
- }
464
- setConnected(false);
465
- }, []);
466
-
467
- const sendPing = useCallback(() => {
468
- if (wsRef.current?.readyState === WebSocket.OPEN) {
469
- wsRef.current.send(JSON.stringify({ type: 'ping' }));
470
- }
471
- }, []);
472
-
473
- // Connect when sessionId changes (with a small delay to ensure session is ready)
474
- useEffect(() => {
475
- if (!sessionId) {
476
- disconnect();
477
- return;
478
- }
479
-
480
- // Small delay to ensure session is fully created on backend
481
- const timeoutId = setTimeout(() => {
482
- connect();
483
- }, 100);
484
-
485
- return () => {
486
- clearTimeout(timeoutId);
487
- disconnect();
488
- };
489
- // eslint-disable-next-line react-hooks/exhaustive-deps
490
- }, [sessionId]);
491
-
492
- // Heartbeat
493
- useEffect(() => {
494
- const interval = setInterval(sendPing, 30000);
495
- return () => clearInterval(interval);
496
- }, [sendPing]);
497
-
498
- return {
499
- isConnected: wsRef.current?.readyState === WebSocket.OPEN,
500
- connect,
501
- disconnect,
502
- };
503
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
frontend/src/hooks/useAuth.ts ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Authentication hook β€” simple server-side OAuth.
3
+ *
4
+ * - Hors iframe: /auth/login redirect (cookies work fine)
5
+ * - Dans iframe: show "Open in full page" link
6
+ *
7
+ * Token is stored via HttpOnly cookie by the backend.
8
+ * In dev mode (no OAUTH_CLIENT_ID), auth is bypassed.
9
+ */
10
+
11
+ import { useEffect } from 'react';
12
+ import { useAgentStore } from '@/store/agentStore';
13
+ import { logger } from '@/utils/logger';
14
+
15
+ /** Check if we're running inside an iframe. */
16
+ export function isInIframe(): boolean {
17
+ try {
18
+ return window.top !== window.self;
19
+ } catch {
20
+ return true; // SecurityError = cross-origin iframe
21
+ }
22
+ }
23
+
24
+ /** Redirect to the server-side OAuth login. */
25
+ export function triggerLogin(): void {
26
+ window.location.href = '/auth/login';
27
+ }
28
+
29
+ /**
30
+ * Hook: on mount, check if user is authenticated.
31
+ * Sets user in the agent store.
32
+ */
33
+ export function useAuth() {
34
+ const setUser = useAgentStore((s) => s.setUser);
35
+
36
+ useEffect(() => {
37
+ let cancelled = false;
38
+
39
+ async function checkAuth() {
40
+ try {
41
+ // Check if user is already authenticated (cookie-based)
42
+ const response = await fetch('/auth/me', { credentials: 'include' });
43
+ if (response.ok) {
44
+ const data = await response.json();
45
+ if (!cancelled && data.authenticated) {
46
+ setUser({
47
+ authenticated: true,
48
+ username: data.username,
49
+ name: data.name,
50
+ picture: data.picture,
51
+ });
52
+ logger.log('Authenticated as', data.username);
53
+ return;
54
+ }
55
+ }
56
+
57
+ // Not authenticated β€” check if auth is enabled
58
+ const statusRes = await fetch('/auth/status', { credentials: 'include' });
59
+ const statusData = await statusRes.json();
60
+ if (!statusData.auth_enabled) {
61
+ // Dev mode β€” no OAuth configured
62
+ if (!cancelled) setUser({ authenticated: true, username: 'dev' });
63
+ return;
64
+ }
65
+
66
+ // Auth enabled but not logged in β€” welcome screen will handle it
67
+ if (!cancelled) setUser(null);
68
+ } catch {
69
+ // Backend unreachable β€” assume dev mode
70
+ if (!cancelled) setUser({ authenticated: true, username: 'dev' });
71
+ }
72
+ }
73
+
74
+ checkAuth();
75
+ return () => { cancelled = true; };
76
+ }, [setUser]);
77
+ }
frontend/src/lib/chat-message-store.ts ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Lightweight localStorage persistence for UIMessage arrays,
3
+ * keyed by session ID.
4
+ *
5
+ * Uses the same storage namespace (`hf-agent-messages`) that the
6
+ * old Zustand-based store used, so existing data is compatible.
7
+ */
8
+ import type { UIMessage } from 'ai';
9
+ import { logger } from '@/utils/logger';
10
+
11
+ const STORAGE_KEY = 'hf-agent-messages';
12
+ const MAX_SESSIONS = 50;
13
+
14
+ type MessagesMap = Record<string, UIMessage[]>;
15
+
16
+ function readAll(): MessagesMap {
17
+ try {
18
+ const raw = localStorage.getItem(STORAGE_KEY);
19
+ if (!raw) return {};
20
+ const parsed = JSON.parse(raw);
21
+ // Legacy format was { messagesBySession: {...} }
22
+ if (parsed.messagesBySession) return parsed.messagesBySession;
23
+ // New flat format
24
+ if (typeof parsed === 'object' && !Array.isArray(parsed)) return parsed;
25
+ return {};
26
+ } catch {
27
+ return {};
28
+ }
29
+ }
30
+
31
+ function writeAll(map: MessagesMap): void {
32
+ try {
33
+ localStorage.setItem(STORAGE_KEY, JSON.stringify(map));
34
+ } catch (e) {
35
+ logger.warn('Failed to persist messages:', e);
36
+ }
37
+ }
38
+
39
+ export function loadMessages(sessionId: string): UIMessage[] {
40
+ const map = readAll();
41
+ return map[sessionId] ?? [];
42
+ }
43
+
44
+ export function saveMessages(sessionId: string, messages: UIMessage[]): void {
45
+ const map = readAll();
46
+ map[sessionId] = messages;
47
+
48
+ // Evict oldest sessions if we exceed the cap
49
+ const keys = Object.keys(map);
50
+ if (keys.length > MAX_SESSIONS) {
51
+ const toRemove = keys.slice(0, keys.length - MAX_SESSIONS);
52
+ for (const k of toRemove) delete map[k];
53
+ }
54
+
55
+ writeAll(map);
56
+ }
57
+
58
+ export function deleteMessages(sessionId: string): void {
59
+ const map = readAll();
60
+ delete map[sessionId];
61
+ writeAll(map);
62
+ }
frontend/src/lib/ws-chat-transport.ts ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Custom ChatTransport that bridges our WebSocket-based backend protocol
3
+ * to the Vercel AI SDK's UIMessageChunk streaming interface.
4
+ *
5
+ * The backend stays unchanged β€” this adapter translates WebSocket events
6
+ * into the chunk types that useChat() expects.
7
+ */
8
+ import type { ChatTransport, UIMessage, UIMessageChunk, ChatRequestOptions } from 'ai';
9
+ import { apiFetch, getWebSocketUrl } from '@/utils/api';
10
+ import { logger } from '@/utils/logger';
11
+ import type { AgentEvent } from '@/types/events';
12
+ import { useAgentStore } from '@/store/agentStore';
13
+
14
+ // ---------------------------------------------------------------------------
15
+ // Side-channel callback interface (non-chat events forwarded to the store)
16
+ // ---------------------------------------------------------------------------
17
+ export interface SideChannelCallbacks {
18
+ onReady: () => void;
19
+ onShutdown: () => void;
20
+ onError: (error: string) => void;
21
+ onProcessing: () => void;
22
+ onProcessingDone: () => void;
23
+ onUndoComplete: () => void;
24
+ onCompacted: (oldTokens: number, newTokens: number) => void;
25
+ onPlanUpdate: (plan: Array<{ id: string; content: string; status: string }>) => void;
26
+ onToolLog: (tool: string, log: string) => void;
27
+ onConnectionChange: (connected: boolean) => void;
28
+ onSessionDead: (sessionId: string) => void;
29
+ /** Called when approval_required arrives β€” lets the store manage panels */
30
+ onApprovalRequired: (tools: Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>) => void;
31
+ /** Called when a tool_call arrives with panel-relevant args */
32
+ onToolCallPanel: (tool: string, args: Record<string, unknown>) => void;
33
+ /** Called when tool_output arrives with panel-relevant data */
34
+ onToolOutputPanel: (tool: string, toolCallId: string, output: string, success: boolean) => void;
35
+ /** Called when assistant text starts streaming */
36
+ onStreaming: () => void;
37
+ /** Called when a tool starts running (non-plan) */
38
+ onToolRunning: (toolName: string) => void;
39
+ }
40
+
41
+ // ---------------------------------------------------------------------------
42
+ // Transport options
43
+ // ---------------------------------------------------------------------------
44
+ export interface WebSocketChatTransportOptions {
45
+ sideChannel: SideChannelCallbacks;
46
+ }
47
+
48
+ // ---------------------------------------------------------------------------
49
+ // Constants
50
+ // ---------------------------------------------------------------------------
51
+ const WS_RECONNECT_DELAY = 1000;
52
+ const WS_MAX_RECONNECT_DELAY = 30000;
53
+ const WS_MAX_RETRIES = 5;
54
+ const WS_PING_INTERVAL = 30000;
55
+
56
+ let partIdCounter = 0;
57
+ function nextPartId(prefix: string): string {
58
+ return `${prefix}-${Date.now()}-${++partIdCounter}`;
59
+ }
60
+
61
+ // ---------------------------------------------------------------------------
62
+ // Transport implementation
63
+ // ---------------------------------------------------------------------------
64
+ export class WebSocketChatTransport implements ChatTransport<UIMessage> {
65
+ private ws: WebSocket | null = null;
66
+ private currentSessionId: string | null = null;
67
+ private sideChannel: SideChannelCallbacks;
68
+
69
+ private streamController: ReadableStreamDefaultController<UIMessageChunk> | null = null;
70
+ private streamGeneration = 0;
71
+ private abortedGeneration = 0;
72
+ private textPartId: string | null = null;
73
+ private awaitingProcessing = false;
74
+
75
+ private connectTimeout: ReturnType<typeof setTimeout> | null = null;
76
+ private reconnectTimeout: ReturnType<typeof setTimeout> | null = null;
77
+ private reconnectDelay = WS_RECONNECT_DELAY;
78
+ private retries = 0;
79
+ private pingInterval: ReturnType<typeof setInterval> | null = null;
80
+ private boundVisibilityHandler: (() => void) | null = null;
81
+ private wasHidden = false;
82
+
83
+ constructor({ sideChannel }: WebSocketChatTransportOptions) {
84
+ this.sideChannel = sideChannel;
85
+ this.setupVisibilityHandler();
86
+ }
87
+
88
+ private setupVisibilityHandler(): void {
89
+ this.boundVisibilityHandler = () => {
90
+ if (document.visibilityState === 'hidden') {
91
+ this.wasHidden = true;
92
+ return;
93
+ }
94
+
95
+ if (document.visibilityState === 'visible' && this.currentSessionId) {
96
+ const wsState = this.ws?.readyState;
97
+ if (wsState !== WebSocket.OPEN && wsState !== WebSocket.CONNECTING) {
98
+ logger.log('Tab visible: WS is dead, reconnecting immediately');
99
+ this.retries = 0;
100
+ this.reconnectDelay = WS_RECONNECT_DELAY;
101
+ this.createWebSocket(this.currentSessionId);
102
+
103
+ if (this.wasHidden) {
104
+ const store = useAgentStore.getState();
105
+ if (store.isProcessing) {
106
+ logger.log('Tab visible after WS drop: resetting stale processing state');
107
+ store.setProcessing(false);
108
+ this.closeActiveStream();
109
+ }
110
+ }
111
+ } else if (wsState === WebSocket.OPEN) {
112
+ this.ws!.send(JSON.stringify({ type: 'ping' }));
113
+ }
114
+ this.wasHidden = false;
115
+ }
116
+ };
117
+ document.addEventListener('visibilitychange', this.boundVisibilityHandler);
118
+ }
119
+
120
+ /** Update side-channel callbacks (e.g. when sessionId changes). */
121
+ updateSideChannel(sideChannel: SideChannelCallbacks): void {
122
+ this.sideChannel = sideChannel;
123
+ }
124
+
125
+ // ── Public API ──────────────────────────────────────────────────────
126
+
127
+ /** Connect (or reconnect) to a session's WebSocket. */
128
+ connectToSession(sessionId: string | null): void {
129
+ if (this.connectTimeout) {
130
+ clearTimeout(this.connectTimeout);
131
+ this.connectTimeout = null;
132
+ }
133
+ this.disconnectWebSocket();
134
+ this.currentSessionId = sessionId;
135
+ if (sessionId) {
136
+ this.retries = 0;
137
+ this.reconnectDelay = WS_RECONNECT_DELAY;
138
+ this.connectTimeout = setTimeout(() => {
139
+ this.connectTimeout = null;
140
+ if (this.currentSessionId === sessionId) {
141
+ this.createWebSocket(sessionId);
142
+ }
143
+ }, 100);
144
+ }
145
+ }
146
+
147
+ /** Approve / reject tools. Called directly from the UI. */
148
+ async approveTools(
149
+ sessionId: string,
150
+ approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>,
151
+ ): Promise<boolean> {
152
+ try {
153
+ const res = await apiFetch('/api/approve', {
154
+ method: 'POST',
155
+ body: JSON.stringify({ session_id: sessionId, approvals }),
156
+ });
157
+ return res.ok;
158
+ } catch (e) {
159
+ logger.error('Approval request failed:', e);
160
+ return false;
161
+ }
162
+ }
163
+
164
+ /** Clean up everything. */
165
+ destroy(): void {
166
+ if (this.connectTimeout) {
167
+ clearTimeout(this.connectTimeout);
168
+ this.connectTimeout = null;
169
+ }
170
+ if (this.boundVisibilityHandler) {
171
+ document.removeEventListener('visibilitychange', this.boundVisibilityHandler);
172
+ this.boundVisibilityHandler = null;
173
+ }
174
+ this.disconnectWebSocket();
175
+ this.closeActiveStream();
176
+ }
177
+
178
+ // ── ChatTransport interface ─────────────────────────────────────────
179
+
180
+ async sendMessages(
181
+ options: {
182
+ trigger: 'submit-message' | 'regenerate-message';
183
+ chatId: string;
184
+ messageId: string | undefined;
185
+ messages: UIMessage[];
186
+ abortSignal: AbortSignal | undefined;
187
+ } & ChatRequestOptions,
188
+ ): Promise<ReadableStream<UIMessageChunk>> {
189
+ const sessionId = options.chatId;
190
+
191
+ // Close any previously active stream (e.g. user sent new msg during approval)
192
+ this.closeActiveStream();
193
+
194
+ // Track generation to protect against late cancel from a stale stream
195
+ const gen = ++this.streamGeneration;
196
+ logger.log(`sendMessages: gen=${gen}, awaitingProcessing=${this.awaitingProcessing}, abortedGen=${this.abortedGeneration}`);
197
+
198
+ // Wire up abort signal to interrupt the backend and close the stream
199
+ if (options.abortSignal) {
200
+ const onAbort = () => {
201
+ if (this.streamGeneration !== gen) return;
202
+ logger.log(`Stream aborted by user (gen=${gen})`);
203
+ this.interruptBackend(sessionId);
204
+ this.endTextPart();
205
+ if (this.streamController) {
206
+ this.enqueue({ type: 'finish-step' });
207
+ this.enqueue({ type: 'finish', finishReason: 'stop' });
208
+ this.closeActiveStream();
209
+ }
210
+ this.awaitingProcessing = true;
211
+ this.abortedGeneration = this.streamGeneration;
212
+ logger.log(`Abort complete: awaitingProcessing=true, abortedGen=${this.abortedGeneration}`);
213
+ this.sideChannel.onProcessingDone();
214
+ };
215
+ if (options.abortSignal.aborted) {
216
+ onAbort();
217
+ } else {
218
+ options.abortSignal.addEventListener('abort', onAbort, { once: true });
219
+ }
220
+ }
221
+
222
+ // Create the stream BEFORE the POST so WebSocket events arriving
223
+ // while the HTTP request is in-flight are captured immediately.
224
+ const stream = new ReadableStream<UIMessageChunk>({
225
+ start: (controller) => {
226
+ this.streamController = controller;
227
+ this.textPartId = null;
228
+ },
229
+ cancel: () => {
230
+ if (this.streamGeneration === gen) {
231
+ this.streamController = null;
232
+ this.textPartId = null;
233
+ }
234
+ },
235
+ });
236
+
237
+ // Extract the latest user text from the messages array
238
+ const lastUserMsg = [...options.messages].reverse().find(m => m.role === 'user');
239
+ const text = lastUserMsg
240
+ ? lastUserMsg.parts
241
+ .filter((p): p is Extract<typeof p, { type: 'text' }> => p.type === 'text')
242
+ .map(p => p.text)
243
+ .join('')
244
+ : '';
245
+
246
+ // POST to the existing backend endpoint
247
+ try {
248
+ await apiFetch('/api/submit', {
249
+ method: 'POST',
250
+ body: JSON.stringify({ session_id: sessionId, text }),
251
+ });
252
+ } catch (e) {
253
+ logger.error('Submit failed:', e);
254
+ this.enqueue({ type: 'error', errorText: 'Failed to send message' });
255
+ this.closeActiveStream();
256
+ }
257
+
258
+ return stream;
259
+ }
260
+
261
+ async reconnectToStream(): Promise<ReadableStream<UIMessageChunk> | null> {
262
+ return null;
263
+ }
264
+
265
+ /** Ask the backend to interrupt the current generation. Fire-and-forget. */
266
+ private interruptBackend(sessionId: string): void {
267
+ apiFetch(`/api/interrupt/${sessionId}`, { method: 'POST' }).catch((e) =>
268
+ logger.warn('Interrupt request failed:', e),
269
+ );
270
+ }
271
+
272
+ // ── WebSocket lifecycle ─────────────────────────────────────────────
273
+
274
+ private createWebSocket(sessionId: string): void {
275
+ if (this.ws?.readyState === WebSocket.OPEN || this.ws?.readyState === WebSocket.CONNECTING) {
276
+ return;
277
+ }
278
+
279
+ const wsUrl = getWebSocketUrl(sessionId);
280
+ logger.log('WS transport connecting:', wsUrl);
281
+ const ws = new WebSocket(wsUrl);
282
+
283
+ ws.onopen = () => {
284
+ logger.log('WS transport connected');
285
+ this.sideChannel.onConnectionChange(true);
286
+ this.reconnectDelay = WS_RECONNECT_DELAY;
287
+ this.retries = 0;
288
+ this.startPing();
289
+ };
290
+
291
+ ws.onmessage = (evt) => {
292
+ try {
293
+ const raw = JSON.parse(evt.data);
294
+ if (raw.type === 'pong') return;
295
+ this.handleEvent(raw as AgentEvent);
296
+ } catch (e) {
297
+ logger.error('WS parse error:', e);
298
+ }
299
+ };
300
+
301
+ ws.onerror = (err) => logger.error('WS error:', err);
302
+
303
+ ws.onclose = (evt) => {
304
+ logger.log('WS closed', evt.code, evt.reason);
305
+ this.sideChannel.onConnectionChange(false);
306
+ this.stopPing();
307
+
308
+ const noRetry = [1000, 4001, 4003, 4004];
309
+ if (evt.code === 4004 && sessionId) {
310
+ this.sideChannel.onSessionDead(sessionId);
311
+ return;
312
+ }
313
+ if (!noRetry.includes(evt.code) && this.currentSessionId === sessionId) {
314
+ this.retries += 1;
315
+ if (this.retries > WS_MAX_RETRIES) {
316
+ logger.warn('WS max retries reached');
317
+ this.sideChannel.onSessionDead(sessionId);
318
+ return;
319
+ }
320
+ this.reconnectTimeout = setTimeout(() => {
321
+ this.reconnectDelay = Math.min(this.reconnectDelay * 2, WS_MAX_RECONNECT_DELAY);
322
+ this.createWebSocket(sessionId);
323
+ }, this.reconnectDelay);
324
+ }
325
+ };
326
+
327
+ this.ws = ws;
328
+ }
329
+
330
+ private disconnectWebSocket(): void {
331
+ if (this.reconnectTimeout) {
332
+ clearTimeout(this.reconnectTimeout);
333
+ this.reconnectTimeout = null;
334
+ }
335
+ this.stopPing();
336
+ if (this.ws) {
337
+ this.ws.close();
338
+ this.ws = null;
339
+ }
340
+ this.sideChannel.onConnectionChange(false);
341
+ }
342
+
343
+ private startPing(): void {
344
+ this.stopPing();
345
+ this.pingInterval = setInterval(() => {
346
+ if (this.ws?.readyState === WebSocket.OPEN) {
347
+ this.ws.send(JSON.stringify({ type: 'ping' }));
348
+ }
349
+ }, WS_PING_INTERVAL);
350
+ }
351
+
352
+ private stopPing(): void {
353
+ if (this.pingInterval) {
354
+ clearInterval(this.pingInterval);
355
+ this.pingInterval = null;
356
+ }
357
+ }
358
+
359
+ // ── Stream helpers ──────────────────────────────────────────────────
360
+
361
+ private closeActiveStream(): void {
362
+ if (this.streamController) {
363
+ try {
364
+ this.streamController.close();
365
+ } catch {
366
+ // already closed
367
+ }
368
+ this.streamController = null;
369
+ this.textPartId = null;
370
+ }
371
+ }
372
+
373
+ private enqueue(chunk: UIMessageChunk): void {
374
+ try {
375
+ this.streamController?.enqueue(chunk);
376
+ } catch {
377
+ // stream already closed
378
+ }
379
+ }
380
+
381
+ private endTextPart(): void {
382
+ if (this.textPartId) {
383
+ this.enqueue({ type: 'text-end', id: this.textPartId });
384
+ this.textPartId = null;
385
+ }
386
+ }
387
+
388
+ // ── Event β†’ UIMessageChunk mapping ──────────────────────────────────
389
+
390
+ private static readonly STREAM_EVENTS = new Set([
391
+ 'assistant_chunk', 'assistant_stream_end', 'assistant_message',
392
+ 'tool_call', 'tool_output', 'approval_required', 'tool_state_change',
393
+ 'turn_complete', 'error',
394
+ ]);
395
+
396
+ private handleEvent(event: AgentEvent): void {
397
+ // After an abort, ignore stale stream events until the next 'processing'
398
+ if (this.awaitingProcessing && WebSocketChatTransport.STREAM_EVENTS.has(event.event_type)) {
399
+ logger.log(`Filtering stale "${event.event_type}" (gen=${this.streamGeneration}, aborted=${this.abortedGeneration})`);
400
+ return;
401
+ }
402
+
403
+ switch (event.event_type) {
404
+ // ── Side-channel only events ────────────────────────────────
405
+ case 'ready':
406
+ this.sideChannel.onReady();
407
+ break;
408
+
409
+ case 'shutdown':
410
+ this.sideChannel.onShutdown();
411
+ this.closeActiveStream();
412
+ break;
413
+
414
+ case 'interrupted':
415
+ // Don't close the stream here β€” the abort handler already did, and
416
+ // a new stream for the next user message may already exist.
417
+ // Closing here would destroy the NEWER stream, causing the next
418
+ // response to be silently dropped.
419
+ this.sideChannel.onProcessingDone();
420
+ break;
421
+
422
+ case 'undo_complete':
423
+ this.endTextPart();
424
+ this.closeActiveStream();
425
+ this.sideChannel.onUndoComplete();
426
+ break;
427
+
428
+ case 'compacted':
429
+ this.sideChannel.onCompacted(
430
+ (event.data?.old_tokens as number) || 0,
431
+ (event.data?.new_tokens as number) || 0,
432
+ );
433
+ break;
434
+
435
+ case 'plan_update':
436
+ this.sideChannel.onPlanUpdate(
437
+ (event.data?.plan as Array<{ id: string; content: string; status: string }>) || [],
438
+ );
439
+ break;
440
+
441
+ case 'tool_log':
442
+ this.sideChannel.onToolLog(
443
+ (event.data?.tool as string) || '',
444
+ (event.data?.log as string) || '',
445
+ );
446
+ break;
447
+
448
+ // ── Chat stream events ──────────────────────────────────────
449
+ case 'processing':
450
+ if (this.awaitingProcessing) {
451
+ if (this.streamGeneration <= this.abortedGeneration) {
452
+ logger.log(`Ignoring stale "processing" (gen=${this.streamGeneration} <= aborted=${this.abortedGeneration})`);
453
+ break;
454
+ }
455
+ logger.log(`Accepting "processing" for new generation (gen=${this.streamGeneration}, aborted=${this.abortedGeneration})`);
456
+ this.awaitingProcessing = false;
457
+ }
458
+ this.sideChannel.onProcessing();
459
+ if (this.streamController) {
460
+ this.enqueue({
461
+ type: 'start',
462
+ messageMetadata: { createdAt: new Date().toISOString() },
463
+ });
464
+ this.enqueue({ type: 'start-step' });
465
+ }
466
+ break;
467
+
468
+ case 'assistant_chunk': {
469
+ const delta = (event.data?.content as string) || '';
470
+ if (!delta || !this.streamController) break;
471
+
472
+ if (!this.textPartId) {
473
+ this.textPartId = nextPartId('text');
474
+ this.enqueue({ type: 'text-start', id: this.textPartId });
475
+ this.sideChannel.onStreaming();
476
+ }
477
+ this.enqueue({ type: 'text-delta', id: this.textPartId, delta });
478
+ break;
479
+ }
480
+
481
+ case 'assistant_stream_end':
482
+ this.endTextPart();
483
+ break;
484
+
485
+ case 'assistant_message': {
486
+ const content = (event.data?.content as string) || '';
487
+ if (!content || !this.streamController) break;
488
+ const id = nextPartId('text');
489
+ this.enqueue({ type: 'text-start', id });
490
+ this.enqueue({ type: 'text-delta', id, delta: content });
491
+ this.enqueue({ type: 'text-end', id });
492
+ break;
493
+ }
494
+
495
+ case 'tool_call': {
496
+ if (!this.streamController) break;
497
+ const toolName = (event.data?.tool as string) || 'unknown';
498
+ const toolCallId = (event.data?.tool_call_id as string) || '';
499
+ const args = (event.data?.arguments as Record<string, unknown>) || {};
500
+
501
+ if (toolName === 'plan_tool') break;
502
+
503
+ this.endTextPart();
504
+ this.enqueue({ type: 'tool-input-start', toolCallId, toolName, dynamic: true });
505
+ this.enqueue({ type: 'tool-input-available', toolCallId, toolName, input: args, dynamic: true });
506
+
507
+ this.sideChannel.onToolRunning(toolName);
508
+ this.sideChannel.onToolCallPanel(toolName, args as Record<string, unknown>);
509
+ break;
510
+ }
511
+
512
+ case 'tool_output': {
513
+ if (!this.streamController) break;
514
+ const toolCallId = (event.data?.tool_call_id as string) || '';
515
+ const output = (event.data?.output as string) || '';
516
+ const success = event.data?.success as boolean;
517
+ const toolName = (event.data?.tool as string) || '';
518
+
519
+ if (toolName === 'plan_tool' || toolCallId.startsWith('plan_tool')) break;
520
+
521
+ if (success) {
522
+ this.enqueue({ type: 'tool-output-available', toolCallId, output, dynamic: true });
523
+ } else {
524
+ this.enqueue({ type: 'tool-output-error', toolCallId, errorText: output, dynamic: true });
525
+ }
526
+
527
+ this.sideChannel.onToolOutputPanel(toolName, toolCallId, output, success);
528
+ break;
529
+ }
530
+
531
+ case 'approval_required': {
532
+ const tools = event.data?.tools as Array<{
533
+ tool: string;
534
+ arguments: Record<string, unknown>;
535
+ tool_call_id: string;
536
+ }>;
537
+ if (!tools || !this.streamController) break;
538
+
539
+ this.endTextPart();
540
+
541
+ for (const t of tools) {
542
+ this.enqueue({ type: 'tool-input-start', toolCallId: t.tool_call_id, toolName: t.tool, dynamic: true });
543
+ this.enqueue({ type: 'tool-input-available', toolCallId: t.tool_call_id, toolName: t.tool, input: t.arguments, dynamic: true });
544
+ this.enqueue({ type: 'tool-approval-request', approvalId: `approval-${t.tool_call_id}`, toolCallId: t.tool_call_id });
545
+ }
546
+
547
+ this.sideChannel.onApprovalRequired(tools);
548
+ this.sideChannel.onProcessingDone();
549
+ break;
550
+ }
551
+
552
+ case 'tool_state_change': {
553
+ const tcId = (event.data?.tool_call_id as string) || '';
554
+ const state = (event.data?.state as string) || '';
555
+ const jobUrl = (event.data?.jobUrl as string) || undefined;
556
+
557
+ if (tcId.startsWith('plan_tool')) break;
558
+
559
+ if (jobUrl && tcId) {
560
+ useAgentStore.getState().setJobUrl(tcId, jobUrl);
561
+ }
562
+
563
+ if (this.streamController && (state === 'rejected' || state === 'abandoned')) {
564
+ this.enqueue({ type: 'tool-output-denied', toolCallId: tcId });
565
+ }
566
+ break;
567
+ }
568
+
569
+ case 'turn_complete':
570
+ this.endTextPart();
571
+ if (this.streamController) {
572
+ this.enqueue({ type: 'finish-step' });
573
+ this.enqueue({ type: 'finish', finishReason: 'stop' });
574
+ this.closeActiveStream();
575
+ }
576
+ this.sideChannel.onProcessingDone();
577
+ break;
578
+
579
+ case 'error': {
580
+ const errorMsg = (event.data?.error as string) || 'Unknown error';
581
+ this.sideChannel.onError(errorMsg);
582
+ if (this.streamController) {
583
+ this.enqueue({ type: 'error', errorText: errorMsg });
584
+ }
585
+ this.sideChannel.onProcessingDone();
586
+ break;
587
+ }
588
+
589
+ default:
590
+ logger.log('WS transport: unknown event', event);
591
+ }
592
+ }
593
+ }
frontend/src/main.tsx CHANGED
@@ -3,13 +3,23 @@ import { createRoot } from 'react-dom/client';
3
  import { ThemeProvider } from '@mui/material/styles';
4
  import CssBaseline from '@mui/material/CssBaseline';
5
  import App from './App';
6
- import theme from './theme';
 
7
 
8
- createRoot(document.getElementById('root')!).render(
9
- <StrictMode>
 
 
 
10
  <ThemeProvider theme={theme}>
11
  <CssBaseline />
12
  <App />
13
  </ThemeProvider>
 
 
 
 
 
 
14
  </StrictMode>
15
  );
 
3
  import { ThemeProvider } from '@mui/material/styles';
4
  import CssBaseline from '@mui/material/CssBaseline';
5
  import App from './App';
6
+ import { darkTheme, lightTheme } from './theme';
7
+ import { useLayoutStore } from './store/layoutStore';
8
 
9
+ function Root() {
10
+ const themeMode = useLayoutStore((s) => s.themeMode);
11
+ const theme = themeMode === 'light' ? lightTheme : darkTheme;
12
+
13
+ return (
14
  <ThemeProvider theme={theme}>
15
  <CssBaseline />
16
  <App />
17
  </ThemeProvider>
18
+ );
19
+ }
20
+
21
+ createRoot(document.getElementById('root')!).render(
22
+ <StrictMode>
23
+ <Root />
24
  </StrictMode>
25
  );
frontend/src/store/agentStore.ts CHANGED
@@ -1,5 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import { create } from 'zustand';
2
- import type { Message, ApprovalBatch, User, TraceLog } from '@/types/agent';
3
 
4
  export interface PlanItem {
5
  id: string;
@@ -7,254 +18,158 @@ export interface PlanItem {
7
  status: 'pending' | 'in_progress' | 'completed';
8
  }
9
 
10
- interface PanelTab {
11
- id: string;
12
- title: string;
13
  content: string;
14
- language?: string;
15
- parameters?: any;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  }
17
 
 
 
 
 
 
 
 
18
  interface AgentStore {
19
- // State per session (keyed by session ID)
20
- messagesBySession: Record<string, Message[]>;
21
  isProcessing: boolean;
22
  isConnected: boolean;
23
- pendingApprovals: ApprovalBatch | null;
24
  user: User | null;
25
  error: string | null;
26
- traceLogs: TraceLog[];
27
- panelContent: { title: string; content: string; language?: string; parameters?: any } | null;
28
- panelTabs: PanelTab[];
29
- activePanelTab: string | null;
 
 
 
 
30
  plan: PlanItem[];
31
- currentTurnMessageId: string | null; // Track the current turn's assistant message
 
 
 
 
 
32
 
33
  // Actions
34
- addMessage: (sessionId: string, message: Message) => void;
35
- updateMessage: (sessionId: string, messageId: string, updates: Partial<Message>) => void;
36
- clearMessages: (sessionId: string) => void;
37
  setProcessing: (isProcessing: boolean) => void;
38
  setConnected: (isConnected: boolean) => void;
39
- setPendingApprovals: (approvals: ApprovalBatch | null) => void;
40
  setUser: (user: User | null) => void;
41
  setError: (error: string | null) => void;
42
- getMessages: (sessionId: string) => Message[];
43
- addTraceLog: (log: TraceLog) => void;
44
- updateTraceLog: (toolName: string, updates: Partial<TraceLog>) => void;
45
- clearTraceLogs: () => void;
46
- setPanelContent: (content: { title: string; content: string; language?: string; parameters?: any } | null) => void;
47
- setPanelTab: (tab: PanelTab) => void;
48
- setActivePanelTab: (tabId: string) => void;
49
- clearPanelTabs: () => void;
50
- removePanelTab: (tabId: string) => void;
51
  setPlan: (plan: PlanItem[]) => void;
52
- setCurrentTurnMessageId: (id: string | null) => void;
53
- updateCurrentTurnTrace: (sessionId: string) => void;
54
- showToolOutput: (log: TraceLog) => void;
 
 
 
 
55
  }
56
 
57
- export const useAgentStore = create<AgentStore>((set, get) => ({
58
- messagesBySession: {},
59
  isProcessing: false,
60
  isConnected: false,
61
- pendingApprovals: null,
62
  user: null,
63
  error: null,
64
- traceLogs: [],
65
- panelContent: null,
66
- panelTabs: [],
67
- activePanelTab: null,
68
- plan: [],
69
- currentTurnMessageId: null,
70
-
71
- addMessage: (sessionId: string, message: Message) => {
72
- set((state) => {
73
- const currentMessages = state.messagesBySession[sessionId] || [];
74
- return {
75
- messagesBySession: {
76
- ...state.messagesBySession,
77
- [sessionId]: [...currentMessages, message],
78
- },
79
- };
80
- });
81
- },
82
 
83
- updateMessage: (sessionId: string, messageId: string, updates: Partial<Message>) => {
84
- set((state) => {
85
- const currentMessages = state.messagesBySession[sessionId] || [];
86
- const updatedMessages = currentMessages.map((msg) =>
87
- msg.id === messageId ? { ...msg, ...updates } : msg
88
- );
89
- return {
90
- messagesBySession: {
91
- ...state.messagesBySession,
92
- [sessionId]: updatedMessages,
93
- },
94
- };
95
- });
96
- },
97
 
98
- clearMessages: (sessionId: string) => {
99
- set((state) => ({
100
- messagesBySession: {
101
- ...state.messagesBySession,
102
- [sessionId]: [],
103
- },
104
- }));
105
- },
106
 
107
- setProcessing: (isProcessing: boolean) => {
108
- set({ isProcessing });
109
- },
110
 
111
- setConnected: (isConnected: boolean) => {
112
- set({ isConnected });
113
- },
114
 
115
- setPendingApprovals: (approvals: ApprovalBatch | null) => {
116
- set({ pendingApprovals: approvals });
 
 
117
  },
 
 
 
 
 
118
 
119
- setUser: (user: User | null) => {
120
- set({ user });
121
- },
122
 
123
- setError: (error: string | null) => {
124
- set({ error });
125
- },
 
 
126
 
127
- getMessages: (sessionId: string) => {
128
- return get().messagesBySession[sessionId] || [];
129
- },
130
 
131
- addTraceLog: (log: TraceLog) => {
132
- set((state) => ({
133
- traceLogs: [...state.traceLogs, log],
134
- }));
135
- },
136
 
137
- updateTraceLog: (toolName: string, updates: Partial<TraceLog>) => {
138
- set((state) => {
139
- // Find the last trace log with this tool name and update it
140
- const traceLogs = [...state.traceLogs];
141
- for (let i = traceLogs.length - 1; i >= 0; i--) {
142
- if (traceLogs[i].tool === toolName && traceLogs[i].type === 'call') {
143
- traceLogs[i] = { ...traceLogs[i], ...updates };
144
- break;
145
- }
146
- }
147
- return { traceLogs };
148
- });
149
- },
150
 
151
- clearTraceLogs: () => {
152
- set({ traceLogs: [] });
153
- },
154
 
155
- setPanelContent: (content) => {
156
- set({ panelContent: content });
157
- },
158
 
159
- setPanelTab: (tab: PanelTab) => {
160
- set((state) => {
161
- const existingIndex = state.panelTabs.findIndex(t => t.id === tab.id);
162
- let newTabs: PanelTab[];
163
- if (existingIndex >= 0) {
164
- // Update existing tab
165
- newTabs = [...state.panelTabs];
166
- newTabs[existingIndex] = tab;
167
- } else {
168
- // Add new tab
169
- newTabs = [...state.panelTabs, tab];
170
- }
171
- return {
172
- panelTabs: newTabs,
173
- activePanelTab: state.activePanelTab || tab.id, // Auto-select first tab
174
- };
175
- });
176
- },
177
 
178
- setActivePanelTab: (tabId: string) => {
179
- set({ activePanelTab: tabId });
180
- },
181
 
182
- clearPanelTabs: () => {
183
- set({ panelTabs: [], activePanelTab: null });
184
- },
185
 
186
- removePanelTab: (tabId: string) => {
187
- set((state) => {
188
- const newTabs = state.panelTabs.filter(t => t.id !== tabId);
189
- // If we removed the active tab, switch to another tab or null
190
- let newActiveTab = state.activePanelTab;
191
- if (state.activePanelTab === tabId) {
192
- newActiveTab = newTabs.length > 0 ? newTabs[newTabs.length - 1].id : null;
193
- }
194
- return {
195
- panelTabs: newTabs,
196
- activePanelTab: newActiveTab,
197
- };
198
- });
199
  },
200
 
201
- setPlan: (plan: PlanItem[]) => {
202
- set({ plan });
203
- },
204
 
205
- setCurrentTurnMessageId: (id: string | null) => {
206
- set({ currentTurnMessageId: id });
207
- },
208
 
209
- updateCurrentTurnTrace: (sessionId: string) => {
210
- const state = get();
211
- if (state.currentTurnMessageId) {
212
- const currentMessages = state.messagesBySession[sessionId] || [];
213
- const updatedMessages = currentMessages.map((msg) =>
214
- msg.id === state.currentTurnMessageId
215
- ? { ...msg, trace: state.traceLogs.length > 0 ? [...state.traceLogs] : undefined }
216
- : msg
217
- );
218
- set({
219
- messagesBySession: {
220
- ...state.messagesBySession,
221
- [sessionId]: updatedMessages,
222
- },
223
- });
224
- }
225
- },
226
 
227
- showToolOutput: (log: TraceLog) => {
228
- // Show tool output in the right panel - only ONE tool output tab at a time
229
- const state = get();
230
-
231
- // Determine language based on content
232
- let language = 'text';
233
- const content = log.output || '';
234
-
235
- // Check if content looks like JSON
236
- if (content.trim().startsWith('{') || content.trim().startsWith('[') || content.includes('```json')) {
237
- language = 'json';
238
- }
239
- // Check if content has markdown tables or formatting
240
- else if (content.includes('|') && content.includes('---') || content.includes('```')) {
241
- language = 'markdown';
242
- }
243
-
244
- // Remove any existing tool output tab (only keep one)
245
- const otherTabs = state.panelTabs.filter(t => t.id !== 'tool_output');
246
-
247
- // Create/replace the single tool output tab
248
- const newTab = {
249
- id: 'tool_output',
250
- title: log.tool,
251
- content: content || 'No output available',
252
- language,
253
- };
254
-
255
- set({
256
- panelTabs: [...otherTabs, newTab],
257
- activePanelTab: 'tool_output',
258
- });
259
  },
 
 
260
  }));
 
1
+ /**
2
+ * Agent store β€” manages UI state that is NOT handled by the Vercel AI SDK.
3
+ *
4
+ * Message state (messages, streaming, tool calls) is now managed by useChat().
5
+ * This store only handles:
6
+ * - Connection / processing flags
7
+ * - Panel state (right panel β€” single-artifact pattern)
8
+ * - Plan state
9
+ * - User info / error banners
10
+ * - Edited scripts (for hf_jobs code editing)
11
+ */
12
  import { create } from 'zustand';
13
+ import type { User } from '@/types/agent';
14
 
15
  export interface PlanItem {
16
  id: string;
 
18
  status: 'pending' | 'in_progress' | 'completed';
19
  }
20
 
21
+ export interface PanelSection {
 
 
22
  content: string;
23
+ language: string;
24
+ }
25
+
26
+ export interface PanelData {
27
+ title: string;
28
+ script?: PanelSection;
29
+ output?: PanelSection;
30
+ parameters?: Record<string, unknown>;
31
+ }
32
+
33
+ export type PanelView = 'script' | 'output';
34
+
35
+ export interface LLMHealthError {
36
+ error: string;
37
+ errorType: 'auth' | 'credits' | 'rate_limit' | 'network' | 'unknown';
38
+ model: string;
39
  }
40
 
41
+ export type ActivityStatus =
42
+ | { type: 'idle' }
43
+ | { type: 'thinking' }
44
+ | { type: 'tool'; toolName: string }
45
+ | { type: 'waiting-approval' }
46
+ | { type: 'streaming' };
47
+
48
  interface AgentStore {
49
+ // Global UI flags
 
50
  isProcessing: boolean;
51
  isConnected: boolean;
52
+ activityStatus: ActivityStatus;
53
  user: User | null;
54
  error: string | null;
55
+ llmHealthError: LLMHealthError | null;
56
+
57
+ // Right panel (single-artifact pattern)
58
+ panelData: PanelData | null;
59
+ panelView: PanelView;
60
+ panelEditable: boolean;
61
+
62
+ // Plan
63
  plan: PlanItem[];
64
+
65
+ // Edited scripts (tool_call_id -> edited content)
66
+ editedScripts: Record<string, string>;
67
+
68
+ // Job URLs (tool_call_id -> job URL) for HF jobs
69
+ jobUrls: Record<string, string>;
70
 
71
  // Actions
 
 
 
72
  setProcessing: (isProcessing: boolean) => void;
73
  setConnected: (isConnected: boolean) => void;
74
+ setActivityStatus: (status: ActivityStatus) => void;
75
  setUser: (user: User | null) => void;
76
  setError: (error: string | null) => void;
77
+ setLlmHealthError: (error: LLMHealthError | null) => void;
78
+
79
+ setPanel: (data: PanelData, view?: PanelView, editable?: boolean) => void;
80
+ setPanelView: (view: PanelView) => void;
81
+ setPanelOutput: (output: PanelSection) => void;
82
+ updatePanelScript: (content: string) => void;
83
+ lockPanel: () => void;
84
+ clearPanel: () => void;
85
+
86
  setPlan: (plan: PlanItem[]) => void;
87
+
88
+ setEditedScript: (toolCallId: string, content: string) => void;
89
+ getEditedScript: (toolCallId: string) => string | undefined;
90
+ clearEditedScripts: () => void;
91
+
92
+ setJobUrl: (toolCallId: string, jobUrl: string) => void;
93
+ getJobUrl: (toolCallId: string) => string | undefined;
94
  }
95
 
96
+ export const useAgentStore = create<AgentStore>()((set, get) => ({
 
97
  isProcessing: false,
98
  isConnected: false,
99
+ activityStatus: { type: 'idle' },
100
  user: null,
101
  error: null,
102
+ llmHealthError: null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ panelData: null,
105
+ panelView: 'script',
106
+ panelEditable: false,
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ plan: [],
 
 
 
 
 
 
 
109
 
110
+ editedScripts: {},
111
+ jobUrls: {},
 
112
 
113
+ // ── Global flags ──────────────────────────────────────────────────
 
 
114
 
115
+ setProcessing: (isProcessing) => {
116
+ const current = get().activityStatus;
117
+ const preserveStatus = current.type === 'waiting-approval';
118
+ set({ isProcessing, ...(!isProcessing && !preserveStatus ? { activityStatus: { type: 'idle' } } : {}) });
119
  },
120
+ setConnected: (isConnected) => set({ isConnected }),
121
+ setActivityStatus: (status) => set({ activityStatus: status }),
122
+ setUser: (user) => set({ user }),
123
+ setError: (error) => set({ error }),
124
+ setLlmHealthError: (error) => set({ llmHealthError: error }),
125
 
126
+ // ── Panel (single-artifact) ───────────────────────────────────────
 
 
127
 
128
+ setPanel: (data, view, editable) => set({
129
+ panelData: data,
130
+ panelView: view ?? (data.script ? 'script' : 'output'),
131
+ panelEditable: editable ?? false,
132
+ }),
133
 
134
+ setPanelView: (view) => set({ panelView: view }),
 
 
135
 
136
+ setPanelOutput: (output) => set((state) => ({
137
+ panelData: state.panelData ? { ...state.panelData, output } : null,
138
+ })),
 
 
139
 
140
+ updatePanelScript: (content) => set((state) => ({
141
+ panelData: state.panelData?.script
142
+ ? { ...state.panelData, script: { ...state.panelData.script, content } }
143
+ : state.panelData,
144
+ })),
 
 
 
 
 
 
 
 
145
 
146
+ lockPanel: () => set({ panelEditable: false }),
 
 
147
 
148
+ clearPanel: () => set({ panelData: null, panelView: 'script', panelEditable: false }),
 
 
149
 
150
+ // ── Plan ──────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ setPlan: (plan) => set({ plan }),
 
 
153
 
154
+ // ── Edited scripts ────────────────────────────────────────────────
 
 
155
 
156
+ setEditedScript: (toolCallId, content) => {
157
+ set((state) => ({
158
+ editedScripts: { ...state.editedScripts, [toolCallId]: content },
159
+ }));
 
 
 
 
 
 
 
 
 
160
  },
161
 
162
+ getEditedScript: (toolCallId) => get().editedScripts[toolCallId],
 
 
163
 
164
+ clearEditedScripts: () => set({ editedScripts: {} }),
 
 
165
 
166
+ // ── Job URLs ────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ setJobUrl: (toolCallId, jobUrl) => {
169
+ set((state) => ({
170
+ jobUrls: { ...state.jobUrls, [toolCallId]: jobUrl },
171
+ }));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  },
173
+
174
+ getJobUrl: (toolCallId) => get().jobUrls[toolCallId],
175
  }));