| """DeerFlowClient — Embedded Python client for DeerFlow agent system. |
| |
| Provides direct programmatic access to DeerFlow's agent capabilities |
| without requiring LangGraph Server or Gateway API processes. |
| |
| Usage: |
| from src.client import DeerFlowClient |
| |
| client = DeerFlowClient() |
| response = client.chat("Analyze this paper for me", thread_id="my-thread") |
| print(response) |
| |
| # Streaming |
| for event in client.stream("hello"): |
| print(event) |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import mimetypes |
| import re |
| import shutil |
| import tempfile |
| import uuid |
| import zipfile |
| from collections.abc import Generator |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any |
|
|
| from langchain.agents import create_agent |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage |
| from langchain_core.runnables import RunnableConfig |
|
|
| from src.agents.lead_agent.agent import _build_middlewares |
| from src.agents.lead_agent.prompt import apply_prompt_template |
| from src.agents.thread_state import ThreadState |
| from src.config.app_config import get_app_config, reload_app_config |
| from src.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config |
| from src.config.paths import get_paths |
| from src.models import create_chat_model |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class StreamEvent:
    """One incremental event yielded by ``DeerFlowClient.stream``.

    The ``type`` values mirror the LangGraph SSE protocol:

    * ``"values"`` -- a full state snapshot (title, messages, artifacts).
    * ``"messages-tuple"`` -- an update for a single message (AI text,
      tool calls, or a tool result).
    * ``"end"`` -- the stream has finished.

    Attributes:
        type: Which of the event kinds above this is.
        data: Event payload; its keys depend on ``type``.
    """

    # Event kind: "values", "messages-tuple" or "end".
    type: str
    # Payload; each instance gets its own fresh dict by default.
    data: dict[str, Any] = field(default_factory=dict)
|
|
|
|
| class DeerFlowClient: |
| """Embedded Python client for DeerFlow agent system. |
| |
| Provides direct programmatic access to DeerFlow's agent capabilities |
| without requiring LangGraph Server or Gateway API processes. |
| |
| Note: |
| Multi-turn conversations require a ``checkpointer``. Without one, |
| each ``stream()`` / ``chat()`` call is stateless — ``thread_id`` |
| is only used for file isolation (uploads / artifacts). |
| |
| The system prompt (including date, memory, and skills context) is |
| generated when the internal agent is first created and cached until |
| the configuration key changes. Call :meth:`reset_agent` to force |
| a refresh in long-running processes. |
| |
| Example:: |
| |
| from src.client import DeerFlowClient |
| |
| client = DeerFlowClient() |
| |
| # Simple one-shot |
| print(client.chat("hello")) |
| |
| # Streaming |
| for event in client.stream("hello"): |
| print(event.type, event.data) |
| |
| # Configuration queries |
| print(client.list_models()) |
| print(client.list_skills()) |
| """ |
|
|
| def __init__( |
| self, |
| config_path: str | None = None, |
| checkpointer=None, |
| *, |
| model_name: str | None = None, |
| thinking_enabled: bool = True, |
| subagent_enabled: bool = False, |
| plan_mode: bool = False, |
| ): |
| """Initialize the client. |
| |
| Loads configuration but defers agent creation to first use. |
| |
| Args: |
| config_path: Path to config.yaml. Uses default resolution if None. |
| checkpointer: LangGraph checkpointer instance for state persistence. |
| Required for multi-turn conversations on the same thread_id. |
| Without a checkpointer, each call is stateless. |
| model_name: Override the default model name from config. |
| thinking_enabled: Enable model's extended thinking. |
| subagent_enabled: Enable subagent delegation. |
| plan_mode: Enable TodoList middleware for plan mode. |
| """ |
| if config_path is not None: |
| reload_app_config(config_path) |
| self._app_config = get_app_config() |
|
|
| self._checkpointer = checkpointer |
| self._model_name = model_name |
| self._thinking_enabled = thinking_enabled |
| self._subagent_enabled = subagent_enabled |
| self._plan_mode = plan_mode |
|
|
| |
| self._agent = None |
| self._agent_config_key: tuple | None = None |
|
|
| def reset_agent(self) -> None: |
| """Force the internal agent to be recreated on the next call. |
| |
| Use this after external changes (e.g. memory updates, skill |
| installations) that should be reflected in the system prompt |
| or tool set. |
| """ |
| self._agent = None |
| self._agent_config_key = None |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def _atomic_write_json(path: Path, data: dict) -> None: |
| """Write JSON to *path* atomically (temp file + replace).""" |
| fd = tempfile.NamedTemporaryFile( |
| mode="w", dir=path.parent, suffix=".tmp", delete=False, |
| ) |
| try: |
| json.dump(data, fd, indent=2) |
| fd.close() |
| Path(fd.name).replace(path) |
| except BaseException: |
| fd.close() |
| Path(fd.name).unlink(missing_ok=True) |
| raise |
|
|
| def _get_runnable_config(self, thread_id: str, **overrides) -> RunnableConfig: |
| """Build a RunnableConfig for agent invocation.""" |
| configurable = { |
| "thread_id": thread_id, |
| "model_name": overrides.get("model_name", self._model_name), |
| "thinking_enabled": overrides.get("thinking_enabled", self._thinking_enabled), |
| "is_plan_mode": overrides.get("plan_mode", self._plan_mode), |
| "subagent_enabled": overrides.get("subagent_enabled", self._subagent_enabled), |
| } |
| return RunnableConfig( |
| configurable=configurable, |
| recursion_limit=overrides.get("recursion_limit", 100), |
| ) |
|
|
| def _ensure_agent(self, config: RunnableConfig): |
| """Create (or recreate) the agent when config-dependent params change.""" |
| cfg = config.get("configurable", {}) |
| key = ( |
| cfg.get("model_name"), |
| cfg.get("thinking_enabled"), |
| cfg.get("is_plan_mode"), |
| cfg.get("subagent_enabled"), |
| ) |
|
|
| if self._agent is not None and self._agent_config_key == key: |
| return |
|
|
| thinking_enabled = cfg.get("thinking_enabled", True) |
| model_name = cfg.get("model_name") |
| subagent_enabled = cfg.get("subagent_enabled", False) |
| max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) |
|
|
| kwargs: dict[str, Any] = { |
| "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), |
| "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), |
| "middleware": _build_middlewares(config, model_name=model_name), |
| "system_prompt": apply_prompt_template( |
| subagent_enabled=subagent_enabled, |
| max_concurrent_subagents=max_concurrent_subagents, |
| ), |
| "state_schema": ThreadState, |
| } |
| if self._checkpointer is not None: |
| kwargs["checkpointer"] = self._checkpointer |
|
|
| self._agent = create_agent(**kwargs) |
| self._agent_config_key = key |
| logger.info("Agent created: model=%s, thinking=%s", model_name, thinking_enabled) |
|
|
| @staticmethod |
| def _get_tools(*, model_name: str | None, subagent_enabled: bool): |
| """Lazy import to avoid circular dependency at module level.""" |
| from src.tools import get_available_tools |
|
|
| return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) |
|
|
| @staticmethod |
| def _serialize_message(msg) -> dict: |
| """Serialize a LangChain message to a plain dict for values events.""" |
| if isinstance(msg, AIMessage): |
| d: dict[str, Any] = {"type": "ai", "content": msg.content, "id": getattr(msg, "id", None)} |
| if msg.tool_calls: |
| d["tool_calls"] = [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in msg.tool_calls] |
| return d |
| if isinstance(msg, ToolMessage): |
| return { |
| "type": "tool", |
| "content": msg.content if isinstance(msg.content, str) else str(msg.content), |
| "name": getattr(msg, "name", None), |
| "tool_call_id": getattr(msg, "tool_call_id", None), |
| "id": getattr(msg, "id", None), |
| } |
| if isinstance(msg, HumanMessage): |
| return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)} |
| if isinstance(msg, SystemMessage): |
| return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)} |
| return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)} |
|
|
| @staticmethod |
| def _extract_text(content) -> str: |
| """Extract plain text from AIMessage content (str or list of blocks).""" |
| if isinstance(content, str): |
| return content |
| if isinstance(content, list): |
| parts = [] |
| for block in content: |
| if isinstance(block, str): |
| parts.append(block) |
| elif isinstance(block, dict) and block.get("type") == "text": |
| parts.append(block["text"]) |
| return "\n".join(parts) if parts else "" |
| return str(content) |
|
|
| |
| |
| |
|
|
| def stream( |
| self, |
| message: str, |
| *, |
| thread_id: str | None = None, |
| **kwargs, |
| ) -> Generator[StreamEvent, None, None]: |
| """Stream a conversation turn, yielding events incrementally. |
| |
| Each call sends one user message and yields events until the agent |
| finishes its turn. A ``checkpointer`` must be provided at init time |
| for multi-turn context to be preserved across calls. |
| |
| Event types align with the LangGraph SSE protocol so that |
| consumers can switch between HTTP streaming and embedded mode |
| without changing their event-handling logic. |
| |
| Args: |
| message: User message text. |
| thread_id: Thread ID for conversation context. Auto-generated if None. |
| **kwargs: Override client defaults (model_name, thinking_enabled, |
| plan_mode, subagent_enabled, recursion_limit). |
| |
| Yields: |
| StreamEvent with one of: |
| - type="values" data={"title": str|None, "messages": [...], "artifacts": [...]} |
| - type="messages-tuple" data={"type": "ai", "content": str, "id": str} |
| - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]} |
| - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str} |
| - type="end" data={} |
| """ |
| if thread_id is None: |
| thread_id = str(uuid.uuid4()) |
|
|
| config = self._get_runnable_config(thread_id, **kwargs) |
| self._ensure_agent(config) |
|
|
| state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} |
| context = {"thread_id": thread_id} |
|
|
| seen_ids: set[str] = set() |
|
|
| for chunk in self._agent.stream(state, config=config, context=context, stream_mode="values"): |
| messages = chunk.get("messages", []) |
|
|
| for msg in messages: |
| msg_id = getattr(msg, "id", None) |
| if msg_id and msg_id in seen_ids: |
| continue |
| if msg_id: |
| seen_ids.add(msg_id) |
|
|
| if isinstance(msg, AIMessage): |
| if msg.tool_calls: |
| yield StreamEvent( |
| type="messages-tuple", |
| data={ |
| "type": "ai", |
| "content": "", |
| "id": msg_id, |
| "tool_calls": [ |
| {"name": tc["name"], "args": tc["args"], "id": tc.get("id")} |
| for tc in msg.tool_calls |
| ], |
| }, |
| ) |
|
|
| text = self._extract_text(msg.content) |
| if text: |
| yield StreamEvent( |
| type="messages-tuple", |
| data={"type": "ai", "content": text, "id": msg_id}, |
| ) |
|
|
| elif isinstance(msg, ToolMessage): |
| yield StreamEvent( |
| type="messages-tuple", |
| data={ |
| "type": "tool", |
| "content": msg.content if isinstance(msg.content, str) else str(msg.content), |
| "name": getattr(msg, "name", None), |
| "tool_call_id": getattr(msg, "tool_call_id", None), |
| "id": msg_id, |
| }, |
| ) |
|
|
| |
| yield StreamEvent( |
| type="values", |
| data={ |
| "title": chunk.get("title"), |
| "messages": [self._serialize_message(m) for m in messages], |
| "artifacts": chunk.get("artifacts", []), |
| }, |
| ) |
|
|
| yield StreamEvent(type="end", data={}) |
|
|
| def chat(self, message: str, *, thread_id: str | None = None, **kwargs) -> str: |
| """Send a message and return the final text response. |
| |
| Convenience wrapper around :meth:`stream` that returns only the |
| **last** AI text from ``messages-tuple`` events. If the agent emits |
| multiple text segments in one turn, intermediate segments are |
| discarded. Use :meth:`stream` directly to capture all events. |
| |
| Args: |
| message: User message text. |
| thread_id: Thread ID for conversation context. Auto-generated if None. |
| **kwargs: Override client defaults (same as stream()). |
| |
| Returns: |
| The last AI message text, or empty string if no response. |
| """ |
| last_text = "" |
| for event in self.stream(message, thread_id=thread_id, **kwargs): |
| if event.type == "messages-tuple" and event.data.get("type") == "ai": |
| content = event.data.get("content", "") |
| if content: |
| last_text = content |
| return last_text |
|
|
| |
| |
| |
|
|
| def list_models(self) -> dict: |
| """List available models from configuration. |
| |
| Returns: |
| Dict with "models" key containing list of model info dicts, |
| matching the Gateway API ``ModelsListResponse`` schema. |
| """ |
| return { |
| "models": [ |
| { |
| "name": model.name, |
| "display_name": getattr(model, "display_name", None), |
| "description": getattr(model, "description", None), |
| "supports_thinking": getattr(model, "supports_thinking", False), |
| } |
| for model in self._app_config.models |
| ] |
| } |
|
|
| def list_skills(self, enabled_only: bool = False) -> dict: |
| """List available skills. |
| |
| Args: |
| enabled_only: If True, only return enabled skills. |
| |
| Returns: |
| Dict with "skills" key containing list of skill info dicts, |
| matching the Gateway API ``SkillsListResponse`` schema. |
| """ |
| from src.skills.loader import load_skills |
|
|
| return { |
| "skills": [ |
| { |
| "name": s.name, |
| "description": s.description, |
| "license": s.license, |
| "category": s.category, |
| "enabled": s.enabled, |
| } |
| for s in load_skills(enabled_only=enabled_only) |
| ] |
| } |
|
|
| def get_memory(self) -> dict: |
| """Get current memory data. |
| |
| Returns: |
| Memory data dict (see src/agents/memory/updater.py for structure). |
| """ |
| from src.agents.memory.updater import get_memory_data |
|
|
| return get_memory_data() |
|
|
| def get_model(self, name: str) -> dict | None: |
| """Get a specific model's configuration by name. |
| |
| Args: |
| name: Model name. |
| |
| Returns: |
| Model info dict matching the Gateway API ``ModelResponse`` |
| schema, or None if not found. |
| """ |
| model = self._app_config.get_model_config(name) |
| if model is None: |
| return None |
| return { |
| "name": model.name, |
| "display_name": getattr(model, "display_name", None), |
| "description": getattr(model, "description", None), |
| "supports_thinking": getattr(model, "supports_thinking", False), |
| } |
|
|
| |
| |
| |
|
|
| def get_mcp_config(self) -> dict: |
| """Get MCP server configurations. |
| |
| Returns: |
| Dict with "mcp_servers" key mapping server name to config, |
| matching the Gateway API ``McpConfigResponse`` schema. |
| """ |
| config = get_extensions_config() |
| return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}} |
|
|
| def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict: |
| """Update MCP server configurations. |
| |
| Writes to extensions_config.json and reloads the cache. |
| |
| Args: |
| mcp_servers: Dict mapping server name to config dict. |
| Each value should contain keys like enabled, type, command, args, env, url, etc. |
| |
| Returns: |
| Dict with "mcp_servers" key, matching the Gateway API |
| ``McpConfigResponse`` schema. |
| |
| Raises: |
| OSError: If the config file cannot be written. |
| """ |
| config_path = ExtensionsConfig.resolve_config_path() |
| if config_path is None: |
| raise FileNotFoundError( |
| "Cannot locate extensions_config.json. " |
| "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root." |
| ) |
|
|
| current_config = get_extensions_config() |
|
|
| config_data = { |
| "mcpServers": mcp_servers, |
| "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, |
| } |
|
|
| self._atomic_write_json(config_path, config_data) |
|
|
| self._agent = None |
| reloaded = reload_extensions_config() |
| return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}} |
|
|
| |
| |
| |
|
|
| def get_skill(self, name: str) -> dict | None: |
| """Get a specific skill by name. |
| |
| Args: |
| name: Skill name. |
| |
| Returns: |
| Skill info dict, or None if not found. |
| """ |
| from src.skills.loader import load_skills |
|
|
| skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None) |
| if skill is None: |
| return None |
| return { |
| "name": skill.name, |
| "description": skill.description, |
| "license": skill.license, |
| "category": skill.category, |
| "enabled": skill.enabled, |
| } |
|
|
| def update_skill(self, name: str, *, enabled: bool) -> dict: |
| """Update a skill's enabled status. |
| |
| Args: |
| name: Skill name. |
| enabled: New enabled status. |
| |
| Returns: |
| Updated skill info dict. |
| |
| Raises: |
| ValueError: If the skill is not found. |
| OSError: If the config file cannot be written. |
| """ |
| from src.skills.loader import load_skills |
|
|
| skills = load_skills(enabled_only=False) |
| skill = next((s for s in skills if s.name == name), None) |
| if skill is None: |
| raise ValueError(f"Skill '{name}' not found") |
|
|
| config_path = ExtensionsConfig.resolve_config_path() |
| if config_path is None: |
| raise FileNotFoundError( |
| "Cannot locate extensions_config.json. " |
| "Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root." |
| ) |
|
|
| extensions_config = get_extensions_config() |
| extensions_config.skills[name] = SkillStateConfig(enabled=enabled) |
|
|
| config_data = { |
| "mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()}, |
| "skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()}, |
| } |
|
|
| self._atomic_write_json(config_path, config_data) |
|
|
| self._agent = None |
| reload_extensions_config() |
|
|
| updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None) |
| if updated is None: |
| raise RuntimeError(f"Skill '{name}' disappeared after update") |
| return { |
| "name": updated.name, |
| "description": updated.description, |
| "license": updated.license, |
| "category": updated.category, |
| "enabled": updated.enabled, |
| } |
|
|
| def install_skill(self, skill_path: str | Path) -> dict: |
| """Install a skill from a .skill archive (ZIP). |
| |
| Args: |
| skill_path: Path to the .skill file. |
| |
| Returns: |
| Dict with success, skill_name, message. |
| |
| Raises: |
| FileNotFoundError: If the file does not exist. |
| ValueError: If the file is invalid. |
| """ |
| from src.gateway.routers.skills import _validate_skill_frontmatter |
| from src.skills.loader import get_skills_root_path |
|
|
| path = Path(skill_path) |
| if not path.exists(): |
| raise FileNotFoundError(f"Skill file not found: {skill_path}") |
| if not path.is_file(): |
| raise ValueError(f"Path is not a file: {skill_path}") |
| if path.suffix != ".skill": |
| raise ValueError("File must have .skill extension") |
| if not zipfile.is_zipfile(path): |
| raise ValueError("File is not a valid ZIP archive") |
|
|
| skills_root = get_skills_root_path() |
| custom_dir = skills_root / "custom" |
| custom_dir.mkdir(parents=True, exist_ok=True) |
|
|
| with tempfile.TemporaryDirectory() as tmp: |
| tmp_path = Path(tmp) |
| with zipfile.ZipFile(path, "r") as zf: |
| total_size = sum(info.file_size for info in zf.infolist()) |
| if total_size > 100 * 1024 * 1024: |
| raise ValueError("Skill archive too large when extracted (>100MB)") |
| for info in zf.infolist(): |
| if Path(info.filename).is_absolute() or ".." in Path(info.filename).parts: |
| raise ValueError(f"Unsafe path in archive: {info.filename}") |
| zf.extractall(tmp_path) |
| for p in tmp_path.rglob("*"): |
| if p.is_symlink(): |
| p.unlink() |
|
|
| items = list(tmp_path.iterdir()) |
| if not items: |
| raise ValueError("Skill archive is empty") |
|
|
| skill_dir = items[0] if len(items) == 1 and items[0].is_dir() else tmp_path |
|
|
| is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir) |
| if not is_valid: |
| raise ValueError(f"Invalid skill: {message}") |
| if not re.fullmatch(r"[a-zA-Z0-9_-]+", skill_name): |
| raise ValueError(f"Invalid skill name: {skill_name}") |
|
|
| target = custom_dir / skill_name |
| if target.exists(): |
| raise ValueError(f"Skill '{skill_name}' already exists") |
|
|
| shutil.copytree(skill_dir, target) |
|
|
| return {"success": True, "skill_name": skill_name, "message": f"Skill '{skill_name}' installed successfully"} |
|
|
| |
| |
| |
|
|
| def reload_memory(self) -> dict: |
| """Reload memory data from file, forcing cache invalidation. |
| |
| Returns: |
| The reloaded memory data dict. |
| """ |
| from src.agents.memory.updater import reload_memory_data |
|
|
| return reload_memory_data() |
|
|
| def get_memory_config(self) -> dict: |
| """Get memory system configuration. |
| |
| Returns: |
| Memory config dict. |
| """ |
| from src.config.memory_config import get_memory_config |
|
|
| config = get_memory_config() |
| return { |
| "enabled": config.enabled, |
| "storage_path": config.storage_path, |
| "debounce_seconds": config.debounce_seconds, |
| "max_facts": config.max_facts, |
| "fact_confidence_threshold": config.fact_confidence_threshold, |
| "injection_enabled": config.injection_enabled, |
| "max_injection_tokens": config.max_injection_tokens, |
| } |
|
|
| def get_memory_status(self) -> dict: |
| """Get memory status: config + current data. |
| |
| Returns: |
| Dict with "config" and "data" keys. |
| """ |
| return { |
| "config": self.get_memory_config(), |
| "data": self.get_memory(), |
| } |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def _get_uploads_dir(thread_id: str) -> Path: |
| """Get (and create) the uploads directory for a thread.""" |
| base = get_paths().sandbox_uploads_dir(thread_id) |
| base.mkdir(parents=True, exist_ok=True) |
| return base |
|
|
| def upload_files(self, thread_id: str, files: list[str | Path]) -> dict: |
| """Upload local files into a thread's uploads directory. |
| |
| For PDF, PPT, Excel, and Word files, they are also converted to Markdown. |
| |
| Args: |
| thread_id: Target thread ID. |
| files: List of local file paths to upload. |
| |
| Returns: |
| Dict with success, files, message — matching the Gateway API |
| ``UploadResponse`` schema. |
| |
| Raises: |
| FileNotFoundError: If any file does not exist. |
| """ |
| from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown |
|
|
| |
| resolved_files = [] |
| for f in files: |
| p = Path(f) |
| if not p.exists(): |
| raise FileNotFoundError(f"File not found: {f}") |
| resolved_files.append(p) |
|
|
| uploads_dir = self._get_uploads_dir(thread_id) |
| uploaded_files: list[dict] = [] |
|
|
| for src_path in resolved_files: |
|
|
| dest = uploads_dir / src_path.name |
| shutil.copy2(src_path, dest) |
|
|
| info: dict[str, Any] = { |
| "filename": src_path.name, |
| "size": str(dest.stat().st_size), |
| "path": str(dest), |
| "virtual_path": f"/mnt/user-data/uploads/{src_path.name}", |
| "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{src_path.name}", |
| } |
|
|
| if src_path.suffix.lower() in CONVERTIBLE_EXTENSIONS: |
| try: |
| try: |
| asyncio.get_running_loop() |
| import concurrent.futures |
| with concurrent.futures.ThreadPoolExecutor() as pool: |
| md_path = pool.submit(lambda: asyncio.run(convert_file_to_markdown(dest))).result() |
| except RuntimeError: |
| md_path = asyncio.run(convert_file_to_markdown(dest)) |
| except Exception: |
| logger.warning("Failed to convert %s to markdown", src_path.name, exc_info=True) |
| md_path = None |
|
|
| if md_path is not None: |
| info["markdown_file"] = md_path.name |
| info["markdown_virtual_path"] = f"/mnt/user-data/uploads/{md_path.name}" |
| info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}" |
|
|
| uploaded_files.append(info) |
|
|
| return { |
| "success": True, |
| "files": uploaded_files, |
| "message": f"Successfully uploaded {len(uploaded_files)} file(s)", |
| } |
|
|
| def list_uploads(self, thread_id: str) -> dict: |
| """List files in a thread's uploads directory. |
| |
| Args: |
| thread_id: Thread ID. |
| |
| Returns: |
| Dict with "files" and "count" keys, matching the Gateway API |
| ``list_uploaded_files`` response. |
| """ |
| uploads_dir = self._get_uploads_dir(thread_id) |
| if not uploads_dir.exists(): |
| return {"files": [], "count": 0} |
|
|
| files = [] |
| for fp in sorted(uploads_dir.iterdir()): |
| if fp.is_file(): |
| stat = fp.stat() |
| files.append({ |
| "filename": fp.name, |
| "size": str(stat.st_size), |
| "path": str(fp), |
| "virtual_path": f"/mnt/user-data/uploads/{fp.name}", |
| "artifact_url": f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{fp.name}", |
| "extension": fp.suffix, |
| "modified": stat.st_mtime, |
| }) |
| return {"files": files, "count": len(files)} |
|
|
| def delete_upload(self, thread_id: str, filename: str) -> dict: |
| """Delete a file from a thread's uploads directory. |
| |
| Args: |
| thread_id: Thread ID. |
| filename: Filename to delete. |
| |
| Returns: |
| Dict with success and message, matching the Gateway API |
| ``delete_uploaded_file`` response. |
| |
| Raises: |
| FileNotFoundError: If the file does not exist. |
| PermissionError: If path traversal is detected. |
| """ |
| uploads_dir = self._get_uploads_dir(thread_id) |
| file_path = (uploads_dir / filename).resolve() |
|
|
| try: |
| file_path.relative_to(uploads_dir.resolve()) |
| except ValueError as exc: |
| raise PermissionError("Access denied: path traversal detected") from exc |
|
|
| if not file_path.is_file(): |
| raise FileNotFoundError(f"File not found: {filename}") |
|
|
| file_path.unlink() |
| return {"success": True, "message": f"Deleted {filename}"} |
|
|
| |
| |
| |
|
|
| def get_artifact(self, thread_id: str, path: str) -> tuple[bytes, str]: |
| """Read an artifact file produced by the agent. |
| |
| Args: |
| thread_id: Thread ID. |
| path: Virtual path (e.g. "mnt/user-data/outputs/file.txt"). |
| |
| Returns: |
| Tuple of (file_bytes, mime_type). |
| |
| Raises: |
| FileNotFoundError: If the artifact does not exist. |
| ValueError: If the path is invalid. |
| """ |
| virtual_prefix = "mnt/user-data" |
| clean_path = path.lstrip("/") |
| if not clean_path.startswith(virtual_prefix): |
| raise ValueError(f"Path must start with /{virtual_prefix}") |
|
|
| relative = clean_path[len(virtual_prefix):].lstrip("/") |
| base_dir = get_paths().sandbox_user_data_dir(thread_id) |
| actual = (base_dir / relative).resolve() |
|
|
| try: |
| actual.relative_to(base_dir.resolve()) |
| except ValueError as exc: |
| raise PermissionError("Access denied: path traversal detected") from exc |
| if not actual.exists(): |
| raise FileNotFoundError(f"Artifact not found: {path}") |
| if not actual.is_file(): |
| raise ValueError(f"Path is not a file: {path}") |
|
|
| mime_type, _ = mimetypes.guess_type(actual) |
| return actual.read_bytes(), mime_type or "application/octet-stream" |
|
|