akseljoonas HF Staff commited on
Commit
2bc3b1a
·
1 Parent(s): c9b2500

feat: implement cooperative agent cancellation with cancel button

Browse files
agent/core/agent_loop.py CHANGED
@@ -164,6 +164,36 @@ async def _compact_and_notify(session: Session) -> None:
164
  )
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  class Handlers:
168
  """Handler functions for each operation type"""
169
 
@@ -218,6 +248,9 @@ class Handlers:
218
 
219
  Laminar.set_trace_session_id(session_id=session.session_id)
220
 
 
 
 
221
  # If there's a pending approval and the user sent a new message,
222
  # abandon the pending tools so the LLM context stays valid.
223
  if text and session.pending_approval:
@@ -238,6 +271,10 @@ class Handlers:
238
  final_response = None
239
 
240
  while iteration < max_iterations:
 
 
 
 
241
  # Compact before calling the LLM if context is near the limit
242
  await _compact_and_notify(session)
243
 
@@ -344,6 +381,11 @@ class Handlers:
344
  )
345
  session.context_manager.add_message(assistant_msg, token_count)
346
 
 
 
 
 
 
347
  # Separate tools into those requiring approval and those that don't
348
  approval_required_tools = []
349
  non_approval_tools = []
@@ -436,6 +478,10 @@ class Handlers:
436
  )
437
  )
438
 
 
 
 
 
439
  # If there are tools requiring approval, ask for batch approval
440
  if approval_required_tools:
441
  # Prepare batch approval data
@@ -493,12 +539,15 @@ class Handlers:
493
  )
494
  break
495
 
496
- await session.send_event(
497
- Event(
498
- event_type="turn_complete",
499
- data={"history_size": len(session.context_manager.items)},
 
 
 
 
500
  )
501
- )
502
 
503
  # Increment turn counter and check for auto-save
504
  session.increment_turn()
@@ -506,12 +555,6 @@ class Handlers:
506
 
507
  return final_response
508
 
509
- @staticmethod
510
- async def interrupt(session: Session) -> None:
511
- """Handle interrupt (like interrupt in codex.rs:1266)"""
512
- session.interrupt()
513
- await session.send_event(Event(event_type="interrupted"))
514
-
515
  @staticmethod
516
  async def undo(session: Session) -> None:
517
  """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
@@ -771,10 +814,6 @@ async def process_submission(session: Session, submission) -> bool:
771
  await Handlers.run_agent(session, text)
772
  return True
773
 
774
- if op.op_type == OpType.INTERRUPT:
775
- await Handlers.interrupt(session)
776
- return True
777
-
778
  if op.op_type == OpType.COMPACT:
779
  await _compact_and_notify(session)
780
  return True
 
164
  )
165
 
166
 
167
+ def _patch_dangling_tool_calls(session: Session) -> None:
168
+ """Add stub tool results for any tool_calls that lack a matching result.
169
+
170
+ After cancellation the last assistant message may contain tool_calls
171
+ whose results were never recorded. LLM APIs require every tool_call
172
+ to have a corresponding tool-result message, so we inject placeholders.
173
+ """
174
+ items = session.context_manager.items
175
+ if not items:
176
+ return
177
+ last = items[-1]
178
+ if getattr(last, "role", None) != "assistant" or not getattr(last, "tool_calls", None):
179
+ return
180
+ answered_ids = {
181
+ getattr(m, "tool_call_id", None)
182
+ for m in items
183
+ if getattr(m, "role", None) == "tool"
184
+ }
185
+ for tc in last.tool_calls:
186
+ if tc.id not in answered_ids:
187
+ items.append(
188
+ Message(
189
+ role="tool",
190
+ content="Cancelled by user.",
191
+ tool_call_id=tc.id,
192
+ name=tc.function.name,
193
+ )
194
+ )
195
+
196
+
197
  class Handlers:
198
  """Handler functions for each operation type"""
199
 
 
248
 
249
  Laminar.set_trace_session_id(session_id=session.session_id)
250
 
251
+ # Clear any stale cancellation flag from a previous run
252
+ session.reset_cancel()
253
+
254
  # If there's a pending approval and the user sent a new message,
255
  # abandon the pending tools so the LLM context stays valid.
256
  if text and session.pending_approval:
 
271
  final_response = None
272
 
273
  while iteration < max_iterations:
274
+ # ── Cancellation check: before LLM call ──
275
+ if session.is_cancelled:
276
+ break
277
+
278
  # Compact before calling the LLM if context is near the limit
279
  await _compact_and_notify(session)
280
 
 
381
  )
382
  session.context_manager.add_message(assistant_msg, token_count)
383
 
384
+ # ── Cancellation check: before tool execution ──
385
+ if session.is_cancelled:
386
+ _patch_dangling_tool_calls(session)
387
+ break
388
+
389
  # Separate tools into those requiring approval and those that don't
390
  approval_required_tools = []
391
  non_approval_tools = []
 
478
  )
479
  )
480
 
481
+ # ── Cancellation check: after tool execution ──
482
+ if session.is_cancelled:
483
+ break
484
+
485
  # If there are tools requiring approval, ask for batch approval
486
  if approval_required_tools:
487
  # Prepare batch approval data
 
539
  )
540
  break
541
 
542
+ if session.is_cancelled:
543
+ await session.send_event(Event(event_type="interrupted"))
544
+ else:
545
+ await session.send_event(
546
+ Event(
547
+ event_type="turn_complete",
548
+ data={"history_size": len(session.context_manager.items)},
549
+ )
550
  )
 
551
 
552
  # Increment turn counter and check for auto-save
553
  session.increment_turn()
 
555
 
556
  return final_response
557
 
 
 
 
 
 
 
558
  @staticmethod
559
  async def undo(session: Session) -> None:
560
  """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
 
