|
|
""" |
|
|
Update Task MCP Tool |
|
|
|
|
|
MCP tool for updating task properties via natural language. |
|
|
Supports task identification by ID or title and updating multiple fields. |
|
|
Implements user context injection for security. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import Union, Optional, Dict, Any |
|
|
from datetime import datetime |
|
|
from sqlmodel import Session, select |
|
|
|
|
|
from ...models.task import Task |
|
|
from ...core.database import get_session |
|
|
from ..tool_registry import ToolExecutionResult |
|
|
|
|
|
# Module-level logger for this tool: warnings for rejected input and tasks that
# are not found, info for successful updates, errors for unexpected failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Fields the LLM may change, mapped to their normalizers below. Dict order is also
# the canonical order used when reporting allowed/updated fields back to the caller.
_VALID_PRIORITIES = ("low", "medium", "high")
_MAX_TITLE_LENGTH = 200
_MAX_DESCRIPTION_LENGTH = 1000
_DUE_DATE_FORMAT_ERROR = "Due date must be in ISO 8601 format (YYYY-MM-DD)"
# Accepted string spellings for the "completed" flag. Tool-call arguments often
# arrive as strings, and bool("false") is True, so strings must be parsed explicitly.
_TRUE_STRINGS = frozenset({"true", "yes", "y", "1", "on"})
_FALSE_STRINGS = frozenset({"false", "no", "n", "0", "off", ""})


class _UpdateValidationError(ValueError):
    """Raised when an update value is rejected; its message is returned to the caller verbatim."""


def _normalize_title(value: Any) -> str:
    """Return the stripped title; reject None/blank, non-string, or over-long values."""
    if value is not None and not isinstance(value, str):
        raise _UpdateValidationError("Task title must be a string")
    title = (value or "").strip()
    if not title:
        raise _UpdateValidationError("Task title cannot be empty")
    # Measure the stripped text so surrounding whitespace does not count toward the limit.
    if len(title) > _MAX_TITLE_LENGTH:
        raise _UpdateValidationError(f"Task title cannot exceed {_MAX_TITLE_LENGTH} characters")
    return title


def _normalize_description(value: Any) -> Optional[str]:
    """Return the stripped description, or None (clearing the field) for empty/blank input."""
    if not value:
        return None
    if not isinstance(value, str):
        raise _UpdateValidationError("Task description must be a string")
    description = value.strip()
    if len(description) > _MAX_DESCRIPTION_LENGTH:
        raise _UpdateValidationError(
            f"Task description cannot exceed {_MAX_DESCRIPTION_LENGTH} characters"
        )
    return description or None


def _normalize_due_date(value: Any):
    """Parse an ISO 8601 date/datetime string into a ``date``; empty input clears the due date."""
    if not value:
        return None
    if not isinstance(value, str):
        raise _UpdateValidationError(_DUE_DATE_FORMAT_ERROR)
    text = value.strip()
    # datetime.fromisoformat() only accepts a trailing "Z" (UTC) from Python 3.11 on.
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(text).date()
    except ValueError:
        raise _UpdateValidationError(_DUE_DATE_FORMAT_ERROR) from None


def _normalize_priority(value: Any) -> str:
    """Return a lower-cased priority from _VALID_PRIORITIES; empty input resets to "medium"."""
    if not value:
        return "medium"
    error = f"Priority must be one of: {', '.join(_VALID_PRIORITIES)}"
    if not isinstance(value, str):
        raise _UpdateValidationError(error)
    priority = value.strip().lower()
    if priority not in _VALID_PRIORITIES:
        raise _UpdateValidationError(error)
    return priority


def _normalize_completed(value: Any) -> bool:
    """Coerce the completed flag, parsing string spellings such as "true"/"false"."""
    if isinstance(value, str):
        text = value.strip().lower()
        if text in _TRUE_STRINGS:
            return True
        if text in _FALSE_STRINGS:
            return False
        raise _UpdateValidationError("Completed must be true or false")
    return bool(value)


_UPDATE_NORMALIZERS = {
    "title": _normalize_title,
    "description": _normalize_description,
    "due_date": _normalize_due_date,
    "priority": _normalize_priority,
    "completed": _normalize_completed,
}
_ALLOWED_UPDATE_FIELDS = tuple(_UPDATE_NORMALIZERS)


def _normalize_updates(updates: Dict[str, Any]) -> Dict[str, Any]:
    """Validate every requested field up front; return normalized values in canonical field order.

    Raises:
        _UpdateValidationError: on the first rejected value.
    """
    return {
        field: normalize(updates[field])
        for field, normalize in _UPDATE_NORMALIZERS.items()
        if field in updates
    }


async def update_task(
    task_identifier: Union[int, str],
    updates: Dict[str, Any],
    user_id: int
) -> ToolExecutionResult:
    """
    Update an existing task's properties.

    SECURITY: user_id is injected by the backend via MCPToolRegistry.
    The LLM cannot specify or modify the user_id. Every lookup is scoped
    to that user, so another user's task can never be matched.

    Args:
        task_identifier: Task ID (integer) or exact task title (string).
        updates: Mapping of fields to update. Allowed keys: title, description,
            due_date (ISO 8601 string; empty clears it), priority (low/medium/high;
            empty resets to "medium"), completed (bool or "true"/"false"-style string).
        user_id: User ID (injected by backend for security).

    Returns:
        ToolExecutionResult with success status and the updated task data. All
        failures, including unexpected exceptions, are reported via
        ``success=False`` rather than raised.
    """
    try:
        if not updates:
            logger.warning("update_task called with empty updates dictionary")
            return ToolExecutionResult(
                success=False,
                error="No updates provided. Please specify at least one field to update."
            )

        if not isinstance(updates, dict):
            logger.warning("update_task called with non-dict updates: %r", type(updates))
            return ToolExecutionResult(
                success=False,
                error="Updates must be an object mapping field names to new values."
            )

        invalid_fields = [field for field in updates if field not in _UPDATE_NORMALIZERS]
        if invalid_fields:
            logger.warning("update_task called with invalid fields: %s", invalid_fields)
            return ToolExecutionResult(
                success=False,
                error=(
                    f"Invalid fields: {', '.join(invalid_fields)}. "
                    f"Allowed fields: {', '.join(_ALLOWED_UPDATE_FIELDS)}"
                )
            )

        # Validate all values before opening a session, so bad input never
        # touches the database and the ORM object is never partially mutated.
        try:
            normalized = _normalize_updates(updates)
        except _UpdateValidationError as exc:
            logger.warning("update_task rejected update value: %s", exc)
            return ToolExecutionResult(success=False, error=str(exc))

        # NOTE(review): the session calls below are synchronous and block the
        # event loop while they run; confirm this is acceptable for this service.
        session_source = get_session()
        db: Session = next(session_source)
        try:
            if isinstance(task_identifier, int):
                match_clause = Task.id == task_identifier
                identifier_type = "ID"
            else:
                # NOTE(review): if titles are not unique per user, .first() picks
                # an arbitrary matching task; consider rejecting ambiguous titles.
                match_clause = Task.title == task_identifier
                identifier_type = "title"

            statement = select(Task).where(match_clause, Task.user_id == user_id)
            task = db.exec(statement).first()

            if not task:
                logger.warning(
                    "Task not found for update: %s=%s, user_id=%s",
                    identifier_type, task_identifier, user_id,
                )
                return ToolExecutionResult(
                    success=False,
                    error=f"Task not found with {identifier_type}: {task_identifier}"
                )

            for field, value in normalized.items():
                setattr(task, field, value)
            updated_fields = list(normalized)

            # Kept as naive UTC to match whatever values the column already stores.
            task.updated_at = datetime.utcnow()

            db.add(task)
            db.commit()
            db.refresh(task)

            logger.info(
                "Task updated successfully: id=%s, user_id=%s, fields=%s",
                task.id, user_id, updated_fields,
            )

            return ToolExecutionResult(
                success=True,
                data={
                    "id": task.id,
                    "title": task.title,
                    "description": task.description,
                    "due_date": task.due_date.isoformat() if task.due_date else None,
                    "priority": task.priority,
                    "completed": task.completed,
                    "updated_at": task.updated_at.isoformat()
                },
                message=f"Task '{task.title}' updated successfully. Updated fields: {', '.join(updated_fields)}"
            )
        finally:
            db.close()
            # get_session() is consumed with next(), so it appears to be a
            # generator dependency; closing it runs any cleanup after its yield
            # now rather than at garbage collection.
            close_source = getattr(session_source, "close", None)
            if close_source is not None:
                close_source()

    except Exception as e:
        logger.exception("Error updating task: %s", e)
        return ToolExecutionResult(
            success=False,
            error=f"Failed to update task: {e}"
        )
|
|
|