akseljoonas HF Staff commited on
Commit
af6a7ab
Β·
1 Parent(s): 110c57a

feat: add headless CLI mode with local filesystem tools and rich terminal rendering

Browse files

- Add argparse-based headless mode: `python -m agent.main "prompt here"`
- Local tool specs override sandbox /app references for CLI mode
- System prompt injects CLI-specific context (working directory, no sandbox)
- Extract streaming/non-streaming LLM call helpers with LLMResult dataclass
- Add shimmer thinking animation (truecolor gradient sweep)
- Progressive markdown rendering via rich Live display
- Clean MCP shutdown sequence, suppress asyncio teardown noise
- Add rich dependency to agent extras

agent/context_manager/manager.py CHANGED
@@ -79,11 +79,13 @@ class ContextManager:
79
  tool_specs: list[dict[str, Any]] | None = None,
80
  prompt_file_suffix: str = "system_prompt_v3.yaml",
81
  hf_token: str | None = None,
 
82
  ):
83
  self.system_prompt = self._load_system_prompt(
84
  tool_specs or [],
85
  prompt_file_suffix="system_prompt_v3.yaml",
86
  hf_token=hf_token,
 
87
  )
88
  self.max_context = max_context - 10000
89
  self.compact_size = int(max_context * compact_size)
@@ -96,6 +98,7 @@ class ContextManager:
96
  tool_specs: list[dict[str, Any]],
97
  prompt_file_suffix: str = "system_prompt.yaml",
98
  hf_token: str | None = None,
 
99
  ):
100
  """Load and render the system prompt from YAML file with Jinja2"""
101
  prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}"
@@ -119,6 +122,23 @@ class ContextManager:
119
  tools=tool_specs,
120
  num_tools=len(tool_specs),
121
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  return (
123
  f"{static_prompt}\n\n"
124
  f"[Session context: Date={current_date}, Time={current_time}, "
 
79
  tool_specs: list[dict[str, Any]] | None = None,
80
  prompt_file_suffix: str = "system_prompt_v3.yaml",
81
  hf_token: str | None = None,
82
+ local_mode: bool = False,
83
  ):
84
  self.system_prompt = self._load_system_prompt(
85
  tool_specs or [],
86
  prompt_file_suffix="system_prompt_v3.yaml",
87
  hf_token=hf_token,
88
+ local_mode=local_mode,
89
  )
90
  self.max_context = max_context - 10000
91
  self.compact_size = int(max_context * compact_size)
 
98
  tool_specs: list[dict[str, Any]],
99
  prompt_file_suffix: str = "system_prompt.yaml",
100
  hf_token: str | None = None,
101
+ local_mode: bool = False,
102
  ):
103
  """Load and render the system prompt from YAML file with Jinja2"""
104
  prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}"
 
122
  tools=tool_specs,
123
  num_tools=len(tool_specs),
124
  )