814
  await Handlers.run_agent(session, text)
815
  return True
816
 
 
 
 
 
817
  if op.op_type == OpType.COMPACT:
818
  await _compact_and_notify(session)
819
  return True
agent/core/session.py CHANGED
@@ -95,7 +95,7 @@ class Session:
95
  model_name="anthropic/claude-sonnet-4-5-20250929",
96
  )
97
  self.is_running = True
98
- self.current_task: asyncio.Task | None = None
99
  self.pending_approval: Optional[dict[str, Any]] = None
100
  # User's HF OAuth token — set by session_manager after construction
101
  self.hf_token: Optional[str] = None
@@ -120,10 +120,17 @@ class Session:
120
  }
121
  )
122
 
123
- def interrupt(self) -> None:
124
- """Interrupt current running task"""
125
- if self.current_task and not self.current_task.done():
126
- self.current_task.cancel()
 
 
 
 
 
 
 
127
 
128
  def increment_turn(self) -> None:
129
  """Increment turn counter (called after each user interaction)"""
 
95
  model_name="anthropic/claude-sonnet-4-5-20250929",
96
  )
97
  self.is_running = True
98
+ self._cancelled = asyncio.Event()
99
  self.pending_approval: Optional[dict[str, Any]] = None
100
  # User's HF OAuth token — set by session_manager after construction
101
  self.hf_token: Optional[str] = None
 
120
  }
121
  )
122
 
123
+ def cancel(self) -> None:
124
+ """Signal cancellation to the running agent loop."""
125
+ self._cancelled.set()
126
+
127
+ def reset_cancel(self) -> None:
128
+ """Clear the cancellation flag before a new run."""
129
+ self._cancelled.clear()
130
+
131
+ @property
132
+ def is_cancelled(self) -> bool:
133
+ return self._cancelled.is_set()
134
 
135
  def increment_turn(self) -> None:
136
  """Increment turn counter (called after each user interaction)"""
backend/session_manager.py CHANGED
@@ -265,9 +265,12 @@ class SessionManager:
265
  return await self.submit(session_id, operation)
266
 
267
  async def interrupt(self, session_id: str) -> bool:
268
- """Interrupt a session."""
269
- operation = Operation(op_type=OpType.INTERRUPT)
270
- return await self.submit(session_id, operation)
 
 
 
271
 
272
  async def undo(self, session_id: str) -> bool:
273
  """Undo last turn in a session."""
 
265
  return await self.submit(session_id, operation)
266
 
267
  async def interrupt(self, session_id: str) -> bool:
268
+ """Interrupt a session by signalling cancellation directly (bypasses queue)."""
269
+ agent_session = self.sessions.get(session_id)
270
+ if not agent_session or not agent_session.is_active:
271
+ return False
272
+ agent_session.session.cancel()
273
+ return True
274
 
