Spaces:
Sleeping
Sleeping
File size: 7,040 Bytes
50c20bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
"""
UpdateQuery - Update operations with access control.
Inherits from BaseQuery for shared filtering logic.
"""
import logging
from typing import Type, TypeVar, Dict, Any
from fastapi import HTTPException, status as http_status
from sqlalchemy import update
from services.db_service.base_query import BaseQuery
# Module logger, named after this module per the standard logging convention.
logger: logging.Logger = logging.getLogger(__name__)
# Generic placeholder for the ORM model type used in UpdateQuery's signatures.
T = TypeVar("T")
class UpdateQuery(BaseQuery):
    """
    Handles UPDATE operations with automatic access control.

    Inherits filtering logic from BaseQuery:
    - User ownership checks (_is_admin)
    - Deleted record protection

    All bulk operations (``update``, ``increment``, ``decrement``,
    ``toggle_boolean``) build their statement through ``_build_update_stmt``
    so the soft-delete guard, the ownership filter and the caller's equality
    filters are applied identically everywhere.
    """

    # Operation name handed to BaseQuery's access-control hooks. Kept in one
    # place so the permission check and the ownership filter always agree
    # (a previous typo passed 'updat' to the ownership filter).
    _UPDATE_OPERATION = 'update'

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _require_attribute(target: Any, name: str) -> None:
        """Raise ``ValueError`` unless ``target`` (model class or instance) has ``name``."""
        if not hasattr(target, name):
            owner = target if isinstance(target, type) else target.__class__
            raise ValueError(f"{owner.__name__} has no attribute '{name}'")

    def _build_update_stmt(self, model_class: Type[T], filters: Dict[str, Any]):
        """Return an UPDATE on ``model_class`` limited to live rows the user may touch.

        Conditions added, in order:
        1. ``<soft_delete_column> IS NULL`` - only when the model actually has
           the configured soft-delete column (same guard as ``update_one``).
        2. The ownership filter from ``BaseQuery._apply_ownership_filter``.
        3. One ``column == value`` condition per entry in ``filters``.

        Raises:
            ValueError: If a filter key is not an attribute of ``model_class``.
        """
        stmt = update(model_class)

        delete_column = self._config.soft_delete_column
        if hasattr(model_class, delete_column):
            # Never touch soft-deleted rows.
            stmt = stmt.where(getattr(model_class, delete_column).is_(None))

        stmt = self._apply_ownership_filter(stmt, model_class, self._UPDATE_OPERATION)

        for key, value in filters.items():
            self._require_attribute(model_class, key)
            stmt = stmt.where(getattr(model_class, key) == value)
        return stmt

    async def _execute_and_commit(self, stmt) -> int:
        """Execute ``stmt``, commit, and return the driver-reported row count.

        If execution or commit fails the session is rolled back, so it remains
        usable by the caller, and the original exception is re-raised.
        """
        try:
            result = await self.db.execute(stmt)
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise
        return result.rowcount

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def update(self, model_class: Type[T], values: Dict[str, Any], **filters) -> int:
        """
        Update records matching filters with new values.

        Automatically applies the ownership filter for non-admins and skips
        soft-deleted rows.

        NOTE(review): with no ``filters`` this updates every live row the
        caller may access (presumably all rows for an admin). Confirm callers
        always pass at least one filter.

        Args:
            model_class: Model class to update
            values: Dictionary of field=value pairs to update. An empty dict
                is a no-op and returns 0.
            **filters: Field=value filters to match records

        Returns:
            Number of records updated (as reported by the database driver)

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        # Validate target fields before doing anything else.
        for key in values:
            self._require_attribute(model_class, key)

        self._verify_operation_access(model_class, self._UPDATE_OPERATION)

        # Building the statement also validates the filter keys.
        stmt = self._build_update_stmt(model_class, filters)

        if not values:
            # An UPDATE with no SET values cannot be executed. There is
            # nothing to change, so report zero rows instead of erroring.
            return 0

        count = await self._execute_and_commit(stmt.values(**values))
        logger.info("Updated %s %s record(s)", count, model_class.__name__)
        return count

    async def update_one(self, instance: T, values: Dict[str, Any]) -> T:
        """
        Update a single model instance with validation.

        Args:
            instance: Model instance to update
            values: Dictionary of field=value pairs to update

        Returns:
            Updated model instance (refreshed from the database after commit)

        Raises:
            HTTPException: 400 if the instance is soft-deleted,
                403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        model_class = instance.__class__

        # Refuse to modify soft-deleted records.
        delete_column = self._config.soft_delete_column
        if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
            raise HTTPException(
                status_code=http_status.HTTP_400_BAD_REQUEST,
                detail="Cannot update deleted record"
            )

        self._verify_operation_access(model_class, self._UPDATE_OPERATION)

        # Non-admins may only update records they own.
        filter_column = self._config.user_filter_column
        if not self.is_admin and hasattr(instance, filter_column):
            if getattr(instance, filter_column) != self.user.id:
                raise HTTPException(
                    status_code=http_status.HTTP_403_FORBIDDEN,
                    detail="You do not have permission to update this record"
                )

        # Validate every field first so a bad key never leaves the instance
        # partially modified.
        for key in values:
            self._require_attribute(instance, key)
        for key, value in values.items():
            setattr(instance, key, value)

        try:
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise
        await self.db.refresh(instance)

        logger.info("Updated %s instance", model_class.__name__)
        return instance

    async def increment(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """
        Increment a numeric field by a specified amount.

        The addition is done in the database, as ``SET field = field + amount``
        inside the UPDATE itself, rather than read-then-write in Python.

        Args:
            model_class: Model class to update
            field: Name of the numeric column to increment
            amount: Value to add (may be negative; see ``decrement``)
            **filters: Field=value filters to match records

        Returns:
            Number of records updated (as reported by the database driver)

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If ``field`` or a filter key is not a model attribute
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, self._UPDATE_OPERATION)

        column = getattr(model_class, field)
        stmt = self._build_update_stmt(model_class, filters)
        stmt = stmt.values({field: column + amount})

        count = await self._execute_and_commit(stmt)
        logger.info(
            "Incremented %s by %s for %s %s record(s)",
            field, amount, count, model_class.__name__,
        )
        return count

    async def decrement(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
        """Decrement a numeric field by a specified amount (``increment`` with ``-amount``)."""
        return await self.increment(model_class, field, -amount, **filters)

    async def toggle_boolean(self, model_class: Type[T], field: str, **filters) -> int:
        """
        Toggle a boolean field (True -> False, False -> True).

        The negation is done in SQL (``NOT field``), so rows where the field
        is NULL stay NULL.

        Args:
            model_class: Model class to update
            field: Name of the boolean column to toggle
            **filters: Field=value filters to match records

        Returns:
            Number of records updated (as reported by the database driver)

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If ``field`` or a filter key is not a model attribute
        """
        self._require_attribute(model_class, field)
        self._verify_operation_access(model_class, self._UPDATE_OPERATION)

        column = getattr(model_class, field)
        # Uses the configured soft-delete column (previously hard-coded to
        # ``deleted_at``, unlike the other update methods).
        stmt = self._build_update_stmt(model_class, filters)
        stmt = stmt.values({field: ~column})

        count = await self._execute_and_commit(stmt)
        logger.info("Toggled %s for %s %s record(s)", field, count, model_class.__name__)
        return count
|