# apigateway/services/db_service/update_query.py
# jebin2 — "db services" (commit 50c20bf)
"""
UpdateQuery - Update operations with access control.
Inherits from BaseQuery for shared filtering logic.
"""
import logging
from typing import Type, TypeVar, Dict, Any
from fastapi import HTTPException, status as http_status
from sqlalchemy import update
from services.db_service.base_query import BaseQuery
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic model type: lets update_one return the same type it was given.
T = TypeVar('T')
class UpdateQuery(BaseQuery):
    """
    Handles UPDATE operations with automatic access control.

    Relies on helpers inherited from BaseQuery:
    - Per-operation permission checks (_verify_operation_access)
    - User ownership filtering for non-admins (_apply_ownership_filter, is_admin)
    - Deleted record protection (rows whose configured soft-delete column is
      non-NULL are never updated)
    """

    def _build_update_stmt(self, model_class: Type[T], filters: Dict[str, Any]):
        """
        Build the access-controlled base UPDATE statement shared by the bulk methods.

        Verifies the caller may update ``model_class``, restricts the statement
        to non-deleted rows, applies the ownership filter, and adds one equality
        condition per entry in ``filters``. The caller still has to attach the
        SET clause via ``.values(...)``.

        Args:
            model_class: Model class to update
            filters: Field=value equality filters to match records

        Returns:
            A SQLAlchemy UPDATE statement without values.

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If a filter names an attribute the model does not have
        """
        self._verify_operation_access(model_class, 'update')

        # Don't update soft-deleted records
        delete_col = getattr(model_class, self._config.soft_delete_column)
        stmt = update(model_class).where(delete_col.is_(None))

        # Non-admin ownership restriction (shared method from BaseQuery).
        # The operation name must match the one used for _verify_operation_access.
        stmt = self._apply_ownership_filter(stmt, model_class, 'update')

        for key, value in filters.items():
            if not hasattr(model_class, key):
                raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
            stmt = stmt.where(getattr(model_class, key) == value)

        return stmt

    async def update(self, model_class: Type[T], values: Dict[str, Any], **filters: Any) -> int:
        """
        Update records matching filters with new values.

        Soft-deleted records are skipped and the ownership filter is applied
        automatically for non-admins. With no filters, every non-deleted record
        visible to the user is updated.

        NOTE(review): ``values`` is not restricted, so it may rewrite the
        ownership or soft-delete column — confirm that is intended.

        Args:
            model_class: Model class to update
            values: Dictionary of field=value pairs to update (must not be empty)
            **filters: Field=value filters to match records

        Returns:
            Number of records updated, as reported by the driver's rowcount

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If values is empty or invalid field names are provided
        """
        if not values:
            # An UPDATE needs at least one SET value; reject empty input up
            # front instead of surfacing an error from the database layer.
            raise ValueError(f"No values provided to update {model_class.__name__}")

        # Validate target fields before any permission/DB work
        for key in values:
            if not hasattr(model_class, key):
                raise ValueError(f"{model_class.__name__} has no attribute '{key}'")

        stmt = self._build_update_stmt(model_class, filters).values(**values)

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info("Updated %s %s record(s)", count, model_class.__name__)
        return count

    async def update_one(self, instance: T, values: Dict[str, Any]) -> T:
        """
        Update a single, already-loaded model instance with validation.

        NOTE(review): ``values`` is not restricted, so it may rewrite the
        ownership or soft-delete column — confirm that is intended.

        Args:
            instance: Model instance to update
            values: Dictionary of field=value pairs to update

        Returns:
            Updated model instance, refreshed from the database after commit

        Raises:
            HTTPException: 400 if the record is soft-deleted;
                403 if user doesn't have permission or doesn't own the record
            ValueError: If invalid field names provided
        """
        model_class = instance.__class__

        # Refuse to modify soft-deleted records
        delete_column = self._config.soft_delete_column
        if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
            raise HTTPException(
                status_code=http_status.HTTP_400_BAD_REQUEST,
                detail="Cannot update deleted record"
            )

        # Verify user has permission to update this model
        self._verify_operation_access(model_class, 'update')

        # Non-admins may only update records they own
        filter_column = self._config.user_filter_column
        if not self.is_admin and hasattr(instance, filter_column):
            if getattr(instance, filter_column) != self.user.id:
                raise HTTPException(
                    status_code=http_status.HTTP_403_FORBIDDEN,
                    detail="You do not have permission to update this record"
                )

        # Validate every field before mutating anything, so a bad key leaves
        # the instance untouched
        for key in values:
            if not hasattr(instance, key):
                raise ValueError(f"{model_class.__name__} has no attribute '{key}'")

        for key, value in values.items():
            setattr(instance, key, value)

        await self.db.commit()
        await self.db.refresh(instance)

        logger.info("Updated %s instance", model_class.__name__)
        return instance

    async def increment(self, model_class: Type[T], field: str, amount: int = 1, **filters: Any) -> int:
        """
        Increment a numeric field by ``amount`` on every matching record.

        The new value is computed in SQL (``SET field = field + amount``) rather
        than read-modify-write in Python. The same access control as ``update``
        applies.

        Args:
            model_class: Model class to update
            field: Name of the numeric column to change
            amount: Value to add (may be negative)
            **filters: Field=value filters to match records

        Returns:
            Number of records updated, as reported by the driver's rowcount

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        if not hasattr(model_class, field):
            raise ValueError(f"{model_class.__name__} has no attribute '{field}'")

        field_obj = getattr(model_class, field)
        stmt = self._build_update_stmt(model_class, filters).values({field: field_obj + amount})

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info(
            "Incremented %s by %s for %s %s record(s)",
            field, amount, count, model_class.__name__,
        )
        return count

    async def decrement(self, model_class: Type[T], field: str, amount: int = 1, **filters: Any) -> int:
        """Decrement a numeric field by ``amount`` (i.e. increment by ``-amount``)."""
        return await self.increment(model_class, field, -amount, **filters)

    async def toggle_boolean(self, model_class: Type[T], field: str, **filters: Any) -> int:
        """
        Toggle a boolean field (True -> False, False -> True) on matching records.

        The negation is done in SQL (``SET field = NOT field``). The same access
        control as ``update`` applies, including the configured soft-delete column.

        Args:
            model_class: Model class to update
            field: Name of the boolean column to flip
            **filters: Field=value filters to match records

        Returns:
            Number of records updated, as reported by the driver's rowcount

        Raises:
            HTTPException: 403 if user doesn't have permission
            ValueError: If invalid field names provided
        """
        if not hasattr(model_class, field):
            raise ValueError(f"{model_class.__name__} has no attribute '{field}'")

        field_obj = getattr(model_class, field)
        stmt = self._build_update_stmt(model_class, filters).values({field: ~field_obj})

        result = await self.db.execute(stmt)
        await self.db.commit()

        count = result.rowcount
        logger.info("Toggled %s for %s %s record(s)", field, count, model_class.__name__)
        return count