275
  async def undo(self, session_id: str) -> bool:
276
  """Undo last turn in a session."""
frontend/src/components/Chat/ChatInput.tsx CHANGED
@@ -2,6 +2,7 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react';
2
  import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
3
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
4
  import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
 
5
  import { apiFetch } from '@/utils/api';
6
 
7
  // Model configuration
@@ -58,12 +59,15 @@ const findModelByPath = (path: string): ModelOption | undefined => {
58
 
59
  interface ChatInputProps {
60
  onSend: (text: string) => void;
 
 
61
  disabled?: boolean;
62
  placeholder?: string;
63
  }
64
 
65
- export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
66
  const [input, setInput] = useState('');
 
67
  const inputRef = useRef<HTMLTextAreaElement>(null);
68
  const [selectedModelId, setSelectedModelId] = useState<string>(() => {
69
  try {
@@ -92,12 +96,12 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
92
 
93
  const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0];
94
 
95
- // Auto-focus the textarea when the session becomes ready (disabled -> false)
96
  useEffect(() => {
97
- if (!disabled && inputRef.current) {
98
  inputRef.current.focus();
99
  }
100
- }, [disabled]);
101
 
102
  const handleSend = useCallback(() => {
103
  if (input.trim() && !disabled) {
@@ -173,7 +177,7 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
173
  onChange={(e) => setInput(e.target.value)}
174
  onKeyDown={handleKeyDown}
175
  placeholder={placeholder}
176
- disabled={disabled}
177
  variant="standard"
178
  inputRef={inputRef}
179
  InputProps={{
@@ -200,26 +204,46 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
200
  }
201
  }}
202
  />
203
- <IconButton
204
- onClick={handleSend}
205
- disabled={disabled || !input.trim()}
206
- sx={{
207
- mt: 1,
208
- p: 1,
209
- borderRadius: '10px',
210
- color: 'var(--muted-text)',
211
- transition: 'all 0.2s',
212
- '&:hover': {
213
- color: 'var(--accent-yellow)',
214
- bgcolor: 'var(--hover-bg)',
215
- },
216
- '&.Mui-disabled': {
217
- opacity: 0.3,
218
- },
219
- }}
220
- >
221
- {disabled ? <CircularProgress size={20} color="inherit" /> : <ArrowUpwardIcon fontSize="small" />}
222
- </IconButton>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  </Box>
224
 
225
  {/* Powered By Badge */}
 
2
  import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
3
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
4
  import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
5
+ import CloseIcon from '@mui/icons-material/Close';
6
  import { apiFetch } from '@/utils/api';
7
 
8
  // Model configuration
 
59
 
60
  interface ChatInputProps {
61
  onSend: (text: string) => void;
62
+ onStop?: () => void;
63
+ isProcessing?: boolean;
64
  disabled?: boolean;
65
  placeholder?: string;
66
  }
67
 
68
+ export default function ChatInput({ onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
69
  const [input, setInput] = useState('');
70
+ const [stopHovered, setStopHovered] = useState(false);
71
  const inputRef = useRef<HTMLTextAreaElement>(null);
72
  const [selectedModelId, setSelectedModelId] = useState<string>(() => {
73
  try {
 
96
 
97
  const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0];
98
 
99
+ // Auto-focus the textarea when the session becomes ready
100
  useEffect(() => {
101
+ if (!disabled && !isProcessing && inputRef.current) {
102
  inputRef.current.focus();
103
  }
104
+ }, [disabled, isProcessing]);
105
 
106
  const handleSend = useCallback(() => {
107
  if (input.trim() && !disabled) {
 
177
  onChange={(e) => setInput(e.target.value)}
178
  onKeyDown={handleKeyDown}
179
  placeholder={placeholder}
180
+ disabled={disabled || isProcessing}
181
  variant="standard"
182
  inputRef={inputRef}
183
  InputProps={{
 
204
  }
205
  }}
206
  />
207
+ {isProcessing ? (
208
+ <IconButton
209
+ onClick={onStop}
210
+ onMouseEnter={() => setStopHovered(true)}
211
+ onMouseLeave={() => setStopHovered(false)}
212
+ sx={{
213
+ mt: 1,
214
+ p: 1,
215
+ borderRadius: '10px',
216
+ color: stopHovered ? 'var(--accent-yellow)' : 'var(--muted-text)',
217
+ transition: 'all 0.2s',
218
+ '&:hover': {
219
+ bgcolor: 'var(--hover-bg)',
220
+ },
221
+ }}
222
+ >
223
+ {stopHovered ? <CloseIcon fontSize="small" /> : <CircularProgress size={20} color="inherit" />}
224
+ </IconButton>
225
+ ) : (
226
+ <IconButton
227
+ onClick={handleSend}
228
+ disabled={disabled || !input.trim()}
229
+ sx={{
230
+ mt: 1,
231
+ p: 1,
232
+ borderRadius: '10px',
233
+ color: 'var(--muted-text)',
234
+ transition: 'all 0.2s',
235
+ '&:hover': {
236
+ color: 'var(--accent-yellow)',
237
+ bgcolor: 'var(--hover-bg)',
238
+ },
239
+ '&.Mui-disabled': {
240
+ opacity: 0.3,
241
+ },
242
+ }}
243
+ >
244
+ <ArrowUpwardIcon fontSize="small" />
245
+ </IconButton>
246
+ )}
247
  </Box>
248
 
249
  {/* Powered By Badge */}
frontend/src/components/Layout/AppLayout.tsx CHANGED
@@ -110,7 +110,7 @@ export default function AppLayout() {
110
 
111
  const hasAnySessions = sessions.length > 0;
112
 
113
- const { messages, sendMessage, undoLastTurn, approveTools } = useAgentChat({
114
  sessionId: activeSessionId,
115
  onReady: () => logger.log('Agent ready'),
116
  onError: (error) => logger.error('Agent error:', error),
@@ -345,7 +345,9 @@ export default function AppLayout() {
345
  <MessageList messages={messages} isProcessing={isProcessing} approveTools={approveTools} onUndoLastTurn={undoLastTurn} />
346
  <ChatInput
347
  onSend={handleSendMessage}
348
- disabled={isProcessing || !isConnected || activityStatus.type === 'waiting-approval'}
 
 
349
  placeholder={activityStatus.type === 'waiting-approval' ? 'Approve or reject pending tools first...' : undefined}
350
  />
351
  </>
@@ -439,7 +441,7 @@ export default function AppLayout() {
439
  onClose={() => setShowExpiredToast(false)}
440
  sx={{ fontFamily: 'monospace', fontSize: '0.8rem' }}
441
  >
442
- Session expired — create a new session to continue.
443
  </Alert>
444
  </Snackbar>
445
  <Snackbar
 
110
 
111
  const hasAnySessions = sessions.length > 0;
112
 
113
+ const { messages, sendMessage, stop, undoLastTurn, approveTools } = useAgentChat({
114
  sessionId: activeSessionId,
115
  onReady: () => logger.log('Agent ready'),
116
  onError: (error) => logger.error('Agent error:', error),
 
345
  <MessageList messages={messages} isProcessing={isProcessing} approveTools={approveTools} onUndoLastTurn={undoLastTurn} />
346
  <ChatInput
347
  onSend={handleSendMessage}
348
+ onStop={stop}
349
+ isProcessing={isProcessing}
350
+ disabled={!isConnected || activityStatus.type === 'waiting-approval'}
351
  placeholder={activityStatus.type === 'waiting-approval' ? 'Approve or reject pending tools first...' : undefined}
352
  />
353
  </>
 
441
  onClose={() => setShowExpiredToast(false)}
442
  sx={{ fontFamily: 'monospace', fontSize: '0.8rem' }}
443
  >
444
+ Task expired — create a new task to continue.
445
  </Alert>
446
  </Snackbar>
447
  <Snackbar
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -270,6 +270,7 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
270
  return {
271
  messages: chat.messages,
272
  sendMessage: chat.sendMessage,
 
273
  status: chat.status,
274
  undoLastTurn,
275
  approveTools,
 
270
  return {
271
  messages: chat.messages,
272
  sendMessage: chat.sendMessage,
273
+ stop: chat.stop,
274
  status: chat.status,
275
  undoLastTurn,
276
  approveTools,