Spaces:
Sleeping
Sleeping
| """ | |
| UpdateQuery - Update operations with access control. | |
| Inherits from BaseQuery for shared filtering logic. | |
| """ | |
| import logging | |
| from typing import Type, TypeVar, Dict, Any | |
| from fastapi import HTTPException, status as http_status | |
| from sqlalchemy import update | |
| from services.db_service.base_query import BaseQuery | |
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
# Generic type variable tying model classes/instances to the return types below.
T = TypeVar("T")
class UpdateQuery(BaseQuery):
    """
    Handles UPDATE operations with automatic access control.

    Every bulk operation (update, increment, decrement, toggle_boolean)
    runs through one shared pipeline:

    1. ``_verify_operation_access(model_class, 'update')`` (from BaseQuery).
    2. Exclude soft-deleted rows: the column named by
       ``self._config.soft_delete_column`` must be NULL.
    3. ``_apply_ownership_filter(stmt, model_class, 'update')`` (from
       BaseQuery) restricts which rows the current user may touch.
    4. One equality WHERE clause per caller-supplied ``**filters`` entry.
    5. Execute, commit, and return the affected row count.

    ``update_one`` works on an already-loaded instance instead and performs
    the soft-delete and ownership checks in Python.
    """

    @staticmethod
    def _require_attribute(model_class, name: str) -> None:
        """Raise ValueError unless *model_class* has an attribute called *name*."""
        if not hasattr(model_class, name):
            raise ValueError(f"{model_class.__name__} has no attribute '{name}'")

    def _build_filtered_update(self, model_class, filters: Dict[str, Any]):
        """
        Build an UPDATE statement limited to rows the current user may modify.

        Args:
            model_class: Model class to update.
            filters: Field=value pairs; each becomes an equality WHERE clause.

        Returns:
            A SQLAlchemy ``Update`` statement with no SET clause yet.

        Raises:
            ValueError: If a filter key is not an attribute of *model_class*.
        """
        delete_col = getattr(model_class, self._config.soft_delete_column)
        # Never touch soft-deleted rows; `.is_(None)` renders as IS NULL.
        stmt = update(model_class).where(delete_col.is_(None))
        # The operation name must match the one passed to
        # _verify_operation_access ('update') so BaseQuery applies the
        # ownership rules for updates.
        stmt = self._apply_ownership_filter(stmt, model_class, 'update')
        for key, value in filters.items():
            self._require_attribute(model_class, key)
            stmt = stmt.where(getattr(model_class, key) == value)
        return stmt

    async def _execute_update(self, stmt) -> int:
        """Execute *stmt*, commit the session, and return ``result.rowcount``."""
        result = await self.db.execute(stmt)
        await self.db.commit()
        return result.rowcount

    async def update(self, model_class: Type[T], values: Dict[str, Any], **filters) -> int:
        """
        Update records matching filters with new values.

        Automatically excludes soft-deleted rows and applies the ownership
        filter. With no ``filters``, every non-deleted row the user is
        allowed to modify is updated.

        Args:
            model_class: Model class to update
            values: Dictionary of field=value pairs to update
            **filters: Field=value filters to match records

        Returns:
            Number of records updated

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        # Validate SET fields before doing any access/DB work.
        for key in values:
            self._require_attribute(model_class, key)

        self._verify_operation_access(model_class, 'update')

        stmt = self._build_filtered_update(model_class, filters).values(**values)
        count = await self._execute_update(stmt)

        logger.info("Updated %s %s record(s)", count, model_class.__name__)
        return count

    async def update_one(self, instance: T, values: Dict[str, Any]) -> T:
        """
        Update a single model instance with validation.

        Args:
            instance: Model instance to update
            values: Dictionary of field=value pairs to update

        Returns:
            Updated model instance (refreshed from the database after commit)

        Raises:
            HTTPException: 400 if the instance is soft-deleted,
                403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        # Refuse to modify soft-deleted records.
        delete_column = self._config.soft_delete_column
        if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
            raise HTTPException(
                status_code=http_status.HTTP_400_BAD_REQUEST,
                detail="Cannot update deleted record"
            )

        self._verify_operation_access(instance.__class__, 'update')

        # Non-admins may only update records whose owner column matches them.
        filter_column = self._config.user_filter_column
        if not self.is_admin and hasattr(instance, filter_column):
            if getattr(instance, filter_column) != self.user.id:
                raise HTTPException(
                    status_code=http_status.HTTP_403_FORBIDDEN,
                    detail="You do not have permission to update this record"
                )

        # Validate every key before mutating anything, so a bad key leaves
        # the instance untouched.
        for key in values:
            if not hasattr(instance, key):
                raise ValueError(f"{instance.__class__.__name__} has no attribute '{key}'")

        for key, value in values.items():
            setattr(instance, key, value)

        await self.db.commit()
        await self.db.refresh(instance)

        logger.info("Updated %s instance", instance.__class__.__name__)
        return instance

    async def increment(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """
        Increment a numeric field by a specified amount.

        The addition is done in SQL (``SET field = field + amount``), so it
        does not read the current value into Python first.

        Args:
            model_class: Model class to update
            field: Name of the numeric column to change
            amount: Amount to add (may be negative)
            **filters: Field=value filters to match records

        Returns:
            Number of records updated

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, 'update')

        field_obj = getattr(model_class, field)
        stmt = self._build_filtered_update(model_class, filters).values(
            {field: field_obj + amount}
        )
        count = await self._execute_update(stmt)

        logger.info(
            "Incremented %s by %s for %s %s record(s)",
            field, amount, count, model_class.__name__,
        )
        return count

    async def decrement(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """Decrement a numeric field by a specified amount (delegates to ``increment``)."""
        return await self.increment(model_class, field, -amount, **filters)

    async def toggle_boolean(self, model_class: Type[T], field: str, **filters) -> int:
        """
        Toggle a boolean field (True -> False, False -> True).

        Uses the same soft-delete column and ownership filter as the other
        bulk operations. The negation is done in SQL (``SET field = NOT field``);
        under SQL three-valued logic a NULL value stays NULL.

        Args:
            model_class: Model class to update
            field: Name of the boolean column to toggle
            **filters: Field=value filters to match records

        Returns:
            Number of records updated

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, 'update')

        field_obj = getattr(model_class, field)
        # `~` on a SQLAlchemy column expression renders as SQL NOT.
        stmt = self._build_filtered_update(model_class, filters).values(
            {field: ~field_obj}
        )
        count = await self._execute_update(stmt)

        logger.info("Toggled %s for %s %s record(s)", field, count, model_class.__name__)
        return count