"""
UpdateQuery - Update operations with access control.

Inherits from BaseQuery for shared filtering logic.
"""

import logging
from typing import Type, TypeVar, Dict, Any

from fastapi import HTTPException, status as http_status
from sqlalchemy import update

from services.db_service.base_query import BaseQuery

logger = logging.getLogger(__name__)

T = TypeVar('T')


class UpdateQuery(BaseQuery):
    """
    Handles UPDATE operations with automatic access control.

    Inherits filtering logic from BaseQuery:
    - User ownership checks (is_admin / ownership filter)
    - Deleted record protection (rows with a non-NULL soft-delete column
      are never updated)
    """

    @staticmethod
    def _require_attribute(model_class: Type[T], name: str) -> None:
        """Raise ValueError unless ``model_class`` has an attribute called ``name``."""
        if not hasattr(model_class, name):
            raise ValueError(f"{model_class.__name__} has no attribute '{name}'")

    def _build_update_stmt(self, model_class: Type[T], filters: Dict[str, Any]):
        """
        Build the base UPDATE statement shared by all bulk update operations.

        Applies, in order:
        1. Soft-delete protection: rows whose configured soft-delete column is
           non-NULL are excluded. Models without that column are left
           unfiltered, matching the ``hasattr`` guard used in ``update_one``.
        2. The ownership filter from BaseQuery.
        3. One equality WHERE clause per ``filters`` entry.

        Args:
            model_class: Model class to update.
            filters: Field=value equality filters.

        Returns:
            An SQLAlchemy UPDATE statement with no SET values applied yet.

        Raises:
            ValueError: If a filter key is not an attribute of ``model_class``.
        """
        stmt = update(model_class)

        delete_column = self._config.soft_delete_column
        if hasattr(model_class, delete_column):
            # Don't update deleted records.
            stmt = stmt.where(getattr(model_class, delete_column).is_(None))

        # NOTE(review): the operation string passed here is 'updat' (not 'update'),
        # as in the original code for every bulk method. Kept unchanged because
        # BaseQuery._apply_ownership_filter's use of it is not visible -- confirm.
        stmt = self._apply_ownership_filter(stmt, model_class, 'updat')

        for key, value in filters.items():
            self._require_attribute(model_class, key)
            stmt = stmt.where(getattr(model_class, key) == value)

        return stmt

    async def update(self, model_class: Type[T], values: Dict[str, Any], **filters) -> int:
        """
        Update records matching filters with new values.

        Automatically applies ownership filter for non-admins and skips
        soft-deleted records.

        Args:
            model_class: Model class to update
            values: Dictionary of field=value pairs to update
            **filters: Field=value filters to match records

        Returns:
            Number of records updated (the driver-reported rowcount)

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        # Validate target fields before any permission or database work.
        for key in values:
            self._require_attribute(model_class, key)

        # Verify user has permission to update this model
        self._verify_operation_access(model_class, 'update')

        stmt = self._build_update_stmt(model_class, filters).values(**values)

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info("Updated %s %s record(s)", count, model_class.__name__)
        return count

    async def update_one(self, instance: T, values: Dict[str, Any]) -> T:
        """
        Update a single model instance with validation.

        Args:
            instance: Model instance to update
            values: Dictionary of field=value pairs to update

        Returns:
            Updated model instance (refreshed from the database after commit)

        Raises:
            HTTPException: 400 if the instance is soft-deleted,
                403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        model_class = instance.__class__

        # Refuse to modify soft-deleted records.
        delete_column = self._config.soft_delete_column
        if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
            raise HTTPException(
                status_code=http_status.HTTP_400_BAD_REQUEST,
                detail="Cannot update deleted record",
            )

        # Verify user has permission
        self._verify_operation_access(model_class, 'update')

        # Non-admins may only update records they own (when the model has an owner column).
        filter_column = self._config.user_filter_column
        if not self.is_admin and hasattr(instance, filter_column):
            if getattr(instance, filter_column) != self.user.id:
                raise HTTPException(
                    status_code=http_status.HTTP_403_FORBIDDEN,
                    detail="You do not have permission to update this record",
                )

        # Validate every field before mutating anything, so a bad key
        # leaves the instance untouched.
        for key in values:
            if not hasattr(instance, key):
                raise ValueError(f"{model_class.__name__} has no attribute '{key}'")

        for key, value in values.items():
            setattr(instance, key, value)

        await self.db.commit()
        await self.db.refresh(instance)

        logger.info("Updated %s instance", model_class.__name__)
        return instance

    async def increment(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """
        Increment a numeric field by a specified amount.

        The increment is computed in SQL (``field = field + amount``), not read
        back into Python first.

        Args:
            model_class: Model class to update
            field: Name of the numeric column to increment
            amount: Amount to add (may be negative)
            **filters: Field=value filters to match records

        Returns:
            Number of records updated

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If ``field`` or a filter key is not a model attribute
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, 'update')

        column = getattr(model_class, field)
        stmt = self._build_update_stmt(model_class, filters).values({field: column + amount})

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info(
            "Incremented %s by %s for %s %s record(s)",
            field, amount, count, model_class.__name__,
        )
        return count

    async def decrement(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """Decrement a numeric field by a specified amount (delegates to ``increment``)."""
        return await self.increment(model_class, field, -amount, **filters)

    async def toggle_boolean(self, model_class: Type[T], field: str, **filters) -> int:
        """
        Toggle a boolean field (True -> False, False -> True).

        Uses SQL negation (``NOT field``), so rows where the field is NULL
        remain NULL.

        Args:
            model_class: Model class to update
            field: Name of the boolean column to toggle
            **filters: Field=value filters to match records

        Returns:
            Number of records updated

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If ``field`` or a filter key is not a model attribute
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, 'update')

        column = getattr(model_class, field)
        # Uses the configured soft-delete column (previously hard-coded to
        # ``deleted_at``) via the shared statement builder.
        stmt = self._build_update_stmt(model_class, filters).values({field: ~column})

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info("Toggled %s for %s %s record(s)", field, count, model_class.__name__)
        return count