Spaces:
Sleeping
Sleeping
File size: 2,857 Bytes
f2b5c2a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | """
Security guard for user authorization checks.
This module ensures that all tool operations are authorized
and that users can only access their own data.
"""
import logging
from typing import Optional
from sqlmodel import Session, select
from src.models.task import Task
logger = logging.getLogger(__name__)
class SecurityGuard:
    """
    Security guard for validating user authorization.

    All MCP tools must use this guard to verify that:
    1. User is authenticated (user_id is valid)
    2. User owns the resources they're trying to access
    3. No cross-user data access is possible

    Every method is a ``@staticmethod``; the class itself holds no state.
    """

    @staticmethod
    async def validate_task_ownership(
        db: Session,
        task_id: int,
        user_id: int,
    ) -> Task:
        """
        Validate that a task belongs to the authenticated user.

        The query filters on both ``Task.id`` and ``Task.user_id``, so a task
        that does not exist and a task owned by another user produce the same
        outcome and the same error message.

        Args:
            db: Database session
            task_id: Task ID to validate
            user_id: Authenticated user ID

        Returns:
            The matching Task. This method never returns None; failure is
            signalled by raising.

        Raises:
            ValueError: If no task with ``task_id`` is owned by ``user_id``
                (it either does not exist or belongs to another user).
        """
        statement = select(Task).where(
            Task.id == task_id,
            Task.user_id == user_id,
        )
        # NOTE(review): db.exec is a synchronous call inside an async def, so it
        # blocks the event loop while the query runs — confirm this is
        # acceptable or move to an async session.
        task = db.exec(statement).first()
        if task is None:
            logger.warning(
                "Task %s not found or unauthorized for user %s", task_id, user_id
            )
            # Deliberately generic: do not reveal whether the task exists.
            raise ValueError("Task not found or access denied")
        return task

    @staticmethod
    def validate_user_id(user_id: int) -> None:
        """
        Validate that user_id is provided and is a positive integer.

        Args:
            user_id: User ID to validate

        Raises:
            ValueError: If user_id is None, not an int, a bool, or not
                strictly positive.
        """
        # bool is a subclass of int, so True would otherwise pass as user 1.
        # The isinstance check also turns non-int input (e.g. "5") into the
        # documented ValueError instead of a TypeError from the `<=` comparison.
        if (
            isinstance(user_id, bool)
            or not isinstance(user_id, int)
            or user_id <= 0
        ):
            logger.error("Invalid user_id: %s", user_id)
            raise ValueError("Invalid user_id")

    @staticmethod
    async def validate_conversation_ownership(
        db: Session,
        conversation_id: int,
        user_id: int,
    ) -> bool:
        """
        Validate that a conversation belongs to the authenticated user.

        Unlike ``validate_task_ownership`` this reports failure through its
        return value instead of raising.

        Args:
            db: Database session
            conversation_id: Conversation ID to validate
            user_id: Authenticated user ID

        Returns:
            True if a conversation with ``conversation_id`` owned by
            ``user_id`` exists, False otherwise (missing or owned by
            another user).
        """
        # Function-scope import, presumably to avoid a circular import between
        # model modules — confirm before hoisting it to module level.
        from src.models.conversation import Conversation

        statement = select(Conversation).where(
            Conversation.id == conversation_id,
            Conversation.user_id == user_id,
        )
        # NOTE(review): synchronous db.exec inside an async def blocks the
        # event loop while the query runs — confirm this is acceptable.
        conversation = db.exec(statement).first()
        if conversation is None:
            logger.warning(
                "Conversation %s not found or unauthorized for user %s",
                conversation_id,
                user_id,
            )
            return False
        return True
# Shared module-level instance for callers that prefer importing an object.
# SecurityGuard defines only static methods and no __init__, so this instance
# carries no state of its own.
security_guard: SecurityGuard = SecurityGuard()
|