Spaces:
Running
Running
Commit ·
2bc3b1a
1
Parent(s): c9b2500
feat: implement cooperative agent cancellation with cancel button
Browse files
agent/core/agent_loop.py
CHANGED
|
@@ -164,6 +164,36 @@ async def _compact_and_notify(session: Session) -> None:
|
|
| 164 |
)
|
| 165 |
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
class Handlers:
|
| 168 |
"""Handler functions for each operation type"""
|
| 169 |
|
|
@@ -218,6 +248,9 @@ class Handlers:
|
|
| 218 |
|
| 219 |
Laminar.set_trace_session_id(session_id=session.session_id)
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
# If there's a pending approval and the user sent a new message,
|
| 222 |
# abandon the pending tools so the LLM context stays valid.
|
| 223 |
if text and session.pending_approval:
|
|
@@ -238,6 +271,10 @@ class Handlers:
|
|
| 238 |
final_response = None
|
| 239 |
|
| 240 |
while iteration < max_iterations:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
# Compact before calling the LLM if context is near the limit
|
| 242 |
await _compact_and_notify(session)
|
| 243 |
|
|
@@ -344,6 +381,11 @@ class Handlers:
|
|
| 344 |
)
|
| 345 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
# Separate tools into those requiring approval and those that don't
|
| 348 |
approval_required_tools = []
|
| 349 |
non_approval_tools = []
|
|
@@ -436,6 +478,10 @@ class Handlers:
|
|
| 436 |
)
|
| 437 |
)
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# If there are tools requiring approval, ask for batch approval
|
| 440 |
if approval_required_tools:
|
| 441 |
# Prepare batch approval data
|
|
@@ -493,12 +539,15 @@ class Handlers:
|
|
| 493 |
)
|
| 494 |
break
|
| 495 |
|
| 496 |
-
|
| 497 |
-
Event(
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
)
|
| 501 |
-
)
|
| 502 |
|
| 503 |
# Increment turn counter and check for auto-save
|
| 504 |
session.increment_turn()
|
|
@@ -506,12 +555,6 @@ class Handlers:
|
|
| 506 |
|
| 507 |
return final_response
|
| 508 |
|
| 509 |
-
@staticmethod
|
| 510 |
-
async def interrupt(session: Session) -> None:
|
| 511 |
-
"""Handle interrupt (like interrupt in codex.rs:1266)"""
|
| 512 |
-
session.interrupt()
|
| 513 |
-
await session.send_event(Event(event_type="interrupted"))
|
| 514 |
-
|
| 515 |
@staticmethod
|
| 516 |
async def undo(session: Session) -> None:
|
| 517 |
"""Remove the last complete turn (user msg + all assistant/tool msgs that follow).
|
|
@@ -771,10 +814,6 @@ async def process_submission(session: Session, submission) -> bool:
|
|
| 771 |
await Handlers.run_agent(session, text)
|
| 772 |
return True
|
| 773 |
|
| 774 |
-
if op.op_type == OpType.INTERRUPT:
|
| 775 |
-
await Handlers.interrupt(session)
|
| 776 |
-
return True
|
| 777 |
-
|
| 778 |
if op.op_type == OpType.COMPACT:
|
| 779 |
await _compact_and_notify(session)
|
| 780 |
return True
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
|
| 167 |
+
def _patch_dangling_tool_calls(session: Session) -> None:
|
| 168 |
+
"""Add stub tool results for any tool_calls that lack a matching result.
|
| 169 |
+
|
| 170 |
+
After cancellation the last assistant message may contain tool_calls
|
| 171 |
+
whose results were never recorded. LLM APIs require every tool_call
|
| 172 |
+
to have a corresponding tool-result message, so we inject placeholders.
|
| 173 |
+
"""
|
| 174 |
+
items = session.context_manager.items
|
| 175 |
+
if not items:
|
| 176 |
+
return
|
| 177 |
+
last = items[-1]
|
| 178 |
+
if getattr(last, "role", None) != "assistant" or not getattr(last, "tool_calls", None):
|
| 179 |
+
return
|
| 180 |
+
answered_ids = {
|
| 181 |
+
getattr(m, "tool_call_id", None)
|
| 182 |
+
for m in items
|
| 183 |
+
if getattr(m, "role", None) == "tool"
|
| 184 |
+
}
|
| 185 |
+
for tc in last.tool_calls:
|
| 186 |
+
if tc.id not in answered_ids:
|
| 187 |
+
items.append(
|
| 188 |
+
Message(
|
| 189 |
+
role="tool",
|
| 190 |
+
content="Cancelled by user.",
|
| 191 |
+
tool_call_id=tc.id,
|
| 192 |
+
name=tc.function.name,
|
| 193 |
+
)
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
class Handlers:
|
| 198 |
"""Handler functions for each operation type"""
|
| 199 |
|
|
|
|
| 248 |
|
| 249 |
Laminar.set_trace_session_id(session_id=session.session_id)
|
| 250 |
|
| 251 |
+
# Clear any stale cancellation flag from a previous run
|
| 252 |
+
session.reset_cancel()
|
| 253 |
+
|
| 254 |
# If there's a pending approval and the user sent a new message,
|
| 255 |
# abandon the pending tools so the LLM context stays valid.
|
| 256 |
if text and session.pending_approval:
|
|
|
|
| 271 |
final_response = None
|
| 272 |
|
| 273 |
while iteration < max_iterations:
|
| 274 |
+
# ── Cancellation check: before LLM call ──
|
| 275 |
+
if session.is_cancelled:
|
| 276 |
+
break
|
| 277 |
+
|
| 278 |
# Compact before calling the LLM if context is near the limit
|
| 279 |
await _compact_and_notify(session)
|
| 280 |
|
|
|
|
| 381 |
)
|
| 382 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 383 |
|
| 384 |
+
# ── Cancellation check: before tool execution ──
|
| 385 |
+
if session.is_cancelled:
|
| 386 |
+
_patch_dangling_tool_calls(session)
|
| 387 |
+
break
|
| 388 |
+
|
| 389 |
# Separate tools into those requiring approval and those that don't
|
| 390 |
approval_required_tools = []
|
| 391 |
non_approval_tools = []
|
|
|
|
| 478 |
)
|
| 479 |
)
|
| 480 |
|
| 481 |
+
# ── Cancellation check: after tool execution ──
|
| 482 |
+
if session.is_cancelled:
|
| 483 |
+
break
|
| 484 |
+
|
| 485 |
# If there are tools requiring approval, ask for batch approval
|
| 486 |
if approval_required_tools:
|
| 487 |
# Prepare batch approval data
|
|
|
|
| 539 |
)
|
| 540 |
break
|
| 541 |
|
| 542 |
+
if session.is_cancelled:
|
| 543 |
+
await session.send_event(Event(event_type="interrupted"))
|
| 544 |
+
else:
|
| 545 |
+
await session.send_event(
|
| 546 |
+
Event(
|
| 547 |
+
event_type="turn_complete",
|
| 548 |
+
data={"history_size": len(session.context_manager.items)},
|
| 549 |
+
)
|
| 550 |
)
|
|
|
|
| 551 |
|
| 552 |
# Increment turn counter and check for auto-save
|
| 553 |
session.increment_turn()
|
|
|
|
| 555 |
|
| 556 |
return final_response
|
| 557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
@staticmethod
|
| 559 |
async def undo(session: Session) -> None:
|
| 560 |
"""Remove the last complete turn (user msg + all assistant/tool msgs that follow).
|
|
|
|
| 814 |
await Handlers.run_agent(session, text)
|
| 815 |
return True
|
| 816 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 817 |
if op.op_type == OpType.COMPACT:
|
| 818 |
await _compact_and_notify(session)
|
| 819 |
return True
|
agent/core/session.py
CHANGED
|
@@ -95,7 +95,7 @@ class Session:
|
|
| 95 |
model_name="anthropic/claude-sonnet-4-5-20250929",
|
| 96 |
)
|
| 97 |
self.is_running = True
|
| 98 |
-
self.
|
| 99 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 100 |
# User's HF OAuth token — set by session_manager after construction
|
| 101 |
self.hf_token: Optional[str] = None
|
|
@@ -120,10 +120,17 @@ class Session:
|
|
| 120 |
}
|
| 121 |
)
|
| 122 |
|
| 123 |
-
def
|
| 124 |
-
"""
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
def increment_turn(self) -> None:
|
| 129 |
"""Increment turn counter (called after each user interaction)"""
|
|
|
|
| 95 |
model_name="anthropic/claude-sonnet-4-5-20250929",
|
| 96 |
)
|
| 97 |
self.is_running = True
|
| 98 |
+
self._cancelled = asyncio.Event()
|
| 99 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 100 |
# User's HF OAuth token — set by session_manager after construction
|
| 101 |
self.hf_token: Optional[str] = None
|
|
|
|
| 120 |
}
|
| 121 |
)
|
| 122 |
|
| 123 |
+
def cancel(self) -> None:
|
| 124 |
+
"""Signal cancellation to the running agent loop."""
|
| 125 |
+
self._cancelled.set()
|
| 126 |
+
|
| 127 |
+
def reset_cancel(self) -> None:
|
| 128 |
+
"""Clear the cancellation flag before a new run."""
|
| 129 |
+
self._cancelled.clear()
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def is_cancelled(self) -> bool:
|
| 133 |
+
return self._cancelled.is_set()
|
| 134 |
|
| 135 |
def increment_turn(self) -> None:
|
| 136 |
"""Increment turn counter (called after each user interaction)"""
|
backend/session_manager.py
CHANGED
|
@@ -265,9 +265,12 @@ class SessionManager:
|
|
| 265 |
return await self.submit(session_id, operation)
|
| 266 |
|
| 267 |
async def interrupt(self, session_id: str) -> bool:
|
| 268 |
-
"""Interrupt a session."""
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
async def undo(self, session_id: str) -> bool:
|
| 273 |
"""Undo last turn in a session."""
|
|
|
|
| 265 |
return await self.submit(session_id, operation)
|
| 266 |
|
| 267 |
async def interrupt(self, session_id: str) -> bool:
|
| 268 |
+
"""Interrupt a session by signalling cancellation directly (bypasses queue)."""
|
| 269 |
+
agent_session = self.sessions.get(session_id)
|
| 270 |
+
if not agent_session or not agent_session.is_active:
|
| 271 |
+
return False
|
| 272 |
+
agent_session.session.cancel()
|
| 273 |
+
return True
|
| 274 |
|
| 275 |
async def undo(self, session_id: str) -> bool:
|
| 276 |
"""Undo last turn in a session."""
|
frontend/src/components/Chat/ChatInput.tsx
CHANGED
|
@@ -2,6 +2,7 @@ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react';
|
|
| 2 |
import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
|
| 3 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 4 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
|
|
|
| 5 |
import { apiFetch } from '@/utils/api';
|
| 6 |
|
| 7 |
// Model configuration
|
|
@@ -58,12 +59,15 @@ const findModelByPath = (path: string): ModelOption | undefined => {
|
|
| 58 |
|
| 59 |
interface ChatInputProps {
|
| 60 |
onSend: (text: string) => void;
|
|
|
|
|
|
|
| 61 |
disabled?: boolean;
|
| 62 |
placeholder?: string;
|
| 63 |
}
|
| 64 |
|
| 65 |
-
export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 66 |
const [input, setInput] = useState('');
|
|
|
|
| 67 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 68 |
const [selectedModelId, setSelectedModelId] = useState<string>(() => {
|
| 69 |
try {
|
|
@@ -92,12 +96,12 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
|
|
| 92 |
|
| 93 |
const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0];
|
| 94 |
|
| 95 |
-
// Auto-focus the textarea when the session becomes ready
|
| 96 |
useEffect(() => {
|
| 97 |
-
if (!disabled && inputRef.current) {
|
| 98 |
inputRef.current.focus();
|
| 99 |
}
|
| 100 |
-
}, [disabled]);
|
| 101 |
|
| 102 |
const handleSend = useCallback(() => {
|
| 103 |
if (input.trim() && !disabled) {
|
|
@@ -173,7 +177,7 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
|
|
| 173 |
onChange={(e) => setInput(e.target.value)}
|
| 174 |
onKeyDown={handleKeyDown}
|
| 175 |
placeholder={placeholder}
|
| 176 |
-
disabled={disabled}
|
| 177 |
variant="standard"
|
| 178 |
inputRef={inputRef}
|
| 179 |
InputProps={{
|
|
@@ -200,26 +204,46 @@ export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask
|
|
| 200 |
}
|
| 201 |
}}
|
| 202 |
/>
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
</Box>
|
| 224 |
|
| 225 |
{/* Powered By Badge */}
|
|
|
|
| 2 |
import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material';
|
| 3 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 4 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
| 5 |
+
import CloseIcon from '@mui/icons-material/Close';
|
| 6 |
import { apiFetch } from '@/utils/api';
|
| 7 |
|
| 8 |
// Model configuration
|
|
|
|
| 59 |
|
| 60 |
interface ChatInputProps {
|
| 61 |
onSend: (text: string) => void;
|
| 62 |
+
onStop?: () => void;
|
| 63 |
+
isProcessing?: boolean;
|
| 64 |
disabled?: boolean;
|
| 65 |
placeholder?: string;
|
| 66 |
}
|
| 67 |
|
| 68 |
+
export default function ChatInput({ onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 69 |
const [input, setInput] = useState('');
|
| 70 |
+
const [stopHovered, setStopHovered] = useState(false);
|
| 71 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 72 |
const [selectedModelId, setSelectedModelId] = useState<string>(() => {
|
| 73 |
try {
|
|
|
|
| 96 |
|
| 97 |
const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0];
|
| 98 |
|
| 99 |
+
// Auto-focus the textarea when the session becomes ready
|
| 100 |
useEffect(() => {
|
| 101 |
+
if (!disabled && !isProcessing && inputRef.current) {
|
| 102 |
inputRef.current.focus();
|
| 103 |
}
|
| 104 |
+
}, [disabled, isProcessing]);
|
| 105 |
|
| 106 |
const handleSend = useCallback(() => {
|
| 107 |
if (input.trim() && !disabled) {
|
|
|
|
| 177 |
onChange={(e) => setInput(e.target.value)}
|
| 178 |
onKeyDown={handleKeyDown}
|
| 179 |
placeholder={placeholder}
|
| 180 |
+
disabled={disabled || isProcessing}
|
| 181 |
variant="standard"
|
| 182 |
inputRef={inputRef}
|
| 183 |
InputProps={{
|
|
|
|
| 204 |
}
|
| 205 |
}}
|
| 206 |
/>
|
| 207 |
+
{isProcessing ? (
|
| 208 |
+
<IconButton
|
| 209 |
+
onClick={onStop}
|
| 210 |
+
onMouseEnter={() => setStopHovered(true)}
|
| 211 |
+
onMouseLeave={() => setStopHovered(false)}
|
| 212 |
+
sx={{
|
| 213 |
+
mt: 1,
|
| 214 |
+
p: 1,
|
| 215 |
+
borderRadius: '10px',
|
| 216 |
+
color: stopHovered ? 'var(--accent-yellow)' : 'var(--muted-text)',
|
| 217 |
+
transition: 'all 0.2s',
|
| 218 |
+
'&:hover': {
|
| 219 |
+
bgcolor: 'var(--hover-bg)',
|
| 220 |
+
},
|
| 221 |
+
}}
|
| 222 |
+
>
|
| 223 |
+
{stopHovered ? <CloseIcon fontSize="small" /> : <CircularProgress size={20} color="inherit" />}
|
| 224 |
+
</IconButton>
|
| 225 |
+
) : (
|
| 226 |
+
<IconButton
|
| 227 |
+
onClick={handleSend}
|
| 228 |
+
disabled={disabled || !input.trim()}
|
| 229 |
+
sx={{
|
| 230 |
+
mt: 1,
|
| 231 |
+
p: 1,
|
| 232 |
+
borderRadius: '10px',
|
| 233 |
+
color: 'var(--muted-text)',
|
| 234 |
+
transition: 'all 0.2s',
|
| 235 |
+
'&:hover': {
|
| 236 |
+
color: 'var(--accent-yellow)',
|
| 237 |
+
bgcolor: 'var(--hover-bg)',
|
| 238 |
+
},
|
| 239 |
+
'&.Mui-disabled': {
|
| 240 |
+
opacity: 0.3,
|
| 241 |
+
},
|
| 242 |
+
}}
|
| 243 |
+
>
|
| 244 |
+
<ArrowUpwardIcon fontSize="small" />
|
| 245 |
+
</IconButton>
|
| 246 |
+
)}
|
| 247 |
</Box>
|
| 248 |
|
| 249 |
{/* Powered By Badge */}
|
frontend/src/components/Layout/AppLayout.tsx
CHANGED
|
@@ -110,7 +110,7 @@ export default function AppLayout() {
|
|
| 110 |
|
| 111 |
const hasAnySessions = sessions.length > 0;
|
| 112 |
|
| 113 |
-
const { messages, sendMessage, undoLastTurn, approveTools } = useAgentChat({
|
| 114 |
sessionId: activeSessionId,
|
| 115 |
onReady: () => logger.log('Agent ready'),
|
| 116 |
onError: (error) => logger.error('Agent error:', error),
|
|
@@ -345,7 +345,9 @@ export default function AppLayout() {
|
|
| 345 |
<MessageList messages={messages} isProcessing={isProcessing} approveTools={approveTools} onUndoLastTurn={undoLastTurn} />
|
| 346 |
<ChatInput
|
| 347 |
onSend={handleSendMessage}
|
| 348 |
-
|
|
|
|
|
|
|
| 349 |
placeholder={activityStatus.type === 'waiting-approval' ? 'Approve or reject pending tools first...' : undefined}
|
| 350 |
/>
|
| 351 |
</>
|
|
@@ -439,7 +441,7 @@ export default function AppLayout() {
|
|
| 439 |
onClose={() => setShowExpiredToast(false)}
|
| 440 |
sx={{ fontFamily: 'monospace', fontSize: '0.8rem' }}
|
| 441 |
>
|
| 442 |
-
|
| 443 |
</Alert>
|
| 444 |
</Snackbar>
|
| 445 |
<Snackbar
|
|
|
|
| 110 |
|
| 111 |
const hasAnySessions = sessions.length > 0;
|
| 112 |
|
| 113 |
+
const { messages, sendMessage, stop, undoLastTurn, approveTools } = useAgentChat({
|
| 114 |
sessionId: activeSessionId,
|
| 115 |
onReady: () => logger.log('Agent ready'),
|
| 116 |
onError: (error) => logger.error('Agent error:', error),
|
|
|
|
| 345 |
<MessageList messages={messages} isProcessing={isProcessing} approveTools={approveTools} onUndoLastTurn={undoLastTurn} />
|
| 346 |
<ChatInput
|
| 347 |
onSend={handleSendMessage}
|
| 348 |
+
onStop={stop}
|
| 349 |
+
isProcessing={isProcessing}
|
| 350 |
+
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 351 |
placeholder={activityStatus.type === 'waiting-approval' ? 'Approve or reject pending tools first...' : undefined}
|
| 352 |
/>
|
| 353 |
</>
|
|
|
|
| 441 |
onClose={() => setShowExpiredToast(false)}
|
| 442 |
sx={{ fontFamily: 'monospace', fontSize: '0.8rem' }}
|
| 443 |
>
|
| 444 |
+
Task expired — create a new task to continue.
|
| 445 |
</Alert>
|
| 446 |
</Snackbar>
|
| 447 |
<Snackbar
|
frontend/src/hooks/useAgentChat.ts
CHANGED
|
@@ -270,6 +270,7 @@ export function useAgentChat({ sessionId, onReady, onError, onSessionDead }: Use
|
|
| 270 |
return {
|
| 271 |
messages: chat.messages,
|
| 272 |
sendMessage: chat.sendMessage,
|
|
|
|
| 273 |
status: chat.status,
|
| 274 |
undoLastTurn,
|
| 275 |
approveTools,
|
|
|
|
| 270 |
return {
|
| 271 |
messages: chat.messages,
|
| 272 |
sendMessage: chat.sendMessage,
|
| 273 |
+
stop: chat.stop,
|
| 274 |
status: chat.status,
|
| 275 |
undoLastTurn,
|
| 276 |
approveTools,
|