125
+
126
+ # CLI-specific context for local mode
127
+ if local_mode:
128
+ import os
129
+ cwd = os.getcwd()
130
+ local_context = (
131
+ f"\n\n# CLI / Local mode\n\n"
132
+ f"You are running as a local CLI tool on the user's machine. "
133
+ f"There is NO sandbox β€” bash, read, write, and edit operate directly "
134
+ f"on the local filesystem.\n\n"
135
+ f"Working directory: {cwd}\n"
136
+ f"Use absolute paths or paths relative to the working directory. "
137
+ f"Do NOT use /app/ paths β€” that is a sandbox convention that does not apply here.\n"
138
+ f"The sandbox_create tool is NOT available. Run code directly with bash."
139
+ )
140
+ static_prompt += local_context
141
+
142
  return (
143
  f"{static_prompt}\n\n"
144
  f"[Session context: Date={current_date}, Time={current_time}, "
agent/core/agent_loop.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
6
  import json
7
  import logging
8
  import os
 
9
 
10
  from litellm import ChatCompletionMessageToolCall, Message, acompletion
11
  from litellm.exceptions import ContextWindowExceededError
@@ -244,6 +245,164 @@ async def _cleanup_on_cancel(session: Session) -> None:
244
  session._running_job_ids.clear()
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  class Handlers:
248
  """Handler functions for each operation type"""
249
 
@@ -345,98 +504,17 @@ class Handlers:
345
  messages = session.context_manager.get_messages()
346
  tools = session.tool_router.get_tool_specs_for_llm()
347
  try:
348
- # ── Stream the LLM response (with retry for transient errors) ──
349
  llm_params = _resolve_hf_router_params(session.config.model_name)
350
- response = None
351
- for _llm_attempt in range(_MAX_LLM_RETRIES):
352
- try:
353
- response = await acompletion(
354
- messages=messages,
355
- tools=tools,
356
- tool_choice="auto",
357
- stream=True,
358
- stream_options={"include_usage": True},
359
- timeout=600,
360
- **llm_params,
361
- )
362
- break
363
- except ContextWindowExceededError:
364
- raise
365
- except Exception as e:
366
- if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
367
- _delay = _LLM_RETRY_DELAYS[_llm_attempt]
368
- logger.warning(
369
- "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds",
370
- _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
371
- )
372
- await session.send_event(Event(
373
- event_type="tool_log",
374
- data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
375
- ))
376
- await asyncio.sleep(_delay)
377
- continue
378
- raise
379
-
380
- full_content = ""
381
- tool_calls_acc: dict[int, dict] = {}
382
- token_count = 0
383
- finish_reason = None
384
-
385
- async for chunk in response:
386
- # ── Check cancellation during streaming ──
387
- if session.is_cancelled:
388
- tool_calls_acc.clear()
389
- break
390
-
391
- choice = chunk.choices[0] if chunk.choices else None
392
- if not choice:
393
- # Last chunk may carry only usage info
394
- if hasattr(chunk, "usage") and chunk.usage:
395
- token_count = chunk.usage.total_tokens
396
- continue
397
-
398
- delta = choice.delta
399
- if choice.finish_reason:
400
- finish_reason = choice.finish_reason
401
-
402
- # Stream text deltas to the frontend
403
- if delta.content:
404
- full_content += delta.content
405
- await session.send_event(
406
- Event(
407
- event_type="assistant_chunk",
408
- data={"content": delta.content},
409
- )
410
- )
411
 
412
- # Accumulate tool-call deltas (name + args arrive in pieces)
413
- if delta.tool_calls:
414
- for tc_delta in delta.tool_calls:
415
- idx = tc_delta.index
416
- if idx not in tool_calls_acc:
417
- tool_calls_acc[idx] = {
418
- "id": "",
419
- "type": "function",
420
- "function": {"name": "", "arguments": ""},
421
- }
422
- if tc_delta.id:
423
- tool_calls_acc[idx]["id"] = tc_delta.id
424
- if tc_delta.function:
425
- if tc_delta.function.name:
426
- tool_calls_acc[idx]["function"]["name"] += (
427
- tc_delta.function.name
428
- )
429
- if tc_delta.function.arguments:
430
- tool_calls_acc[idx]["function"]["arguments"] += (
431
- tc_delta.function.arguments
432
- )
433
-
434
- # Capture usage from the final chunk
435
- if hasattr(chunk, "usage") and chunk.usage:
436
- token_count = chunk.usage.total_tokens
437
-
438
- # ── Stream finished β€” reconstruct full message ───────
439
- content = full_content or None
440
 
441
  # If output was truncated, all tool call args are garbage.
442
  # Inject a system hint so the LLM retries with smaller content.
@@ -468,9 +546,10 @@ class Handlers:
468
  session.context_manager.add_message(
469
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
470
  )
471
- await session.send_event(
472
- Event(event_type="assistant_stream_end", data={})
473
- )
 
474
  await session.send_event(
475
  Event(
476
  event_type="tool_log",
@@ -496,9 +575,10 @@ class Handlers:
496
  )
497
 
498
  # Signal end of streaming to the frontend
499
- await session.send_event(
500
- Event(event_type="assistant_stream_end", data={})
501
- )
 
502
 
503
  # If no tool calls, add assistant message and we're done
504
  if not tool_calls:
@@ -1043,6 +1123,8 @@ async def submission_loop(
1043
  tool_router: ToolRouter | None = None,
1044
  session_holder: list | None = None,
1045
  hf_token: str | None = None,
 
 
1046
  ) -> None:
1047
  """
1048
  Main agent loop - processes submissions and dispatches to handlers.
@@ -1051,7 +1133,8 @@ async def submission_loop(
1051
 
1052
  # Create session with tool router
1053
  session = Session(
1054
- event_queue, config=config, tool_router=tool_router, hf_token=hf_token
 
1055
  )
1056
  if session_holder is not None:
1057
  session_holder[0] = session
 
6
  import json
7
  import logging
8
  import os
9
+ from dataclasses import dataclass
10
 
11
  from litellm import ChatCompletionMessageToolCall, Message, acompletion
12
  from litellm.exceptions import ContextWindowExceededError
 
245
  session._running_job_ids.clear()
246
 
247
 
248
+ @dataclass
249
+ class LLMResult:
250
+ """Result from an LLM call (streaming or non-streaming)."""
251
+ content: str | None
252
+ tool_calls_acc: dict[int, dict]
253
+ token_count: int
254
+ finish_reason: str | None
255
+
256
+
257
+ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
258
+ """Call the LLM with streaming, emitting assistant_chunk events."""
259
+ response = None
260
+ for _llm_attempt in range(_MAX_LLM_RETRIES):
261
+ try:
262
+ response = await acompletion(
263
+ messages=messages,
264
+ tools=tools,
265
+ tool_choice="auto",
266
+ stream=True,
267
+ stream_options={"include_usage": True},
268
+ timeout=600,
269
+ **llm_params,
270
+ )
271
+ break
272
+ except ContextWindowExceededError:
273
+ raise
274
+ except Exception as e:
275
+ if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
276
+ _delay = _LLM_RETRY_DELAYS[_llm_attempt]
277
+ logger.warning(
278
+ "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds",
279
+ _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
280
+ )
281
+ await session.send_event(Event(
282
+ event_type="tool_log",
283
+ data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
284
+ ))
285
+ await asyncio.sleep(_delay)
286
+ continue
287
+ raise
288
+
289
+ full_content = ""
290
+ tool_calls_acc: dict[int, dict] = {}
291
+ token_count = 0
292
+ finish_reason = None
293
+
294
+ async for chunk in response:
295
+ if session.is_cancelled:
296
+ tool_calls_acc.clear()
297
+ break
298
+
299
+ choice = chunk.choices[0] if chunk.choices else None
300
+ if not choice:
301
+ if hasattr(chunk, "usage") and chunk.usage:
302
+ token_count = chunk.usage.total_tokens
303
+ continue
304
+
305
+ delta = choice.delta
306
+ if choice.finish_reason:
307
+ finish_reason = choice.finish_reason
308
+
309
+ if delta.content:
310
+ full_content += delta.content
311
+ await session.send_event(
312
+ Event(event_type="assistant_chunk", data={"content": delta.content})
313
+ )
314
+
315
+ if delta.tool_calls:
316
+ for tc_delta in delta.tool_calls:
317
+ idx = tc_delta.index
318
+ if idx not in tool_calls_acc:
319
+ tool_calls_acc[idx] = {
320
+ "id": "", "type": "function",
321
+ "function": {"name": "", "arguments": ""},
322
+ }
323
+ if tc_delta.id:
324
+ tool_calls_acc[idx]["id"] = tc_delta.id
325
+ if tc_delta.function:
326
+ if tc_delta.function.name:
327
+ tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
328
+ if tc_delta.function.arguments:
329
+ tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments
330
+
331
+ if hasattr(chunk, "usage") and chunk.usage:
332
+ token_count = chunk.usage.total_tokens
333
+
334
+ return LLMResult(
335
+ content=full_content or None,
336
+ tool_calls_acc=tool_calls_acc,
337
+ token_count=token_count,
338
+ finish_reason=finish_reason,
339
+ )
340
+
341
+
342
+ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
343
+ """Call the LLM without streaming, emit assistant_message at the end."""
344
+ response = None
345
+ for _llm_attempt in range(_MAX_LLM_RETRIES):
346
+ try:
347
+ response = await acompletion(
348
+ messages=messages,
349
+ tools=tools,
350
+ tool_choice="auto",
351
+ stream=False,
352
+ timeout=600,
353
+ **llm_params,
354
+ )
355
+ break
356
+ except ContextWindowExceededError:
357
+ raise
358
+ except Exception as e:
359
+ if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e):
360
+ _delay = _LLM_RETRY_DELAYS[_llm_attempt]
361
+ logger.warning(
362
+ "Transient LLM error (attempt %d/%d): %s β€” retrying in %ds",
363
+ _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
364
+ )
365
+ await session.send_event(Event(
366
+ event_type="tool_log",
367
+ data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
368
+ ))
369
+ await asyncio.sleep(_delay)
370
+ continue
371
+ raise
372
+
373
+ choice = response.choices[0]
374
+ message = choice.message
375
+ content = message.content or None
376
+ finish_reason = choice.finish_reason
377
+ token_count = response.usage.total_tokens if response.usage else 0
378
+
379
+ # Build tool_calls_acc in the same format as streaming
380
+ tool_calls_acc: dict[int, dict] = {}
381
+ if message.tool_calls:
382
+ for idx, tc in enumerate(message.tool_calls):
383
+ tool_calls_acc[idx] = {
384
+ "id": tc.id,
385
+ "type": "function",
386
+ "function": {
387
+ "name": tc.function.name,
388
+ "arguments": tc.function.arguments,
389
+ },
390
+ }
391
+
392
+ # Emit the full message as a single event
393
+ if content:
394
+ await session.send_event(
395
+ Event(event_type="assistant_message", data={"content": content})
396
+ )
397
+
398
+ return LLMResult(
399
+ content=content,
400
+ tool_calls_acc=tool_calls_acc,
401
+ token_count=token_count,
402
+ finish_reason=finish_reason,
403
+ )
404
+
405
+
406
  class Handlers:
407
  """Handler functions for each operation type"""
408
 
 
504
  messages = session.context_manager.get_messages()
505
  tools = session.tool_router.get_tool_specs_for_llm()
506
  try:
507
+ # ── Call the LLM (streaming or non-streaming) ──
508
  llm_params = _resolve_hf_router_params(session.config.model_name)
509
+ if session.stream:
510
+ llm_result = await _call_llm_streaming(session, messages, tools, llm_params)
511
+ else:
512
+ llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
+ content = llm_result.content
515
+ tool_calls_acc = llm_result.tool_calls_acc
516
+ token_count = llm_result.token_count
517
+ finish_reason = llm_result.finish_reason
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
  # If output was truncated, all tool call args are garbage.
520
  # Inject a system hint so the LLM retries with smaller content.
 
546
  session.context_manager.add_message(
547
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
548
  )
549
+ if session.stream:
550
+ await session.send_event(
551
+ Event(event_type="assistant_stream_end", data={})
552
+ )
553
  await session.send_event(
554
  Event(
555
  event_type="tool_log",
 
575
  )
576
 
577
  # Signal end of streaming to the frontend
578
+ if session.stream:
579
+ await session.send_event(
580
+ Event(event_type="assistant_stream_end", data={})
581
+ )
582
 
583
  # If no tool calls, add assistant message and we're done
584
  if not tool_calls:
 
1123
  tool_router: ToolRouter | None = None,
1124
  session_holder: list | None = None,
1125
  hf_token: str | None = None,
1126
+ local_mode: bool = False,
1127
+ stream: bool = True,
1128
  ) -> None:
1129
  """
1130
  Main agent loop - processes submissions and dispatches to handlers.
 
1133
 
1134
  # Create session with tool router
1135
  session = Session(
1136
+ event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1137
+ local_mode=local_mode, stream=stream,
1138
  )
1139
  if session_holder is not None:
1140
  session_holder[0] = session
agent/core/session.py CHANGED
@@ -84,9 +84,12 @@ class Session:
84
  tool_router=None,
85
  context_manager: ContextManager | None = None,
86
  hf_token: str | None = None,
 
 
87
  ):
88
  self.hf_token: Optional[str] = hf_token
89
  self.tool_router = tool_router
 
90
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
91
  self.context_manager = context_manager or ContextManager(
92
  max_context=_get_max_tokens_safe(config.model_name),
@@ -94,6 +97,7 @@ class Session:
94
  untouched_messages=5,
95
  tool_specs=tool_specs,
96
  hf_token=hf_token,
 
97
  )
98
  self.event_queue = event_queue
99
  self.session_id = str(uuid.uuid4())
 
84
  tool_router=None,
85
  context_manager: ContextManager | None = None,
86
  hf_token: str | None = None,
87
+ local_mode: bool = False,
88
+ stream: bool = True,
89
  ):
