diff --git a/README.md b/README.md index 729674aabe970a7c0cf11778b679e84cbf4a371f..fed2a689ba76144abbf89ce49d0b74a973325062 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,11 @@ hf_oauth: true hf_oauth_scopes: - read-repos - write-repos + - contribute-repos + - manage-repos - inference-api + - jobs + - write-discussions --- # HF Agent diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 8d7296ae0f8952dd976b522d93a34156db41ed72..230589b0cc0dcea63ab723d036131366f925a825 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -2,6 +2,7 @@ Context management for conversation history """ +import logging import os import zoneinfo from datetime import datetime @@ -13,6 +14,72 @@ from huggingface_hub import HfApi from jinja2 import Template from litellm import Message, acompletion +logger = logging.getLogger(__name__) + +# Module-level cache for HF username — avoids repeating the slow whoami() call +_hf_username_cache: str | None = None + +_HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2" +_HF_WHOAMI_TIMEOUT = 5 # seconds + + +def _get_hf_username() -> str: + """Return the HF username, cached after the first call. + + Uses subprocess + curl to avoid Python HTTP client IPv6 issues that + cause 40+ second hangs (httpx/urllib try IPv6 first which times out + at OS level before falling back to IPv4 — the "Happy Eyeballs" problem). 
+ """ + import json + import subprocess + import time as _t + + global _hf_username_cache + if _hf_username_cache is not None: + return _hf_username_cache + + hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") + if not hf_token: + logger.warning("No HF_TOKEN set, using 'unknown' as username") + _hf_username_cache = "unknown" + return _hf_username_cache + + t0 = _t.monotonic() + try: + result = subprocess.run( + [ + "curl", + "-s", + "-4", # force IPv4 + "-m", + str(_HF_WHOAMI_TIMEOUT), # max time + "-H", + f"Authorization: Bearer {hf_token}", + _HF_WHOAMI_URL, + ], + capture_output=True, + text=True, + timeout=_HF_WHOAMI_TIMEOUT + 2, + ) + t1 = _t.monotonic() + if result.returncode == 0 and result.stdout: + data = json.loads(result.stdout) + _hf_username_cache = data.get("name", "unknown") + logger.info( + f"HF username resolved to '{_hf_username_cache}' in {t1 - t0:.2f}s" + ) + else: + logger.warning( + f"curl whoami failed (rc={result.returncode}) in {t1 - t0:.2f}s" + ) + _hf_username_cache = "unknown" + except Exception as e: + t1 = _t.monotonic() + logger.warning(f"HF whoami failed in {t1 - t0:.2f}s: {e}") + _hf_username_cache = "unknown" + + return _hf_username_cache + class ContextManager: """Manages conversation context and message history for the agent""" @@ -23,11 +90,11 @@ class ContextManager: compact_size: float = 0.1, untouched_messages: int = 5, tool_specs: list[dict[str, Any]] | None = None, - prompt_file_suffix: str = "system_prompt_v3.yaml", + prompt_file_suffix: str = "system_prompt_v2.yaml", ): self.system_prompt = self._load_system_prompt( tool_specs or [], - prompt_file_suffix="system_prompt_v3.yaml", + prompt_file_suffix="system_prompt_v2.yaml", ) self.max_context = max_context self.compact_size = int(max_context * compact_size) @@ -54,9 +121,8 @@ class ContextManager: current_time = now.strftime("%H:%M:%S.%f")[:-3] current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})" - 
# Get HF user info with explicit token from env - hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") - hf_user_info = HfApi(token=hf_token).whoami().get("name", "unknown") + # Get HF user info (cached after the first call) + hf_user_info = _get_hf_username() template = Template(template_str) return template.render( @@ -78,9 +144,7 @@ class ContextManager: """Get all messages for sending to LLM""" return self.items - async def compact( - self, model_name: str, tool_specs: list[dict] | None = None - ) -> None: + async def compact(self, model_name: str) -> None: """Remove old messages to keep history under target size""" if (self.context_length <= self.max_context) or not self.items: return @@ -110,11 +174,14 @@ class ContextManager: ) ) + hf_key = os.environ.get("INFERENCE_TOKEN") response = await acompletion( model=model_name, messages=messages_to_summarize, max_completion_tokens=self.compact_size, - tools=tool_specs, + api_key=hf_key + if hf_key and model_name.startswith("huggingface/") + else None, ) summarized_message = Message( role="assistant", content=response.choices[0].message.content diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index b1f85cd156ce8f79f1e51ffae72f4ee818b28480..d7bb255588111e2eef95d62083ad0a0c8f961de5 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -4,9 +4,10 @@ Main agent implementation with integrated tool system and MCP support import asyncio import json +import logging +import os -from litellm import ChatCompletionMessageToolCall, Message, ModelResponse, acompletion -from litellm.exceptions import ContextWindowExceededError +from litellm import ChatCompletionMessageToolCall, Message, acompletion from lmnr import observe from agent.config import Config @@ -14,7 +15,42 @@ from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter from agent.tools.jobs_tool import CPU_FLAVORS +logger = logging.getLogger(__name__) + ToolCall = 
ChatCompletionMessageToolCall +# Explicit inference token — needed because litellm checks HF_TOKEN before +# HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions. +_INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN") + + +def _resolve_hf_router_params(model_name: str) -> dict: + """ + Build LiteLLM kwargs for HuggingFace Router models. + + api-inference.huggingface.co is deprecated; the new router lives at + router.huggingface.co//v3/openai. LiteLLM's built-in + ``huggingface/`` provider still targets the old endpoint, so we + rewrite model names to ``openai/`` and supply the correct api_base. + + Input format: huggingface/// + Example: huggingface/novita/moonshotai/kimi-k2.5 + """ + if not model_name.startswith("huggingface/"): + return {"model": model_name} + + parts = model_name.split("/", 2) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5'] + if len(parts) < 3: + return {"model": model_name} + + router_provider = parts[1] + actual_model = parts[2] + api_key = _INFERENCE_API_KEY or os.environ.get("HF_TOKEN") + + return { + "model": f"openai/{actual_model}", + "api_base": f"https://router.huggingface.co/{router_provider}/v3/openai", + "api_key": api_key, + } def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: @@ -52,9 +88,6 @@ def _needs_approval( if not args_valid: return False - if tool_name == "sandbox_create": - return True - if tool_name == "hf_jobs": operation = tool_args.get("operation", "") if operation not in ["run", "uv", "scheduled run", "scheduled uv"]: @@ -109,31 +142,49 @@ def _needs_approval( return False -async def _compact_and_notify(session: Session) -> None: - """Run compaction and send event if context was reduced.""" - old_length = session.context_manager.context_length - tool_specs = session.tool_router.get_tool_specs_for_llm() - await session.context_manager.compact( - model_name=session.config.model_name, - tool_specs=tool_specs, - ) - new_length = session.context_manager.context_length - 
if new_length != old_length: - await session.send_event( - Event( - event_type="compacted", - data={"old_tokens": old_length, "new_tokens": new_length}, +class Handlers: + """Handler functions for each operation type""" + + @staticmethod + async def _abandon_pending_approval(session: Session) -> None: + """Cancel pending approval tools when the user continues the conversation. + + Injects rejection tool-result messages into the LLM context (so the + history stays valid) and notifies the frontend that those tools were + abandoned. + """ + tool_calls = session.pending_approval.get("tool_calls", []) + for tc in tool_calls: + tool_name = tc.function.name + abandon_msg = "Task abandoned — user continued the conversation without approving." + + # Keep LLM context valid: every tool_call needs a tool result + tool_msg = Message( + role="tool", + content=abandon_msg, + tool_call_id=tc.id, + name=tool_name, ) - ) + session.context_manager.add_message(tool_msg) + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "abandoned", + }, + ) + ) -class Handlers: - """Handler functions for each operation type""" + session.pending_approval = None + logger.info("Abandoned %d pending approval tool(s)", len(tool_calls)) @staticmethod @observe(name="run_agent") async def run_agent( - session: Session, text: str, max_iterations: int = 300 + session: Session, text: str, max_iterations: int = 10 ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -145,6 +196,11 @@ class Handlers: Laminar.set_trace_session_id(session_id=session.session_id) + # If there's a pending approval and the user sent a new message, + # abandon the pending tools so the LLM context stays valid. 
+ if text and session.pending_approval: + await Handlers._abandon_pending_approval(session) + # Add user message to history only if there's actual content if text: user_msg = Message(role="user", content=text) @@ -160,42 +216,102 @@ class Handlers: final_response = None while iteration < max_iterations: - # Compact before calling the LLM if context is near the limit - await _compact_and_notify(session) - messages = session.context_manager.get_messages() tools = session.tool_router.get_tool_specs_for_llm() - try: - response: ModelResponse = await acompletion( - model=session.config.model_name, + # ── Stream the LLM response ────────────────────────── + llm_params = _resolve_hf_router_params(session.config.model_name) + response = await acompletion( messages=messages, tools=tools, tool_choice="auto", + stream=True, + stream_options={"include_usage": True}, + **llm_params, ) - # Extract text response, token usage, and tool calls - message = response.choices[0].message - content = message.content - token_count = response.usage.total_tokens - tool_calls: list[ToolCall] = message.get("tool_calls", []) + full_content = "" + tool_calls_acc: dict[int, dict] = {} + token_count = 0 + + async for chunk in response: + choice = chunk.choices[0] if chunk.choices else None + if not choice: + # Last chunk may carry only usage info + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + continue + + delta = choice.delta + + # Stream text deltas to the frontend + if delta.content: + full_content += delta.content + await session.send_event( + Event( + event_type="assistant_chunk", + data={"content": delta.content}, + ) + ) + + # Accumulate tool-call deltas (name + args arrive in pieces) + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tc_delta.id: + tool_calls_acc[idx]["id"] = 
tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) + if tc_delta.function.arguments: + tool_calls_acc[idx]["function"]["arguments"] += ( + tc_delta.function.arguments + ) + + # Capture usage from the final chunk + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + + # ── Stream finished — reconstruct full message ─────── + content = full_content or None + + # Build tool_calls list from accumulated deltas + tool_calls: list[ToolCall] = [] + for idx in sorted(tool_calls_acc.keys()): + tc_data = tool_calls_acc[idx] + tool_calls.append( + ToolCall( + id=tc_data["id"], + type="function", + function={ + "name": tc_data["function"]["name"], + "arguments": tc_data["function"]["arguments"], + }, + ) + ) + + # Signal end of streaming to the frontend + await session.send_event( + Event(event_type="assistant_stream_end", data={}) + ) # If no tool calls, add assistant message and we're done if not tool_calls: if content: assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) - await session.send_event( - Event( - event_type="assistant_message", - data={"content": content}, - ) - ) final_response = content break # Add assistant message with tool calls to history - # LiteLLM will format this correctly for the provider assistant_msg = Message( role="assistant", content=content, @@ -203,66 +319,97 @@ class Handlers: ) session.context_manager.add_message(assistant_msg, token_count) - if content: - await session.send_event( - Event(event_type="assistant_message", data={"content": content}) - ) - # Separate tools into those requiring approval and those that don't approval_required_tools = [] non_approval_tools = [] for tc in tool_calls: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except 
(json.JSONDecodeError, TypeError) as e: + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_args = {} if _needs_approval(tool_name, tool_args, session.config): approval_required_tools.append(tc) else: non_approval_tools.append(tc) + # Execute non-approval tools (in parallel when possible) + if non_approval_tools: + # 1. Parse args and validate upfront + parsed_tools: list[ + tuple[ChatCompletionMessageToolCall, str, dict, bool, str] + ] = [] + for tc in non_approval_tools: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} + + args_valid, error_msg = _validate_tool_args(tool_args) + parsed_tools.append( + (tc, tool_name, tool_args, args_valid, error_msg) + ) - # Execute non-approval tools first - for tc in non_approval_tools: - tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) - - # Validate tool arguments before calling - args_valid, error_msg = _validate_tool_args(tool_args) - if not args_valid: - # Return error to agent instead of calling tool - output = error_msg - success = False - else: - await session.send_event( - Event( - event_type="tool_call", - data={"tool": tool_name, "arguments": tool_args}, + # 2. Send all tool_call events upfront (so frontend shows them all) + for tc, tool_name, tool_args, args_valid, _ in parsed_tools: + if args_valid: + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + }, + ) ) - ) - output, success = await session.tool_router.call_tool( - tool_name, tool_args, session=session + # 3. 
Execute all valid tools in parallel + async def _exec_tool( + tc: ChatCompletionMessageToolCall, + name: str, + args: dict, + valid: bool, + err: str, + ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]: + if not valid: + return (tc, name, args, err, False) + out, ok = await session.tool_router.call_tool( + name, args, session=session ) + return (tc, name, args, out, ok) - # Add tool result to history - tool_msg = Message( - role="tool", - content=output, - tool_call_id=tc.id, - name=tool_name, + results = await asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] ) - session.context_manager.add_message(tool_msg) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tool_name, - "output": output, - "success": success, - }, + # 4. Record results and send outputs (order preserved) + for tc, tool_name, tool_args, output, success in results: + tool_msg = Message( + role="tool", + content=output, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": output, + "success": success, + }, + ) ) - ) # If there are tools requiring approval, ask for batch approval if approval_required_tools: @@ -270,7 +417,10 @@ class Handlers: tools_data = [] for tc in approval_required_tools: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} tools_data.append( { "tool": tool_name, @@ -299,14 +449,6 @@ class Handlers: iteration += 1 - except ContextWindowExceededError: - # Force compact and retry this iteration - session.context_manager.context_length = ( - session.context_manager.max_context + 1 - ) - await _compact_and_notify(session) - continue - except Exception as e: import 
traceback @@ -318,6 +460,18 @@ class Handlers: ) break + old_length = session.context_manager.context_length + await session.context_manager.compact(model_name=session.config.model_name) + new_length = session.context_manager.context_length + + if new_length != old_length: + await session.send_event( + Event( + event_type="compacted", + data={"old_tokens": old_length, "new_tokens": new_length}, + ) + ) + await session.send_event( Event( event_type="turn_complete", @@ -337,13 +491,43 @@ class Handlers: session.interrupt() await session.send_event(Event(event_type="interrupted")) + @staticmethod + async def compact(session: Session) -> None: + """Handle compact (like compact in codex.rs:1317)""" + old_length = session.context_manager.context_length + await session.context_manager.compact(model_name=session.config.model_name) + new_length = session.context_manager.context_length + + await session.send_event( + Event( + event_type="compacted", + data={"removed": old_length, "remaining": new_length}, + ) + ) + @staticmethod async def undo(session: Session) -> None: - """Handle undo (like undo in codex.rs:1314)""" - # Remove last user turn and all following items - # Simplified: just remove last 2 items - for _ in range(min(2, len(session.context_manager.items))): - session.context_manager.items.pop() + """Remove the last complete turn (user msg + all assistant/tool msgs that follow). + + Anthropic requires every tool_use to have a matching tool_result, + so we can't just pop 2 items — we must pop everything back to + (and including) the last user message to keep the history valid. 
+ """ + items = session.context_manager.items + if not items: + await session.send_event(Event(event_type="undo_complete")) + return + + # Pop from the end until we've removed the last user message + removed_user = False + while items: + msg = items.pop() + if getattr(msg, "role", None) == "user": + removed_user = True + break + + if not removed_user: + logger.warning("Undo: no user message found to remove") await session.send_event(Event(event_type="undo_complete")) @@ -371,6 +555,9 @@ class Handlers: # Create a map of tool_call_id -> approval decision approval_map = {a["tool_call_id"]: a for a in approvals} + for a in approvals: + if a.get("edited_script"): + logger.info(f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)") # Separate approved and rejected tool calls approved_tasks = [] @@ -378,36 +565,99 @@ class Handlers: for tc in tool_calls: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError) as e: + # Malformed arguments — treat as failed, notify agent + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_msg = Message( + role="tool", + content=f"Malformed arguments: {e}", + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": f"Malformed arguments: {e}", + "success": False, + }, + ) + ) + continue + approval_decision = approval_map.get(tc.id, {"approved": False}) if approval_decision.get("approved", False): - approved_tasks.append((tc, tool_name, tool_args)) + edited_script = approval_decision.get("edited_script") + was_edited = False + if edited_script and "script" in tool_args: + tool_args["script"] = edited_script + was_edited = True + logger.info(f"Using user-edited script for {tool_name} ({tc.id})") + 
approved_tasks.append((tc, tool_name, tool_args, was_edited)) else: rejected_tasks.append((tc, tool_name, approval_decision)) + # Notify frontend of approval decisions immediately (before execution) + for tc, tool_name, tool_args, _was_edited in approved_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "approved", + }, + ) + ) + for tc, tool_name, approval_decision in rejected_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "rejected", + }, + ) + ) + # Execute all approved tools concurrently - async def execute_tool(tc, tool_name, tool_args): - """Execute a single tool and return its result""" + async def execute_tool(tc, tool_name, tool_args, was_edited): + """Execute a single tool and return its result. + + The TraceLog already exists on the frontend (created by + approval_required), so we send tool_state_change instead of + tool_call to avoid creating a duplicate. 
+ """ await session.send_event( Event( - event_type="tool_call", - data={"tool": tool_name, "arguments": tool_args}, + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "running", + }, ) ) output, success = await session.tool_router.call_tool( - tool_name, tool_args, session=session + tool_name, tool_args, session=session, tool_call_id=tc.id ) - return (tc, tool_name, output, success) + return (tc, tool_name, output, success, was_edited) # Execute all approved tools concurrently and wait for ALL to complete if approved_tasks: results = await asyncio.gather( *[ - execute_tool(tc, tool_name, tool_args) - for tc, tool_name, tool_args in approved_tasks + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks ], return_exceptions=True, ) @@ -416,10 +666,13 @@ class Handlers: for result in results: if isinstance(result, Exception): # Handle execution error - print(f"Tool execution error: {result}") + logger.error(f"Tool execution error: {result}") continue - tc, tool_name, output, success = result + tc, tool_name, output, success, was_edited = result + + if was_edited: + output = f"[Note: The user edited the script before execution. The output below reflects the user-modified version, not your original script.]\n\n{output}" # Add tool result to context tool_msg = Message( @@ -435,6 +688,7 @@ class Handlers: event_type="tool_output", data={ "tool": tool_name, + "tool_call_id": tc.id, "output": output, "success": success, }, @@ -446,7 +700,14 @@ class Handlers: rejection_msg = "Job execution cancelled by user" user_feedback = approval_decision.get("feedback") if user_feedback: - rejection_msg += f". 
User feedback: {user_feedback}" + # Ensure feedback is a string and sanitize any problematic characters + feedback_str = str(user_feedback).strip() + # Remove any control characters that might break JSON parsing + feedback_str = "".join(char for char in feedback_str if ord(char) >= 32 or char in "\n\t") + rejection_msg += f". User feedback: {feedback_str}" + + # Ensure rejection_msg is a clean string + rejection_msg = str(rejection_msg).strip() tool_msg = Message( role="tool", @@ -461,6 +722,7 @@ class Handlers: event_type="tool_output", data={ "tool": tool_name, + "tool_call_id": tc.id, "output": rejection_msg, "success": False, }, @@ -478,11 +740,9 @@ class Handlers: """Handle shutdown (like shutdown in codex.rs:1329)""" # Save session trajectory if enabled (fire-and-forget, returns immediately) if session.config.save_sessions: - print("💾 Saving session...") + logger.info("Saving session...") repo_id = session.config.session_dataset_repo _ = session.save_and_upload_detached(repo_id) - # if local_path: - # print("✅ Session saved locally, upload in progress") session.is_running = False await session.send_event(Event(event_type="shutdown")) @@ -497,7 +757,7 @@ async def process_submission(session: Session, submission) -> bool: bool: True to continue, False to shutdown """ op = submission.operation - # print(f"📨 Received: {op.op_type.value}") + logger.debug("Received operation: %s", op.op_type.value) if op.op_type == OpType.USER_INPUT: text = op.data.get("text", "") if op.data else "" @@ -509,8 +769,7 @@ async def process_submission(session: Session, submission) -> bool: return True if op.op_type == OpType.COMPACT: - # compact from the frontend - await _compact_and_notify(session) + await Handlers.compact(session) return True if op.op_type == OpType.UNDO: @@ -525,7 +784,7 @@ async def process_submission(session: Session, submission) -> bool: if op.op_type == OpType.SHUTDOWN: return not await Handlers.shutdown(session) - print(f"⚠️ Unknown operation: {op.op_type}") + 
logger.warning(f"Unknown operation: {op.op_type}") return True @@ -543,7 +802,7 @@ async def submission_loop( # Create session with tool router session = Session(event_queue, config=config, tool_router=tool_router) - print("Agent loop started") + logger.info("Agent loop started") # Retry any failed uploads from previous sessions (fire-and-forget) if config and config.save_sessions: @@ -567,25 +826,25 @@ async def submission_loop( if not should_continue: break except asyncio.CancelledError: - print("\n⚠️ Agent loop cancelled") + logger.warning("Agent loop cancelled") break except Exception as e: - print(f"❌ Error in agent loop: {e}") + logger.error(f"Error in agent loop: {e}") await session.send_event( Event(event_type="error", data={"error": str(e)}) ) - print("🛑 Agent loop exited") + logger.info("Agent loop exited") finally: # Emergency save if session saving is enabled and shutdown wasn't called properly if session.config.save_sessions and session.is_running: - print("\n💾 Emergency save: preserving session before exit...") + logger.info("Emergency save: preserving session before exit...") try: local_path = session.save_and_upload_detached( session.config.session_dataset_repo ) if local_path: - print("✅ Emergency save successful, upload in progress") + logger.info("Emergency save successful, upload in progress") except Exception as e: - print(f"❌ Emergency save failed: {e}") + logger.error(f"Emergency save failed: {e}") diff --git a/agent/core/session.py b/agent/core/session.py index 439260d5a9fc3992cc5ccb8aee3a26e56b41bc8c..1f51153a895835358be618ba39061f37571f27a9 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import subprocess import sys import uuid @@ -9,11 +10,48 @@ from enum import Enum from pathlib import Path from typing import Any, Optional -from litellm import get_max_tokens - from agent.config import Config from agent.context_manager.manager import ContextManager +logger = 
logging.getLogger(__name__) + +# Local max-token lookup — avoids litellm.get_max_tokens() which can hang +# on network calls for certain providers (known litellm issue). +_MAX_TOKENS_MAP: dict[str, int] = { + # Anthropic + "anthropic/claude-opus-4-5-20251101": 200_000, + "anthropic/claude-sonnet-4-5-20250929": 200_000, + "anthropic/claude-sonnet-4-20250514": 200_000, + "anthropic/claude-haiku-3-5-20241022": 200_000, + "anthropic/claude-3-5-sonnet-20241022": 200_000, + "anthropic/claude-3-opus-20240229": 200_000, + "huggingface/novita/minimax/minimax-m2.1": 196_608, + "huggingface/novita/moonshotai/kimi-k2.5": 262_144, + "huggingface/novita/zai-org/glm-5": 200_000, +} +_DEFAULT_MAX_TOKENS = 200_000 + + +def _get_max_tokens_safe(model_name: str) -> int: + """Return the max context window for a model without network calls.""" + tokens = _MAX_TOKENS_MAP.get(model_name) + if tokens: + return tokens + # Fallback: try litellm but with a short timeout via threading + try: + from litellm import get_max_tokens + + result = get_max_tokens(model_name) + if result and isinstance(result, int): + return result + logger.warning( + f"get_max_tokens returned {result} for {model_name}, using default" + ) + return _DEFAULT_MAX_TOKENS + except Exception as e: + logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}") + return _DEFAULT_MAX_TOKENS + class OpType(Enum): USER_INPUT = "user_input" @@ -46,7 +84,7 @@ class Session: self.tool_router = tool_router tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = context_manager or ContextManager( - max_context=get_max_tokens(config.model_name), + max_context=_get_max_tokens_safe(config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, @@ -59,7 +97,8 @@ class Session: self.is_running = True self.current_task: asyncio.Task | None = None self.pending_approval: Optional[dict[str, Any]] = None - self.sandbox = None + # User's HF OAuth token — set by 
session_manager after construction + self.hf_token: Optional[str] = None # Session trajectory logging self.logged_events: list[dict] = [] @@ -100,7 +139,7 @@ class Session: turns_since_last_save = self.turn_count - self.last_auto_save_turn if turns_since_last_save >= interval: - print(f"\n💾 Auto-saving session (turn {self.turn_count})...") + logger.info(f"Auto-saving session (turn {self.turn_count})...") # Fire-and-forget save - returns immediately self.save_and_upload_detached(self.config.session_dataset_repo) self.last_auto_save_turn = self.turn_count @@ -152,7 +191,7 @@ class Session: return str(filepath) except Exception as e: - print(f"Failed to save session locally: {e}") + logger.error(f"Failed to save session locally: {e}") return None def update_local_save_status( @@ -172,7 +211,7 @@ class Session: return True except Exception as e: - print(f"Failed to update local save status: {e}") + logger.error(f"Failed to update local save status: {e}") return False def save_and_upload_detached(self, repo_id: str) -> Optional[str]: @@ -203,7 +242,7 @@ class Session: start_new_session=True, # Detach from parent ) except Exception as e: - print(f"⚠️ Failed to spawn upload subprocess: {e}") + logger.warning(f"Failed to spawn upload subprocess: {e}") return local_path @@ -233,4 +272,4 @@ class Session: start_new_session=True, # Detach from parent ) except Exception as e: - print(f"⚠️ Failed to spawn retry subprocess: {e}") + logger.warning(f"Failed to spawn retry subprocess: {e}") diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py index 2cbef9e306530e40b4cad310233a9952f737120e..ef2f9496d87f832489010f9a9529c538d939bedb 100644 --- a/agent/core/session_uploader.py +++ b/agent/core/session_uploader.py @@ -15,10 +15,8 @@ from dotenv import load_dotenv load_dotenv() -# Fallback token for session uploads (write-only access to akseljoonas/hf-agent-sessions) -_SESSION_TOKEN = "".join([ - "hf_", "Nzya", "Eeb", "ESz", "DtA", "BoW", "Czj", "SEC", "ZZv", 
"kVL", "Ac", "Vf", "Sz" -]) +# Token for session uploads — loaded from env var (never hardcode tokens in source) +_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "") def upload_session_as_file( diff --git a/agent/core/tools.py b/agent/core/tools.py index 586afab5eb074b4e1c391f140edc30d90c52c0a2..af442f066ca6880b5a537be2f0b2b5044cafc77b 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -3,10 +3,13 @@ Tool system for the agent Provides ToolSpec and ToolRouter for managing both built-in and MCP tools """ +import logging import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional +logger = logging.getLogger(__name__) + from fastmcp import Client from fastmcp.exceptions import ToolError from lmnr import observe @@ -45,7 +48,6 @@ from agent.tools.hf_repo_git_tool import ( ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler -from agent.tools.sandbox_tool import get_sandbox_tools # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git # from agent.tools.private_hf_repo_tools import ( @@ -132,6 +134,7 @@ class ToolRouter: for tool in create_builtin_tools(): self.register_tool(tool) + self.mcp_client: Client | None = None if mcp_servers: mcp_servers_payload = {} for name, server in mcp_servers.items(): @@ -159,7 +162,7 @@ class ToolRouter: handler=None, ) ) - print( + logger.info( f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)" ) @@ -180,7 +183,7 @@ class ToolRouter: handler=search_openapi_handler, ) ) - print(f"Loaded OpenAPI search tool: {openapi_spec['name']}") + logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}") def get_tool_specs_for_llm(self) -> list[dict[str, Any]]: """Get tool specifications in OpenAI format""" @@ -209,7 +212,7 @@ class ToolRouter: await self.register_openapi_tool() total_tools = len(self.tools) - 
print(f"\nAgent ready with {total_tools} tools total\n") + logger.info(f"Agent ready with {total_tools} tools total") return self @@ -220,7 +223,7 @@ class ToolRouter: @observe(name="call_tool") async def call_tool( - self, tool_name: str, arguments: dict[str, Any], session: Any = None + self, tool_name: str, arguments: dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """ Call a tool and return (output_string, success_bool). @@ -236,6 +239,9 @@ class ToolRouter: # Check if handler accepts session argument sig = inspect.signature(tool.handler) if "session" in sig.parameters: + # Check if handler also accepts tool_call_id parameter + if "tool_call_id" in sig.parameters: + return await tool.handler(arguments, session=session, tool_call_id=tool_call_id) return await tool.handler(arguments, session=session) return await tool.handler(arguments) @@ -328,10 +334,7 @@ def create_builtin_tools() -> list[ToolSpec]: ), ] - # Sandbox tools - tools = get_sandbox_tools() + tools - tool_names = ", ".join([t.name for t in tools]) - print(f"Loaded {len(tools)} built-in tools: {tool_names}") + logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}") return tools diff --git a/agent/prompts/system_prompt.yaml b/agent/prompts/system_prompt.yaml index fc9607dd2e1d38cd5e0bde7ccd938cdf6a131645..00f28be1718457e1d1004c7e3745f903f14d5ab1 100644 --- a/agent/prompts/system_prompt.yaml +++ b/agent/prompts/system_prompt.yaml @@ -1,5 +1,5 @@ system_prompt: | - You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and ressources (models, datasets, compute) to execute them. You will aid users to do theses tasks, interacting with the Hugging Face stack via {{ num_tools }}. + You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. 
Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and resources (models, datasets, compute) to execute them. You will aid users to do these tasks, interacting with the Hugging Face stack via {{ num_tools }}. # General behavior @@ -9,7 +9,7 @@ system_prompt: | **CRITICAL : Research first, Then Implement** - For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in thoses three mandatory steps: + For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in these three mandatory steps: 1. **FIRST**: Search HF documentation to find the correct approach. - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers"). diff --git a/agent/prompts/system_prompt_v2.yaml b/agent/prompts/system_prompt_v2.yaml index 9f80bbefb98424d6ac281c628442322aea3c9fc7..d404b2788fe887a1a6f0f326961b284efbc9ca09 100644 --- a/agent/prompts/system_prompt_v2.yaml +++ b/agent/prompts/system_prompt_v2.yaml @@ -186,59 +186,61 @@ system_prompt: | 3. ✅ Determine optimal processing approach based on requirements 4. 
✅ Plan output format and destination - ## PHASE 3: IMPLEMENT (Develop in Sandbox, Launch via Jobs) - - ⚠️ **CRITICAL WORKFLOW: Sandbox First, Jobs Second** - - For ANY implementation task (training, data processing, inference), follow this pattern: - - **Step 1: Create a sandbox** — `sandbox_create` with appropriate hardware (cpu-basic for scripting, t4-small for GPU testing) - **Step 2: Develop & iterate** — Write scripts, install dependencies, test with small runs, fix errors interactively - **Step 3: Launch via hf_jobs** — Once the script works, pass the sandbox file path directly: `hf_jobs(operation="run", script="/app/train.py", ...)` - - This is the CORRECT pattern: - ``` - sandbox_create(hardware="t4-small") # interactive dev environment - bash("pip install trl transformers") # install deps - write("/app/train.py", "...") # write training script - bash("cd /app && python train.py --max_steps 10") # test run - edit("/app/train.py", ...) # fix issues - bash("cd /app && python train.py --max_steps 10") # verify fix - hf_jobs(operation="run", script="/app/train.py", hardware_flavor="a10g-large", timeout="4h") # launch at scale - ``` - - Do NOT write long inline scripts directly in hf_jobs if necessary — develop in sandbox first. 
- - ### Training Script Requirements - - **Script MUST Include:** - - Imports from researched documentation (current APIs) - - Trackio initialization with project/run_name/config - - Model and tokenizer loading - - Dataset loading with verified columns and conversational format - - Training config with ALL critical settings: + ## PHASE 3: IMPLEMENT (Execute with Researched Approaches) + + ### For Training Tasks + + ⚠️ **TRAINING REQUIREMENTS CHECKLIST:** + + **Before Submission:** + - [ ] Researched current TRL documentation + - [ ] Found and verified base model + - [ ] Found dataset and VALIDATED columns and conversational format matches method + - [ ] Selected optimal model + dataset + hardware configuration + - [ ] Created plan with plan_tool + - [ ] Researched Trackio monitoring setup + + **Training Script MUST Include:** + - [ ] Imports from researched documentation (current APIs) + - [ ] Trackio initialization with project/run_name/config + - [ ] Model and tokenizer loading + - [ ] Dataset loading with verified columns and conversational format + - [ ] Training config with ALL critical settings: - `push_to_hub=True` ⚠️ MANDATORY - `hub_model_id="username/model-name"` ⚠️ MANDATORY - `report_to=["trackio"]` (for monitoring) - `output_dir="./output"` - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate` - `logging_steps`, `save_steps` - - `trainer.train()` call - - `trainer.push_to_hub()` at end ⚠️ MANDATORY - - **hf_jobs Launch Configuration:** - - `script`: Path to sandbox file (e.g. "/app/train.py") or inline code - - `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio'] - - `hardware_flavor`: Based on model size: - - 1-3B models: `t4-small` or `a10g-small` - - 7-13B models: `a10g-large` - - 30B+ models: `a100-large` - - 70B+ models: `h100` or `h100x8` - - `timeout`: ⚠️ CRITICAL — Small (2-4h), Medium (4-8h), Large (8-24h). NEVER default 30m for training. 
+ - `max_length` if needed (default 1024 usually fine) + - [ ] Trainer initialization with model, args, dataset, tokenizer + - [ ] `trainer.train()` call + - [ ] `trainer.push_to_hub()` at end ⚠️ MANDATORY + - [ ] `tracker.finish()` for Trackio + + **Job Configuration MUST Include:** + - [ ] `operation`: "run" (for one-time) or "scheduled run" (for recurring) + - [ ] `script`: Training script with all above elements + - [ ] `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio'] + - [ ] `hardware_flavor`: Based on model size (see hf_jobs tool for detailed vCPU/RAM/GPU specs): + - 1-3B models: `t4-small` (4vCPU/15GB/GPU 16GB) for demos or `a10g-small` (4vCPU/14GB/GPU 24GB) for production + - 7-13B models: `a10g-large` (12vCPU/46GB/GPU 24GB) + - 30B+ models: `a100-large` (12vCPU/142GB/GPU 80GB) + - 70B+ models: `h100` (23vCPU/240GB/GPU 80GB) or `h100x8` for distributed + - [ ] `timeout`: ⚠️ CRITICAL - Set based on model/data size: + - Small models (1-3B): "2h" to "4h" + - Medium models (7-13B): "4h" to "8h" + - Large models (30B+): "8h" to "24h" + - **NEVER use default 30m for training!** ### For Data Processing Tasks - **Same pattern:** develop script in sandbox, test on subset, launch via hf_jobs. 
+ **Script Requirements:** + - Load dataset with `load_dataset` + - Process according to user requirements + - Push results with `push_to_hub()` or upload to `hf_private_repos` + + **Job Configuration:** - Use `cpu-upgrade` or `cpu-performance` for most data tasks - Set timeout based on dataset size (1-4 hours typical) @@ -339,21 +341,6 @@ system_prompt: | - ⚠️ Include HF_TOKEN for Hub operations - ⚠️ Storage is EPHEMERAL - must push_to_hub - ## Sandbox (Interactive Development Environment) - - **sandbox_create:** - - ⚠️ **Create a sandbox FIRST for any implementation task** — develop and test before launching jobs - - Persistent remote Linux environment on HF Spaces - - First call sandbox_create with hardware choice, then use bash/read/write/edit freely - - Hardware: cpu-basic (free tier), cpu-upgrade (8vCPU/32GB), t4-small (16GB GPU), a10g-small (24GB GPU), a10g-large (24GB GPU + 46GB RAM), a100-large (80GB GPU) - - `pip install` works out of the box — no special flags needed - - Workflow: sandbox_create → write script → test → fix → hf_jobs(script="/app/script.py") to launch at scale - - **bash / read / write / edit:** - - Available after sandbox_create — no additional approvals needed - - Same semantics as local file/shell operations, but run on the remote sandbox - - bash: run shell commands; read/write/edit: file operations - **hf_private_repos:** - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion) - Upload logs, scripts, results that can't push_to_hub diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml deleted file mode 100644 index 4b31e9a05abb5748e104f1e66381a46b67920f57..0000000000000000000000000000000000000000 --- a/agent/prompts/system_prompt_v3.yaml +++ /dev/null @@ -1,118 +0,0 @@ -system_prompt: | - You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face 
ecosystem. - - _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_ - {% if hf_user_info %}_Authenticated as: **{{ hf_user_info }}**_{% endif %} - - Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. - - # Your knowledge of HF libraries is outdated - - You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. - - Before writing any ML implementation code (training, fine-tuning, inference, data processing), ground yourself in current working code: - - github_find_examples → github_read_file → explore_hf_docs + fetch_hf_docs - - Skip research only for trivial non-code operations. - - # Mistakes you WILL make without research - - HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first. - - WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs. - - WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call hf_inspect_dataset or hub_repo_details and verify columns match the training method. - - DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training). - - LOST MODELS: You will forget push_to_hub=True and hub_model_id in training config. Job storage is ephemeral — the filesystem is deleted when the job ends. 
Without push_to_hub, the trained model is permanently lost. - - BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest. - - SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do. - - HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like 'flash-attn' for flash_attention_2 or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job. - - SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with the minimal change that preserves the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach or any other part of the task. - - # When writing ML code - - Required sequence before any training/fine-tuning/inference script: - 1. Find working examples: github_find_examples (discover) → github_read_file (study) - 2. Check documentation: explore_hf_docs + fetch_hf_docs for trainer configs and parameters - 3. Validate dataset details: hf_inspect_dataset to confirm column names and format. - 4. Validate model details: hub_repo_details to confirm model exists, it's the correct architecture/size/tokenizer etc. 
- - Dataset format requirements by training method: - SFT: "messages", "text", or "prompt"/"completion" - DPO: "prompt", "chosen", "rejected" - GRPO: "prompt" - - # When submitting a training job - - Before calling hf_jobs, output a pre-flight check: - - Reference implementation: [which example you based this on] - - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details] - - push_to_hub=True and hub_model_id set - - timeout: [value] (based on: [model size] on [hardware]) - - Trackio monitoring included and working - - If you cannot fill in all items, stop and complete the missing steps first. - - For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once. - - Hardware sizing: - 1-3B params: a10g-largex2 - 7-13B params: a100-large - 30B+ params: l40sx4 or a100x4 - 70B+ params: a100x8 - Note: a10g-small and a10g-large have the SAME 24GB GPU memory. The difference is CPU/RAM only. - - # Sandbox-first development - - For non-trivial scripts, develop and test in a sandbox before launching via hf_jobs: - sandbox_create → install deps → write script → test with small run → fix errors → launch via hf_jobs at scale - - Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths. - - - # When a task has 3+ steps - - Use plan_tool to track progress. One task in_progress at a time. Mark completed immediately after finishing. Update frequently to show the user what you're doing. - - # Error recovery - - When something fails: - - Diagnose the actual error. Read the full error message and logs. - - Do not retry the exact same thing. Identify what needs to change. - - If an API/import error: check documentation for the correct API. 
- - If an OOM error: (1) reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally to keep effective batch size identical, (2) enable gradient_checkpointing=True, (3) upgrade to larger GPU (a10gx4→a100→a100x4→a100x8). Do NOT switch training methods (e.g. SFT→LoRA) or reduce max_length — those change what the user gets. If OOM happens in sandbox, create a new sandbox with larger GPU hardware. - - Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval. - - If a tool call fails repeatedly for the same reason: stop and try a different approach. - - Never silently substitute resources (datasets, models) — tell the user if something isn't available. - - # Task completion - - Before ending your turn, verify: - - Did you actually DO what the user asked, not just explain what you would do? - - If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input? - - For training jobs: did you include a working Trackio dashboard URL? - - Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done. - Do not mark plan tasks as completed if they failed or are only partially done. - - # Communication - - - Be concise and direct. No filler, no restating what the user said. - - One-word answers when appropriate for simple questions. - - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. - - For errors: state what went wrong, why, and what you're doing to fix it. - - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. - - # Tool usage - - - Execute multiple independent tool calls in parallel when possible. - - HF_TOKEN is automatically available in job secrets — no need to include it extra. 
- - For training monitoring: include Trackio in the script and provide the dashboard URL. - - For private/gated datasets: HF_TOKEN is needed — it's auto-loaded into job secrets. diff --git a/agent/tools/dataset_tools.py b/agent/tools/dataset_tools.py index 5fed516f175cd147db7ca9d18460a3a9200ddd8c..39f5d5d85b4478a1dd1e8934397f3b86aad71431 100644 --- a/agent/tools/dataset_tools.py +++ b/agent/tools/dataset_tools.py @@ -388,15 +388,22 @@ def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None: HF_INSPECT_DATASET_TOOL_SPEC = { "name": "hf_inspect_dataset", "description": ( - "Inspect a HF dataset in one call: status, configs/splits, schema, sample rows, parquet info.\n\n" - "REQUIRED before any training job to verify dataset format matches training method:\n" - " SFT: needs 'messages', 'text', or 'prompt'/'completion'\n" - " DPO: needs 'prompt', 'chosen', 'rejected'\n" - " GRPO: needs 'prompt'\n" - "All datasets used for training have to be in conversational ChatML format to be compatible with HF libraries.'\n" - "Training will fail with KeyError if columns don't match.\n\n" - "Also use to get example datapoints, understand column names, data types, and available splits before writing any data loading code. " - "Supports private/gated datasets when HF_TOKEN is set." + "Inspect a Hugging Face dataset comprehensively in one call.\n\n" + "## What you get\n" + "- Status check (validates dataset works without errors)\n" + "- All configs and splits (row counts/shares may be '?' 
when metadata is missing)\n" + "- Column names and types (schema)\n" + "- Sample rows to understand data format\n" + "- Parquet file structure and sizes\n\n" + "## CRITICAL\n" + "**Always inspect datasets before writing training code** to understand:\n" + "- Column names for your dataloader\n" + "- Data types and format\n" + "- Available splits (train/test/validation)\n\n" + "Supports private/gated datasets when HF_TOKEN is set.\n\n" + "## Examples\n" + '{"dataset": "stanfordnlp/imdb"}\n' + '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n' ), "parameters": { "type": "object", diff --git a/agent/tools/docs_tools.py b/agent/tools/docs_tools.py index eb4c84591c029b305822d14c123da456319dbd48..49a330bedfccb47bcfbf2caf4d51aafa2af1babc 100644 --- a/agent/tools/docs_tools.py +++ b/agent/tools/docs_tools.py @@ -845,12 +845,17 @@ DOC_ENDPOINTS = [ EXPLORE_HF_DOCS_TOOL_SPEC = { "name": "explore_hf_docs", "description": ( - "Browse HF documentation structure — discover all available documentation with 200-char previews.\n\n" - "Use this to find relevant documentation and/or examples with detailed parameter docs and API reference. " - "To be used together with github_find_examples and github_read_file to find working examples and documentation.\n\n" - "Pattern: explore_hf_docs (find relevant pages) → fetch_hf_docs (get full content).\n\n" - "For training tasks: fetch the trainer config docs (SFTConfig, DPOConfig, GRPOConfig) to verify parameter names. " - "Returns top 20 results by default; set max_results (max 50) to adjust." + "Explore Hugging Face documentation structure and discover available pages with 200-character previews. " + "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). " + "Your training data may be outdated - current documentation is the source of truth. 
" + "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, " + "(3) Before writing training/processing code, (4) Researching library capabilities, " + "(5) Verifying API syntax and parameters. " + "**Pattern:** explore (discover structure) → fetch_hf_docs (get details) → implement with researched approach. " + "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. " + "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. " + "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently." + " By default returns the top 20 results; set max_results (max 50) to adjust." ), "parameters": { "type": "object", @@ -923,10 +928,16 @@ EXPLORE_HF_DOCS_TOOL_SPEC = { HF_DOCS_FETCH_TOOL_SPEC = { "name": "fetch_hf_docs", "description": ( - "Fetch full markdown content of an HF documentation page. Use after explore_hf_docs.\n\n" - "Critical for finding documentation e.g. current trainer configuration parameters (SFTConfig, DPOConfig, etc.) " - "Use for researching solutions and before writing training scripts. Your internal knowledge is outdated.\n\n" - "Provide the full URL from explore_hf_docs results. The .md extension is added automatically." + "Fetch full markdown content of a specific HF documentation page. " + "⚠️ CRITICAL: Use this after explore_hf_docs to get detailed implementation guidance. " + "**Use when:** (1) Found relevant page in explore_hf_docs results, (2) Need complete API documentation, " + "(3) Need training method details (SFT/DPO/GRPO), (4) Need configuration examples, " + "(5) Need parameter descriptions and usage patterns. " + "**Pattern:** explore_hf_docs (find relevant page) → fetch_hf_docs (get full content) → implement using documented approach. " + "Provide full URL from explore_hf_docs results (e.g., 'https://huggingface.co/docs/trl/sft_trainer'). 
" + "Returns: Complete markdown documentation with examples, parameters, and usage patterns. " + "**For training tasks:** ALWAYS fetch trainer docs (SFTConfig, DPOConfig, etc.) before creating training scripts. " + "**Critical for reliability:** This ensures you use current APIs and best practices." ), "parameters": { "type": "object", diff --git a/agent/tools/github_find_examples.py b/agent/tools/github_find_examples.py index f5f2ddaad0a1959ec3418cc45ed88432a40e13c2..c0d795d93363a93f8f4f3e316f71f988017b98c4 100644 --- a/agent/tools/github_find_examples.py +++ b/agent/tools/github_find_examples.py @@ -405,16 +405,55 @@ def find_examples( GITHUB_FIND_EXAMPLES_TOOL_SPEC = { "name": "github_find_examples", "description": ( - "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). " - "Uses fuzzy keyword matching.\n\n" - "MANDATORY before writing any ML training, fine-tuning, or inference code. " - "Your internal knowledge of library APIs is outdated — working examples show current API patterns.\n\n" - "Sequence: github_find_examples → github_read_file (study the example) → implement based on what you found.\n\n" - "Skip this only for: simple data queries, status checks, non-code tasks.\n\n" - "Examples:\n" - " {keyword: 'sft', repo: 'trl'} → finds examples/scripts/sft.py\n" - " {keyword: 'grpo', repo: 'trl'} → finds GRPO training examples\n" - " {repo: 'trl', max_results: 20} → lists all available training method examples" + "Discover working code examples, tutorials, scripts, and demos in GitHub repositories. " + "⚠️ CRITICAL: ALWAYS use this BEFORE implementing ML tasks - find working reference code first. " + "Your training data may be outdated; real repository examples show current best practices. 
" + "**Use when:** (1) Starting any ML implementation (training, inference, evaluation), " + "(2) User asks 'how to' questions about libraries, (3) Need reference implementations, " + "(4) Exploring library capabilities, (5) Before writing training/processing scripts. " + "**Pattern:** github_find_examples (discover) → github_read_file (study code) → implement with researched approach. " + "Returns: List of example files (scripts/notebooks/tutorials) with paths and URLs, sorted by relevance. " + "**Then:** Use github_read_file to read the actual implementation code. " + "**Critical for reliability:** Real examples prevent outdated API usage and show proven patterns. " + "## How it works\n\n" + "1. Fetches all example files (examples/, scripts/, tutorials/, demos/, notebooks/, etc.) from repository\n" + "2. If keyword provided, scores files against keyword using fuzzy matching\n" + "3. Returns best matches sorted by relevance and pattern priority\n" + "4. Provides copyable parameters for github_read_file tool\n\n" + "## Examples\n\n" + "\n" + "// ML Workflow Step: Find GRPO training examples before implementation\n" + "// Task: Starting GRPO fine-tuning project, need reference implementation\n" + "{\n" + " keyword: 'grpo',\n" + " repo: 'trl',\n" + " org: 'huggingface'\n" + "}\n" + "// Returns: examples/scripts/grpo_agent.py, examples/scripts/grpo_vlm.py\n" + "// Next step: github_read_file to study working implementation\n" + "\n\n" + "\n" + "// ML Workflow Step: Discover all available training methods\n" + "// Task: Exploring TRL training options before choosing approach\n" + "{\n" + " repo: 'trl',\n" + " org: 'huggingface',\n" + " max_results: 20\n" + "}\n" + "// Lists: SFT, DPO, GRPO, PPO, reward modeling examples\n" + "// Helps user choose appropriate method\n" + "\n\n" + "\n" + "// ML Workflow Step: Find LoRA fine-tuning examples\n" + "// Task: Learning parameter-efficient fine-tuning patterns\n" + "{\n" + " keyword: 'lora',\n" + " repo: 'peft',\n" + " org: 
'huggingface'\n" + "}\n" + "// Discovers LoRA configuration and training examples\n" + "// Shows current PEFT API usage patterns\n" + "" ), "parameters": { "type": "object", diff --git a/agent/tools/github_read_file.py b/agent/tools/github_read_file.py index 485fe277972f8ebf6c52ff62cc488ed2b4e97d9b..02bccef05d53120670f95dd7556e40811fad9db0 100644 --- a/agent/tools/github_read_file.py +++ b/agent/tools/github_read_file.py @@ -250,13 +250,59 @@ def read_file( GITHUB_READ_FILE_TOOL_SPEC = { "name": "github_read_file", "description": ( - "Read file contents from GitHub repositories. Returns first 300 lines by default. " - "Auto-converts Jupyter notebooks to markdown.\n\n" - "Use AFTER github_find_examples to study the working implementation. " - "The purpose is to learn current API patterns — imports, trainer configs, dataset handling — " - "so your implementation uses correct, up-to-date code.\n\n" + "Read file contents from GitHub repositories with line range support (default 300 lines). " + "⚠️ CRITICAL: Use AFTER github_find_examples to study working implementation code. " + "**Use when:** (1) Found example file via github_find_examples and need full code, " + "(2) Need to read trainer class implementation, (3) Study configuration patterns, " + "(4) Read specific code sections with line ranges, (5) Review code from specific branches/commits. " + "**Pattern:** github_find_examples (discover files) → github_read_file (read code) → implement using researched patterns. " + "Returns: File contents with line numbers, formatted for LLM reading. Auto-converts Jupyter notebooks to markdown. " + "**Then:** Implement using patterns and APIs from the example code. " + "**Critical for reliability:** Reading working examples prevents API errors and shows current best practices. " "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n" - "When NOT to use: when you don't know the file path (use github_find_examples first)." 
+ "## When to use this tool\n\n" + "- When reading example code, trainer implementations, or configuration files\n" + "- After github_find_examples returns file paths you want to study\n" + "- When investigating specific code sections with line ranges\n" + "- When reading from specific branches, tags, or commits (use ref parameter)\n\n" + "## When NOT to use this tool\n\n" + "- When you don't know exact file path (use github_find_examples or github_search_code first)\n" + "- When searching for code patterns across repos (use github_search_code instead)\n\n" + "## Examples\n\n" + "\n" + "// ML Workflow Step: Read GRPO trainer class after finding via github_find_examples\n" + "// Use case: Understand GRPOTrainer API, parameters, and methods\n" + "{\n" + " repo: 'huggingface/trl',\n" + " path: 'trl/trainer/grpo_trainer.py',\n" + " line_start: 1,\n" + " line_end: 200\n" + "}\n" + "// Read class definition and constructor to understand current API\n" + "// Shows: __init__ parameters, configuration, required arguments\n" + "\n\n" + "\n" + "// ML Workflow Step: Study complete training script from examples\n" + "// Use case: Learn end-to-end VLM fine-tuning workflow\n" + "{\n" + " repo: 'huggingface/trl',\n" + " path: 'examples/scripts/grpo_vlm.py'\n" + "}\n" + "// Returns first 300 lines - shows full training setup\n" + "// Use line_start/line_end if need to read more\n" + "\n\n" + "\n" + "// ML Workflow Step: Check TrainingArguments configuration patterns\n" + "// Use case: Learn how to structure training configs correctly\n" + "{\n" + " repo: 'huggingface/transformers',\n" + " path: 'examples/pytorch/language-modeling/run_clm.py',\n" + " line_start: 50,\n" + " line_end: 150\n" + "}\n" + "// Read argument parsing and config setup section\n" + "// Shows: current parameter names, default values, best practices\n" + "" ), "parameters": { "type": "object", diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index 
c9819b5c75f65e9c9154b40f66637c8935a8fe03..05a19d8501bc10e87ad5d4dbee4619266feec39b 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -9,7 +9,9 @@ import base64 import http.client import os import re -from typing import Any, Awaitable, Callable, Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional, Callable, Awaitable + +import logging import httpx from huggingface_hub import HfApi @@ -17,6 +19,8 @@ from huggingface_hub.utils import HfHubHTTPError from agent.core.session import Event from agent.tools.types import ToolResult + +logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, @@ -25,33 +29,38 @@ from agent.tools.utilities import ( ) # Hardware flavors -CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] +CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"] GPU_FLAVORS = [ + "sprx8", + "zero-a10g", "t4-small", "t4-medium", - "a10g-small", - "a10g-large", - "a10g-largex2", - "a10g-largex4", - "a100-large", - "a100x4", - "a100x8", "l4x1", "l4x4", "l40sx1", "l40sx4", "l40sx8", + "a10g-small", + "a10g-large", + "a10g-largex2", + "a10g-largex4", + "a100-large", + "h100", + "h100x8", ] # Detailed specs for display (vCPU/RAM/GPU VRAM) -CPU_FLAVORS_DESC = "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB)" +CPU_FLAVORS_DESC = ( + "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB), cpu-performance, cpu-xl" +) GPU_FLAVORS_DESC = ( "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), " - "a10g-small(4vCPU/15GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " - "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " - "a100-large(12vCPU/142GB/GPU 80GB), a100x4(48vCPU/568GB/GPU 320GB), a100x8(96vCPU/1136GB/GPU 640GB), " "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), " - "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB)" + "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), 
l40sx8(192vCPU/1534GB/GPU 384GB), " + "a10g-small(4vCPU/14GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " + "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " + "a100-large(12vCPU/142GB/GPU 80GB), h100(23vCPU/240GB/GPU 80GB), h100x8(184vCPU/1920GB/GPU 640GB), " + "zero-a10g(dynamic alloc)" ) SPECIALIZED_FLAVORS = ["inf2x6"] ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS @@ -113,23 +122,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]: return logs -_DEFAULT_ENV = { - "HF_HUB_DISABLE_PROGRESS_BARS": "1", - "TQDM_DISABLE": "1", - "TRANSFORMERS_VERBOSITY": "warning", - "HF_HUB_ENABLE_HF_TRANSFER": "1", -} - - -def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]: - """Inject default env vars for clean, agent-friendly output.""" - result = dict(_DEFAULT_ENV) - result.update(params or {}) # user-provided values override defaults - return result - - -def _add_environment_variables(params: Dict[str, Any] | None) -> Dict[str, Any]: - token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "" +def _add_environment_variables( + params: Dict[str, Any] | None, user_token: str | None = None +) -> Dict[str, Any]: + # Prefer the authenticated user's OAuth token, fall back to global env var + token = user_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "" # Start with user-provided env vars, then force-set token last result = dict(params or {}) @@ -285,10 +282,15 @@ class HfJobsTool: hf_token: Optional[str] = None, namespace: Optional[str] = None, log_callback: Optional[Callable[[str], Awaitable[None]]] = None, + session: Any = None, + tool_call_id: Optional[str] = None, ): + self.hf_token = hf_token self.api = HfApi(token=hf_token) self.namespace = namespace self.log_callback = log_callback + self.session = session + self.tool_call_id = tool_call_id async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute the specified operation""" 
@@ -384,9 +386,7 @@ class HfJobsTool: def log_producer(): try: # fetch_job_logs is a blocking sync generator - logs_gen = self.api.fetch_job_logs( - job_id=job_id, namespace=namespace - ) + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) @@ -413,7 +413,7 @@ class HfJobsTool: # Process log line log_line = item - print("\t" + log_line) + logger.debug(log_line) if self.log_callback: await self.log_callback(log_line) all_logs.append(log_line) @@ -441,19 +441,19 @@ class HfJobsTool: if current_status in terminal_states: # Job finished, no need to retry - print(f"\tJob reached terminal state: {current_status}") + logger.info(f"Job reached terminal state: {current_status}") break # Job still running, retry connection - print( - f"\tConnection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." + logger.warning( + f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." 
) await asyncio.sleep(retry_delay) continue except (ConnectionError, TimeoutError, OSError): # Can't even check job status, wait and retry - print(f"\tConnection error, retrying in {retry_delay}s...") + logger.warning(f"Connection error, retrying in {retry_delay}s...") await asyncio.sleep(retry_delay) continue @@ -509,16 +509,30 @@ class HfJobsTool: self.api.run_job, image=image, command=command, - env=_add_default_env(args.get("env")), - secrets=_add_environment_variables(args.get("secrets")), + env=args.get("env"), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, ) + # Send job URL immediately after job creation (before waiting for completion) + if self.session and self.tool_call_id: + await self.session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": "running", + "jobUrl": job.url, + }, + ) + ) + # Wait for completion and stream logs - print(f"{job_type} job started: {job.url}") - print("Streaming logs...\n---\n") + logger.info(f"{job_type} job started: {job.url}") + logger.info("Streaming logs...") final_status, all_logs = await self._wait_for_job_completion( job_id=job.id, @@ -727,8 +741,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}} image=image, command=command, schedule=schedule, - env=_add_default_env(args.get("env")), - secrets=_add_environment_variables(args.get("secrets")), + env=args.get("env"), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, @@ -887,31 +901,56 @@ To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_ HF_JOBS_TOOL_SPEC = { "name": "hf_jobs", "description": ( - "Execute Python scripts or Docker containers on HF cloud 
infrastructure.\n\n" - "Two modes (mutually exclusive): Python mode (script + dependencies) or Docker mode (command + image). " - "Provide exactly ONE of 'script' or 'command'.\n\n" - "BEFORE submitting training/fine-tuning jobs:\n" - "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. " - "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n" - "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" - "- Training config MUST include push_to_hub=True and hub_model_id. " - "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" - "- Include trackio monitoring and provide the dashboard URL to the user.\n\n" - "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " - "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n" - "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" - f"Hardware: CPU: {CPU_FLAVORS_DESC}. GPU: {GPU_FLAVORS_DESC}.\n" - "Common picks: t4-small ($0.60/hr, 1-3B), a10g-large ($2/hr, 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+). " - "Note: a10g-small and a10g-large have the SAME 24GB GPU — the difference is CPU/RAM only.\n\n" - "OOM RECOVERY: When a training job fails with CUDA OOM:\n" - "1. Reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally (keep effective batch size identical)\n" - "2. Enable gradient_checkpointing=True\n" - "3. Upgrade to larger GPU (a10g→a100→h100)\n" - "Do NOT switch training methods (e.g. 
full SFT to LoRA) or reduce max_length — those change what the user gets and require explicit approval.\n\n" - "Examples:\n" - "Training: {'operation': 'run', 'script': '/app/train.py', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a100-large', 'timeout': '8h'}\n" - "Monitor: {'operation': 'ps'}, {'operation': 'logs', 'job_id': 'xxx'}, {'operation': 'cancel', 'job_id': 'xxx'}" - "Docker: {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb', 'hardware_flavor': 'cpu-basic', 'timeout': '1h'}\n" + "Execute Python scripts or Docker containers on HF cloud infrastructure (CPUs/GPUs) in one of two modes. " + "\n\n" + "**Two Modes (mutually exclusive):**\n" + "1. Python mode: using 'script' arg (REQUIRED) + 'dependencies'\n" + "2. Docker mode: using 'command' arg (REQUIRED) + 'image'\n\n" + "🚨 **REQUIRED:** You MUST provide exactly ONE of: 'script' (Python code as string) OR 'command' (Docker command as array). " + "They are mutually exclusive - provide one or the other, never both, never neither. " + "Do NOT call with just {'operation': 'run'} - always include your code. Example: {'operation': 'run', 'script': 'import torch; print(torch.cuda.is_available())', 'dependencies': ['torch']} or {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb'}\n\n" + "⚠️ CRITICAL for reliability: (1) Jobs run ASYNC - provide monitoring URL immediately, don't poll; " + "(2) Set timeout >30min (default too short - training needs 2-8h); " + "(3) HF_TOKEN auto-loaded to secrets for Hub ops (push_to_hub, private repos); " + "(4) Job storage EPHEMERAL - MUST push_to_hub() or ALL work is LOST. " + "**Use when:** User wants cloud compute, training models, data processing, batch inference, GPU workloads, scheduled tasks. " + "ALWAYS use this tool (✓), never bash 'hf jobs' commands (✗). Pass script content inline (✓), don't save to files unless requested (✗). 
" + "\n\n" + "**Operations:** run, ps, logs, inspect, cancel, scheduled run, scheduled ps, scheduled inspect, scheduled delete, scheduled suspend, scheduled resume. " + "**Available Hardware (vCPU/RAM/GPU):**\n" + f"• CPU: {CPU_FLAVORS_DESC}\n" + f"• GPU: {GPU_FLAVORS_DESC}\n" + " ◦ Common: t4-small ($0.60/hr, demos/1-3B models), a10g-small ($1/hr), a10g-large ($2/hr, production 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+)\n\n" + "**After Submission Ground Rules:**\n" + "✓ Return immediately with job ID and monitoring URL\n" + "✓ Provide expected completion time and cost estimate\n" + "✓ For training: Include Trackio dashboard URL\n" + "✓ Note user can check status later\n" + "✗ DON'T poll logs automatically\n" + "✗ DON'T wait for completion\n" + "✗ DON'T check status unless user asks\n\n" + "**For Training Tasks:**\n" + "• ALWAYS research TRL docs first: explore_hf_docs('trl') → fetch_hf_docs()\n" + "• ALWAYS validate dataset format with hub_repo_details (SFT needs messages/text, DPO needs chosen/rejected)\n" + "• ALWAYS include Trackio monitoring in script (explore_hf_docs('trackio'))\n" + "• ALWAYS enable push_to_hub=True in training config\n" + "• Set timeout 2-8h for training (NOT default 30m)\n" + "• Confirm model/dataset choices with user before submitting\n\n" + "**Examples:**\n\n" + "**Training - Fine-tune LLM:**\n" + "{'operation': 'run', 'script': '# Training script with TRL\\nfrom trl import SFTConfig, SFTTrainer\\nfrom transformers import AutoModelForCausalLM\\nmodel = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen3-4B\")\\n# ... 
researched implementation from docs ...\\ntrainer.train()\\ntrainer.push_to_hub(\"user-name/my-model\")', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a10g-large', 'timeout': '4h'}\n\n" + "**Data Processing:**\n" + "{'operation': 'run', 'script': 'from datasets import load_dataset\\nds = load_dataset(\"data\")\\n# process...\\nds.push_to_hub(\"user/processed\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-upgrade', 'timeout': '2h'}\n\n" + "**Scheduled Daily Job:**\n" + "{'operation': 'scheduled run', 'schedule': '@daily', 'script': 'from datasets import Dataset\\nimport pandas as pd\\n# scrape/generate data\\ndf = pd.DataFrame(data)\\nds = Dataset.from_pandas(df)\\nds.push_to_hub(\"user-name/daily-dataset\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-basic'}\n\n" + "**Docker Mode:**\n" + "{'operation': 'run', 'image': 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime', 'command': ['python', 'train.py', '--epochs', '10'], 'hardware_flavor': 'a100-large'}\n\n" + "**Monitor Operations:**\n" + "{'operation': 'ps'} - List all jobs\n" + "{'operation': 'logs', 'job_id': 'xxx'} - Stream logs (only when user requests)\n" + "{'operation': 'inspect', 'job_id': 'xxx'} - Get job details\n" + "{'operation': 'cancel', 'job_id': 'xxx'} - Stop job\n\n" + "⚠️ CRITICAL: Files created during execution are DELETED when job finishes. MUST push_to_hub() all outputs (models, datasets, artifacts) in script. For logs/scripts, use hf_private_repos after completion." ), "parameters": { "type": "object", @@ -931,65 +970,58 @@ HF_JOBS_TOOL_SPEC = { "scheduled suspend", "scheduled resume", ], - "description": "Operation to execute.", + "description": ( + "Operation to execute. 
Valid values: [run, ps, logs, inspect, cancel, " + "scheduled run, scheduled ps, scheduled inspect, scheduled delete, " + "scheduled suspend, scheduled resume]" + ), }, + # Python/UV specific parameters "script": { "type": "string", - "description": ( - "Python code or sandbox file path (e.g. '/app/train.py') or URL. " - "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. " - "Mutually exclusive with 'command'." - ), + "description": "Python code to execute. Triggers Python mode (auto pip install). Use with 'run'/'scheduled run'. Mutually exclusive with 'command'.", }, "dependencies": { "type": "array", "items": {"type": "string"}, - "description": ( - "Pip packages to install. Include ALL required packages. " - "Common training set: ['transformers', 'trl', 'torch', 'datasets', 'trackio', 'accelerate']. " - "Only used with 'script'." - ), + "description": "Pip packages to install. Example: ['trl', 'torch', 'datasets', 'transformers']. Only used with 'script'.", }, + # Docker specific parameters "image": { "type": "string", - "description": "Docker image. Optional — auto-selected if not provided. Use with 'command'.", + "description": "Docker image. Example: 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime'. Use with 'run'/'scheduled run'. Optional (auto-selected if not provided).", }, "command": { "type": "array", "items": {"type": "string"}, - "description": "Command to execute as list. Triggers Docker mode. Mutually exclusive with 'script'.", + "description": "Command to execute as list. Example: ['python', 'train.py', '--epochs', '10']. Triggers Docker mode. Use with 'run'/'scheduled run'. Mutually exclusive with 'script'.", }, + # Hardware and environment "hardware_flavor": { "type": "string", - "description": ( - "Hardware type. Sizing guide: 1-3B params → t4-small/a10g-small, " - "7-13B → a10g-large, 30B+ → a100-large, 70B+ → h100/h100x8. " - f"All options: CPU: {CPU_FLAVORS}. 
GPU: {GPU_FLAVORS}." - ), + "description": f"Hardware type. Available CPU flavors: {CPU_FLAVORS}. Available GPU flavors: {GPU_FLAVORS}. Use with 'run'/'scheduled run'.", }, "timeout": { "type": "string", - "description": ( - "Maximum job runtime. MUST be >2h for any training job — default 30m kills training mid-run. " - "Guidelines: 1-3B models: 3-4h, 7-13B: 6-8h, 30B+: 12-24h. " - "Use 30m-1h only for quick data processing or inference tasks. Default: '30m'." - ), + "description": "Max runtime. Examples: '30m', '2h', '4h'. Default: '30m'. Important for long training jobs. Use with 'run'/'scheduled run'.", }, "env": { "type": "object", - "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", + "description": "Environment variables. Format: {'KEY': 'VALUE'}. HF_TOKEN is automatically included from your auth. Use with 'run'/'scheduled run'.", }, + # Job management parameters "job_id": { "type": "string", - "description": "Job ID. Required for: logs, inspect, cancel.", + "description": "Job ID to operate on. Required for: 'logs', 'inspect', 'cancel'.", }, + # Scheduled job parameters "scheduled_job_id": { "type": "string", - "description": "Scheduled job ID. Required for: scheduled inspect/delete/suspend/resume.", + "description": "Scheduled job ID. Required for: 'scheduled inspect', 'scheduled delete', 'scheduled suspend', 'scheduled resume'.", }, "schedule": { "type": "string", - "description": "Cron schedule or preset (@hourly, @daily, @weekly, @monthly). Required for: scheduled run.", + "description": "Schedule for recurring job. Presets: '@hourly', '@daily', '@weekly', '@monthly'. Cron: '0 9 * * 1' (Mon 9am). 
Required for: 'scheduled run'.", }, }, "required": ["operation"], @@ -998,7 +1030,7 @@ HF_JOBS_TOOL_SPEC = { async def hf_jobs_handler( - arguments: Dict[str, Any], session: Any = None + arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """Handler for agent tool router""" try: @@ -1009,36 +1041,20 @@ async def hf_jobs_handler( Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log}) ) - # If script is a sandbox file path, read it from the sandbox - script = arguments.get("script", "") - sandbox = getattr(session, "sandbox", None) if session else None - is_path = ( - sandbox - and isinstance(script, str) - and script.strip() == script - and not any(c in script for c in "\r\n\0") - and ( - script.startswith("/") - or script.startswith("./") - or script.startswith("../") - ) + # Prefer the authenticated user's OAuth token, fall back to global env + hf_token = ( + (getattr(session, "hf_token", None) if session else None) + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_HUB_TOKEN") ) - if is_path: - import shlex - - result = await asyncio.to_thread(sandbox.bash, f"cat {shlex.quote(script)}") - if not result.success: - return f"Failed to read {script} from sandbox: {result.error}", False - arguments = {**arguments, "script": result.output} - - # Get token and namespace from HF token - hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") - namespace = HfApi(token=hf_token).whoami().get("name") if hf_token else None + namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) tool = HfJobsTool( namespace=namespace, hf_token=hf_token, log_callback=log_callback if session else None, + session=session, + tool_call_id=tool_call_id, ) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) diff --git a/agent/tools/plan_tool.py b/agent/tools/plan_tool.py index 
a923d53c27068fe81d5fe5dd1e774255c4339601..25ba5f87201ff45d874b94abc8975857f10b40d1 100644 --- a/agent/tools/plan_tool.py +++ b/agent/tools/plan_tool.py @@ -85,11 +85,18 @@ def get_current_plan() -> List[Dict[str, str]]: PLAN_TOOL_SPEC = { "name": "plan_tool", "description": ( - "Track progress on multi-step tasks with a todo list (pending/in_progress/completed).\n\n" - "Use for tasks with 3+ steps. Each call replaces the entire plan (send full list).\n\n" - "Rules: exactly ONE task in_progress at a time. Mark completed immediately after finishing. " - "Only mark completed when the task fully succeeded — keep in_progress if there are errors. " - "Update frequently so the user sees progress." + "Manage task planning and progress tracking with todo list (pending/in_progress/completed statuses). " + "⚠️ CRITICAL: ALWAYS use for multi-step tasks (3+ steps) and MUST update frequently to show progress. " + "**Use when:** (1) User provides multiple tasks, (2) Complex workflows (training, evaluation, data processing), " + "(3) Tasks requiring multiple tool calls, (4) Need to communicate progress clearly to user, " + "(5) Breaking down ambiguous requests into concrete steps. " + "**Pattern:** Create plan at start → Mark in_progress when starting task → Mark completed immediately after finishing → User sees clear progress. " + "Each call replaces entire plan (full list required). " + "**Critical for reliability:** Exactly ONE task in_progress at a time (not zero, not multiple). " + "Mark tasks completed IMMEDIATELY after finishing - don't batch completions. " + "**For long-running tasks:** Update plan after each major step to keep user informed. " + "**Only mark completed when:** Task fully accomplished, no errors, all requirements met. " + "Keep tasks pending if blocked/errors occur - create new task to resolve blockers." 
), "parameters": { "type": "object", diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py deleted file mode 100644 index 2eb74c9ce61420d0ce69757835cf3c9e4d92d120..0000000000000000000000000000000000000000 --- a/agent/tools/sandbox_client.py +++ /dev/null @@ -1,714 +0,0 @@ -#!/usr/bin/env python3 -# /// script -# requires-python = ">=3.10" -# dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"] -# /// -""" -Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes. - -Architecture: - - Creates a sandbox by duplicating a template Space (runs sandbox_server.py) - - Waits for it to come online - - Communicates via HTTPS to the Space's API - - Optionally deletes the Space when done - -Lifecycle: - sb = Sandbox.create(owner="burtenshaw") # duplicate, wait, connect - sb = Sandbox.create(owner="burtenshaw", # with options - hardware="t4-small", - private=True, - sleep_time=3600) - sb = Sandbox.connect("burtenshaw/my-sandbox-abc") # attach to existing - - sb.bash("uv run train.py") - sb.read("/app/train.py") - sb.edit("/app/train.py", old_str="lr=1e-3", new_str="lr=1e-4") - - sb.delete() # tear down when done - - # Or use as a context manager for automatic cleanup - with Sandbox.create(owner="burtenshaw") as sb: - sb.bash("python train.py") - # Space deleted on exit - -Tools: bash, read, write, edit, upload -""" - -from __future__ import annotations - -import io -import os -import sys -import time -import uuid -from dataclasses import dataclass, field -from typing import Any - -import httpx -from huggingface_hub import CommitOperationAdd, HfApi - -TEMPLATE_SPACE = "burtenshaw/sandbox" -HARDWARE_OPTIONS = [ - "cpu-basic", - "cpu-upgrade", - "t4-small", - "t4-medium", - "a10g-small", - "a10g-large", - "a100-large", -] -OUTPUT_LIMIT = 30000 -LINE_LIMIT = 2000 -DEFAULT_READ_LIMIT = 2000 -DEFAULT_TIMEOUT = 120 -MAX_TIMEOUT = 600 -WAIT_TIMEOUT = 300 -WAIT_INTERVAL = 5 -API_WAIT_TIMEOUT = 180 - -_DOCKERFILE = """\ -FROM 
ghcr.io/astral-sh/uv:python3.12-bookworm-slim - -RUN apt-get update && \\ - apt-get install -y \\ - bash git git-lfs wget curl procps \\ - htop vim nano jq tmux \\ - build-essential && \\ - rm -rf /var/lib/apt/lists/* - -RUN uv pip install --system fastapi uvicorn python-multipart - -RUN useradd -m -u 1000 user -USER user - -ENV HOME=/home/user \\ - PATH=/home/user/.local/bin:$PATH \\ - PIP_USER=1 \\ - HF_HUB_DISABLE_PROGRESS_BARS=1 \\ - TQDM_DISABLE=1 \\ - TRANSFORMERS_VERBOSITY=warning \\ - HF_HUB_ENABLE_HF_TRANSFER=1 - -WORKDIR /app -COPY --chown=user . /app - -EXPOSE 7860 - -CMD ["python", "sandbox_server.py"] -""" - -_SANDBOX_SERVER = '''\ -"""Minimal FastAPI server for sandbox operations.""" -import os, subprocess, pathlib -from fastapi import FastAPI -from pydantic import BaseModel -from typing import Optional -import uvicorn - -app = FastAPI() - -class BashReq(BaseModel): - command: str - work_dir: str = "/app" - timeout: int = 120 - -class ReadReq(BaseModel): - path: str - offset: Optional[int] = None - limit: Optional[int] = 2000 - -class WriteReq(BaseModel): - path: str - content: str - -class EditReq(BaseModel): - path: str - old_str: str - new_str: str - replace_all: bool = False - -class ExistsReq(BaseModel): - path: str - -@app.get("/api/health") -def health(): - return {"status": "ok"} - -@app.post("/api/bash") -def bash(req: BashReq): - try: - r = subprocess.run( - req.command, shell=True, capture_output=True, text=True, - cwd=req.work_dir, timeout=req.timeout, - ) - output = r.stdout + r.stderr - if len(output) > 30000: - output = output[:30000] + "\\n... 
(truncated)" - return {"success": r.returncode == 0, "output": output, "error": "" if r.returncode == 0 else f"Exit code {r.returncode}"} - except subprocess.TimeoutExpired: - return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/read") -def read(req: ReadReq): - try: - p = pathlib.Path(req.path) - if not p.exists(): - return {"success": False, "output": "", "error": f"File not found: {req.path}"} - if p.is_dir(): - return {"success": False, "output": "", "error": f"Is a directory: {req.path}"} - lines = p.read_text().splitlines() - start = (req.offset or 1) - 1 - end = start + (req.limit or len(lines)) - selected = lines[start:end] - numbered = "\\n".join(f"{start + i + 1}\\t{line}" for i, line in enumerate(selected)) - return {"success": True, "output": numbered, "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/write") -def write(req: WriteReq): - try: - p = pathlib.Path(req.path) - p.parent.mkdir(parents=True, exist_ok=True) - p.write_text(req.content) - return {"success": True, "output": f"Wrote {len(req.content)} bytes to {req.path}", "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/edit") -def edit(req: EditReq): - try: - p = pathlib.Path(req.path) - if not p.exists(): - return {"success": False, "output": "", "error": f"File not found: {req.path}"} - content = p.read_text() - if req.old_str not in content: - return {"success": False, "output": "", "error": f"old_str not found in {req.path}"} - if not req.replace_all and content.count(req.old_str) > 1: - return {"success": False, "output": "", "error": f"old_str appears {content.count(req.old_str)} times. 
Use replace_all=true or provide more context."} - if req.replace_all: - new_content = content.replace(req.old_str, req.new_str) - else: - new_content = content.replace(req.old_str, req.new_str, 1) - p.write_text(new_content) - return {"success": True, "output": f"Edited {req.path}", "error": ""} - except Exception as e: - return {"success": False, "output": "", "error": str(e)} - -@app.post("/api/exists") -def exists(req: ExistsReq): - return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=7860) -''' - - -@dataclass -class ToolResult: - success: bool - output: str = "" - error: str = "" - - def __str__(self): - if self.success: - return self.output or "(no output)" - return f"ERROR: {self.error}" - - def to_dict(self) -> dict: - return {"success": self.success, "output": self.output, "error": self.error} - - -@dataclass -class Sandbox: - """ - A handle to an HF Space sandbox. - - Use Sandbox.create() to spin up a new one, or Sandbox.connect() to - attach to an existing running Space. - """ - - space_id: str - token: str | None = None - work_dir: str = "/app" - timeout: int = DEFAULT_TIMEOUT - _owns_space: bool = field(default=False, repr=False) - _base_url: str = field(init=False, repr=False) - _client: httpx.Client = field(init=False, repr=False) - _hf_api: HfApi = field(init=False, repr=False) - _files_read: set = field(init=False, repr=False, default_factory=set) - - def __post_init__(self): - self.token = self.token or os.environ.get("HF_TOKEN") - slug = self.space_id.replace("/", "-") - # Trailing slash is critical: httpx resolves relative paths against base_url. - # Without it, client.get("health") resolves to /health instead of /api/health. 
- self._base_url = f"https://{slug}.hf.space/api/" - self._client = httpx.Client( - base_url=self._base_url, - headers={"Authorization": f"Bearer {self.token}"} if self.token else {}, - timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), - follow_redirects=True, - ) - self._hf_api = HfApi(token=self.token) - - # ── Lifecycle ───────────────────────────────────────────────── - - @classmethod - def create( - cls, - owner: str, - *, - name: str | None = None, - template: str = TEMPLATE_SPACE, - hardware: str = "cpu-basic", - private: bool = False, - sleep_time: int | None = None, - token: str | None = None, - wait_timeout: int = WAIT_TIMEOUT, - ) -> Sandbox: - """ - Create a new sandbox by duplicating the template Space. - - Generates a unique space name, duplicates the template, waits for it - to come online, then returns a connected Sandbox. - - Args: - owner: HF username or org (e.g. "burtenshaw"). - name: Base name for the space. Defaults to "sandbox". - A unique suffix is always appended. - template: Source Space to duplicate (default: burtenshaw/sandbox). - hardware: Hardware tier (cpu-basic, t4-small, etc.). - private: Whether the Space should be private. - sleep_time: Auto-sleep after N seconds of inactivity. - token: HF API token. Falls back to HF_TOKEN env var. - wait_timeout: Max seconds to wait for Space to start (default: 300). - - Returns: - A Sandbox instance connected to the running Space. 
- """ - token = token or os.environ.get("HF_TOKEN") - api = HfApi(token=token) - - base = name or "sandbox" - suffix = uuid.uuid4().hex[:8] - space_id = f"{owner}/{base}-{suffix}" - - print(f"Creating sandbox: {space_id} (from {template})...") - - kwargs = { - "from_id": template, - "to_id": space_id, - "private": private, - "hardware": hardware, - } - if sleep_time is not None: - kwargs["sleep_time"] = sleep_time - - api.duplicate_space(**kwargs) - print(f"Space created: https://huggingface.co/spaces/{space_id}") - - # Upload sandbox server and Dockerfile (triggers rebuild) - cls._setup_server(space_id, api) - - # Wait for it to come online (rebuild + start) - print(f"Waiting for Space to start (timeout: {wait_timeout}s)...") - deadline = time.time() + wait_timeout - while time.time() < deadline: - runtime = api.get_space_runtime(space_id) - if runtime.stage == "RUNNING": - print(f"Space is running (hardware: {runtime.hardware})") - break - if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"): - raise RuntimeError( - f"Space failed to start: {runtime.stage}. " - f"Check https://huggingface.co/spaces/{space_id}" - ) - print(f" {runtime.stage}...") - time.sleep(WAIT_INTERVAL) - else: - raise TimeoutError( - f"Space did not start within {wait_timeout}s. " - f"Check https://huggingface.co/spaces/{space_id}" - ) - - # Wait for the API server to be responsive (non-fatal) - sb = cls(space_id=space_id, token=token, _owns_space=True) - try: - sb._wait_for_api(timeout=API_WAIT_TIMEOUT) - except TimeoutError as e: - print( - f"Warning: API health check timed out ({e}), but Space is RUNNING. Continuing." 
- ) - return sb - - @staticmethod - def _setup_server(space_id: str, api: HfApi) -> None: - """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" - print(f"Uploading sandbox server to {space_id}...") - api.create_commit( - repo_id=space_id, - repo_type="space", - operations=[ - CommitOperationAdd( - path_in_repo="sandbox_server.py", - path_or_fileobj=io.BytesIO(_SANDBOX_SERVER.encode()), - ), - CommitOperationAdd( - path_in_repo="Dockerfile", - path_or_fileobj=io.BytesIO(_DOCKERFILE.encode()), - ), - ], - commit_message="Setup sandbox server", - ) - print("Server files uploaded, rebuild triggered.") - - @classmethod - def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox: - """ - Connect to an existing running Space. - - Does a health check to verify the Space is reachable. - """ - sb = cls(space_id=space_id, token=token, _owns_space=False) - sb._wait_for_api(timeout=60) - return sb - - def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT): - """Poll the health endpoint until the server responds.""" - deadline = time.time() + timeout - last_err = None - last_status = None - while time.time() < deadline: - try: - resp = self._client.get("health", timeout=10) - last_status = resp.status_code - if resp.status_code == 200: - print(f"API is responsive at {self._base_url}") - return - except Exception as e: - last_err = e - time.sleep(3) - raise TimeoutError( - f"Sandbox API at {self._base_url} not responding after {timeout}s. " - f"Last status: {last_status}, last error: {last_err}" - ) - - def delete(self): - """Delete the Space. Only works if this Sandbox created it.""" - if not self._owns_space: - raise RuntimeError( - f"This Sandbox did not create {self.space_id}. " - f"Use self._hf_api.delete_repo() directly if you're sure." 
- ) - print(f"Deleting sandbox: {self.space_id}...") - self._hf_api.delete_repo(self.space_id, repo_type="space") - self._client.close() - print("Deleted.") - - def pause(self): - """Pause the Space (stops billing, preserves state).""" - self._hf_api.pause_space(self.space_id) - - def restart(self): - """Restart the Space.""" - self._hf_api.restart_space(self.space_id) - self._wait_for_api() - - @property - def url(self) -> str: - """Public URL of the Space.""" - return f"https://huggingface.co/spaces/{self.space_id}" - - @property - def status(self) -> str: - """Current Space stage (RUNNING, BUILDING, PAUSED, etc.).""" - return self._hf_api.get_space_runtime(self.space_id).stage - - def __enter__(self) -> Sandbox: - return self - - def __exit__(self, *exc): - if self._owns_space: - try: - self.delete() - except Exception as e: - print(f"Warning: failed to delete sandbox: {e}", file=sys.stderr) - self._client.close() - - # ── HTTP plumbing ───────────────────────────────────────────── - - def _call( - self, endpoint: str, payload: dict, timeout: float | None = None - ) -> ToolResult: - # Strip leading slash for correct httpx base_url resolution - endpoint = endpoint.lstrip("/") - try: - resp = self._client.post( - endpoint, - json=payload, - timeout=timeout or self.timeout, - ) - data = resp.json() - if resp.status_code == 200: - return ToolResult( - success=data.get("success", True), - output=data.get("output", ""), - error=data.get("error", ""), - ) - return ToolResult( - success=False, - error=data.get("error", f"HTTP {resp.status_code}"), - ) - except httpx.TimeoutException: - return ToolResult( - success=False, error=f"Timeout after {timeout or self.timeout}s" - ) - except httpx.ConnectError: - return ToolResult( - success=False, - error=f"Cannot connect to sandbox. Is {self.space_id} running? 
Status: {self.status}", - ) - except Exception as e: - return ToolResult(success=False, error=str(e)) - - # ── Tools ───────────────────────────────────────────────────── - - def bash( - self, - command: str, - *, - work_dir: str | None = None, - timeout: int | None = None, - description: str | None = None, - ) -> ToolResult: - return self._call( - "bash", - { - "command": command, - "work_dir": work_dir or self.work_dir, - "timeout": min(timeout or self.timeout, MAX_TIMEOUT), - }, - timeout=timeout, - ) - - def read( - self, path: str, *, offset: int | None = None, limit: int | None = None - ) -> ToolResult: - self._files_read.add(path) - return self._call( - "read", - { - "path": path, - "offset": offset, - "limit": limit or (DEFAULT_READ_LIMIT if offset is None else None), - }, - ) - - def write(self, path: str, content: str) -> ToolResult: - if path not in self._files_read: - check = self._call("exists", {"path": path}) - if check.success and check.output == "true": - return ToolResult( - success=False, - error=( - f"File {path} exists but has not been read this session. " - f"Read it first, or use sandbox_edit for targeted changes." - ), - ) - result = self._call("write", {"path": path, "content": content}) - if result.success: - self._files_read.add(path) - return result - - def edit( - self, path: str, old_str: str, new_str: str, *, replace_all: bool = False - ) -> ToolResult: - if old_str == new_str: - return ToolResult(success=False, error="old_str and new_str are identical.") - if path not in self._files_read: - return ToolResult( - success=False, - error=f"File {path} has not been read this session. 
Read it first.", - ) - return self._call( - "edit", - { - "path": path, - "old_str": old_str, - "new_str": new_str, - "replace_all": replace_all, - }, - ) - - # ── Tool schemas & dispatch ─────────────────────────────────── - - TOOLS = { - "bash": { - "description": ( - "Run a shell command in the remote sandbox and return stdout/stderr.\n" - "\n" - "Commands run in a shell at the working directory (default /app). " - "Each invocation is independent — use files in /app to persist state.\n" - "\n" - "AVOID using bash for operations covered by specialized tools:\n" - "- File reading: use read (not cat/head/tail)\n" - "- File editing: use edit (not sed/awk)\n" - "- File writing: use write (not echo/cat < /app/train.log 2>&1 &\n" - "Then check with read on the log file.\n" - "\n" - "Chain dependent commands with &&. Independent commands should be " - "separate bash calls (they can run in parallel).\n" - "\n" - "Timeout default 120s, max 600s." - ), - "parameters": { - "type": "object", - "required": ["command"], - "additionalProperties": False, - "properties": { - "command": { - "type": "string", - "description": "The shell command to execute.", - }, - "description": { - "type": "string", - "description": "Short description (5-10 words, active voice). E.g. 'Install dependencies', 'Run training script'.", - }, - "work_dir": { - "type": "string", - "description": "Working directory (default: /app).", - }, - "timeout": { - "type": "integer", - "description": "Timeout in seconds (default: 120, max: 600).", - }, - }, - }, - }, - "read": { - "description": ( - "Read file contents with line numbers (cat -n format).\n" - "\n" - "Returns the first 2000 lines by default. For large files, use offset/limit " - "to read a specific range. Line numbers always match the original file.\n" - "\n" - "Lines longer than 2000 chars are truncated.\n" - "Cannot read directories — use bash with 'ls' instead." 
- ), - "parameters": { - "type": "object", - "required": ["path"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to read.", - }, - "offset": { - "type": "integer", - "description": "Start from this line (1-based). Only if file is too large.", - }, - "limit": { - "type": "integer", - "description": "Number of lines to read. Only if file is too large.", - }, - }, - }, - }, - "write": { - "description": ( - "Create or overwrite a file. Creates parent directories as needed.\n" - "\n" - "For existing files, you MUST read the file first (system enforced). " - "Prefer edit for modifications." - ), - "parameters": { - "type": "object", - "required": ["path", "content"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file to write.", - }, - "content": { - "type": "string", - "description": "Complete file content.", - }, - }, - }, - }, - "edit": { - "description": ( - "Targeted edit via exact string replacement.\n" - "\n" - "Rules:\n" - "- old_str must appear EXACTLY once (unless replace_all is true).\n" - "- Include enough context in old_str for uniqueness.\n" - "- old_str and new_str must differ.\n" - "- Preserve indentation exactly.\n" - "- To delete code, set new_str to empty string.\n" - "- File MUST have been read this session (system enforced).\n" - "- Do NOT include line number prefixes in old_str/new_str.\n" - "\n" - "Use replace_all=true for batch operations like variable renaming." 
- ), - "parameters": { - "type": "object", - "required": ["path", "old_str", "new_str"], - "additionalProperties": False, - "properties": { - "path": { - "type": "string", - "description": "Absolute path to the file.", - }, - "old_str": { - "type": "string", - "description": "Exact text to find (must differ from new_str).", - }, - "new_str": {"type": "string", "description": "Replacement text."}, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences (default: false).", - "default": False, - }, - }, - }, - }, - } - - @classmethod - def tool_definitions(cls) -> list[dict]: - return [{"name": name, **spec} for name, spec in cls.TOOLS.items()] - - def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: - dispatch = { - "bash": lambda a: self.bash( - a["command"], - work_dir=a.get("work_dir"), - timeout=a.get("timeout"), - description=a.get("description"), - ), - "read": lambda a: self.read( - a["path"], - offset=a.get("offset"), - limit=a.get("limit"), - ), - "write": lambda a: self.write(a["path"], a["content"]), - "edit": lambda a: self.edit( - a["path"], - a["old_str"], - a["new_str"], - replace_all=a.get("replace_all", False), - ), - } - fn = dispatch.get(name) - if not fn: - return ToolResult(success=False, error=f"Unknown tool: {name}") - return fn(arguments) diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py deleted file mode 100644 index 61dea149a2605ecc4789d82f96ddeab9264058c9..0000000000000000000000000000000000000000 --- a/agent/tools/sandbox_tool.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Sandbox tools — expose the Sandbox client as agent tools. - -5 tools total: - sandbox_create — explicit sandbox creation (requires approval) - bash, read, write, edit — operations on the sandbox - -If any operation tool is called without an active sandbox, -a cpu-basic sandbox is auto-created (no approval needed). 
-""" - -from __future__ import annotations - -import asyncio -import os -from typing import Any - -from huggingface_hub import HfApi, SpaceHardware - -from agent.core.session import Event -from agent.tools.sandbox_client import Sandbox - -# ── Tool name mapping (short agent names → Sandbox client names) ────── - - -async def _ensure_sandbox( - session: Any, hardware: str = "cpu-basic", **create_kwargs -) -> tuple[Sandbox | None, str | None]: - """ - Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. - - Returns: - (sandbox, error_message) — one will be None. - """ - if session and getattr(session, "sandbox", None): - return session.sandbox, None - - if not session: - return None, "No session available." - - token = os.environ.get("HF_TOKEN") - if not token: - return None, "HF_TOKEN environment variable not set. Cannot create sandbox." - - api = HfApi(token=token) - user_info = api.whoami() - owner = user_info.get("name", user_info.get("user", "")) - if not owner: - return None, "Could not determine HF username from token." - - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "sandbox", - "log": f"Auto-creating sandbox for {owner} ({hardware})...", - }, - ) - ) - - kwargs = {"owner": owner, "hardware": hardware, "token": token, **create_kwargs} - sb = await asyncio.to_thread(Sandbox.create, **kwargs) - session.sandbox = sb - - await session.send_event( - Event( - event_type="tool_log", - data={"tool": "sandbox", "log": f"Sandbox ready: {sb.space_id} ({sb.url})"}, - ) - ) - - return sb, None - - -# ── sandbox_create tool ────────────────────────────────────────────── - -SANDBOX_CREATE_TOOL_SPEC = { - "name": "sandbox_create", - "description": ( - "Create a persistent remote Linux environment for developing and testing scripts.\n\n" - "Workflow: sandbox_create → write script → pip install → test with small run → fix errors → hf_jobs at scale.\n" - "The sandbox persists across tool calls within the session. 
pip install works out of the box.\n\n" - "Use this when: you need to develop, test, and iterate on scripts before launching via hf_jobs. " - "Especially for training scripts where you need to verify imports, test on a small subset, and fix errors interactively.\n\n" - "Skip this when: the task is a simple one-shot operation (status check, resource search, quick data query), " - "or the script is copied from a verified working example with minimal changes.\n\n" - "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). " - "CPU sandboxes cannot run GPU code paths — your test will not catch GPU-related errors.\n\n" - "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" - ), - "parameters": { - "type": "object", - "required": [], - "additionalProperties": False, - "properties": { - "hardware": { - "type": "string", - "enum": [e.value for e in SpaceHardware], - "description": "Hardware tier for the sandbox (default: cpu-basic)", - }, - "private": { - "type": "boolean", - "description": "If true, create a private Space", - }, - }, - }, -} - - -async def sandbox_create_handler( - args: dict[str, Any], session: Any = None -) -> tuple[str, bool]: - """Handle sandbox_create tool calls.""" - # If sandbox already exists, return its info - if session and getattr(session, "sandbox", None): - sb = session.sandbox - return ( - f"Sandbox already active: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Use bash/read/write/edit to interact with it." - ), True - - hardware = args.get("hardware", "cpu-basic") - create_kwargs = {} - if "private" in args: - create_kwargs["private"] = args["private"] - - try: - sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs) - except Exception as e: - return f"Failed to create sandbox: {e}", False - - if error: - return error, False - - return ( - f"Sandbox created: {sb.space_id}\n" - f"URL: {sb.url}\n" - f"Hardware: {hardware}\n" - f"Use bash/read/write/edit to interact with it." 
- ), True - - -def _make_tool_handler(sandbox_tool_name: str): - """Factory: create a handler for a sandbox operation tool.""" - - async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]: - # Auto-create sandbox if not present - try: - sb, error = await _ensure_sandbox(session) - except Exception as e: - return f"Failed to auto-create sandbox: {e}", False - - if error: - return error, False - - try: - result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) - if result.success: - return result.output or "(no output)", True - else: - error_msg = result.error or "Unknown error" - output = result.output - if output: - return f"{output}\n\nERROR: {error_msg}", False - return f"ERROR: {error_msg}", False - except Exception as e: - return f"Sandbox operation failed: {e}", False - - return handler - - -def get_sandbox_tools(): - """Return all 5 sandbox ToolSpecs (sandbox_create + 4 operation tools).""" - from agent.core.tools import ToolSpec - - tools = [] - - # sandbox_create (explicit creation, requires approval) - tools.append( - ToolSpec( - name=SANDBOX_CREATE_TOOL_SPEC["name"], - description=SANDBOX_CREATE_TOOL_SPEC["description"], - parameters=SANDBOX_CREATE_TOOL_SPEC["parameters"], - handler=sandbox_create_handler, - ) - ) - - # Operation tools (auto-execute, no approval needed) - for name in Sandbox.TOOLS.keys(): - spec = Sandbox.TOOLS[name] - tools.append( - ToolSpec( - name=name, - description=spec["description"], - parameters=spec["parameters"], - handler=_make_tool_handler(name), - ) - ) - - return tools diff --git a/backend/dependencies.py b/backend/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..03a1bb284507b6b78ad2a7534492934e416f6bed --- /dev/null +++ b/backend/dependencies.py @@ -0,0 +1,144 @@ +"""Authentication dependencies for FastAPI routes. + +Provides auth validation for both REST and WebSocket endpoints. 
+- In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user. +- In production: validates Bearer tokens or cookies against HF OAuth. +""" + +import logging +import os +import time +from typing import Any + +import httpx +from fastapi import HTTPException, Request, WebSocket, status + +logger = logging.getLogger(__name__) + +OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") +AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", "")) + +# Simple in-memory token cache: token -> (user_info, expiry_time) +_token_cache: dict[str, tuple[dict[str, Any], float]] = {} +TOKEN_CACHE_TTL = 300 # 5 minutes + +DEV_USER: dict[str, Any] = { + "user_id": "dev", + "username": "dev", + "authenticated": True, +} + + +async def _validate_token(token: str) -> dict[str, Any] | None: + """Validate a token against HF OAuth userinfo endpoint. + + Results are cached for TOKEN_CACHE_TTL seconds to avoid excessive API calls. + """ + now = time.time() + + # Check cache + if token in _token_cache: + user_info, expiry = _token_cache[token] + if now < expiry: + return user_info + del _token_cache[token] + + # Validate against HF + async with httpx.AsyncClient(timeout=10.0) as client: + try: + response = await client.get( + f"{OPENID_PROVIDER_URL}/oauth/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + if response.status_code != 200: + logger.debug("Token validation failed: status %d", response.status_code) + return None + user_info = response.json() + _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL) + return user_info + except httpx.HTTPError as e: + logger.warning("Token validation error: %s", e) + return None + + +def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]: + """Build a normalized user dict from HF userinfo response.""" + return { + "user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")), + "username": user_info.get("preferred_username", "unknown"), + "name": 
user_info.get("name"), + "picture": user_info.get("picture"), + "authenticated": True, + } + + +async def _extract_user_from_token(token: str) -> dict[str, Any] | None: + """Validate a token and return a user dict, or None.""" + user_info = await _validate_token(token) + if user_info: + return _user_from_info(user_info) + return None + + +async def get_current_user(request: Request) -> dict[str, Any]: + """FastAPI dependency: extract and validate the current user. + + Checks (in order): + 1. Authorization: Bearer header + 2. hf_access_token cookie + + In dev mode (AUTH_ENABLED=False), returns a default dev user. + """ + if not AUTH_ENABLED: + return DEV_USER + + # Try Authorization header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + user = await _extract_user_from_token(token) + if user: + return user + + # Try cookie + token = request.cookies.get("hf_access_token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated. Please log in via /auth/login.", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None: + """Extract and validate user from WebSocket connection. + + WebSocket doesn't support custom headers from browser, so we check: + 1. ?token= query parameter + 2. hf_access_token cookie (sent automatically for same-origin) + + Returns user dict or None if not authenticated. + In dev mode, returns the default dev user. 
+ """ + if not AUTH_ENABLED: + return DEV_USER + + # Try query param + token = websocket.query_params.get("token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + # Try cookie (works for same-origin WebSocket) + token = websocket.cookies.get("hf_access_token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + return None diff --git a/backend/main.py b/backend/main.py index 2ea33e05b92332f2bc4e32ba160b4029fdb69e31..fc75ab9e11696664776cc2370d68e589196af7ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,6 +5,14 @@ import os from contextlib import asynccontextmanager from pathlib import Path +from dotenv import load_dotenv + +load_dotenv() + +# Ensure HF_TOKEN is set — fall back to HF_ADMIN_TOKEN if available (HF Spaces) +if not os.environ.get("HF_TOKEN") and os.environ.get("HF_ADMIN_TOKEN"): + os.environ["HF_TOKEN"] = os.environ["HF_ADMIN_TOKEN"] + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles diff --git a/backend/models.py b/backend/models.py index 0c529522f88804e8ccdc97224359a6ed62462ab2..f22ab3048b1d1a18e75517e1475521f59bce526d 100644 --- a/backend/models.py +++ b/backend/models.py @@ -37,6 +37,7 @@ class ToolApproval(BaseModel): tool_call_id: str approved: bool feedback: str | None = None + edited_script: str | None = None class ApprovalRequest(BaseModel): @@ -67,6 +68,7 @@ class SessionInfo(BaseModel): created_at: str is_active: bool message_count: int + user_id: str = "dev" class HealthResponse(BaseModel): @@ -74,3 +76,13 @@ class HealthResponse(BaseModel): status: str = "ok" active_sessions: int = 0 + max_sessions: int = 0 + + +class LLMHealthResponse(BaseModel): + """LLM provider health check response.""" + + status: str # "ok" | "error" + model: str + error: str | None = None + error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" diff --git 
a/backend/routes/agent.py b/backend/routes/agent.py index 8a3db071ef8707d8e79ac272facb203f44cd9857..fed198d672ff147e60ecbced0c60c474e964eced 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -1,58 +1,252 @@ -"""Agent API routes - WebSocket and REST endpoints.""" +"""Agent API routes - WebSocket and REST endpoints. -import logging +All routes (except /health) require authentication via the get_current_user +dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. +""" -from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +import logging +import os +from typing import Any + +from dependencies import get_current_user, get_ws_user +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + WebSocket, + WebSocketDisconnect, +) +from litellm import acompletion +from agent.core.agent_loop import _resolve_hf_router_params from models import ( ApprovalRequest, HealthResponse, + LLMHealthResponse, SessionInfo, SessionResponse, SubmitRequest, ) -from session_manager import session_manager +from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager from websocket import manager as ws_manager logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) +AVAILABLE_MODELS = [ + { + "id": "huggingface/novita/minimax/minimax-m2.1", + "label": "MiniMax M2.1", + "provider": "huggingface", + "recommended": True, + }, + { + "id": "anthropic/claude-opus-4-5-20251101", + "label": "Claude Opus 4.5", + "provider": "anthropic", + "recommended": True, + }, + { + "id": "huggingface/novita/moonshotai/kimi-k2.5", + "label": "Kimi K2.5", + "provider": "huggingface", + }, + { + "id": "huggingface/novita/zai-org/glm-5", + "label": "GLM 5", + "provider": "huggingface", + }, +] + + +def _check_session_access(session_id: str, user: dict[str, Any]) -> None: + """Verify the user has access to the given session. 
Raises 403 or 404.""" + info = session_manager.get_session_info(session_id) + if not info: + raise HTTPException(status_code=404, detail="Session not found") + if not session_manager.verify_session_access(session_id, user["user_id"]): + raise HTTPException(status_code=403, detail="Access denied to this session") + @router.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: """Health check endpoint.""" return HealthResponse( - status="ok", active_sessions=session_manager.active_session_count + status="ok", + active_sessions=session_manager.active_session_count, + max_sessions=MAX_SESSIONS, ) +@router.get("/health/llm", response_model=LLMHealthResponse) +async def llm_health_check() -> LLMHealthResponse: + """Check if the LLM provider is reachable and the API key is valid. + + Makes a minimal 1-token completion call. Catches common errors: + - 401 → invalid API key + - 402/insufficient_quota → out of credits + - 429 → rate limited + - timeout / network → provider unreachable + """ + model = session_manager.config.model_name + try: + llm_params = _resolve_hf_router_params(model) + await acompletion( + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + timeout=10, + **llm_params, + ) + return LLMHealthResponse(status="ok", model=model) + except Exception as e: + err_str = str(e).lower() + error_type = "unknown" + + if ( + "401" in err_str + or "auth" in err_str + or "invalid" in err_str + or "api key" in err_str + ): + error_type = "auth" + elif ( + "402" in err_str + or "credit" in err_str + or "quota" in err_str + or "insufficient" in err_str + or "billing" in err_str + ): + error_type = "credits" + elif "429" in err_str or "rate" in err_str: + error_type = "rate_limit" + elif "timeout" in err_str or "connect" in err_str or "network" in err_str: + error_type = "network" + + logger.warning(f"LLM health check failed ({error_type}): {e}") + return LLMHealthResponse( + status="error", + model=model, + error=str(e)[:500], 
+ error_type=error_type, + ) + + +@router.get("/config/model") +async def get_model() -> dict: + """Get current model and available models. No auth required.""" + return { + "current": session_manager.config.model_name, + "available": AVAILABLE_MODELS, + } + + +@router.post("/config/model") +async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict: + """Set the LLM model. Applies to new conversations.""" + model_id = body.get("model") + if not model_id: + raise HTTPException(status_code=400, detail="Missing 'model' field") + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model_id not in valid_ids: + raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") + session_manager.config.model_name = model_id + logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") + return {"model": model_id} + + +@router.post("/title") +async def generate_title( + request: SubmitRequest, user: dict = Depends(get_current_user) +) -> dict: + """Generate a short title for a chat session based on the first user message.""" + model = session_manager.config.model_name + llm_params = _resolve_hf_router_params(model) + try: + response = await acompletion( + messages=[ + { + "role": "system", + "content": ( + "Generate a very short title (max 6 words) for a chat conversation " + "that starts with the following user message. " + "Reply with ONLY the title, no quotes, no punctuation at the end." 
+ ), + }, + {"role": "user", "content": request.text[:500]}, + ], + max_tokens=20, + temperature=0.3, + timeout=8, + **llm_params, + ) + title = response.choices[0].message.content.strip().strip('"').strip("'") + # Safety: cap at 50 chars + if len(title) > 50: + title = title[:50].rstrip() + "…" + return {"title": title} + except Exception as e: + logger.warning(f"Title generation failed: {e}") + # Fallback: truncate the message + fallback = request.text.strip() + title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback + return {"title": title} + + @router.post("/session", response_model=SessionResponse) -async def create_session() -> SessionResponse: - """Create a new agent session.""" - session_id = await session_manager.create_session() +async def create_session( + request: Request, user: dict = Depends(get_current_user) +) -> SessionResponse: + """Create a new agent session bound to the authenticated user. + + The user's HF access token is extracted from the Authorization header + and stored in the session so that tools (e.g. hf_jobs) can act on + behalf of the user. + + Returns 503 if the server or user has reached the session limit. 
+ """ + # Extract the user's HF token (Bearer header or HttpOnly cookie) + hf_token = None + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + hf_token = auth_header[7:] + if not hf_token: + hf_token = request.cookies.get("hf_access_token") + + try: + session_id = await session_manager.create_session( + user_id=user["user_id"], hf_token=hf_token + ) + except SessionCapacityError as e: + raise HTTPException(status_code=503, detail=str(e)) + return SessionResponse(session_id=session_id, ready=True) @router.get("/session/{session_id}", response_model=SessionInfo) -async def get_session(session_id: str) -> SessionInfo: - """Get session information.""" +async def get_session( + session_id: str, user: dict = Depends(get_current_user) +) -> SessionInfo: + """Get session information. Only accessible by the session owner.""" + _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) - if not info: - raise HTTPException(status_code=404, detail="Session not found") return SessionInfo(**info) @router.get("/sessions", response_model=list[SessionInfo]) -async def list_sessions() -> list[SessionInfo]: - """List all sessions.""" - sessions = session_manager.list_sessions() +async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: + """List sessions belonging to the authenticated user.""" + sessions = session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] @router.delete("/session/{session_id}") -async def delete_session(session_id: str) -> dict: - """Delete a session.""" +async def delete_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: + """Delete a session. 
Only accessible by the session owner.""" + _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") @@ -60,8 +254,11 @@ async def delete_session(session_id: str) -> dict: @router.post("/submit") -async def submit_input(request: SubmitRequest) -> dict: - """Submit user input to a session.""" +async def submit_input( + request: SubmitRequest, user: dict = Depends(get_current_user) +) -> dict: + """Submit user input to a session. Only accessible by the session owner.""" + _check_session_access(request.session_id, user) success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -69,13 +266,17 @@ async def submit_input(request: SubmitRequest) -> dict: @router.post("/approve") -async def submit_approval(request: ApprovalRequest) -> dict: - """Submit tool approvals to a session.""" +async def submit_approval( + request: ApprovalRequest, user: dict = Depends(get_current_user) +) -> dict: + """Submit tool approvals to a session. 
Only accessible by the session owner.""" + _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, "approved": a.approved, "feedback": a.feedback, + "edited_script": a.edited_script, } for a in request.approvals ] @@ -86,8 +287,11 @@ async def submit_approval(request: ApprovalRequest) -> dict: @router.post("/interrupt/{session_id}") -async def interrupt_session(session_id: str) -> dict: +async def interrupt_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: """Interrupt the current operation in a session.""" + _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -95,8 +299,9 @@ async def interrupt_session(session_id: str) -> dict: @router.post("/undo/{session_id}") -async def undo_session(session_id: str) -> dict: +async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" + _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -104,8 +309,11 @@ async def undo_session(session_id: str) -> dict: @router.post("/compact/{session_id}") -async def compact_session(session_id: str) -> dict: +async def compact_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: """Compact the context in a session.""" + _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -113,8 +321,11 @@ async def compact_session(session_id: str) -> dict: @router.post("/shutdown/{session_id}") -async def shutdown_session(session_id: str) -> dict: +async def shutdown_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: 
"""Shutdown a session.""" + _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -123,17 +334,61 @@ async def shutdown_session(session_id: str) -> dict: @router.websocket("/ws/{session_id}") async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None: - """WebSocket endpoint for real-time events.""" + """WebSocket endpoint for real-time events. + + Authentication is done via: + - ?token= query parameter (for browsers that can't send WS headers) + - Cookie (automatic for same-origin connections) + - Dev mode bypass (when OAUTH_CLIENT_ID is not set) + + NOTE: We must accept() before close() so the browser receives our custom + close codes (4001, 4003, 4004). If we close() before accept(), Starlette + sends HTTP 403 and the browser only sees code 1006 (abnormal closure). + """ logger.info(f"WebSocket connection request for session {session_id}") + + # Authenticate the WebSocket connection + user = await get_ws_user(websocket) + if not user: + logger.warning( + f"WebSocket rejected: authentication failed for session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4001, reason="Authentication required") + return + # Verify session exists info = session_manager.get_session_info(session_id) if not info: - logger.warning(f"WebSocket connection rejected: Session {session_id} not found") + logger.warning(f"WebSocket rejected: session {session_id} not found") + await websocket.accept() await websocket.close(code=4004, reason="Session not found") return + # Verify user owns the session + if not session_manager.verify_session_access(session_id, user["user_id"]): + logger.warning( + f"WebSocket rejected: user {user['user_id']} denied access to session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4003, reason="Access denied") + return + await 
ws_manager.connect(websocket, session_id) + # Send "ready" immediately on WebSocket connection so the frontend + # knows the session is alive. The original ready event from _run_session + # fires before the WS is connected and is always lost. + try: + await websocket.send_json( + { + "event_type": "ready", + "data": {"message": "Agent initialized"}, + } + ) + except Exception as e: + logger.error(f"Failed to send ready event for session {session_id}: {e}") + try: while True: # Keep connection alive, handle ping/pong diff --git a/backend/routes/auth.py b/backend/routes/auth.py index a39bacd0b4696d2c3989ff161be3c54ffec2b14f..224febf4b926890eb58943e3103a985fe0ed4626 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -1,11 +1,17 @@ -"""Authentication routes for HF OAuth.""" +"""Authentication routes for HF OAuth. + +Handles the OAuth 2.0 authorization code flow with HF as provider. +After successful auth, sets an HttpOnly cookie with the access token. +""" import os import secrets +import time from urllib.parse import urlencode import httpx -from fastapi import APIRouter, HTTPException, Request +from dependencies import AUTH_ENABLED, get_current_user +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import RedirectResponse router = APIRouter(prefix="/auth", tags=["auth"]) @@ -15,10 +21,19 @@ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "") OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "") OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -# In-memory session store (replace with proper session management in production) +# In-memory OAuth state store with expiry (5 min TTL) +_OAUTH_STATE_TTL = 300 oauth_states: dict[str, dict] = {} +def _cleanup_expired_states() -> None: + """Remove expired OAuth states to prevent memory growth.""" + now = time.time() + expired = [k for k, v in oauth_states.items() if now > v.get("expires_at", 0)] + for k in expired: + del 
oauth_states[k] + + def get_redirect_uri(request: Request) -> str: """Get the OAuth callback redirect URI.""" # In HF Spaces, use the SPACE_HOST if available @@ -38,17 +53,26 @@ async def oauth_login(request: Request) -> RedirectResponse: detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.", ) + # Clean up expired states to prevent memory growth + _cleanup_expired_states() + # Generate state for CSRF protection state = secrets.token_urlsafe(32) - oauth_states[state] = {"redirect_uri": get_redirect_uri(request)} + oauth_states[state] = { + "redirect_uri": get_redirect_uri(request), + "expires_at": time.time() + _OAUTH_STATE_TTL, + } # Build authorization URL params = { "client_id": OAUTH_CLIENT_ID, "redirect_uri": get_redirect_uri(request), - "scope": "openid profile", + "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions", "response_type": "code", "state": state, + "orgIds": os.environ.get( + "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1" + ), # ml-agent-explorers } auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}" @@ -91,58 +115,57 @@ async def oauth_callback( # Get user info access_token = token_data.get("access_token") - if access_token: - async with httpx.AsyncClient() as client: - try: - userinfo_response = await client.get( - f"{OPENID_PROVIDER_URL}/oauth/userinfo", - headers={"Authorization": f"Bearer {access_token}"}, - ) - userinfo_response.raise_for_status() - user_info = userinfo_response.json() - except httpx.HTTPError: - user_info = {} - else: - user_info = {} - - # For now, redirect to home with token in query params - # In production, use secure cookies or session storage - redirect_params = { - "access_token": access_token, - "username": user_info.get("preferred_username", ""), - } + if not access_token: + raise HTTPException( + status_code=500, + detail="Token exchange succeeded but no access_token was returned.", + ) - return 
RedirectResponse(url=f"/?{urlencode(redirect_params)}") + # Fetch user info (optional — failure is not fatal) + async with httpx.AsyncClient() as client: + try: + userinfo_response = await client.get( + f"{OPENID_PROVIDER_URL}/oauth/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + userinfo_response.raise_for_status() + except httpx.HTTPError: + pass # user_info not required for auth flow + + # Set access token as HttpOnly cookie (not in URL — avoids leaks via + # Referrer headers, browser history, and server logs) + is_production = bool(os.environ.get("SPACE_HOST")) + response = RedirectResponse(url="/", status_code=302) + response.set_cookie( + key="hf_access_token", + value=access_token, + httponly=True, + secure=is_production, # Secure flag only in production (HTTPS) + samesite="lax", + max_age=3600 * 24, # 24 hours + path="/", + ) + return response @router.get("/logout") async def logout() -> RedirectResponse: - """Log out the user.""" - return RedirectResponse(url="/") + """Log out the user by clearing the auth cookie.""" + response = RedirectResponse(url="/") + response.delete_cookie(key="hf_access_token", path="/") + return response -@router.get("/me") -async def get_current_user(request: Request) -> dict: - """Get current user info from Authorization header.""" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return {"authenticated": False} +@router.get("/status") +async def auth_status() -> dict: + """Check if OAuth is enabled on this instance.""" + return {"auth_enabled": AUTH_ENABLED} - token = auth_header.split(" ")[1] - async with httpx.AsyncClient() as client: - try: - response = await client.get( - f"{OPENID_PROVIDER_URL}/oauth/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - user_info = response.json() - return { - "authenticated": True, - "username": user_info.get("preferred_username"), - "name": user_info.get("name"), - 
"picture": user_info.get("picture"), - } - except httpx.HTTPError: - return {"authenticated": False} +@router.get("/me") +async def get_me(user: dict = Depends(get_current_user)) -> dict: + """Get current user info. Returns the authenticated user or dev user. + + Uses the shared auth dependency which handles cookie + Bearer token. + """ + return user diff --git a/backend/session_manager.py b/backend/session_manager.py index 058b376123ff9d00013ce4fa7e49e7bbee5585d0..03d9b2d9b8d706f1fa391f69e43e759d77246b86 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -48,11 +48,28 @@ class AgentSession: session: Session tool_router: ToolRouter submission_queue: asyncio.Queue + user_id: str = "dev" # Owner of this session + hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True +class SessionCapacityError(Exception): + """Raised when no more sessions can be created.""" + + def __init__(self, message: str, error_type: str = "global") -> None: + super().__init__(message) + self.error_type = error_type # "global" or "per_user" + + +# ── Capacity limits ───────────────────────────────────────────────── +# Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM). +# Each session uses ~10-20 MB (context, tools, queues, task). 
+MAX_SESSIONS: int = 50 +MAX_SESSIONS_PER_USER: int = 10 + + class SessionManager: """Manages multiple concurrent agent sessions.""" @@ -61,19 +78,69 @@ class SessionManager: self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() - async def create_session(self) -> str: - """Create a new agent session and return its ID.""" + def _count_user_sessions(self, user_id: str) -> int: + """Count active sessions owned by a specific user.""" + return sum( + 1 + for s in self.sessions.values() + if s.user_id == user_id and s.is_active + ) + + async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str: + """Create a new agent session and return its ID. + + Session() and ToolRouter() constructors contain blocking I/O + (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are + executed in a thread pool to avoid freezing the async event loop. + + Args: + user_id: The ID of the user who owns this session. + + Raises: + SessionCapacityError: If the server or user has reached the + maximum number of concurrent sessions. + """ + # ── Capacity checks ────────────────────────────────────────── + async with self._lock: + active_count = self.active_session_count + if active_count >= MAX_SESSIONS: + raise SessionCapacityError( + f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). " + "Please try again later.", + error_type="global", + ) + if user_id != "dev": + user_count = self._count_user_sessions(user_id) + if user_count >= MAX_SESSIONS_PER_USER: + raise SessionCapacityError( + f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " + "concurrent sessions. 
Please close an existing session first.", + error_type="per_user", + ) + session_id = str(uuid.uuid4()) # Create queues for this session submission_queue: asyncio.Queue = asyncio.Queue() event_queue: asyncio.Queue = asyncio.Queue() - # Create tool router - tool_router = ToolRouter(self.config.mcpServers) + # Run blocking constructors in a thread to keep the event loop responsive. + # Without this, Session.__init__ → ContextManager → litellm.get_max_tokens() + # blocks all HTTP/WebSocket handling. + import time as _time + + def _create_session_sync(): + t0 = _time.monotonic() + tool_router = ToolRouter(self.config.mcpServers) + session = Session(event_queue, config=self.config, tool_router=tool_router) + t1 = _time.monotonic() + logger.info(f"Session initialized in {t1 - t0:.2f}s") + return tool_router, session - # Create the agent session - session = Session(event_queue, config=self.config, tool_router=tool_router) + tool_router, session = await asyncio.to_thread(_create_session_sync) + + # Store user's HF token on the session so tools can use it + session.hf_token = hf_token # Create wrapper agent_session = AgentSession( @@ -81,6 +148,8 @@ class SessionManager: session=session, tool_router=tool_router, submission_queue=submission_queue, + user_id=user_id, + hf_token=hf_token, ) async with self._lock: @@ -92,7 +161,7 @@ class SessionManager: ) agent_session.task = task - logger.info(f"Created session {session_id}") + logger.info(f"Created session {session_id} for user {user_id}") return session_id async def _run_session( @@ -245,6 +314,27 @@ class SessionManager: return True + def get_session_owner(self, session_id: str) -> str | None: + """Get the user_id that owns a session, or None if session doesn't exist.""" + agent_session = self.sessions.get(session_id) + if not agent_session: + return None + return agent_session.user_id + + def verify_session_access(self, session_id: str, user_id: str) -> bool: + """Check if a user has access to a session. 
+ + Returns True if: + - The session exists AND the user owns it + - The user_id is "dev" (dev mode bypass) + """ + owner = self.get_session_owner(session_id) + if owner is None: + return False + if user_id == "dev" or owner == "dev": + return True + return owner == user_id + def get_session_info(self, session_id: str) -> dict[str, Any] | None: """Get information about a session.""" agent_session = self.sessions.get(session_id) @@ -256,15 +346,25 @@ class SessionManager: "created_at": agent_session.created_at.isoformat(), "is_active": agent_session.is_active, "message_count": len(agent_session.session.context_manager.items), + "user_id": agent_session.user_id, } - def list_sessions(self) -> list[dict[str, Any]]: - """List all sessions.""" - return [ - self.get_session_info(sid) - for sid in self.sessions - if self.get_session_info(sid) - ] + def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: + """List sessions, optionally filtered by user. + + Args: + user_id: If provided, only return sessions owned by this user. + If "dev", return all sessions (dev mode). 
+ """ + results = [] + for sid in self.sessions: + info = self.get_session_info(sid) + if not info: + continue + if user_id and user_id != "dev" and info.get("user_id") != user_id: + continue + results.append(info) + return results @property def active_session_count(self) -> int: diff --git a/backend/websocket.py b/backend/websocket.py index 924d7a831273890db939d32a8094373ee0a69fd3..bc09ed747b164bbe99ddebd6d35a36ae6a2faad8 100644 --- a/backend/websocket.py +++ b/backend/websocket.py @@ -1,6 +1,5 @@ """WebSocket connection manager for real-time communication.""" -import asyncio import logging from typing import Any @@ -15,23 +14,18 @@ class ConnectionManager: def __init__(self) -> None: # session_id -> WebSocket self.active_connections: dict[str, WebSocket] = {} - # session_id -> asyncio.Queue for outgoing messages - self.message_queues: dict[str, asyncio.Queue] = {} async def connect(self, websocket: WebSocket, session_id: str) -> None: """Accept a WebSocket connection and register it.""" logger.info(f"Attempting to accept WebSocket for session {session_id}") await websocket.accept() self.active_connections[session_id] = websocket - self.message_queues[session_id] = asyncio.Queue() logger.info(f"WebSocket connected and registered for session {session_id}") def disconnect(self, session_id: str) -> None: """Remove a WebSocket connection.""" if session_id in self.active_connections: del self.active_connections[session_id] - if session_id in self.message_queues: - del self.message_queues[session_id] logger.info(f"WebSocket disconnected for session {session_id}") async def send_event( @@ -63,10 +57,6 @@ class ConnectionManager: """Check if a session has an active WebSocket connection.""" return session_id in self.active_connections - def get_queue(self, session_id: str) -> asyncio.Queue | None: - """Get the message queue for a session.""" - return self.message_queues.get(session_id) - # Global connection manager instance manager = ConnectionManager() diff --git 
a/configs/main_agent_config.json b/configs/main_agent_config.json index 18a414b3bfced18b47d2737579e3db9c9d137cd6..1ef25f251b053a73dc461cf4bdf617bb11c983d5 100644 --- a/configs/main_agent_config.json +++ b/configs/main_agent_config.json @@ -1,9 +1,9 @@ { - "model_name": "anthropic/claude-opus-4-5-20251101", + "model_name": "huggingface/novita/moonshotai/kimi-k2.5", "save_sessions": true, "session_dataset_repo": "akseljoonas/hf-agent-sessions", "yolo_mode": false, - "confirm_cpu_jobs": false, + "confirm_cpu_jobs": true, "auto_file_upload": true, "mcpServers": { "hf-mcp-server": { diff --git a/frontend/package-lock.json b/frontend/package-lock.json index a800dd3f254b2ff725890c4f250e34d7490bf52d..1e1a41bb0535aa7abb2b7a8ac165ab216c1bd384 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,10 +8,12 @@ "name": "hf-agent-frontend", "version": "1.0.0", "dependencies": { + "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", + "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", @@ -34,6 +36,70 @@ "vite": "^5.4.10" } }, + "node_modules/@ai-sdk/gateway": { + "version": "3.0.50", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.50.tgz", + "integrity": "sha512-Jdd1a8VgbD7l7r+COj0h5SuaYRfPvOJ/AO6l0OrmTPEcI2MUQPr3C4JttfpNkcheEN+gOdy0CtZWuG17bW2fjw==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@vercel/oidc": "3.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "license": "Apache-2.0", + "dependencies": { + 
"json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz", + "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/react": { + "version": "3.0.93", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-3.0.93.tgz", + "integrity": "sha512-FY1HmeAfCpiAGLhIZh2QR8QFzHFZfhjMmkA9D5KC/O3eGqPeY7CwBABLkzRH+5Gkf+MfxXnEm4VF0MpmvDMjpg==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider-utils": "4.0.15", + "ai": "6.0.91", + "swr": "^2.2.5", + "throttleit": "2.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1" + } + }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -1348,6 +1414,15 @@ } } }, + "node_modules/@opentelemetry/api": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", + "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", + "license": "Apache-2.0", + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/@popperjs/core": { "version": "2.11.8", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz", @@ -1715,6 +1790,12 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": 
"sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "license": "MIT" + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -2155,6 +2236,15 @@ "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", "license": "ISC" }, + "node_modules/@vercel/oidc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz", + "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, "node_modules/@vitejs/plugin-react": { "version": "4.7.0", "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz", @@ -2200,6 +2290,24 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/ai": { + "version": "6.0.91", + "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.91.tgz", + "integrity": "sha512-k1/8BusZMhYVxxLZt0BUZzm9HVDCCh117nyWfWUx5xjR2+tWisJbXgysL7EBMq2lgyHwgpA1jDR3tVjWSdWZXw==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "3.0.50", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -2848,6 +2956,15 @@ "node": ">=0.10.0" } }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": 
"https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -3356,6 +3473,12 @@ "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", "license": "MIT" }, + "node_modules/json-schema": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", + "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", + "license": "(AFL-2.1 OR BSD-3-Clause)" + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -5052,6 +5175,31 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/swr": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/swr/-/swr-2.4.0.tgz", + "integrity": "sha512-sUlC20T8EOt1pHmDiqueUWMmRRX03W7w5YxovWX7VR2KHEPCTMly85x05vpkP5i6Bu4h44ePSMD9Tc+G2MItFw==", + "license": "MIT", + "dependencies": { + "dequal": "^2.0.3", + "use-sync-external-store": "^1.6.0" + }, + "peerDependencies": { + "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/throttleit": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz", + "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -5282,6 +5430,16 @@ "punycode": "^2.1.0" } }, + "node_modules/use-sync-external-store": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", + "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", + "license": 
"MIT", + "peer": true, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/vfile": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", @@ -5426,6 +5584,16 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/zod": { + "version": "4.3.6", + "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", + "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, "node_modules/zustand": { "version": "5.0.10", "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.10.tgz", diff --git a/frontend/package.json b/frontend/package.json index 553726bae62a96f8869c8bec29bf3fbad511bc0c..9efe3dced3118cbf0976e413f376f1050f1b2853 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,10 +10,12 @@ "preview": "vite preview" }, "dependencies": { + "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", + "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e9aecc9ed99aee276dc509c7078d4c7404b50d89..de1f785734359130675174b487a4d30d1ca34f50 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,7 +1,12 @@ import { Box } from '@mui/material'; import AppLayout from '@/components/Layout/AppLayout'; +import { useAuth } from '@/hooks/useAuth'; function App() { + // Non-blocking auth check — fires in background, updates store when done. + // If auth fails later, apiFetch redirects to /auth/login. 
+ useAuth(); + return ( diff --git a/frontend/src/components/ApprovalModal/ApprovalModal.tsx b/frontend/src/components/ApprovalModal/ApprovalModal.tsx deleted file mode 100644 index 98414524160bf8c81c7efdb3d6ce1adc2578435f..0000000000000000000000000000000000000000 --- a/frontend/src/components/ApprovalModal/ApprovalModal.tsx +++ /dev/null @@ -1,208 +0,0 @@ -import { useState, useCallback } from 'react'; -import { - Dialog, - DialogTitle, - DialogContent, - DialogActions, - Button, - Box, - Typography, - Checkbox, - FormControlLabel, - Accordion, - AccordionSummary, - AccordionDetails, - TextField, - Chip, -} from '@mui/material'; -import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; -import WarningIcon from '@mui/icons-material/Warning'; -import { useAgentStore } from '@/store/agentStore'; - -interface ApprovalModalProps { - sessionId: string | null; -} - -interface ApprovalState { - [toolCallId: string]: { - approved: boolean; - feedback: string; - }; -} - -export default function ApprovalModal({ sessionId }: ApprovalModalProps) { - const { pendingApprovals, setPendingApprovals } = useAgentStore(); - const [approvalState, setApprovalState] = useState({}); - - const isOpen = pendingApprovals !== null && pendingApprovals.tools.length > 0; - - const handleApprovalChange = useCallback( - (toolCallId: string, approved: boolean) => { - setApprovalState((prev) => ({ - ...prev, - [toolCallId]: { - ...prev[toolCallId], - approved, - feedback: prev[toolCallId]?.feedback || '', - }, - })); - }, - [] - ); - - const handleFeedbackChange = useCallback( - (toolCallId: string, feedback: string) => { - setApprovalState((prev) => ({ - ...prev, - [toolCallId]: { - ...prev[toolCallId], - feedback, - }, - })); - }, - [] - ); - - const handleSubmit = useCallback(async () => { - if (!sessionId || !pendingApprovals) return; - - const approvals = pendingApprovals.tools.map((tool) => ({ - tool_call_id: tool.tool_call_id, - approved: approvalState[tool.tool_call_id]?.approved ?? 
false, - feedback: approvalState[tool.tool_call_id]?.feedback || null, - })); - - try { - await fetch('/api/approve', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: sessionId, - approvals, - }), - }); - setPendingApprovals(null); - setApprovalState({}); - } catch (e) { - console.error('Approval submission failed:', e); - } - }, [sessionId, pendingApprovals, approvalState, setPendingApprovals]); - - const handleApproveAll = useCallback(() => { - if (!pendingApprovals) return; - const newState: ApprovalState = {}; - pendingApprovals.tools.forEach((tool) => { - newState[tool.tool_call_id] = { approved: true, feedback: '' }; - }); - setApprovalState(newState); - }, [pendingApprovals]); - - const handleRejectAll = useCallback(() => { - if (!pendingApprovals) return; - const newState: ApprovalState = {}; - pendingApprovals.tools.forEach((tool) => { - newState[tool.tool_call_id] = { approved: false, feedback: '' }; - }); - setApprovalState(newState); - }, [pendingApprovals]); - - if (!isOpen || !pendingApprovals) return null; - - const approvedCount = Object.values(approvalState).filter((s) => s.approved).length; - - return ( - - - - Approval Required - 1 ? 's' : ''}`} - size="small" - sx={{ ml: 1 }} - /> - - - - The following tool calls require your approval before execution: - - {pendingApprovals.tools.map((tool, index) => ( - - }> - - { - e.stopPropagation(); - handleApprovalChange(tool.tool_call_id, e.target.checked); - }} - onClick={(e) => e.stopPropagation()} - /> - } - label="" - sx={{ m: 0 }} - /> - - - {approvalState[tool.tool_call_id]?.approved ? 
'Approved' : 'Pending'} - - - - - - Arguments: - - - {JSON.stringify(tool.arguments, null, 2)} - - {!approvalState[tool.tool_call_id]?.approved && ( - handleFeedbackChange(tool.tool_call_id, e.target.value)} - sx={{ mt: 2 }} - /> - )} - - - ))} - - - - - - - {approvedCount} of {pendingApprovals.count} approved - - - - - ); -} diff --git a/frontend/src/components/Chat/ActivityStatusBar.tsx b/frontend/src/components/Chat/ActivityStatusBar.tsx new file mode 100644 index 0000000000000000000000000000000000000000..5aac135875b951d97fdb942eda9b6e1f4943e193 --- /dev/null +++ b/frontend/src/components/Chat/ActivityStatusBar.tsx @@ -0,0 +1,57 @@ +import { Box, Typography } from '@mui/material'; +import { keyframes } from '@mui/system'; +import { useAgentStore, type ActivityStatus } from '@/store/agentStore'; + +const shimmer = keyframes` + 0% { background-position: -100% center; } + 50% { background-position: 200% center; } + 100% { background-position: -100% center; } +`; + +const TOOL_LABELS: Record = { + hf_jobs: 'Running job', + hf_repo_files: 'Uploading file', + hf_repo_git: 'Git operation', + hf_inspect_dataset: 'Inspecting dataset', + hf_search: 'Searching', + plan_tool: 'Planning', +}; + +function statusLabel(status: ActivityStatus): string { + switch (status.type) { + case 'thinking': return 'Thinking'; + case 'streaming': return 'Writing'; + case 'tool': return TOOL_LABELS[status.toolName] || `Running ${status.toolName}`; + case 'waiting-approval': return 'Waiting for approval'; + default: return ''; + } +} + +export default function ActivityStatusBar() { + const activityStatus = useAgentStore(s => s.activityStatus); + + if (activityStatus.type === 'idle') return null; + + const label = statusLabel(activityStatus); + + return ( + + + {label}… + + + ); +} diff --git a/frontend/src/components/Chat/ApprovalFlow.tsx b/frontend/src/components/Chat/ApprovalFlow.tsx deleted file mode 100644 index 
58c1d8e6520ac561ae341965c2f421c74112ad63..0000000000000000000000000000000000000000 --- a/frontend/src/components/Chat/ApprovalFlow.tsx +++ /dev/null @@ -1,515 +0,0 @@ -import { useState, useCallback, useEffect } from 'react'; -import { Box, Typography, Button, TextField, IconButton, Link } from '@mui/material'; -import SendIcon from '@mui/icons-material/Send'; -import OpenInNewIcon from '@mui/icons-material/OpenInNew'; -import CheckCircleIcon from '@mui/icons-material/CheckCircle'; -import CancelIcon from '@mui/icons-material/Cancel'; -import LaunchIcon from '@mui/icons-material/Launch'; -import { useAgentStore } from '@/store/agentStore'; -import { useLayoutStore } from '@/store/layoutStore'; -import { useSessionStore } from '@/store/sessionStore'; -import type { Message, ToolApproval } from '@/types/agent'; - -interface ApprovalFlowProps { - message: Message; -} - -export default function ApprovalFlow({ message }: ApprovalFlowProps) { - const { setPanelContent, setPanelTab, setActivePanelTab, clearPanelTabs, updateMessage } = useAgentStore(); - const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); - const { activeSessionId } = useSessionStore(); - const [currentIndex, setCurrentIndex] = useState(0); - const [feedback, setFeedback] = useState(''); - const [decisions, setDecisions] = useState([]); - - const approvalData = message.approval; - - if (!approvalData) return null; - - const { batch, status } = approvalData; - - // Parse toolOutput to extract job info (URL, status, logs, errors) - let logsContent = ''; - let showLogsButton = false; - let jobUrl = ''; - let jobStatus = ''; - let jobFailed = false; - let errorMessage = ''; - - if (message.toolOutput) { - const output = message.toolOutput; - - // Extract job URL: **View at:** https://... - const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/); - if (urlMatch) { - jobUrl = urlMatch[1]; - } - - // Extract job status: **Final Status:** ... 
- const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/); - if (statusMatch) { - jobStatus = statusMatch[1].trim(); - jobFailed = jobStatus.toLowerCase().includes('error') || jobStatus.toLowerCase().includes('failed'); - } - - // Extract logs - if (output.includes('**Logs:**')) { - const parts = output.split('**Logs:**'); - if (parts.length > 1) { - const logsPart = parts[1].trim(); - const codeBlockMatch = logsPart.match(/```([\s\S]*?)```/); - if (codeBlockMatch) { - logsContent = codeBlockMatch[1].trim(); - showLogsButton = true; - } - } - } - - // Detect errors - if output exists but doesn't have the expected job completion format - // This catches early failures (validation errors, API errors, etc.) - const isExpectedFormat = output.includes('**Job ID:**') || output.includes('**View at:**'); - const looksLikeError = output.toLowerCase().includes('error') || - output.toLowerCase().includes('failed') || - output.toLowerCase().includes('exception') || - output.includes('Traceback'); - - if (!isExpectedFormat || (looksLikeError && !logsContent)) { - // This is likely an error message - show it - errorMessage = output; - jobFailed = true; - } - } - - // Sync right panel with current tool - useEffect(() => { - if (!batch || currentIndex >= batch.tools.length) return; - - // Only auto-open panel if pending - if (status !== 'pending') return; - - const tool = batch.tools[currentIndex]; - const args = tool.arguments as any; - - if (tool.tool === 'hf_jobs' && (args.operation === 'run' || args.operation === 'scheduled run') && args.script) { - setPanelContent({ - title: 'Compute Job Script', - content: args.script, - language: 'python', - parameters: args - }); - // Don't auto-open if already resolved - } else if (tool.tool === 'hf_repo_files' && args.operation === 'upload' && args.content) { - setPanelContent({ - title: `File Upload: ${args.path || 'unnamed'}`, - content: args.content, - parameters: args - }); - } - }, [currentIndex, batch, status, 
setPanelContent]); - - const handleResolve = useCallback(async (approved: boolean) => { - if (!batch || !activeSessionId) return; - - const currentTool = batch.tools[currentIndex]; - const newDecisions = [ - ...decisions, - { - tool_call_id: currentTool.tool_call_id, - approved, - feedback: approved ? null : feedback || 'Rejected by user', - }, - ]; - - if (currentIndex < batch.tools.length - 1) { - setDecisions(newDecisions); - setCurrentIndex(currentIndex + 1); - setFeedback(''); - } else { - // All tools in batch resolved - try { - await fetch('/api/approve', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: activeSessionId, - approvals: newDecisions, - }), - }); - - // Update message status - updateMessage(activeSessionId, message.id, { - approval: { - ...approvalData!, - status: approved ? 'approved' : 'rejected', - decisions: newDecisions - } - }); - - } catch (e) { - console.error('Approval submission failed:', e); - } - } - }, [activeSessionId, message.id, batch, currentIndex, feedback, decisions, approvalData, updateMessage]); - - if (!batch || currentIndex >= batch.tools.length) return null; - - const currentTool = batch.tools[currentIndex]; - - // Check if script contains push_to_hub or upload_file - const args = currentTool.arguments as any; - const containsPushToHub = currentTool.tool === 'hf_jobs' && args.script && (args.script.includes('push_to_hub') || args.script.includes('upload_file')); - - const getToolDescription = (toolName: string, args: any) => { - if (toolName === 'hf_jobs') { - return ( - - - The agent wants to execute hf_jobs on{' '} - {args.hardware_flavor || 'default'} with a timeout of{' '} - {args.timeout || '30m'} - - - ); - } - return ( - - The agent wants to execute {toolName} - - ); - }; - - const showCode = () => { - const args = currentTool.arguments as any; - if (currentTool.tool === 'hf_jobs' && args.script) { - // Clear existing tabs and set up script tab (and logs 
if available) - clearPanelTabs(); - setPanelTab({ - id: 'script', - title: 'Script', - content: args.script, - language: 'python', - parameters: args - }); - // If logs are available (job completed), also add logs tab - if (logsContent) { - setPanelTab({ - id: 'logs', - title: 'Logs', - content: logsContent, - language: 'text' - }); - } - setActivePanelTab('script'); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - } else { - setPanelContent({ - title: `Tool: ${currentTool.tool}`, - content: JSON.stringify(args, null, 2), - language: 'json', - parameters: args - }); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - } - }; - - const handleViewLogs = (e: React.MouseEvent) => { - e.stopPropagation(); - const args = currentTool.arguments as any; - // Set up both tabs so user can switch between script and logs - clearPanelTabs(); - if (currentTool.tool === 'hf_jobs' && args.script) { - setPanelTab({ - id: 'script', - title: 'Script', - content: args.script, - language: 'python', - parameters: args - }); - } - setPanelTab({ - id: 'logs', - title: 'Logs', - content: logsContent, - language: 'text' - }); - setActivePanelTab('logs'); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - }; - - return ( - - - - {status === 'pending' ? 'Approval Required' : status === 'approved' ? 'Approved' : 'Rejected'} - - - ({currentIndex + 1}/{batch.count}) - - {status === 'approved' && } - {status === 'rejected' && } - - - - {getToolDescription(currentTool.tool, currentTool.arguments)} - - - - {/* Script/Logs buttons for hf_jobs - always show when we have a script */} - {currentTool.tool === 'hf_jobs' && args.script && ( - - - - - - - {/* Job URL - only show when we have a specific URL */} - {jobUrl && ( - - - View Job on Hugging Face - - )} - - {/* Show job status if available */} - {jobStatus && ( - - Status: {jobStatus} - - )} - - )} - - {containsPushToHub && ( - - We've detected the result will be pushed to hub. 
- - )} - - {/* Show error message if job failed */} - {errorMessage && status !== 'pending' && ( - - - Error - - - {errorMessage.length > 500 ? errorMessage.substring(0, 500) + '...' : errorMessage} - - - )} - - - {status === 'pending' && ( - - - setFeedback(e.target.value)} - variant="outlined" - sx={{ - '& .MuiOutlinedInput-root': { - bgcolor: 'rgba(0,0,0,0.2)', - fontFamily: 'inherit', - fontSize: '0.9rem' - } - }} - /> - handleResolve(false)} - disabled={!feedback} - title="Reject with feedback" - sx={{ - color: 'var(--accent-red)', - border: '1px solid rgba(255,255,255,0.05)', - borderRadius: '8px', - width: 40, - height: 40, - '&:hover': { - bgcolor: 'rgba(224, 90, 79, 0.1)', - borderColor: 'var(--accent-red)', - }, - '&.Mui-disabled': { - color: 'rgba(255,255,255,0.1)', - borderColor: 'rgba(255,255,255,0.02)' - } - }} - > - - - - - - - - - - )} - - {status === 'rejected' && decisions.some(d => d.feedback) && ( - - Feedback: {decisions.find(d => d.feedback)?.feedback} - - )} - - ); -} \ No newline at end of file diff --git a/frontend/src/components/Chat/AssistantMessage.tsx b/frontend/src/components/Chat/AssistantMessage.tsx new file mode 100644 index 0000000000000000000000000000000000000000..83bd8cae505808781908a2292eaa8acc1242536b --- /dev/null +++ b/frontend/src/components/Chat/AssistantMessage.tsx @@ -0,0 +1,119 @@ +import { useMemo } from 'react'; +import { Box, Stack, Typography } from '@mui/material'; +import MarkdownContent from './MarkdownContent'; +import ToolCallGroup from './ToolCallGroup'; +import type { UIMessage } from 'ai'; +import type { MessageMeta } from '@/types/agent'; + +interface AssistantMessageProps { + message: UIMessage; + isStreaming?: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; +} + +/** + * Groups consecutive tool parts together so they render as a single + * ToolCallGroup (visually identical to the old segments approach). 
+ */ +type DynamicToolPart = Extract; + +function groupParts(parts: UIMessage['parts']) { + const groups: Array< + | { kind: 'text'; text: string; idx: number } + | { kind: 'tools'; tools: DynamicToolPart[]; idx: number } + > = []; + + for (let i = 0; i < parts.length; i++) { + const part = parts[i]; + + if (part.type === 'text') { + groups.push({ kind: 'text', text: part.text, idx: i }); + } else if (part.type === 'dynamic-tool') { + const toolPart = part as DynamicToolPart; + const last = groups[groups.length - 1]; + if (last?.kind === 'tools') { + last.tools.push(toolPart); + } else { + groups.push({ kind: 'tools', tools: [toolPart], idx: i }); + } + } + // step-start, step-end, etc. are ignored visually + } + + return groups; +} + +export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) { + const groups = useMemo(() => groupParts(message.parts), [message.parts]); + + // Find the last text group index for streaming cursor + let lastTextIdx = -1; + for (let i = groups.length - 1; i >= 0; i--) { + if (groups[i].kind === 'text') { lastTextIdx = i; break; } + } + + const meta = message.metadata as MessageMeta | undefined; + const timeStr = meta?.createdAt + ? 
new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) + : null; + + if (groups.length === 0) return null; + + return ( + + + + Assistant + + {timeStr && ( + + {timeStr} + + )} + + + + {groups.map((group, i) => { + if (group.kind === 'text' && group.text) { + return ( + + ); + } + if (group.kind === 'tools' && group.tools.length > 0) { + return ( + + ); + } + return null; + })} + + + ); +} diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index a242eee54a373b6559ef565803e0be76d15df5c1..5fa7bd5f03c321a8d39245a95014b314221cd1b2 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,14 +1,103 @@ -import { useState, useCallback, KeyboardEvent } from 'react'; -import { Box, TextField, IconButton, CircularProgress, Typography } from '@mui/material'; +import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; +import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; +import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; +import { apiFetch } from '@/utils/api'; + +// Model configuration +interface ModelOption { + id: string; + name: string; + description: string; + modelPath: string; + avatarUrl: string; + recommended?: boolean; +} + +const getHfAvatarUrl = (modelId: string) => { + const org = modelId.split('/')[0]; + return `https://huggingface.co/api/avatars/${org}`; +}; + +const MODEL_OPTIONS: ModelOption[] = [ + { + id: 'minimax-m2.1', + name: 'MiniMax M2.1', + description: 'Via Novita', + modelPath: 'huggingface/novita/minimax/minimax-m2.1', + avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.1'), + recommended: true, + }, + { + id: 'claude-opus', + name: 'Claude Opus 4.5', + description: 'Anthropic', + modelPath: 'anthropic/claude-opus-4-5-20251101', + 
avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', + recommended: true, + }, + { + id: 'kimi-k2.5', + name: 'Kimi K2.5', + description: 'Via Novita', + modelPath: 'huggingface/novita/moonshotai/kimi-k2.5', + avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.5'), + }, + { + id: 'glm-5', + name: 'GLM 5', + description: 'Via Novita', + modelPath: 'huggingface/novita/zai-org/glm-5', + avatarUrl: getHfAvatarUrl('zai-org/GLM-5'), + }, +]; + +const findModelByPath = (path: string): ModelOption | undefined => { + return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); +}; interface ChatInputProps { onSend: (text: string) => void; disabled?: boolean; + placeholder?: string; } -export default function ChatInput({ onSend, disabled = false }: ChatInputProps) { +export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); + const inputRef = useRef(null); + const [selectedModelId, setSelectedModelId] = useState(() => { + try { + const stored = localStorage.getItem('hf-agent-model'); + if (stored && MODEL_OPTIONS.some(m => m.id === stored)) return stored; + } catch { /* localStorage unavailable */ } + return MODEL_OPTIONS[0].id; + }); + const [modelAnchorEl, setModelAnchorEl] = useState(null); + + // Sync with backend on mount (backend is source of truth, localStorage is just a cache) + useEffect(() => { + fetch('/api/config/model') + .then((res) => (res.ok ? 
res.json() : null)) + .then((data) => { + if (data?.current) { + const model = findModelByPath(data.current); + if (model) { + setSelectedModelId(model.id); + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } + } + } + }) + .catch(() => { /* ignore */ }); + }, []); + + const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + + // Auto-focus the textarea when the session becomes ready (disabled -> false) + useEffect(() => { + if (!disabled && inputRef.current) { + inputRef.current.focus(); + } + }, [disabled]); const handleSend = useCallback(() => { if (input.trim() && !disabled) { @@ -27,26 +116,48 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) [handleSend] ); + const handleModelClick = (event: React.MouseEvent) => { + setModelAnchorEl(event.currentTarget); + }; + + const handleModelClose = () => { + setModelAnchorEl(null); + }; + + const handleSelectModel = async (model: ModelOption) => { + handleModelClose(); + try { + const res = await apiFetch('/api/config/model', { + method: 'POST', + body: JSON.stringify({ model: model.modelPath }), + }); + if (res.ok) { + setSelectedModelId(model.id); + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } + } + } catch { /* ignore */ } + }; + return ( - + setInput(e.target.value)} onKeyDown={handleKeyDown} - placeholder="Ask anything..." 
+ placeholder={placeholder} disabled={disabled} variant="standard" + inputRef={inputRef} InputProps={{ disableUnderline: true, sx: { @@ -72,7 +184,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) fontFamily: 'inherit', padding: 0, lineHeight: 1.5, - minHeight: '56px', + minHeight: { xs: '44px', md: '56px' }, alignItems: 'flex-start', } }} @@ -99,7 +211,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) transition: 'all 0.2s', '&:hover': { color: 'var(--accent-yellow)', - bgcolor: 'rgba(255,255,255,0.05)', + bgcolor: 'var(--hover-bg)', }, '&.Mui-disabled': { opacity: 0.3, @@ -109,17 +221,108 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) {disabled ? : } - + {/* Powered By Badge */} - + powered by - Claude + {selectedModel.name} - claude-opus-4-5-20251101 + {selectedModel.name} + + + {/* Model Selection Menu */} + + {MODEL_OPTIONS.map((model) => ( + handleSelectModel(model)} + selected={selectedModelId === model.id} + sx={{ + py: 1.5, + '&.Mui-selected': { + bgcolor: 'rgba(255,255,255,0.05)', + } + }} + > + + {model.name} + + + {model.name} + {model.recommended && ( + + )} + + } + secondary={model.description} + secondaryTypographyProps={{ + sx: { fontSize: '12px', color: 'var(--muted-text)' } + }} + /> + + ))} + ); diff --git a/frontend/src/components/Chat/MarkdownContent.tsx b/frontend/src/components/Chat/MarkdownContent.tsx new file mode 100644 index 0000000000000000000000000000000000000000..beb682720bf2b4d846b67a86d45607bc4544044b --- /dev/null +++ b/frontend/src/components/Chat/MarkdownContent.tsx @@ -0,0 +1,160 @@ +import { useMemo, useRef, useState, useEffect } from 'react'; +import { Box } from '@mui/material'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import type { SxProps, Theme } from '@mui/material/styles'; + +interface MarkdownContentProps { + content: string; + sx?: SxProps; + /** When true, shows a blinking 
cursor and throttles renders. */ + isStreaming?: boolean; +} + +/** Shared markdown styles — adapts to light/dark via CSS variables. */ +const markdownSx: SxProps = { + fontSize: '0.925rem', + lineHeight: 1.7, + color: 'var(--text)', + wordBreak: 'break-word', + + '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, + + '& h1, & h2, & h3, & h4': { mt: 2.5, mb: 1, fontWeight: 600, lineHeight: 1.3 }, + '& h1': { fontSize: '1.35rem' }, + '& h2': { fontSize: '1.15rem' }, + '& h3': { fontSize: '1.05rem' }, + + '& pre': { + bgcolor: 'var(--code-bg)', + p: 2, + borderRadius: 2, + overflow: 'auto', + fontSize: '0.82rem', + lineHeight: 1.6, + border: '1px solid var(--tool-border)', + my: 2, + }, + '& code': { + bgcolor: 'var(--hover-bg)', + px: 0.75, + py: 0.25, + borderRadius: 0.5, + fontSize: '0.84rem', + fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& pre code': { bgcolor: 'transparent', p: 0 }, + + '& a': { + color: 'var(--accent-yellow)', + textDecoration: 'none', + fontWeight: 500, + '&:hover': { textDecoration: 'underline' }, + }, + + '& ul, & ol': { pl: 3, my: 1 }, + '& li': { mb: 0.5 }, + '& li::marker': { color: 'var(--muted-text)' }, + + '& blockquote': { + borderLeft: '3px solid var(--accent-yellow)', + pl: 2, + ml: 0, + my: 1.5, + color: 'var(--muted-text)', + fontStyle: 'italic', + }, + + '& table': { + borderCollapse: 'collapse', + width: '100%', + my: 2, + fontSize: '0.85rem', + }, + '& th': { + borderBottom: '2px solid var(--border-hover)', + textAlign: 'left', + p: 1, + fontWeight: 600, + }, + '& td': { + borderBottom: '1px solid var(--tool-border)', + p: 1, + }, + + '& hr': { + border: 'none', + borderTop: '1px solid var(--border)', + my: 2, + }, + + '& img': { + maxWidth: '100%', + borderRadius: 2, + }, +}; + +/** + * Throttled content for streaming: render the full markdown through + * ReactMarkdown but only re-parse every ~80ms to avoid layout thrashing. 
+ * This is the Claude approach — always render as markdown, never split + * into raw text. The parser handles incomplete tables gracefully. + */ +function useThrottledValue(value: string, isStreaming: boolean, intervalMs = 80): string { + const [throttled, setThrottled] = useState(value); + const lastUpdate = useRef(0); + const pending = useRef | null>(null); + const latestValue = useRef(value); + latestValue.current = value; + + useEffect(() => { + if (!isStreaming) { + // Not streaming — always use latest value immediately + setThrottled(value); + return; + } + + const now = Date.now(); + const elapsed = now - lastUpdate.current; + + if (elapsed >= intervalMs) { + // Enough time passed — update immediately + setThrottled(value); + lastUpdate.current = now; + } else { + // Schedule an update for the remaining time + if (pending.current) clearTimeout(pending.current); + pending.current = setTimeout(() => { + setThrottled(latestValue.current); + lastUpdate.current = Date.now(); + pending.current = null; + }, intervalMs - elapsed); + } + + return () => { + if (pending.current) clearTimeout(pending.current); + }; + }, [value, isStreaming, intervalMs]); + + // When streaming ends, flush immediately + useEffect(() => { + if (!isStreaming) { + setThrottled(latestValue.current); + } + }, [isStreaming]); + + return throttled; +} + +export default function MarkdownContent({ content, sx, isStreaming = false }: MarkdownContentProps) { + // Throttle re-parses during streaming to ~12fps (every 80ms) + const displayContent = useThrottledValue(content, isStreaming); + + const remarkPlugins = useMemo(() => [remarkGfm], []); + + return ( + + {displayContent} + + ); +} diff --git a/frontend/src/components/Chat/MessageBubble.tsx b/frontend/src/components/Chat/MessageBubble.tsx index 5e7e0f197afc796beb443bf4b326d8b93f57023a..af5d5a49a1d426039b1bb65ef52abae3b601a104 100644 --- a/frontend/src/components/Chat/MessageBubble.tsx +++ b/frontend/src/components/Chat/MessageBubble.tsx @@ 
-1,215 +1,44 @@ -import { Box, Paper, Typography } from '@mui/material'; -import ReactMarkdown from 'react-markdown'; -import remarkGfm from 'remark-gfm'; -import ApprovalFlow from './ApprovalFlow'; -import type { Message, TraceLog } from '@/types/agent'; -import { useAgentStore } from '@/store/agentStore'; -import { useLayoutStore } from '@/store/layoutStore'; +import UserMessage from './UserMessage'; +import AssistantMessage from './AssistantMessage'; +import type { UIMessage } from 'ai'; interface MessageBubbleProps { - message: Message; + message: UIMessage; + isLastTurn?: boolean; + onUndoTurn?: () => void; + isProcessing?: boolean; + isStreaming?: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; } -// Render a tools segment with clickable tool calls -function ToolsSegment({ tools }: { tools: TraceLog[] }) { - const { showToolOutput } = useAgentStore(); - const { setRightPanelOpen } = useLayoutStore(); - - const handleToolClick = (log: TraceLog) => { - if (log.completed && log.output) { - showToolOutput(log); - setRightPanelOpen(true); - } - }; - - return ( - - - {tools.map((log) => { - const isClickable = log.completed && log.output; - return ( - handleToolClick(log)} - sx={{ - color: 'var(--muted-text)', - fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', - fontSize: '0.75rem', - display: 'flex', - alignItems: 'center', - gap: 0.5, - cursor: isClickable ? 'pointer' : 'default', - borderRadius: 0.5, - px: 0.5, - mx: -0.5, - transition: 'background-color 0.15s ease', - '&:hover': isClickable ? { - bgcolor: 'rgba(255,255,255,0.05)', - } : {}, - }} - > - - {log.completed ? (log.success === false ? 
'✗' : '✓') : '•'} - - - {log.tool} - - {!log.completed && ...} - {isClickable && ( - - click to view - - )} - - ); - })} - - - ); -} - -// Markdown styles -const markdownStyles = { - '& p': { m: 0, mb: 1, '&:last-child': { mb: 0 } }, - '& pre': { - bgcolor: 'rgba(0,0,0,0.5)', - p: 1.5, - borderRadius: 1, - overflow: 'auto', - fontSize: '0.85rem', - border: '1px solid rgba(255,255,255,0.05)', - }, - '& code': { - bgcolor: 'rgba(255,255,255,0.05)', - px: 0.5, - py: 0.25, - borderRadius: 0.5, - fontSize: '0.85rem', - fontFamily: '"JetBrains Mono", monospace', - }, - '& pre code': { bgcolor: 'transparent', p: 0 }, - '& a': { - color: 'var(--accent-yellow)', - textDecoration: 'none', - '&:hover': { textDecoration: 'underline' }, - }, - '& ul, & ol': { pl: 2, my: 1 }, - '& table': { - borderCollapse: 'collapse', - width: '100%', - my: 2, - fontSize: '0.875rem', - }, - '& th': { - borderBottom: '1px solid rgba(255,255,255,0.1)', - textAlign: 'left', - p: 1, - bgcolor: 'rgba(255,255,255,0.02)', - }, - '& td': { - borderBottom: '1px solid rgba(255,255,255,0.05)', - p: 1, - }, -}; - -export default function MessageBubble({ message }: MessageBubbleProps) { - const isUser = message.role === 'user'; - const isAssistant = message.role === 'assistant'; - - if (message.approval) { +export default function MessageBubble({ + message, + isLastTurn = false, + onUndoTurn, + isProcessing = false, + isStreaming = false, + approveTools, +}: MessageBubbleProps) { + if (message.role === 'user') { return ( - - - + ); } - // Render segments chronologically if available, otherwise fall back to content - const renderContent = () => { - if (message.segments && message.segments.length > 0) { - return message.segments.map((segment, idx) => { - if (segment.type === 'text' && segment.content) { - return ( - - {segment.content} - - ); - } - if (segment.type === 'tools' && segment.tools && segment.tools.length > 0) { - return ; - } - return null; - }); - } - // Fallback: just render content + if 
(message.role === 'assistant') { return ( - - {message.content} - + ); - }; - - return ( - - - {renderContent()} + } - - {new Date(message.timestamp).toLocaleTimeString()} - - - - ); + return null; } diff --git a/frontend/src/components/Chat/MessageList.tsx b/frontend/src/components/Chat/MessageList.tsx index c54d4761c83681f1f1eebe7c7eb8619de3c5d962..21729fdc441dea8ebfcad1189686645ff1844e7b 100644 --- a/frontend/src/components/Chat/MessageList.tsx +++ b/frontend/src/components/Chat/MessageList.tsx @@ -1,100 +1,151 @@ -import { useEffect, useRef } from 'react'; -import { Box, Typography } from '@mui/material'; -import { useSessionStore } from '@/store/sessionStore'; +import { useCallback, useEffect, useRef, useMemo } from 'react'; +import { Box, Stack, Typography } from '@mui/material'; import MessageBubble from './MessageBubble'; -import type { Message } from '@/types/agent'; +import ActivityStatusBar from './ActivityStatusBar'; +import { useAgentStore } from '@/store/agentStore'; +import type { UIMessage } from 'ai'; interface MessageListProps { - messages: Message[]; + messages: UIMessage[]; isProcessing: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; + onUndoLastTurn: () => void | Promise; } -const TechnicalIndicator = () => ( - -); +function getGreeting(): string { + const h = new Date().getHours(); + if (h < 12) return 'Morning'; + if (h < 17) return 'Afternoon'; + return 'Evening'; +} + +function WelcomeGreeting() { + const { user } = useAgentStore(); + const firstName = user?.name?.split(' ')[0] || user?.username; + const greeting = firstName ? `${getGreeting()}, ${firstName}` : getGreeting(); -export default function MessageList({ messages, isProcessing }: MessageListProps) { - const bottomRef = useRef(null); - const { activeSessionId } = useSessionStore(); + return ( + + + {greeting} + + + Let's build something impressive? 
+ + + ); +} + +export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn }: MessageListProps) { + const scrollContainerRef = useRef(null); + const stickToBottom = useRef(true); + + const scrollToBottom = useCallback(() => { + const el = scrollContainerRef.current; + if (el) el.scrollTop = el.scrollHeight; + }, []); + + useEffect(() => { + const el = scrollContainerRef.current; + if (!el) return; + const onScroll = () => { + const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; + stickToBottom.current = distFromBottom < 80; + }; + el.addEventListener('scroll', onScroll, { passive: true }); + return () => el.removeEventListener('scroll', onScroll); + }, []); + + useEffect(() => { + if (stickToBottom.current) scrollToBottom(); + }, [messages, isProcessing, scrollToBottom]); - // Auto-scroll to bottom when new messages arrive useEffect(() => { - bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); - }, [messages, isProcessing]); + const el = scrollContainerRef.current; + if (!el) return; + const observer = new MutationObserver(() => { + if (stickToBottom.current) el.scrollTop = el.scrollHeight; + }); + observer.observe(el, { childList: true, subtree: true, characterData: true }); + return () => observer.disconnect(); + }, []); + + const lastUserMsgId = useMemo(() => { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'user') return messages[i].id; + } + return null; + }, [messages]); + + // The last assistant message is "streaming" when we're processing + const lastAssistantId = useMemo(() => { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') return messages[i].id; + } + return null; + }, [messages]); return ( - + {messages.length === 0 && !isProcessing ? 
( - - - Awaiting input… - - + ) : ( - messages.map((message) => ( - + messages.map((msg) => ( + )) )} - - {isProcessing && ( - - - - Thinking - - - - - )} - {activeSessionId && ( - // ApprovalFlow is now handled within messages - null - )} - -
- + + +
+ ); -} \ No newline at end of file +} diff --git a/frontend/src/components/Chat/ThinkingIndicator.tsx b/frontend/src/components/Chat/ThinkingIndicator.tsx new file mode 100644 index 0000000000000000000000000000000000000000..b8c37181f5de70eeb26a6f42049311393cc3d73e --- /dev/null +++ b/frontend/src/components/Chat/ThinkingIndicator.tsx @@ -0,0 +1,48 @@ +import { Box, Typography } from '@mui/material'; + +/** Pulsing dots shown while the agent is processing. */ +export default function ThinkingIndicator() { + return ( + + + Thinking + + + + + + + + ); +} diff --git a/frontend/src/components/Chat/ToolCallGroup.tsx b/frontend/src/components/Chat/ToolCallGroup.tsx new file mode 100644 index 0000000000000000000000000000000000000000..65bc752ba9020a3a78f036271e94adb3db62960c --- /dev/null +++ b/frontend/src/components/Chat/ToolCallGroup.tsx @@ -0,0 +1,655 @@ +import { useCallback, useMemo, useRef, useState } from 'react'; +import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material'; +import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline'; +import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline'; +import OpenInNewIcon from '@mui/icons-material/OpenInNew'; +import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; +import LaunchIcon from '@mui/icons-material/Launch'; +import SendIcon from '@mui/icons-material/Send'; +import BlockIcon from '@mui/icons-material/Block'; +import { useAgentStore } from '@/store/agentStore'; +import { useLayoutStore } from '@/store/layoutStore'; +import { logger } from '@/utils/logger'; +import type { UIMessage } from 'ai'; + +// --------------------------------------------------------------------------- +// Type helpers — extract the dynamic-tool part type from UIMessage +// --------------------------------------------------------------------------- +type DynamicToolPart = Extract; + +type ToolPartState = DynamicToolPart['state']; + +interface 
ToolCallGroupProps { + tools: DynamicToolPart[]; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => Promise; +} + +// --------------------------------------------------------------------------- +// Visual helpers +// --------------------------------------------------------------------------- + +function StatusIcon({ state }: { state: ToolPartState }) { + switch (state) { + case 'approval-requested': + return ; + case 'output-available': + return ; + case 'output-error': + return ; + case 'output-denied': + return ; + case 'input-streaming': + case 'input-available': + default: + return ; + } +} + +function statusLabel(state: ToolPartState): string | null { + switch (state) { + case 'approval-requested': return 'awaiting approval'; + case 'input-streaming': + case 'input-available': return 'running'; + case 'output-denied': return 'denied'; + case 'output-error': return 'error'; + default: return null; + } +} + +function statusColor(state: ToolPartState): string { + switch (state) { + case 'approval-requested': return 'var(--accent-yellow)'; + case 'output-available': return 'var(--accent-green)'; + case 'output-error': return 'var(--accent-red)'; + case 'output-denied': return 'var(--muted-text)'; + default: return 'var(--accent-yellow)'; + } +} + +// --------------------------------------------------------------------------- +// Inline approval UI (per-tool) +// --------------------------------------------------------------------------- + +function InlineApproval({ + toolCallId, + toolName, + input, + scriptLabel, + onResolve, +}: { + toolCallId: string; + toolName: string; + input: unknown; + scriptLabel: string; + onResolve: (toolCallId: string, approved: boolean, feedback?: string) => void; +}) { + const [feedback, setFeedback] = useState(''); + const args = input as Record | undefined; + const { setPanel, getEditedScript } = useAgentStore(); + const { setRightPanelOpen, 
setLeftSidebarOpen } = useLayoutStore(); + const hasEditedScript = !!getEditedScript(toolCallId); + + const handleScriptClick = useCallback(() => { + if (toolName === 'hf_jobs' && args?.script) { + const scriptContent = getEditedScript(toolCallId) || String(args.script); + setPanel( + { title: scriptLabel, script: { content: scriptContent, language: 'python' }, parameters: { tool_call_id: toolCallId } }, + 'script', + true, + ); + setRightPanelOpen(true); + setLeftSidebarOpen(false); + } + }, [toolCallId, toolName, args, scriptLabel, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen]); + + return ( + + {toolName === 'hf_jobs' && args && ( + + + Execute {scriptLabel.replace('Script', 'Job')} on{' '} + + {String(args.hardware_flavor || 'default')} + + {!!args.timeout && ( + <> with timeout + {String(args.timeout)} + + )} + + {typeof args.script === 'string' && args.script && ( + + + {String(args.script).trim()} + + + Click to view & edit + + + )} + + )} + + + setFeedback(e.target.value)} + variant="outlined" + sx={{ + '& .MuiOutlinedInput-root': { + bgcolor: 'var(--hover-bg)', + fontFamily: 'inherit', + fontSize: '0.8rem', + '& fieldset': { borderColor: 'var(--tool-border)' }, + '&:hover fieldset': { borderColor: 'var(--border-hover)' }, + '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' }, + }, + '& .MuiOutlinedInput-input': { + color: 'var(--text)', + '&::placeholder': { color: 'var(--muted-text)', opacity: 0.7 }, + }, + }} + /> + onResolve(toolCallId, false, feedback || 'Rejected by user')} + disabled={!feedback} + size="small" + sx={{ + color: 'var(--accent-red)', + border: '1px solid var(--tool-border)', + borderRadius: '6px', + '&:hover': { bgcolor: 'rgba(224,90,79,0.1)', borderColor: 'var(--accent-red)' }, + '&.Mui-disabled': { color: 'var(--muted-text)', opacity: 0.3 }, + }} + > + + + + + + + + + + ); +} + +// --------------------------------------------------------------------------- +// Main component +// 
--------------------------------------------------------------------------- + +export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) { + const { setPanel, lockPanel, getJobUrl, getEditedScript } = useAgentStore(); + const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); + + // ── Batch approval state ────────────────────────────────────────── + const pendingTools = useMemo( + () => tools.filter(t => t.state === 'approval-requested'), + [tools], + ); + + const [decisions, setDecisions] = useState>({}); + const [isSubmitting, setIsSubmitting] = useState(false); + const submittingRef = useRef(false); + + const { scriptLabelMap, toolDisplayMap } = useMemo(() => { + const hfJobs = tools.filter(t => t.toolName === 'hf_jobs' && (t.input as Record)?.script); + const scriptMap: Record = {}; + const displayMap: Record = {}; + for (let i = 0; i < hfJobs.length; i++) { + const id = hfJobs[i].toolCallId; + if (hfJobs.length > 1) { + scriptMap[id] = `Script ${i + 1}`; + displayMap[id] = `hf_jobs #${i + 1}`; + } else { + scriptMap[id] = 'Script'; + displayMap[id] = 'hf_jobs'; + } + } + return { scriptLabelMap: scriptMap, toolDisplayMap: displayMap }; + }, [tools]); + + // ── Send all decisions as a single batch ────────────────────────── + const sendBatch = useCallback( + async (batch: Record) => { + if (submittingRef.current) return; + submittingRef.current = true; + setIsSubmitting(true); + + const approvals = Object.entries(batch).map(([toolCallId, d]) => { + const editedScript = d.approved ? (getEditedScript(toolCallId) ?? null) : null; + if (editedScript) { + logger.log(`Sending edited script for ${toolCallId} (${editedScript.length} chars)`); + } + return { + tool_call_id: toolCallId, + approved: d.approved, + feedback: d.approved ? 
null : (d.feedback || 'Rejected by user'), + edited_script: editedScript, + }; + }); + + const ok = await approveTools(approvals); + if (ok) { + lockPanel(); + } else { + logger.error('Batch approval failed'); + submittingRef.current = false; + setIsSubmitting(false); + } + }, + [approveTools, lockPanel, getEditedScript], + ); + + const handleApproveAll = useCallback(() => { + const batch: Record = {}; + for (const t of pendingTools) batch[t.toolCallId] = { approved: true }; + sendBatch(batch); + }, [pendingTools, sendBatch]); + + const handleRejectAll = useCallback(() => { + const batch: Record = {}; + for (const t of pendingTools) batch[t.toolCallId] = { approved: false }; + sendBatch(batch); + }, [pendingTools, sendBatch]); + + const handleIndividualDecision = useCallback( + (toolCallId: string, approved: boolean, feedback?: string) => { + setDecisions(prev => { + const next = { ...prev, [toolCallId]: { approved, feedback } }; + if (pendingTools.every(t => next[t.toolCallId])) { + queueMicrotask(() => sendBatch(next)); + } + return next; + }); + }, + [pendingTools, sendBatch], + ); + + const undoDecision = useCallback((toolCallId: string) => { + setDecisions(prev => { + const next = { ...prev }; + delete next[toolCallId]; + return next; + }); + }, []); + + // ── Panel click handler ─────────────────────────────────────────── + const handleClick = useCallback( + (tool: DynamicToolPart) => { + const args = tool.input as Record | undefined; + const displayName = toolDisplayMap[tool.toolCallId] || tool.toolName; + + if (tool.toolName === 'hf_jobs' && args?.script) { + const hasOutput = (tool.state === 'output-available' || tool.state === 'output-error') && tool.output; + const scriptContent = getEditedScript(tool.toolCallId) || String(args.script); + setPanel( + { + title: displayName, + script: { content: scriptContent, language: 'python' }, + ...(hasOutput ? 
{ output: { content: String(tool.output), language: 'markdown' } } : {}), + parameters: { tool_call_id: tool.toolCallId }, + }, + hasOutput ? 'output' : 'script', + ); + setRightPanelOpen(true); + setLeftSidebarOpen(false); + return; + } + + if ((tool.state === 'output-available' || tool.state === 'output-error') && tool.output) { + let language = 'text'; + const content = String(tool.output); + if (content.trim().startsWith('{') || content.trim().startsWith('[')) language = 'json'; + else if (content.includes('```')) language = 'markdown'; + + setPanel({ title: displayName, output: { content, language } }, 'output'); + setRightPanelOpen(true); + } else if (args) { + const content = JSON.stringify(args, null, 2); + setPanel({ title: displayName, output: { content, language: 'json' } }, 'output'); + setRightPanelOpen(true); + } + }, + [toolDisplayMap, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen], + ); + + // ── Parse hf_jobs metadata from output ──────────────────────────── + function parseJobMeta(output: unknown): { jobUrl?: string; jobStatus?: string } { + if (typeof output !== 'string') return {}; + const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/); + const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/); + return { + jobUrl: urlMatch?.[1], + jobStatus: statusMatch?.[1]?.trim(), + }; + } + + // ── Render ──────────────────────────────────────────────────────── + const decidedCount = pendingTools.filter(t => decisions[t.toolCallId]).length; + + return ( + + {/* Batch approval header — hidden once user starts deciding individually */} + {pendingTools.length > 1 && !isSubmitting && decidedCount === 0 && ( + + + {`${pendingTools.length} tool${pendingTools.length > 1 ? 
's' : ''} pending`} + + + + + )} + + {/* Tool list */} + }> + {tools.map((tool) => { + const state = tool.state; + const isPending = state === 'approval-requested'; + const clickable = + state === 'output-available' || + state === 'output-error' || + !!tool.input; + const localDecision = decisions[tool.toolCallId]; + + const displayState = isPending && localDecision + ? (localDecision.approved ? 'input-available' : 'output-denied') + : state; + const label = statusLabel(displayState as ToolPartState); + + // Parse job metadata from hf_jobs output and store + const jobUrlFromStore = tool.toolName === 'hf_jobs' ? getJobUrl(tool.toolCallId) : undefined; + const jobMetaFromOutput = tool.toolName === 'hf_jobs' && tool.state === 'output-available' + ? parseJobMeta(tool.output) + : {}; + + // Combine job URL from store (available immediately) with output metadata (available at completion) + const jobMeta = { + jobUrl: jobUrlFromStore || jobMetaFromOutput.jobUrl, + jobStatus: jobMetaFromOutput.jobStatus, + }; + + return ( + + {/* Main tool row */} + !isPending && handleClick(tool)} + sx={{ + px: 1.5, + py: 1, + cursor: isPending ? 'default' : clickable ? 'pointer' : 'default', + transition: 'background-color 0.15s', + '&:hover': clickable && !isPending ? 
{ bgcolor: 'var(--hover-bg)' } : {}, + }} + > + + + + {toolDisplayMap[tool.toolCallId] || tool.toolName} + + + {/* Status chip (non hf_jobs, or hf_jobs without final status) */} + {label && !(tool.toolName === 'hf_jobs' && jobMeta.jobStatus) && ( + + )} + + {/* HF Jobs: final status chip from job metadata */} + {tool.toolName === 'hf_jobs' && jobMeta.jobStatus && ( + + )} + + {/* View on HF link — single place, shown whenever URL is available */} + {tool.toolName === 'hf_jobs' && jobMeta.jobUrl && ( + e.stopPropagation()} + sx={{ + display: 'inline-flex', + alignItems: 'center', + gap: 0.5, + color: 'var(--accent-yellow)', + fontSize: '0.68rem', + textDecoration: 'none', + ml: 0.5, + '&:hover': { textDecoration: 'underline' }, + }} + > + + View on HF + + )} + + {clickable && !isPending && ( + + )} + + + + {/* Per-tool approval: undecided */} + {isPending && !localDecision && !isSubmitting && ( + + )} + + {/* Per-tool approval: locally decided (undo available) */} + {isPending && localDecision && !isSubmitting && ( + + + {localDecision.approved + ? 'Marked for approval' + : `Marked for rejection${localDecision.feedback ? 
`: ${localDecision.feedback}` : ''}`} + + + + )} + + ); + })} + + + ); +} diff --git a/frontend/src/components/Chat/UserMessage.tsx b/frontend/src/components/Chat/UserMessage.tsx new file mode 100644 index 0000000000000000000000000000000000000000..4bcea89867df46d92e2630b12ac9978bc5da476f --- /dev/null +++ b/frontend/src/components/Chat/UserMessage.tsx @@ -0,0 +1,105 @@ +import { Box, Stack, Typography, IconButton, Tooltip } from '@mui/material'; +import CloseIcon from '@mui/icons-material/Close'; +import type { UIMessage } from 'ai'; +import type { MessageMeta } from '@/types/agent'; + +interface UserMessageProps { + message: UIMessage; + isLastTurn?: boolean; + onUndoTurn?: () => void; + isProcessing?: boolean; +} + +function extractText(message: UIMessage): string { + return message.parts + .filter((p): p is Extract => p.type === 'text') + .map(p => p.text) + .join(''); +} + +export default function UserMessage({ + message, + isLastTurn = false, + onUndoTurn, + isProcessing = false, +}: UserMessageProps) { + const showUndo = isLastTurn && !isProcessing && !!onUndoTurn; + const text = extractText(message); + const meta = message.metadata as MessageMeta | undefined; + const timeStr = meta?.createdAt + ? 
new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) + : null; + return ( + + {showUndo && ( + + + + + + + + )} + + + + {text} + + + {timeStr && ( + + {timeStr} + + )} + + + ); +} diff --git a/frontend/src/components/CodePanel/CodePanel.tsx b/frontend/src/components/CodePanel/CodePanel.tsx index 3cda17dab23b7e2bbc4940b3ca171dc9b776e661..4b38fbf5f7bda3664bfacac65bb20d67d2e6a626 100644 --- a/frontend/src/components/CodePanel/CodePanel.tsx +++ b/frontend/src/components/CodePanel/CodePanel.tsx @@ -1,138 +1,463 @@ -import { useRef, useEffect, useMemo } from 'react'; -import { Box, Typography, IconButton } from '@mui/material'; +import { useRef, useEffect, useMemo, useState, useCallback } from 'react'; +import { Box, Stack, Typography, IconButton, Button, Tooltip } from '@mui/material'; import CloseIcon from '@mui/icons-material/Close'; import RadioButtonUncheckedIcon from '@mui/icons-material/RadioButtonUnchecked'; import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import PlayCircleOutlineIcon from '@mui/icons-material/PlayCircleOutline'; import CodeIcon from '@mui/icons-material/Code'; -import TerminalIcon from '@mui/icons-material/Terminal'; import ArticleIcon from '@mui/icons-material/Article'; +import EditIcon from '@mui/icons-material/Edit'; +import UndoIcon from '@mui/icons-material/Undo'; +import ContentCopyIcon from '@mui/icons-material/ContentCopy'; +import CheckIcon from '@mui/icons-material/Check'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; -import { vscDarkPlus } from 'react-syntax-highlighter/dist/esm/styles/prism'; +import { vscDarkPlus, vs } from 'react-syntax-highlighter/dist/esm/styles/prism'; import ReactMarkdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; import { useAgentStore } from '@/store/agentStore'; import { useLayoutStore } from '@/store/layoutStore'; import { processLogs } from '@/utils/logProcessor'; +import type { PanelView } from '@/store/agentStore'; 
+ +// ── Helpers ────────────────────────────────────────────────────── + +function PlanStatusIcon({ status }: { status: string }) { + if (status === 'completed') return ; + if (status === 'in_progress') return ; + return ; +} + +// ── Markdown styles (adapts via CSS vars) ──────────────────────── +const markdownSx = { + color: 'var(--text)', + fontSize: '13px', + lineHeight: 1.6, + '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, + '& pre': { + bgcolor: 'var(--code-bg)', + p: 1.5, + borderRadius: 1, + overflow: 'auto', + fontSize: '12px', + border: '1px solid var(--tool-border)', + }, + '& code': { + bgcolor: 'var(--hover-bg)', + px: 0.5, + py: 0.25, + borderRadius: 0.5, + fontSize: '12px', + fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& pre code': { bgcolor: 'transparent', p: 0 }, + '& a': { + color: 'var(--accent-yellow)', + textDecoration: 'none', + '&:hover': { textDecoration: 'underline' }, + }, + '& ul, & ol': { pl: 2.5, my: 1 }, + '& li': { mb: 0.5 }, + '& table': { + borderCollapse: 'collapse', + width: '100%', + my: 2, + fontSize: '12px', + fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& th': { + borderBottom: '2px solid var(--border-hover)', + textAlign: 'left', + p: 1, + fontWeight: 600, + }, + '& td': { + borderBottom: '1px solid var(--tool-border)', + p: 1, + }, + '& h1, & h2, & h3, & h4': { mt: 2, mb: 1, fontWeight: 600 }, + '& h1': { fontSize: '1.25rem' }, + '& h2': { fontSize: '1.1rem' }, + '& h3': { fontSize: '1rem' }, + '& blockquote': { + borderLeft: '3px solid var(--accent-yellow)', + pl: 2, + ml: 0, + color: 'var(--muted-text)', + }, +} as const; + +// ── View toggle button ────────────────────────────────────────── + +function ViewToggle({ view, icon, label, isActive, onClick }: { + view: PanelView; + icon: React.ReactNode; + label: string; + isActive: boolean; + onClick: (v: PanelView) => void; +}) { + return ( + onClick(view)} + sx={{ + display: 'flex', + alignItems: 
'center', + gap: 0.5, + px: 1.5, + py: 0.75, + borderRadius: 1, + cursor: 'pointer', + fontSize: '0.7rem', + fontWeight: 600, + textTransform: 'uppercase', + letterSpacing: '0.05em', + whiteSpace: 'nowrap', + color: isActive ? 'var(--text)' : 'var(--muted-text)', + bgcolor: isActive ? 'var(--tab-active-bg)' : 'transparent', + border: '1px solid', + borderColor: isActive ? 'var(--tab-active-border)' : 'transparent', + transition: 'all 0.15s ease', + '&:hover': { bgcolor: 'var(--tab-hover-bg)' }, + }} + > + {icon} + {label} + + ); +} + +// ── Component ──────────────────────────────────────────────────── export default function CodePanel() { - const { panelContent, panelTabs, activePanelTab, setActivePanelTab, removePanelTab, plan } = useAgentStore(); - const { setRightPanelOpen } = useLayoutStore(); + const { panelData, panelView, panelEditable, setPanelView, updatePanelScript, setEditedScript, plan } = + useAgentStore(); + const { setRightPanelOpen, themeMode } = useLayoutStore(); const scrollRef = useRef(null); + const textareaRef = useRef(null); + const [isEditing, setIsEditing] = useState(false); + const [editedContent, setEditedContent] = useState(''); + const [originalContent, setOriginalContent] = useState(''); + const [copied, setCopied] = useState(false); + + const isDark = themeMode === 'dark'; + const syntaxTheme = isDark ? vscDarkPlus : vs; + + const activeSection = panelView === 'script' ? 
panelData?.script : panelData?.output; + const hasScript = !!panelData?.script; + const hasOutput = !!panelData?.output; + const hasBothViews = hasScript && hasOutput; + + const isEditableScript = panelView === 'script' && panelEditable; + const hasUnsavedChanges = isEditing && editedContent !== originalContent; + + // Sync edited content when panel data changes + useEffect(() => { + if (panelData?.script?.content && panelView === 'script' && panelEditable) { + setOriginalContent(panelData.script.content); + if (!isEditing) { + setEditedContent(panelData.script.content); + } + } + }, [panelData?.script?.content, panelView, panelEditable, isEditing]); + + // Exit editing when switching away from script view or losing editable + useEffect(() => { + if (!isEditableScript && isEditing) { + setIsEditing(false); + } + }, [isEditableScript, isEditing]); + + const handleStartEdit = useCallback(() => { + if (panelData?.script?.content) { + setEditedContent(panelData.script.content); + setOriginalContent(panelData.script.content); + setIsEditing(true); + setTimeout(() => textareaRef.current?.focus(), 0); + } + }, [panelData?.script?.content]); + + const handleCancelEdit = useCallback(() => { + setEditedContent(originalContent); + setIsEditing(false); + }, [originalContent]); + + const handleSaveEdit = useCallback(() => { + if (editedContent !== originalContent) { + updatePanelScript(editedContent); + const toolCallId = panelData?.parameters?.tool_call_id as string | undefined; + if (toolCallId) { + setEditedScript(toolCallId, editedContent); + } + setOriginalContent(editedContent); + } + setIsEditing(false); + }, [panelData?.parameters?.tool_call_id, editedContent, originalContent, updatePanelScript, setEditedScript]); - // Get the active tab content, or fall back to panelContent for backwards compatibility - const activeTab = panelTabs.find(t => t.id === activePanelTab); - const currentContent = activeTab || panelContent; + const handleCopy = useCallback(async () => { + 
const contentToCopy = isEditing ? editedContent : (activeSection?.content || ''); + if (contentToCopy) { + try { + await navigator.clipboard.writeText(contentToCopy); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + } catch (err) { + console.error('Failed to copy:', err); + } + } + }, [isEditing, editedContent, activeSection?.content]); const displayContent = useMemo(() => { - if (!currentContent?.content) return ''; - // Apply log processing only for text/logs, not for code/json - if (!currentContent.language || currentContent.language === 'text') { - return processLogs(currentContent.content); + if (!activeSection?.content) return ''; + if (!activeSection.language || activeSection.language === 'text') { + return processLogs(activeSection.content); } - return currentContent.content; - }, [currentContent?.content, currentContent?.language]); + return activeSection.content; + }, [activeSection?.content, activeSection?.language]); useEffect(() => { - // Auto-scroll only for logs tab - if (scrollRef.current && activePanelTab === 'logs') { + if (scrollRef.current && panelView === 'output') { scrollRef.current.scrollTop = scrollRef.current.scrollHeight; } - }, [displayContent, activePanelTab]); + }, [displayContent, panelView]); + + // ── Syntax-highlighted code block (DRY) ──────────────────────── + const renderSyntaxBlock = (language: string) => ( + + {displayContent} + + ); + + // ── Content renderer ─────────────────────────────────────────── + const renderContent = () => { + if (!activeSection?.content) { + return ( + + NO CONTENT TO DISPLAY + + ); + } + + if (isEditing && isEditableScript) { + return ( + + + {editedContent || ' '} + +