90
  self.hf_token: Optional[str] = hf_token
91
  self.tool_router = tool_router
92
+ self.stream = stream
93
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
94
  self.context_manager = context_manager or ContextManager(
95
  max_context=_get_max_tokens_safe(config.model_name),
 
97
  untouched_messages=5,
98
  tool_specs=tool_specs,
99
  hf_token=hf_token,
100
+ local_mode=local_mode,
101
  )
102
  self.event_queue = event_queue
103
  self.session_id = str(uuid.uuid4())
agent/main.py CHANGED
@@ -1,10 +1,16 @@
1
  """
2
  Interactive CLI chat with the agent
 
 
 
 
3
  """
4
 
 
5
  import asyncio
6
  import json
7
  import os
 
8
  import time
9
  from dataclasses import dataclass
10
  from pathlib import Path
@@ -51,7 +57,7 @@ def _safe_get_args(arguments: dict) -> dict:
51
 
52
 
53
  def _get_hf_token() -> str | None:
54
- """Get HF token from environment or huggingface_hub cached login."""
55
  token = os.environ.get("HF_TOKEN")
56
  if token:
57
  return token
@@ -63,6 +69,12 @@ def _get_hf_token() -> str | None:
63
  return token
64
  except Exception:
65
  pass
 
 
 
 
 
 
66
  return None
67
 
68
 
@@ -123,6 +135,128 @@ class Submission:
123
  operation: Operation
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  async def event_listener(
127
  event_queue: asyncio.Queue,
128
  submission_queue: asyncio.Queue,
@@ -134,6 +268,9 @@ async def event_listener(
134
  """Background task that listens for events and displays them"""
135
  submission_id = [1000] # Use list to make it mutable in closure
136
  last_tool_name = [None] # Track last tool called
 
 
 
137
 
138
  while True:
139
  try:
@@ -144,16 +281,22 @@ async def event_listener(
144
  print(format_success("\U0001f917 Agent ready"))
145
  ready_event.set()
146
  elif event.event_type == "assistant_message":
 
 
147
  content = event.data.get("content", "") if event.data else ""
148
  if content:
149
- print(f"\nAssistant: {content}")
 
150
  elif event.event_type == "assistant_chunk":
 
151
  content = event.data.get("content", "") if event.data else ""
152
  if content:
153
- print(content, end="", flush=True)
154
  elif event.event_type == "assistant_stream_end":
155
- print() # newline after streaming
156
  elif event.event_type == "tool_call":
 
 
157
  tool_name = event.data.get("tool", "") if event.data else ""
158
  arguments = event.data.get("arguments", {}) if event.data else {}
159
  if tool_name:
@@ -167,7 +310,11 @@ async def event_listener(
167
  # Don't truncate plan_tool output, truncate everything else
168
  should_truncate = last_tool_name[0] != "plan_tool"
169
  print(format_tool_output(output, success, truncate=should_truncate))
 
 
170
  elif event.event_type == "turn_complete":
 
 
171
  print(format_turn_complete())
172
  # Display plan after turn complete
173
  plan_display = format_plan_display()
@@ -175,6 +322,8 @@ async def event_listener(
175
  print(plan_display)
176
  turn_complete_event.set()
177
  elif event.event_type == "interrupted":
 
 
178
  print("\n(interrupted)")
179
  turn_complete_event.set()
180
  elif event.event_type == "undo_complete":
@@ -191,6 +340,8 @@ async def event_listener(
191
  if state in ("approved", "rejected", "running"):
192
  print(f" {tool}: {state}")
193
  elif event.event_type == "error":
 
 
194
  error = (
195
  event.data.get("error", "Unknown error")
196
  if event.data
@@ -199,9 +350,11 @@ async def event_listener(
199
  print(format_error(error))
200
  turn_complete_event.set()
201
  elif event.event_type == "shutdown":
 
 
202
  break
203
  elif event.event_type == "processing":
204
- pass # print("Processing...", flush=True)
205
  elif event.event_type == "compacted":
206
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
207
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
@@ -670,6 +823,8 @@ async def main():
670
  tool_router=tool_router,
671
  session_holder=session_holder,
672
  hf_token=hf_token,
 
 
673
  )
674
  )
675
 
@@ -762,17 +917,167 @@ async def main():
762
  )
763
  await submission_queue.put(shutdown_submission)
764
 
 
 
765
  try:
766
- await asyncio.wait_for(agent_task, timeout=5.0)
767
  except asyncio.TimeoutError:
768
  agent_task.cancel()
 
 
 
 
769
  listener_task.cancel()
770
 
771
  print("Goodbye!\n")
772
 
773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
775
  try:
776
- asyncio.run(main())
 
 
 
777
  except KeyboardInterrupt:
778
  print("\n\nGoodbye!")
 
1
  """
2
  Interactive CLI chat with the agent
3
+
4
+ Supports two modes:
5
+ Interactive: python -m agent.main
6
+ Headless: python -m agent.main "find me bird datasets"
7
  """
8
 
9
+ import argparse
10
  import asyncio
11
  import json
12
  import os
13
+ import sys
14
  import time
15
  from dataclasses import dataclass
16
  from pathlib import Path
 
57
 
58
 
59
  def _get_hf_token() -> str | None:
60
+ """Get HF token from environment, huggingface_hub API, or cached token file."""
61
  token = os.environ.get("HF_TOKEN")
62
  if token:
63
  return token
 
69
  return token
70
  except Exception:
71
  pass
72
+ # Fallback: read the cached token file directly
73
+ token_path = Path.home() / ".cache" / "huggingface" / "token"
74
+ if token_path.exists():
75
+ token = token_path.read_text().strip()
76
+ if token:
77
+ return token
78
  return None
79
 
80
 
 
135
  operation: Operation
136
 
137
 
138
+ def _create_rich_console():
139
+ """Create a rich Console for markdown rendering."""
140
+ from rich.console import Console
141
+ return Console(highlight=False)
142
+
143
+
144
+ def _render_markdown(console, text: str) -> None:
145
+ """Render markdown text to the terminal via rich."""
146
+ from rich.markdown import Markdown
147
+ console.print(Markdown(text))
148
+
149
+
150
+ class _ThinkingShimmer:
151
+ """Animated shiny/shimmer thinking indicator β€” a bright gradient sweeps across the text."""
152
+
153
+ _BASE = (90, 90, 110) # dim base color
154
+ _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold)
155
+ _WIDTH = 5 # shimmer width in characters
156
+ _FPS = 24
157
+
158
+ def __init__(self, console):
159
+ self._console = console
160
+ self._task = None
161
+ self._running = False
162
+
163
+ def start(self):
164
+ if self._running:
165
+ return
166
+ self._running = True
167
+ self._task = asyncio.ensure_future(self._animate())
168
+
169
+ def stop(self):
170
+ self._running = False
171
+ if self._task:
172
+ self._task.cancel()
173
+ self._task = None
174
+ # Clear the shimmer line
175
+ self._console.file.write("\r\033[K")
176
+ self._console.file.flush()
177
+
178
+ def _render_frame(self, text: str, offset: float) -> str:
179
+ """Render one frame: a bright spot sweeps left-to-right across `text`."""
180
+ out = []
181
+ n = len(text)
182
+ for i, ch in enumerate(text):
183
+ # Distance from the shimmer center (wraps around)
184
+ dist = abs(i - offset)
185
+ wrap_dist = abs(i - offset + n + self._WIDTH)
186
+ dist = min(dist, wrap_dist, abs(i - offset - n - self._WIDTH))
187
+ # Blend factor: 1.0 at center, 0.0 beyond _WIDTH
188
+ t = max(0.0, 1.0 - dist / self._WIDTH)
189
+ t = t * t * (3 - 2 * t) # smoothstep
190
+ r = int(self._BASE[0] + (self._HIGHLIGHT[0] - self._BASE[0]) * t)
191
+ g = int(self._BASE[1] + (self._HIGHLIGHT[1] - self._BASE[1]) * t)
192
+ b = int(self._BASE[2] + (self._HIGHLIGHT[2] - self._BASE[2]) * t)
193
+ out.append(f"\033[38;2;{r};{g};{b}m{ch}")
194
+ out.append("\033[0m")
195
+ return "".join(out)
196
+
197
+ async def _animate(self):
198
+ text = "Thinking..."
199
+ n = len(text)
200
+ speed = 0.45 # characters per frame
201
+ pos = 0.0
202
+ try:
203
+ while self._running:
204
+ frame = self._render_frame(text, pos)
205
+ self._console.file.write(f"\r{frame}")
206
+ self._console.file.flush()
207
+ pos = (pos + speed) % (n + self._WIDTH)
208
+ await asyncio.sleep(1.0 / self._FPS)
209
+ except asyncio.CancelledError:
210
+ pass
211
+
212
+
213
+ class _StreamBuffer:
214
+ """Buffers streaming chunks and renders markdown line-by-line via rich Live."""
215
+
216
+ def __init__(self, console):
217
+ self._console = console
218
+ self._buffer = ""
219
+ self._live = None
220
+ self._lines_printed = 0
221
+
222
+ def _start_live(self):
223
+ if self._live is None:
224
+ from rich.live import Live
225
+ self._live = Live(
226
+ "",
227
+ console=self._console,
228
+ refresh_per_second=8,
229
+ vertical_overflow="visible",
230
+ )
231
+ self._live.start()
232
+
233
+ def add_chunk(self, text: str):
234
+ self._buffer += text
235
+ self._start_live()
236
+ self._update()
237
+
238
+ def _update(self):
239
+ from rich.markdown import Markdown
240
+ if self._live:
241
+ self._live.update(Markdown(self._buffer))
242
+
243
+ def finish(self):
244
+ """Finalize: stop live display (final frame is already rendered)."""
245
+ if self._live:
246
+ self._live.stop()
247
+ self._live = None
248
+ self._buffer = ""
249
+ self._lines_printed = 0
250
+
251
+ def discard(self):
252
+ """Discard without final render (e.g. for tool-only turns)."""
253
+ if self._live:
254
+ self._live.stop()
255
+ self._live = None
256
+ self._buffer = ""
257
+ self._lines_printed = 0
258
+
259
+
260
  async def event_listener(
261
  event_queue: asyncio.Queue,
262
  submission_queue: asyncio.Queue,
 
268
  """Background task that listens for events and displays them"""
269
  submission_id = [1000] # Use list to make it mutable in closure
270
  last_tool_name = [None] # Track last tool called
271
+ console = _create_rich_console()
272
+ spinner = _ThinkingShimmer(console)
273
+ stream_buf = _StreamBuffer(console)
274
 
275
  while True:
276
  try:
 
281
  print(format_success("\U0001f917 Agent ready"))
282
  ready_event.set()
283
  elif event.event_type == "assistant_message":
284
+ # Non-streaming: full message arrives at once
285
+ spinner.stop()
286
  content = event.data.get("content", "") if event.data else ""
287
  if content:
288
+ console.print()
289
+ _render_markdown(console, content)
290
  elif event.event_type == "assistant_chunk":
291
+ spinner.stop()
292
  content = event.data.get("content", "") if event.data else ""
293
  if content:
294
+ stream_buf.add_chunk(content)
295
  elif event.event_type == "assistant_stream_end":
296
+ stream_buf.finish()
297
  elif event.event_type == "tool_call":
298
+ spinner.stop()
299
+ stream_buf.discard()
300
  tool_name = event.data.get("tool", "") if event.data else ""
301
  arguments = event.data.get("arguments", {}) if event.data else {}
302
  if tool_name:
 
310
  # Don't truncate plan_tool output, truncate everything else
311
  should_truncate = last_tool_name[0] != "plan_tool"
312
  print(format_tool_output(output, success, truncate=should_truncate))
313
+ # After tool output, agent will think again
314
+ spinner.start()
315
  elif event.event_type == "turn_complete":
316
+ spinner.stop()
317
+ stream_buf.discard()
318
  print(format_turn_complete())
319
  # Display plan after turn complete
320
  plan_display = format_plan_display()
 
322
  print(plan_display)
323
  turn_complete_event.set()
324
  elif event.event_type == "interrupted":
325
+ spinner.stop()
326
+ stream_buf.discard()
327
  print("\n(interrupted)")
328
  turn_complete_event.set()
329
  elif event.event_type == "undo_complete":
 
340
  if state in ("approved", "rejected", "running"):
341
  print(f" {tool}: {state}")
342
  elif event.event_type == "error":
343
+ spinner.stop()
344
+ stream_buf.discard()
345
  error = (
346
  event.data.get("error", "Unknown error")
347
  if event.data
 
350
  print(format_error(error))
351
  turn_complete_event.set()
352
  elif event.event_type == "shutdown":
353
+ spinner.stop()
354
+ stream_buf.discard()
355
  break
356
  elif event.event_type == "processing":
357
+ spinner.start()
358
  elif event.event_type == "compacted":
359
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
360
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
 
823
  tool_router=tool_router,
824
  session_holder=session_holder,
825
  hf_token=hf_token,
826
+ local_mode=True,
827
+ stream=True,
828
  )
829
  )
830
 
 
917
  )
918
  await submission_queue.put(shutdown_submission)
919
 
920
+ # Wait for agent to finish (the listener must keep draining events
921
+ # or the agent will block on event_queue.put)
922
  try:
923
+ await asyncio.wait_for(agent_task, timeout=10.0)
924
  except asyncio.TimeoutError:
925
  agent_task.cancel()
926
+ # Agent didn't shut down cleanly β€” close MCP explicitly
927
+ await tool_router.__aexit__(None, None, None)
928
+
929
+ # Now safe to cancel the listener (agent is done emitting events)
930
  listener_task.cancel()
931
 
932
  print("Goodbye!\n")
933
 
934
 
935
+ async def headless_main(prompt: str, model: str | None = None) -> None:
936
+ """Run a single prompt headlessly and exit."""
937
+ import logging
938
+
939
+ logging.basicConfig(level=logging.WARNING)
940
+
941
+ hf_token = _get_hf_token()
942
+ if not hf_token:
943
+ print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
944
+ sys.exit(1)
945
+
946
+ print(f"HF token loaded", file=sys.stderr)
947
+
948
+ config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
949
+ config = load_config(config_path)
950
+ config.yolo_mode = True # Auto-approve everything in headless mode
951
+
952
+ if model:
953
+ if model not in VALID_MODEL_IDS:
954
+ print(f"ERROR: Unknown model '{model}'. Valid: {', '.join(VALID_MODEL_IDS)}", file=sys.stderr)
955
+ sys.exit(1)
956
+ config.model_name = model
957
+
958
+ print(f"Model: {config.model_name}", file=sys.stderr)
959
+ print(f"Prompt: {prompt}", file=sys.stderr)
960
+ print("---", file=sys.stderr)
961
+
962
+ submission_queue: asyncio.Queue = asyncio.Queue()
963
+ event_queue: asyncio.Queue = asyncio.Queue()
964
+
965
+ tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
966
+ session_holder: list = [None]
967
+
968
+ agent_task = asyncio.create_task(
969
+ submission_loop(
970
+ submission_queue,
971
+ event_queue,
972
+ config=config,
973
+ tool_router=tool_router,
974
+ session_holder=session_holder,
975
+ hf_token=hf_token,
976
+ local_mode=True,
977
+ stream=True,
978
+ )
979
+ )
980
+
981
+ # Wait for ready
982
+ while True:
983
+ event = await event_queue.get()
984
+ if event.event_type == "ready":
985
+ break
986
+
987
+ # Submit the prompt
988
+ submission = Submission(
989
+ id="sub_1",
990
+ operation=Operation(op_type=OpType.USER_INPUT, data={"text": prompt}),
991
+ )
992
+ await submission_queue.put(submission)
993
+
994
+ # Process events until turn completes
995
+ console = _create_rich_console()
996
+ err_console = _create_rich_console()
997
+ err_console.file = sys.stderr
998
+ spinner = _ThinkingShimmer(console)
999
+ stream_buf = _StreamBuffer(console)
1000
+ spinner.start()
1001
+
1002
+ while True:
1003
+ event = await event_queue.get()
1004
+
1005
+ if event.event_type == "assistant_chunk":
1006
+ spinner.stop()
1007
+ content = event.data.get("content", "") if event.data else ""
1008
+ if content:
1009
+ stream_buf.add_chunk(content)
1010
+ elif event.event_type == "assistant_stream_end":
1011
+ stream_buf.finish()
1012
+ elif event.event_type == "assistant_message":
1013
+ spinner.stop()
1014
+ content = event.data.get("content", "") if event.data else ""
1015
+ if content:
1016
+ _render_markdown(console, content)
1017
+ elif event.event_type == "tool_call":
1018
+ spinner.stop()
1019
+ stream_buf.discard()
1020
+ tool_name = event.data.get("tool", "") if event.data else ""
1021
+ arguments = event.data.get("arguments", {}) if event.data else {}
1022
+ if tool_name:
1023
+ args_str = json.dumps(arguments)[:100] + "..."
1024
+ print(format_tool_call(tool_name, args_str), file=sys.stderr)
1025
+ elif event.event_type == "tool_output":
1026
+ output = event.data.get("output", "") if event.data else ""
1027
+ success = event.data.get("success", False) if event.data else False
1028
+ if output:
1029
+ print(format_tool_output(output, success, truncate=True), file=sys.stderr)
1030
+ spinner.start()
1031
+ elif event.event_type == "tool_log":
1032
+ tool = event.data.get("tool", "") if event.data else ""
1033
+ log = event.data.get("log", "") if event.data else ""
1034
+ if log:
1035
+ print(f" [{tool}] {log}", file=sys.stderr)
1036
+ elif event.event_type == "compacted":
1037
+ old_tokens = event.data.get("old_tokens", 0) if event.data else 0
1038
+ new_tokens = event.data.get("new_tokens", 0) if event.data else 0
1039
+ print(f"Compacted: {old_tokens} -> {new_tokens} tokens", file=sys.stderr)
1040
+ elif event.event_type == "error":
1041
+ spinner.stop()
1042
+ stream_buf.discard()
1043
+ error = event.data.get("error", "Unknown error") if event.data else "Unknown error"
1044
+ print(f"ERROR: {error}", file=sys.stderr)
1045
+ break
1046
+ elif event.event_type in ("turn_complete", "interrupted"):
1047
+ spinner.stop()
1048
+ stream_buf.discard()
1049
+ break
1050
+
1051
+ # Shutdown
1052
+ shutdown_submission = Submission(
1053
+ id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN)
1054
+ )
1055
+ await submission_queue.put(shutdown_submission)
1056
+
1057
+ try:
1058
+ await asyncio.wait_for(agent_task, timeout=10.0)
1059
+ except asyncio.TimeoutError:
1060
+ agent_task.cancel()
1061
+ await tool_router.__aexit__(None, None, None)
1062
+
1063
+
1064
  if __name__ == "__main__":
1065
+ import logging as _logging
1066
+ import warnings
1067
+ # Suppress aiohttp "Unclosed client session" noise during event loop teardown
1068
+ _logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
1069
+ # Suppress litellm pydantic deprecation warnings
1070
+ warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm")
1071
+
1072
+ parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
1073
+ parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt")
1074
+ parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)")
1075
+ args = parser.parse_args()
1076
+
1077
  try:
1078
+ if args.prompt:
1079
+ asyncio.run(headless_main(args.prompt, model=args.model))
1080
+ else:
1081
+ asyncio.run(main())
1082
  except KeyboardInterrupt:
1083
  print("\n\nGoodbye!")
agent/tools/local_tools.py CHANGED
@@ -227,7 +227,63 @@ async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
227
  return msg, True
228
 
229
 
230
- # ── Public API ──────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  _HANDLERS = {
233
  "bash": _bash_handler,
@@ -242,7 +298,7 @@ def get_local_tools():
242
  from agent.core.tools import ToolSpec
243
 
244
  tools = []
245
- for name, spec in Sandbox.TOOLS.items():
246
  handler = _HANDLERS.get(name)
247
  if handler is None:
248
  continue
 
227
  return msg, True
228
 
229
 
230
# ── Local tool specs (override sandbox /app references) ────────────────

# Only `bash` needs a locally rewritten spec: the sandbox version talks
# about /app paths that don't apply on the local filesystem.  The
# read/write/edit tools keep the sandbox contract, so their specs are
# reused verbatim from Sandbox.TOOLS via the comprehension below.
_LOCAL_TOOL_SPECS = {
    "bash": {
        "description": (
            "Run a shell command on the local machine and return stdout/stderr.\n"
            "\n"
            "Commands run in a shell at the working directory (default: current directory). "
            "Each invocation is independent.\n"
            "\n"
            "AVOID using bash for operations covered by specialized tools:\n"
            "- File reading: use read (not cat/head/tail)\n"
            "- File editing: use edit (not sed/awk)\n"
            "- File writing: use write (not echo/cat <<EOF)\n"
            "\n"
            "Chain dependent commands with &&. Independent commands should be "
            "separate bash calls (they can run in parallel).\n"
            "\n"
            "Timeout default 120s, max 600s."
        ),
        "parameters": {
            "type": "object",
            "required": ["command"],
            "additionalProperties": False,
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The shell command to execute.",
                },
                "description": {
                    "type": "string",
                    "description": "Short description (5-10 words, active voice).",
                },
                "work_dir": {
                    "type": "string",
                    "description": "Working directory (default: current directory).",
                },
                "timeout": {
                    "type": "integer",
                    "description": "Timeout in seconds (default: 120, max: 600).",
                },
            },
        },
    },
    # Sandbox specs carried over unchanged (same key order as before:
    # bash, read, write, edit).
    **{
        tool_name: {
            "description": Sandbox.TOOLS[tool_name]["description"],
            "parameters": Sandbox.TOOLS[tool_name]["parameters"],
        }
        for tool_name in ("read", "write", "edit")
    },
}
287
 
288
  _HANDLERS = {
289
  "bash": _bash_handler,
 
298
  from agent.core.tools import ToolSpec
299
 
300
  tools = []
301
+ for name, spec in _LOCAL_TOOL_SPECS.items():
302
  handler = _HANDLERS.get(name)
303
  if handler is None:
304
  continue
pyproject.toml CHANGED
@@ -20,6 +20,7 @@ agent = [
20
  "fastmcp>=2.4.0",
21
  "prompt-toolkit>=3.0.0",
22
  "thefuzz>=0.22.1",
 
23
  "nbconvert>=7.16.6",
24
  "nbformat>=5.10.4",
25
  "datasets>=4.3.0", # For session logging to HF datasets
 
20
  "fastmcp>=2.4.0",
21
  "prompt-toolkit>=3.0.0",
22
  "thefuzz>=0.22.1",
23
+ "rich>=13.0.0",
24
  "nbconvert>=7.16.6",
25
  "nbformat>=5.10.4",
26
  "datasets>=4.3.0", # For session logging to HF datasets