Spaces:
Sleeping
Sleeping
db services
Browse files- app.py +15 -4
- core/models.py +7 -0
- routers/blink.py +102 -94
- routers/gemini.py +38 -20
- services/db_service/__init__.py +59 -0
- services/db_service/base_query.py +233 -0
- services/db_service/config.py +135 -0
- services/db_service/db_init.py +74 -0
- services/db_service/delete_query.py +160 -0
- services/db_service/query_service.py +78 -0
- services/db_service/register_config.py +121 -0
- services/db_service/select_query.py +98 -0
- services/db_service/update_query.py +188 -0
- tests/README.md +78 -0
- tests/requirements-test.txt +3 -0
- tests/test_db_service.py +338 -0
app.py
CHANGED
|
@@ -11,9 +11,11 @@ from fastapi import FastAPI, Request
|
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
from fastapi.responses import JSONResponse
|
| 13 |
|
| 14 |
-
from core.database import
|
| 15 |
from routers import auth, blink, contact, credits, general, gemini, payments, schema
|
| 16 |
from services.drive_service import DriveService
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(
|
|
@@ -33,12 +35,20 @@ async def lifespan(app: FastAPI):
|
|
| 33 |
"""
|
| 34 |
logger.info("Starting up - initializing database...")
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# Check for RESET_DB environment variable
|
| 37 |
if os.getenv("RESET_DB", "").lower() == "true":
|
| 38 |
logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
|
| 39 |
if os.path.exists(DB_FILENAME):
|
| 40 |
os.remove(DB_FILENAME)
|
| 41 |
logger.info("Local database deleted.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
else:
|
| 43 |
# Startup: Download DB from Drive ONLY if local file doesn't exist
|
| 44 |
if not os.path.exists(DB_FILENAME):
|
|
@@ -46,9 +56,10 @@ async def lifespan(app: FastAPI):
|
|
| 46 |
drive_service.download_db()
|
| 47 |
else:
|
| 48 |
logger.info("Startup: Local DB found. Skipping download to preserve data.")
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
| 52 |
|
| 53 |
# Start background job worker
|
| 54 |
from services.gemini_job_worker import start_worker, stop_worker
|
|
|
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
from fastapi.responses import JSONResponse
|
| 13 |
|
| 14 |
+
from core.database import engine, DB_FILENAME
|
| 15 |
from routers import auth, blink, contact, credits, general, gemini, payments, schema
|
| 16 |
from services.drive_service import DriveService
|
| 17 |
+
from services.db_service import init_database, reset_database
|
| 18 |
+
from services.db_service.register_config import register_db_service_config
|
| 19 |
|
| 20 |
# Configure logging
|
| 21 |
logging.basicConfig(
|
|
|
|
| 35 |
"""
|
| 36 |
logger.info("Starting up - initializing database...")
|
| 37 |
|
| 38 |
+
# Register DB Service configuration
|
| 39 |
+
register_db_service_config()
|
| 40 |
+
logger.info("✅ DB Service configured")
|
| 41 |
+
|
| 42 |
# Check for RESET_DB environment variable
|
| 43 |
if os.getenv("RESET_DB", "").lower() == "true":
|
| 44 |
logger.warning(f"RESET_DB is set to true. Skipping download and clearing local database ({DB_FILENAME}).")
|
| 45 |
if os.path.exists(DB_FILENAME):
|
| 46 |
os.remove(DB_FILENAME)
|
| 47 |
logger.info("Local database deleted.")
|
| 48 |
+
|
| 49 |
+
# Reset database (drop + create all tables)
|
| 50 |
+
await reset_database(engine)
|
| 51 |
+
logger.info("✅ Database reset complete")
|
| 52 |
else:
|
| 53 |
# Startup: Download DB from Drive ONLY if local file doesn't exist
|
| 54 |
if not os.path.exists(DB_FILENAME):
|
|
|
|
| 56 |
drive_service.download_db()
|
| 57 |
else:
|
| 58 |
logger.info("Startup: Local DB found. Skipping download to preserve data.")
|
| 59 |
+
|
| 60 |
+
# Initialize database (create tables if not exist)
|
| 61 |
+
await init_database(engine)
|
| 62 |
+
logger.info("✅ Database initialized")
|
| 63 |
|
| 64 |
# Start background job worker
|
| 65 |
from services.gemini_job_worker import start_worker, stop_worker
|
core/models.py
CHANGED
|
@@ -47,6 +47,7 @@ class User(Base):
|
|
| 47 |
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
| 48 |
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
| 49 |
is_active = Column(Boolean, default=True)
|
|
|
|
| 50 |
|
| 51 |
# Relationships
|
| 52 |
client_users = relationship("ClientUser", back_populates="user", lazy="dynamic")
|
|
@@ -86,6 +87,7 @@ class ClientUser(Base):
|
|
| 86 |
# Timestamps
|
| 87 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 88 |
last_seen_at = Column(DateTime(timezone=True), nullable=True)
|
|
|
|
| 89 |
|
| 90 |
# Relationship
|
| 91 |
user = relationship("User", back_populates="client_users")
|
|
@@ -131,6 +133,7 @@ class AuditLog(Base):
|
|
| 131 |
|
| 132 |
# Timestamp
|
| 133 |
timestamp = Column(DateTime(timezone=True), server_default=func.now(), index=True)
|
|
|
|
| 134 |
|
| 135 |
# Relationship
|
| 136 |
user = relationship("User", back_populates="audit_logs")
|
|
@@ -174,6 +177,7 @@ class GeminiJob(Base):
|
|
| 174 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 175 |
started_at = Column(DateTime(timezone=True), nullable=True)
|
| 176 |
completed_at = Column(DateTime(timezone=True), nullable=True)
|
|
|
|
| 177 |
|
| 178 |
# Credit tracking for reservation pattern
|
| 179 |
credits_reserved = Column(Integer, default=0) # Credits reserved for this job
|
|
@@ -219,6 +223,7 @@ class PaymentTransaction(Base):
|
|
| 219 |
# Timestamps
|
| 220 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 221 |
paid_at = Column(DateTime(timezone=True), nullable=True)
|
|
|
|
| 222 |
|
| 223 |
# Metadata
|
| 224 |
razorpay_signature = Column(String(255), nullable=True) # For verification audit
|
|
@@ -250,6 +255,7 @@ class Contact(Base):
|
|
| 250 |
message = Column(Text, nullable=False)
|
| 251 |
ip_address = Column(String(45), nullable=True) # IPv6 can be up to 45 chars
|
| 252 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
|
|
| 253 |
|
| 254 |
# Relationship
|
| 255 |
user = relationship("User", back_populates="contacts")
|
|
@@ -270,6 +276,7 @@ class RateLimit(Base):
|
|
| 270 |
attempts = Column(Integer, default=0)
|
| 271 |
window_start = Column(DateTime(timezone=True), nullable=False)
|
| 272 |
expires_at = Column(DateTime(timezone=True), nullable=False)
|
|
|
|
| 273 |
|
| 274 |
|
| 275 |
class ApiKeyUsage(Base):
|
|
|
|
| 47 |
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
| 48 |
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
| 49 |
is_active = Column(Boolean, default=True)
|
| 50 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 51 |
|
| 52 |
# Relationships
|
| 53 |
client_users = relationship("ClientUser", back_populates="user", lazy="dynamic")
|
|
|
|
| 87 |
# Timestamps
|
| 88 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 89 |
last_seen_at = Column(DateTime(timezone=True), nullable=True)
|
| 90 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 91 |
|
| 92 |
# Relationship
|
| 93 |
user = relationship("User", back_populates="client_users")
|
|
|
|
| 133 |
|
| 134 |
# Timestamp
|
| 135 |
timestamp = Column(DateTime(timezone=True), server_default=func.now(), index=True)
|
| 136 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 137 |
|
| 138 |
# Relationship
|
| 139 |
user = relationship("User", back_populates="audit_logs")
|
|
|
|
| 177 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 178 |
started_at = Column(DateTime(timezone=True), nullable=True)
|
| 179 |
completed_at = Column(DateTime(timezone=True), nullable=True)
|
| 180 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 181 |
|
| 182 |
# Credit tracking for reservation pattern
|
| 183 |
credits_reserved = Column(Integer, default=0) # Credits reserved for this job
|
|
|
|
| 223 |
# Timestamps
|
| 224 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 225 |
paid_at = Column(DateTime(timezone=True), nullable=True)
|
| 226 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 227 |
|
| 228 |
# Metadata
|
| 229 |
razorpay_signature = Column(String(255), nullable=True) # For verification audit
|
|
|
|
| 255 |
message = Column(Text, nullable=False)
|
| 256 |
ip_address = Column(String(45), nullable=True) # IPv6 can be up to 45 chars
|
| 257 |
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 258 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 259 |
|
| 260 |
# Relationship
|
| 261 |
user = relationship("User", back_populates="contacts")
|
|
|
|
| 276 |
attempts = Column(Integer, default=0)
|
| 277 |
window_start = Column(DateTime(timezone=True), nullable=False)
|
| 278 |
expires_at = Column(DateTime(timezone=True), nullable=False)
|
| 279 |
+
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) # Soft delete timestamp
|
| 280 |
|
| 281 |
|
| 282 |
class ApiKeyUsage(Base):
|
routers/blink.py
CHANGED
|
@@ -11,7 +11,7 @@ import logging
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
-
from dependencies import get_geolocation, get_optional_user
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
@@ -30,37 +30,36 @@ async def get_data(
|
|
| 30 |
page: int = Query(1, ge=1, description="Page number"),
|
| 31 |
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 32 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
|
|
|
| 33 |
db: AsyncSession = Depends(get_db)
|
| 34 |
):
|
| 35 |
"""
|
| 36 |
-
Get paginated audit log data.
|
| 37 |
-
|
| 38 |
"""
|
|
|
|
|
|
|
| 39 |
try:
|
|
|
|
| 40 |
offset = (page - 1) * limit
|
| 41 |
|
| 42 |
# Build query with optional filter
|
| 43 |
-
|
| 44 |
-
count_query = select(func.count(AuditLog.id))
|
| 45 |
-
|
| 46 |
if log_type:
|
| 47 |
-
|
| 48 |
-
count_query = count_query.where(AuditLog.log_type == log_type)
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
|
| 52 |
-
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
-
#
|
| 61 |
-
|
| 62 |
-
result = await db.execute(query)
|
| 63 |
-
items = result.scalars().all()
|
| 64 |
|
| 65 |
return {
|
| 66 |
"items": [
|
|
@@ -96,22 +95,29 @@ async def get_data(
|
|
| 96 |
async def get_users(
|
| 97 |
page: int = Query(1, ge=1, description="Page number"),
|
| 98 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 99 |
db: AsyncSession = Depends(get_db)
|
| 100 |
):
|
| 101 |
"""
|
| 102 |
-
Get
|
|
|
|
| 103 |
"""
|
|
|
|
|
|
|
| 104 |
try:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
return {
|
| 117 |
"items": [
|
|
@@ -145,31 +151,23 @@ async def get_client_users(
|
|
| 145 |
page: int = Query(1, ge=1, description="Page number"),
|
| 146 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 147 |
user_id: str = Query(None, description="Filter by server user_id"),
|
|
|
|
| 148 |
db: AsyncSession = Depends(get_db)
|
| 149 |
):
|
| 150 |
"""
|
| 151 |
-
Get
|
| 152 |
-
|
| 153 |
"""
|
|
|
|
|
|
|
| 154 |
try:
|
|
|
|
| 155 |
offset = (page - 1) * limit
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
if user_id:
|
| 162 |
-
base_query = base_query.where(ClientUser.user_id == user_id)
|
| 163 |
-
count_query = count_query.where(ClientUser.user_id == user_id)
|
| 164 |
-
|
| 165 |
-
# Get total count
|
| 166 |
-
total_result = await db.execute(count_query)
|
| 167 |
-
total = total_result.scalar() or 0
|
| 168 |
-
|
| 169 |
-
# Get paginated items
|
| 170 |
-
query = base_query.order_by(ClientUser.id.desc()).offset(offset).limit(limit)
|
| 171 |
-
result = await db.execute(query)
|
| 172 |
-
items = result.scalars().all()
|
| 173 |
|
| 174 |
return {
|
| 175 |
"items": [
|
|
@@ -203,34 +201,37 @@ async def get_audit_logs(
|
|
| 203 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 204 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 205 |
action: str = Query(None, description="Filter by action"),
|
|
|
|
| 206 |
db: AsyncSession = Depends(get_db)
|
| 207 |
):
|
| 208 |
"""
|
| 209 |
-
Get
|
|
|
|
| 210 |
"""
|
|
|
|
|
|
|
| 211 |
try:
|
|
|
|
| 212 |
offset = (page - 1) * limit
|
| 213 |
|
| 214 |
# Build query with filters
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
if log_type:
|
| 219 |
-
base_query = base_query.where(AuditLog.log_type == log_type)
|
| 220 |
count_query = count_query.where(AuditLog.log_type == log_type)
|
| 221 |
-
|
| 222 |
if action:
|
| 223 |
-
base_query = base_query.where(AuditLog.action == action)
|
| 224 |
count_query = count_query.where(AuditLog.action == action)
|
| 225 |
-
|
| 226 |
-
# Get total count
|
| 227 |
-
total_result = await db.execute(count_query)
|
| 228 |
-
total = total_result.scalar() or 0
|
| 229 |
-
|
| 230 |
-
# Get paginated items
|
| 231 |
-
query = base_query.order_by(AuditLog.timestamp.desc()).offset(offset).limit(limit)
|
| 232 |
-
result = await db.execute(query)
|
| 233 |
-
items = result.scalars().all()
|
| 234 |
|
| 235 |
return {
|
| 236 |
"items": [
|
|
@@ -264,22 +265,25 @@ async def get_audit_logs(
|
|
| 264 |
async def get_gemini_jobs(
|
| 265 |
page: int = Query(1, ge=1, description="Page number"),
|
| 266 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 267 |
db: AsyncSession = Depends(get_db)
|
| 268 |
):
|
| 269 |
"""
|
| 270 |
-
Get
|
|
|
|
| 271 |
"""
|
|
|
|
|
|
|
| 272 |
try:
|
|
|
|
| 273 |
offset = (page - 1) * limit
|
| 274 |
|
| 275 |
-
#
|
| 276 |
-
total_result = await db.execute(select(func.count(GeminiJob.id)))
|
| 277 |
-
total = total_result.scalar() or 0
|
| 278 |
-
|
| 279 |
-
# Get paginated items
|
| 280 |
query = select(GeminiJob).order_by(GeminiJob.id.desc()).offset(offset).limit(limit)
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
| 283 |
|
| 284 |
return {
|
| 285 |
"items": [
|
|
@@ -311,32 +315,35 @@ async def get_gemini_jobs(
|
|
| 311 |
async def get_payment_transactions(
|
| 312 |
page: int = Query(1, ge=1, description="Page number"),
|
| 313 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 314 |
db: AsyncSession = Depends(get_db)
|
| 315 |
):
|
| 316 |
"""
|
| 317 |
-
Get
|
|
|
|
| 318 |
"""
|
| 319 |
from core.models import PaymentTransaction
|
|
|
|
| 320 |
|
| 321 |
try:
|
|
|
|
| 322 |
offset = (page - 1) * limit
|
| 323 |
|
| 324 |
-
#
|
| 325 |
-
|
| 326 |
-
total = total_result.scalar() or 0
|
| 327 |
|
| 328 |
-
# Get
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
PaymentTransaction.status == "paid"
|
| 332 |
-
)
|
| 333 |
)
|
| 334 |
-
|
|
|
|
| 335 |
|
| 336 |
-
# Get paginated items
|
| 337 |
-
query = select(PaymentTransaction).order_by(
|
| 338 |
-
|
| 339 |
-
|
|
|
|
| 340 |
|
| 341 |
return {
|
| 342 |
"items": [
|
|
@@ -378,22 +385,23 @@ async def get_payment_transactions(
|
|
| 378 |
async def get_contacts(
|
| 379 |
page: int = Query(1, ge=1, description="Page number"),
|
| 380 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
|
|
|
| 381 |
db: AsyncSession = Depends(get_db)
|
| 382 |
):
|
| 383 |
"""
|
| 384 |
-
Get
|
|
|
|
| 385 |
"""
|
|
|
|
|
|
|
| 386 |
try:
|
|
|
|
| 387 |
offset = (page - 1) * limit
|
| 388 |
|
| 389 |
-
#
|
| 390 |
-
total_result = await db.execute(select(func.count(Contact.id)))
|
| 391 |
-
total = total_result.scalar() or 0
|
| 392 |
-
|
| 393 |
-
# Get paginated items
|
| 394 |
query = select(Contact).order_by(Contact.id.desc()).offset(offset).limit(limit)
|
| 395 |
-
|
| 396 |
-
|
| 397 |
|
| 398 |
return {
|
| 399 |
"items": [
|
|
|
|
| 11 |
from core.database import get_db
|
| 12 |
from core.models import User, AuditLog, GeminiJob, Contact, ClientUser
|
| 13 |
from services.encryption_service import decrypt_multiple_blocks
|
| 14 |
+
from dependencies import get_geolocation, get_optional_user, get_current_user
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
|
|
|
| 30 |
page: int = Query(1, ge=1, description="Page number"),
|
| 31 |
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 32 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 33 |
+
user: User = Depends(get_current_user), # Auth required
|
| 34 |
db: AsyncSession = Depends(get_db)
|
| 35 |
):
|
| 36 |
"""
|
| 37 |
+
Get paginated audit log data for the authenticated user.
|
| 38 |
+
Admins see all logs from all users.
|
| 39 |
"""
|
| 40 |
+
from services.db_service import QueryService
|
| 41 |
+
|
| 42 |
try:
|
| 43 |
+
qs = QueryService(user, db)
|
| 44 |
offset = (page - 1) * limit
|
| 45 |
|
| 46 |
# Build query with optional filter
|
| 47 |
+
query = select(AuditLog)
|
|
|
|
|
|
|
| 48 |
if log_type:
|
| 49 |
+
query = query.where(AuditLog.log_type == log_type)
|
|
|
|
| 50 |
|
| 51 |
+
# Automatically filtered by user (unless admin)
|
| 52 |
+
query = query.order_by(AuditLog.timestamp.desc()).offset(offset).limit(limit)
|
| 53 |
+
items = await qs.select().execute(query)
|
| 54 |
|
| 55 |
+
# Count also automatically filtered
|
| 56 |
+
count_query = select(AuditLog)
|
| 57 |
+
if log_type:
|
| 58 |
+
count_query = count_query.where(AuditLog.log_type == log_type)
|
| 59 |
+
total = await qs.select().count(count_query)
|
| 60 |
|
| 61 |
+
# Unique users: 1 for regular users, multiple for admins
|
| 62 |
+
unique_users = 1 if not qs.select().is_admin else total
|
|
|
|
|
|
|
| 63 |
|
| 64 |
return {
|
| 65 |
"items": [
|
|
|
|
| 95 |
async def get_users(
|
| 96 |
page: int = Query(1, ge=1, description="Page number"),
|
| 97 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 98 |
+
user: User = Depends(get_current_user), # Auth required
|
| 99 |
db: AsyncSession = Depends(get_db)
|
| 100 |
):
|
| 101 |
"""
|
| 102 |
+
Get current user's profile data.
|
| 103 |
+
Admins see paginated list of all users.
|
| 104 |
"""
|
| 105 |
+
from services.db_service import QueryService
|
| 106 |
+
|
| 107 |
try:
|
| 108 |
+
qs = QueryService(user, db)
|
| 109 |
+
|
| 110 |
+
if qs.select().is_admin:
|
| 111 |
+
# Admins get paginated list of all users
|
| 112 |
+
offset = (page - 1) * limit
|
| 113 |
+
query = select(User).order_by(User.id.desc()).offset(offset).limit(limit)
|
| 114 |
+
items = await qs.select().execute(query) # No filter for admin
|
| 115 |
+
total = await qs.select().count(select(User))
|
| 116 |
+
else:
|
| 117 |
+
# Regular users only see their own profile
|
| 118 |
+
query = select(User)
|
| 119 |
+
items = await qs.select().execute(query) # Filtered to current user
|
| 120 |
+
total = 1
|
| 121 |
|
| 122 |
return {
|
| 123 |
"items": [
|
|
|
|
| 151 |
page: int = Query(1, ge=1, description="Page number"),
|
| 152 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 153 |
user_id: str = Query(None, description="Filter by server user_id"),
|
| 154 |
+
user: User = Depends(get_current_user), # Auth required
|
| 155 |
db: AsyncSession = Depends(get_db)
|
| 156 |
):
|
| 157 |
"""
|
| 158 |
+
Get current user's client mappings.
|
| 159 |
+
Admins see all client mappings from all users.
|
| 160 |
"""
|
| 161 |
+
from services.db_service import QueryService
|
| 162 |
+
|
| 163 |
try:
|
| 164 |
+
qs = QueryService(user, db)
|
| 165 |
offset = (page - 1) * limit
|
| 166 |
|
| 167 |
+
# Automatically filtered by user (unless admin)
|
| 168 |
+
query = select(ClientUser).order_by(ClientUser.id.desc()).offset(offset).limit(limit)
|
| 169 |
+
items = await qs.select().execute(query)
|
| 170 |
+
total = await qs.select().count(select(ClientUser))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
return {
|
| 173 |
"items": [
|
|
|
|
| 201 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 202 |
log_type: str = Query(None, description="Filter by log type: client, server"),
|
| 203 |
action: str = Query(None, description="Filter by action"),
|
| 204 |
+
user: User = Depends(get_current_user), # Auth required
|
| 205 |
db: AsyncSession = Depends(get_db)
|
| 206 |
):
|
| 207 |
"""
|
| 208 |
+
Get current user's audit logs with optional filters.
|
| 209 |
+
Admins see all logs from all users.
|
| 210 |
"""
|
| 211 |
+
from services.db_service import QueryService
|
| 212 |
+
|
| 213 |
try:
|
| 214 |
+
qs = QueryService(user, db)
|
| 215 |
offset = (page - 1) * limit
|
| 216 |
|
| 217 |
# Build query with filters
|
| 218 |
+
query = select(AuditLog)
|
| 219 |
+
if log_type:
|
| 220 |
+
query = query.where(AuditLog.log_type == log_type)
|
| 221 |
+
if action:
|
| 222 |
+
query = query.where(AuditLog.action == action)
|
| 223 |
|
| 224 |
+
# Automatically filtered by user (unless admin)
|
| 225 |
+
query = query.order_by(AuditLog.timestamp.desc()).offset(offset).limit(limit)
|
| 226 |
+
items = await qs.select().execute(query)
|
| 227 |
+
|
| 228 |
+
# Count with same filters
|
| 229 |
+
count_query = select(AuditLog)
|
| 230 |
if log_type:
|
|
|
|
| 231 |
count_query = count_query.where(AuditLog.log_type == log_type)
|
|
|
|
| 232 |
if action:
|
|
|
|
| 233 |
count_query = count_query.where(AuditLog.action == action)
|
| 234 |
+
total = await qs.select().count(count_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
return {
|
| 237 |
"items": [
|
|
|
|
| 265 |
async def get_gemini_jobs(
|
| 266 |
page: int = Query(1, ge=1, description="Page number"),
|
| 267 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 268 |
+
user: User = Depends(get_current_user), # Auth required
|
| 269 |
db: AsyncSession = Depends(get_db)
|
| 270 |
):
|
| 271 |
"""
|
| 272 |
+
Get current user's Gemini jobs.
|
| 273 |
+
Admins see all jobs from all users.
|
| 274 |
"""
|
| 275 |
+
from services.db_service import QueryService
|
| 276 |
+
|
| 277 |
try:
|
| 278 |
+
qs = QueryService(user, db)
|
| 279 |
offset = (page - 1) * limit
|
| 280 |
|
| 281 |
+
# Automatically filtered by user (unless admin)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
query = select(GeminiJob).order_by(GeminiJob.id.desc()).offset(offset).limit(limit)
|
| 283 |
+
items = await qs.select().execute(query)
|
| 284 |
+
|
| 285 |
+
# Count also automatically filtered
|
| 286 |
+
total = await qs.select().count(select(GeminiJob))
|
| 287 |
|
| 288 |
return {
|
| 289 |
"items": [
|
|
|
|
| 315 |
async def get_payment_transactions(
|
| 316 |
page: int = Query(1, ge=1, description="Page number"),
|
| 317 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 318 |
+
user: User = Depends(get_current_user), # Auth required
|
| 319 |
db: AsyncSession = Depends(get_db)
|
| 320 |
):
|
| 321 |
"""
|
| 322 |
+
Get current user's payment transactions.
|
| 323 |
+
Admins see all transactions from all users.
|
| 324 |
"""
|
| 325 |
from core.models import PaymentTransaction
|
| 326 |
+
from services.db_service import QueryService
|
| 327 |
|
| 328 |
try:
|
| 329 |
+
qs = QueryService(user, db)
|
| 330 |
offset = (page - 1) * limit
|
| 331 |
|
| 332 |
+
# Automatically filtered by user (unless admin)
|
| 333 |
+
total = await qs.select().count(select(PaymentTransaction))
|
|
|
|
| 334 |
|
| 335 |
+
# Get paid transactions revenue - automatically filtered
|
| 336 |
+
revenue_query = select(func.sum(PaymentTransaction.amount_paise)).where(
|
| 337 |
+
PaymentTransaction.status == "paid"
|
|
|
|
|
|
|
| 338 |
)
|
| 339 |
+
revenue_result = await qs.execute_one(revenue_query)
|
| 340 |
+
total_revenue_paise = revenue_result or 0
|
| 341 |
|
| 342 |
+
# Get paginated items - automatically filtered
|
| 343 |
+
query = select(PaymentTransaction).order_by(
|
| 344 |
+
PaymentTransaction.id.desc()
|
| 345 |
+
).offset(offset).limit(limit)
|
| 346 |
+
items = await qs.select().execute(query)
|
| 347 |
|
| 348 |
return {
|
| 349 |
"items": [
|
|
|
|
| 385 |
async def get_contacts(
|
| 386 |
page: int = Query(1, ge=1, description="Page number"),
|
| 387 |
limit: int = Query(50, ge=1, le=500, description="Items per page"),
|
| 388 |
+
user: User = Depends(get_current_user), # Auth required
|
| 389 |
db: AsyncSession = Depends(get_db)
|
| 390 |
):
|
| 391 |
"""
|
| 392 |
+
Get current user's contact form submissions.
|
| 393 |
+
Admins see all contact submissions from all users.
|
| 394 |
"""
|
| 395 |
+
from services.db_service import QueryService
|
| 396 |
+
|
| 397 |
try:
|
| 398 |
+
qs = QueryService(user, db)
|
| 399 |
offset = (page - 1) * limit
|
| 400 |
|
| 401 |
+
# Automatically filtered by user (unless admin)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
query = select(Contact).order_by(Contact.id.desc()).offset(offset).limit(limit)
|
| 403 |
+
items = await qs.select().execute(query)
|
| 404 |
+
total = await qs.select().count(select(Contact))
|
| 405 |
|
| 406 |
return {
|
| 407 |
"items": [
|
routers/gemini.py
CHANGED
|
@@ -516,16 +516,20 @@ async def delete_job(
|
|
| 516 |
db: AsyncSession = Depends(get_db)
|
| 517 |
):
|
| 518 |
"""
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
|
|
|
|
|
|
| 522 |
"""
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
)
|
| 527 |
-
result = await db.execute(query)
|
| 528 |
-
job = result.scalar_one_or_none()
|
| 529 |
|
| 530 |
if not job:
|
| 531 |
raise HTTPException(
|
|
@@ -548,23 +552,37 @@ async def delete_job(
|
|
| 548 |
# Job has third_party_id but is queued? (Unlikely for video, but maybe for others?)
|
| 549 |
# Or maybe it was reset to queued?
|
| 550 |
# Use existing logic: Refund 8 credits (10 - 2) for video
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
| 553 |
user.credits += refund_amount
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
message =
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
|
| 562 |
return {
|
| 563 |
"success": True,
|
| 564 |
-
"job_id": job_id,
|
| 565 |
"message": message,
|
| 566 |
-
"
|
| 567 |
-
"
|
| 568 |
}
|
| 569 |
|
| 570 |
|
|
|
|
| 516 |
db: AsyncSession = Depends(get_db)
|
| 517 |
):
|
| 518 |
"""
|
| 519 |
+
Soft delete a job with conditional credit refund.
|
| 520 |
+
|
| 521 |
+
Refund policy:
|
| 522 |
+
- If queued: Refund 8 credits (10 cost - 2 penalty), soft delete job.
|
| 523 |
+
- If processing/completed/failed: Soft delete job (no refund).
|
| 524 |
"""
|
| 525 |
+
from services.db_service import QueryService
|
| 526 |
+
|
| 527 |
+
qs = QueryService(user, db)
|
| 528 |
+
|
| 529 |
+
# Get job (automatically filtered to user's job)
|
| 530 |
+
job = await qs.select().execute_one(
|
| 531 |
+
select(GeminiJob).where(GeminiJob.job_id == job_id)
|
| 532 |
)
|
|
|
|
|
|
|
| 533 |
|
| 534 |
if not job:
|
| 535 |
raise HTTPException(
|
|
|
|
| 552 |
# Job has third_party_id but is queued? (Unlikely for video, but maybe for others?)
|
| 553 |
# Or maybe it was reset to queued?
|
| 554 |
# Use existing logic: Refund 8 credits (10 - 2) for video
|
| 555 |
+
penalty = 2 # Fixed penalty
|
| 556 |
+
refund_amount = max(0, job.credits_reserved - penalty)
|
| 557 |
+
|
| 558 |
+
if refund_amount > 0 and not job.credits_refunded:
|
| 559 |
user.credits += refund_amount
|
| 560 |
+
job.credits_refunded = True
|
| 561 |
+
message = f"Job deleted. {refund_amount} credits refunded (queued status)."
|
| 562 |
+
else:
|
| 563 |
+
message = "Job deleted (no refund - already refunded or 0 credit job)."
|
| 564 |
+
elif job.status in ["processing", "completed", "failed", "expired"]:
|
| 565 |
+
# No refund for jobs that started or completed
|
| 566 |
+
message = f"Job deleted (no refund for {job.status} jobs)."
|
| 567 |
+
|
| 568 |
+
# Commit credit refund if any
|
| 569 |
+
if refund_amount > 0:
|
| 570 |
+
await db.commit()
|
| 571 |
+
|
| 572 |
+
# Soft delete the job using DeleteQuery
|
| 573 |
+
deleted = await qs.delete().soft_delete_one(job)
|
| 574 |
+
|
| 575 |
+
if not deleted:
|
| 576 |
+
raise HTTPException(
|
| 577 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 578 |
+
detail="Failed to delete job"
|
| 579 |
+
)
|
| 580 |
|
| 581 |
return {
|
| 582 |
"success": True,
|
|
|
|
| 583 |
"message": message,
|
| 584 |
+
"refund_amount": refund_amount,
|
| 585 |
+
"new_credit_balance": user.credits
|
| 586 |
}
|
| 587 |
|
| 588 |
|
services/db_service/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DB Service - Database access layer with automatic filtering and access control.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- DBServiceConfig: Configuration registration (call at startup)
|
| 6 |
+
- QueryService: Main entry point
|
| 7 |
+
- SelectQuery: Read operations
|
| 8 |
+
- UpdateQuery: Update operations
|
| 9 |
+
- DeleteQuery: Soft delete operations
|
| 10 |
+
- BaseQuery: Shared filtering logic
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
# 1. Register configuration at startup
|
| 14 |
+
from services.db_service import DBServiceConfig
|
| 15 |
+
from core.models import User, GeminiJob, ...
|
| 16 |
+
|
| 17 |
+
DBServiceConfig.register(
|
| 18 |
+
user_filter_column="user_id",
|
| 19 |
+
soft_delete_column="deleted_at",
|
| 20 |
+
user_read_scoped=[User, GeminiJob, ...],
|
| 21 |
+
...
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# 2. Use in endpoints
|
| 25 |
+
from services.db_service import QueryService
|
| 26 |
+
|
| 27 |
+
qs = QueryService(user, db)
|
| 28 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from services.db_service.config import DBServiceConfig
|
| 32 |
+
from services.db_service.query_service import QueryService, get_query_service
|
| 33 |
+
from services.db_service.select_query import SelectQuery
|
| 34 |
+
from services.db_service.update_query import UpdateQuery
|
| 35 |
+
from services.db_service.delete_query import DeleteQuery
|
| 36 |
+
from services.db_service.base_query import BaseQuery
|
| 37 |
+
from services.db_service.db_init import (
|
| 38 |
+
init_database,
|
| 39 |
+
drop_database,
|
| 40 |
+
reset_database,
|
| 41 |
+
get_registered_models,
|
| 42 |
+
get_model_by_name,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
__all__ = [
|
| 46 |
+
'DBServiceConfig',
|
| 47 |
+
'QueryService',
|
| 48 |
+
'get_query_service',
|
| 49 |
+
'SelectQuery',
|
| 50 |
+
'UpdateQuery',
|
| 51 |
+
'DeleteQuery',
|
| 52 |
+
'BaseQuery',
|
| 53 |
+
# Database initialization
|
| 54 |
+
'init_database',
|
| 55 |
+
'drop_database',
|
| 56 |
+
'reset_database',
|
| 57 |
+
'get_registered_models',
|
| 58 |
+
'get_model_by_name',
|
| 59 |
+
]
|
services/db_service/base_query.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BaseQuery - Base class for all query types with shared filtering logic.
|
| 3 |
+
|
| 4 |
+
Uses DBServiceConfig for plug-and-play configuration of models and columns.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Type
|
| 10 |
+
from fastapi import HTTPException, status as http_status
|
| 11 |
+
from sqlalchemy import Select
|
| 12 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 13 |
+
|
| 14 |
+
from core.models import User
|
| 15 |
+
from services.db_service.config import DBServiceConfig
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BaseQuery:
|
| 21 |
+
"""
|
| 22 |
+
Base class for all query operations.
|
| 23 |
+
|
| 24 |
+
Uses DBServiceConfig for model scopes and column names.
|
| 25 |
+
Configuration must be registered at application startup.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, user: User, db: AsyncSession, is_system: bool = False):
|
| 29 |
+
"""Initialize base query with user and database session."""
|
| 30 |
+
# Ensure config is registered
|
| 31 |
+
DBServiceConfig.assert_registered()
|
| 32 |
+
|
| 33 |
+
self.user = user
|
| 34 |
+
self.db = db
|
| 35 |
+
self._is_system = is_system
|
| 36 |
+
self._config = DBServiceConfig
|
| 37 |
+
|
| 38 |
+
@property
|
| 39 |
+
def is_admin(self) -> bool:
|
| 40 |
+
"""Check if current user is an admin."""
|
| 41 |
+
admin_emails_str = os.getenv("ADMIN_EMAILS", "")
|
| 42 |
+
if not admin_emails_str:
|
| 43 |
+
return False
|
| 44 |
+
|
| 45 |
+
admin_emails = [email.strip() for email in admin_emails_str.split(",")]
|
| 46 |
+
is_admin = self.user.email in admin_emails
|
| 47 |
+
|
| 48 |
+
if is_admin:
|
| 49 |
+
logger.info(f"Admin access granted for {self.user.email}")
|
| 50 |
+
|
| 51 |
+
return is_admin
|
| 52 |
+
|
| 53 |
+
def _verify_operation_access(self, model_class: Type, operation: str) -> None:
|
| 54 |
+
"""
|
| 55 |
+
Check if user has permission for this operation on this model.
|
| 56 |
+
|
| 57 |
+
Permission hierarchy: SYSTEM > ADMIN > USER
|
| 58 |
+
Uses DBServiceConfig for scope checking.
|
| 59 |
+
"""
|
| 60 |
+
# System operations have highest priority
|
| 61 |
+
if self._is_system:
|
| 62 |
+
logger.info(f"System operation: {operation} on {model_class.__name__}")
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
# Admins can do anything
|
| 66 |
+
if self.is_admin:
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
# Map operation to config scope sets
|
| 70 |
+
admin_only_sets = {
|
| 71 |
+
'read': self._config.admin_read_only,
|
| 72 |
+
'create': self._config.admin_create_only,
|
| 73 |
+
'update': self._config.admin_update_only,
|
| 74 |
+
'delete': self._config.admin_delete_only,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
user_allowed_sets = {
|
| 78 |
+
'read': self._config.user_read_scoped,
|
| 79 |
+
'create': self._config.user_create_scoped,
|
| 80 |
+
'update': self._config.user_update_scoped,
|
| 81 |
+
'delete': self._config.user_delete_scoped,
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
system_only_sets = {
|
| 85 |
+
'read': self._config.system_read_scoped,
|
| 86 |
+
'create': self._config.system_create_scoped,
|
| 87 |
+
'update': self._config.system_update_scoped,
|
| 88 |
+
'delete': self._config.system_delete_scoped,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
admin_only = admin_only_sets.get(operation, set())
|
| 92 |
+
user_allowed = user_allowed_sets.get(operation, set())
|
| 93 |
+
system_only = system_only_sets.get(operation, set())
|
| 94 |
+
|
| 95 |
+
# Check if this is system-only operation
|
| 96 |
+
if model_class in system_only and model_class not in user_allowed and model_class not in admin_only:
|
| 97 |
+
raise HTTPException(
|
| 98 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 99 |
+
detail=f"Only system processes can {operation} {model_class.__name__}"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Check if admin-only
|
| 103 |
+
if model_class in admin_only:
|
| 104 |
+
raise HTTPException(
|
| 105 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 106 |
+
detail=f"Only administrators can {operation} {model_class.__name__}"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Check if user is allowed
|
| 110 |
+
if model_class not in user_allowed:
|
| 111 |
+
raise HTTPException(
|
| 112 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 113 |
+
detail=f"You do not have permission to {operation} {model_class.__name__}"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def _apply_ownership_filter(self, stmt, model_class: Type, operation: str):
|
| 117 |
+
"""
|
| 118 |
+
Shared method to apply ownership filter for UPDATE/DELETE operations.
|
| 119 |
+
Uses DBServiceConfig.user_filter_column for filtering.
|
| 120 |
+
"""
|
| 121 |
+
# Admins can modify all records
|
| 122 |
+
if self.is_admin:
|
| 123 |
+
logger.info(f"Admin {self.user.email} {operation}ing {model_class.__name__} records")
|
| 124 |
+
return stmt
|
| 125 |
+
|
| 126 |
+
# Non-admins can only modify their own records
|
| 127 |
+
filter_column = self._config.user_filter_column
|
| 128 |
+
if hasattr(model_class, filter_column):
|
| 129 |
+
user_id_col = getattr(model_class, filter_column)
|
| 130 |
+
stmt = stmt.where(user_id_col == self.user.id)
|
| 131 |
+
logger.info(f"User {self.user.email} {operation}ing own {model_class.__name__} records")
|
| 132 |
+
|
| 133 |
+
return stmt
|
| 134 |
+
|
| 135 |
+
def _verify_admin_access(self, query: Select) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Check if query is for admin-only models (READ operation).
|
| 138 |
+
Uses DBServiceConfig.admin_read_only.
|
| 139 |
+
"""
|
| 140 |
+
if self.is_admin:
|
| 141 |
+
return
|
| 142 |
+
|
| 143 |
+
# Check if querying admin-only models
|
| 144 |
+
if hasattr(query, 'column_descriptions'):
|
| 145 |
+
for description in query.column_descriptions:
|
| 146 |
+
entity = description.get('entity') or description.get('type')
|
| 147 |
+
if entity in self._config.admin_read_only:
|
| 148 |
+
raise HTTPException(
|
| 149 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 150 |
+
detail=f"Only administrators can read {entity.__name__}"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Check froms for admin-only models
|
| 154 |
+
for from_clause in query.froms:
|
| 155 |
+
table_class = from_clause.entity_namespace.get('__class__')
|
| 156 |
+
if table_class in self._config.admin_read_only:
|
| 157 |
+
raise HTTPException(
|
| 158 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 159 |
+
detail=f"Only administrators can read {table_class.__name__}"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def _apply_user_filter(self, query: Select) -> Select:
|
| 163 |
+
"""
|
| 164 |
+
Automatically apply user_id filter to READ queries.
|
| 165 |
+
Uses DBServiceConfig for model scopes and filter column.
|
| 166 |
+
"""
|
| 167 |
+
# First check if this is an admin-only query
|
| 168 |
+
self._verify_admin_access(query)
|
| 169 |
+
|
| 170 |
+
# Admins see all data
|
| 171 |
+
if self.is_admin:
|
| 172 |
+
logger.debug(f"Admin query - no user filter applied")
|
| 173 |
+
return query
|
| 174 |
+
|
| 175 |
+
# Get filter column from config
|
| 176 |
+
filter_column = self._config.user_filter_column
|
| 177 |
+
special_user_model = self._config.special_user_model
|
| 178 |
+
user_id_column = self._config.user_id_column
|
| 179 |
+
|
| 180 |
+
# Detect which model is being queried
|
| 181 |
+
if hasattr(query, 'column_descriptions'):
|
| 182 |
+
for description in query.column_descriptions:
|
| 183 |
+
entity = description.get('entity') or description.get('type')
|
| 184 |
+
if entity in self._config.user_read_scoped:
|
| 185 |
+
logger.debug(f"Applying user filter to {entity.__name__} query")
|
| 186 |
+
# Special handling for User model (uses id instead of user_id)
|
| 187 |
+
if entity == special_user_model:
|
| 188 |
+
user_col = getattr(entity, user_id_column)
|
| 189 |
+
return query.where(user_col == self.user.id)
|
| 190 |
+
# Standard user_id filtering
|
| 191 |
+
if hasattr(entity, filter_column):
|
| 192 |
+
user_col = getattr(entity, filter_column)
|
| 193 |
+
return query.where(user_col == self.user.id)
|
| 194 |
+
|
| 195 |
+
# Check froms
|
| 196 |
+
for from_clause in query.froms:
|
| 197 |
+
table_class = from_clause.entity_namespace.get('__class__')
|
| 198 |
+
if table_class in self._config.user_read_scoped:
|
| 199 |
+
logger.debug(f"Applying user filter to {table_class.__name__} query")
|
| 200 |
+
# Special handling for User model
|
| 201 |
+
if table_class == special_user_model:
|
| 202 |
+
user_col = getattr(table_class, user_id_column)
|
| 203 |
+
return query.where(user_col == self.user.id)
|
| 204 |
+
# Standard user_id filtering
|
| 205 |
+
if hasattr(table_class, filter_column):
|
| 206 |
+
user_col = getattr(table_class, filter_column)
|
| 207 |
+
return query.where(user_col == self.user.id)
|
| 208 |
+
|
| 209 |
+
return query
|
| 210 |
+
|
| 211 |
+
def _filter_deleted(self, query: Select) -> Select:
|
| 212 |
+
"""
|
| 213 |
+
Add filter to exclude soft-deleted records.
|
| 214 |
+
Uses DBServiceConfig.soft_delete_column.
|
| 215 |
+
"""
|
| 216 |
+
delete_column = self._config.soft_delete_column
|
| 217 |
+
|
| 218 |
+
# Detect which model is being queried
|
| 219 |
+
if hasattr(query, 'column_descriptions'):
|
| 220 |
+
for description in query.column_descriptions:
|
| 221 |
+
entity = description.get('entity') or description.get('type')
|
| 222 |
+
if entity and hasattr(entity, delete_column):
|
| 223 |
+
deleted_at_col = getattr(entity, delete_column)
|
| 224 |
+
return query.where(deleted_at_col == None)
|
| 225 |
+
|
| 226 |
+
# Check froms
|
| 227 |
+
for from_clause in query.froms:
|
| 228 |
+
table_class = from_clause.entity_namespace.get('__class__')
|
| 229 |
+
if table_class and hasattr(table_class, delete_column):
|
| 230 |
+
deleted_at_col = getattr(table_class, delete_column)
|
| 231 |
+
return query.where(deleted_at_col == None)
|
| 232 |
+
|
| 233 |
+
return query
|
services/db_service/config.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DBServiceConfig - Configuration for DB Service permissions and filtering.
|
| 3 |
+
|
| 4 |
+
Allows application layer to register model scopes and column names,
|
| 5 |
+
making the DB service completely generic and plug-and-play.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Type, Set, Optional
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DBServiceConfig:
|
| 15 |
+
"""
|
| 16 |
+
Centralized configuration for DB Service.
|
| 17 |
+
|
| 18 |
+
Register your application's models and permissions at startup.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
# Configuration state
|
| 22 |
+
_registered = False
|
| 23 |
+
|
| 24 |
+
# Database metadata
|
| 25 |
+
db_base = None # SQLAlchemy declarative base
|
| 26 |
+
all_models: list = [] # All model classes
|
| 27 |
+
|
| 28 |
+
# Column names (generic)
|
| 29 |
+
user_filter_column: str = "user_id"
|
| 30 |
+
user_id_column: str = "id"
|
| 31 |
+
soft_delete_column: str = "deleted_at"
|
| 32 |
+
|
| 33 |
+
# Special models
|
| 34 |
+
special_user_model: Optional[Type] = None
|
| 35 |
+
|
| 36 |
+
# USER scopes
|
| 37 |
+
user_read_scoped: Set[Type] = set()
|
| 38 |
+
user_create_scoped: Set[Type] = set()
|
| 39 |
+
user_update_scoped: Set[Type] = set()
|
| 40 |
+
user_delete_scoped: Set[Type] = set()
|
| 41 |
+
|
| 42 |
+
# ADMIN scopes
|
| 43 |
+
admin_read_only: Set[Type] = set()
|
| 44 |
+
admin_create_only: Set[Type] = set()
|
| 45 |
+
admin_update_only: Set[Type] = set()
|
| 46 |
+
admin_delete_only: Set[Type] = set()
|
| 47 |
+
|
| 48 |
+
# SYSTEM scopes
|
| 49 |
+
system_read_scoped: Set[Type] = set()
|
| 50 |
+
system_create_scoped: Set[Type] = set()
|
| 51 |
+
system_update_scoped: Set[Type] = set()
|
| 52 |
+
system_delete_scoped: Set[Type] = set()
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def register(
|
| 56 |
+
cls,
|
| 57 |
+
# Database metadata
|
| 58 |
+
db_base=None, # SQLAlchemy Base (for table creation)
|
| 59 |
+
all_models: list = None, # All model classes
|
| 60 |
+
# Column names
|
| 61 |
+
user_filter_column: str = "user_id",
|
| 62 |
+
user_id_column: str = "id",
|
| 63 |
+
soft_delete_column: str = "deleted_at",
|
| 64 |
+
# Special models
|
| 65 |
+
special_user_model: Optional[Type] = None,
|
| 66 |
+
# USER scopes
|
| 67 |
+
user_read_scoped: list = None,
|
| 68 |
+
user_create_scoped: list = None,
|
| 69 |
+
user_update_scoped: list = None,
|
| 70 |
+
user_delete_scoped: list = None,
|
| 71 |
+
# ADMIN scopes
|
| 72 |
+
admin_read_only: list = None,
|
| 73 |
+
admin_create_only: list = None,
|
| 74 |
+
admin_update_only: list = None,
|
| 75 |
+
admin_delete_only: list = None,
|
| 76 |
+
# SYSTEM scopes
|
| 77 |
+
system_read_scoped: list = None,
|
| 78 |
+
system_create_scoped: list = None,
|
| 79 |
+
system_update_scoped: list = None,
|
| 80 |
+
system_delete_scoped: list = None,
|
| 81 |
+
) -> None:
|
| 82 |
+
"""Register DB Service configuration at application startup."""
|
| 83 |
+
# Database metadata
|
| 84 |
+
cls.db_base = db_base
|
| 85 |
+
cls.all_models = all_models or []
|
| 86 |
+
|
| 87 |
+
# Column names
|
| 88 |
+
cls.user_filter_column = user_filter_column
|
| 89 |
+
cls.user_id_column = user_id_column
|
| 90 |
+
cls.soft_delete_column = soft_delete_column
|
| 91 |
+
|
| 92 |
+
# Special models
|
| 93 |
+
cls.special_user_model = special_user_model
|
| 94 |
+
|
| 95 |
+
# USER scopes
|
| 96 |
+
cls.user_read_scoped = set(user_read_scoped or [])
|
| 97 |
+
cls.user_create_scoped = set(user_create_scoped or [])
|
| 98 |
+
cls.user_update_scoped = set(user_update_scoped or [])
|
| 99 |
+
cls.user_delete_scoped = set(user_delete_scoped or [])
|
| 100 |
+
|
| 101 |
+
# ADMIN scopes
|
| 102 |
+
cls.admin_read_only = set(admin_read_only or [])
|
| 103 |
+
cls.admin_create_only = set(admin_create_only or [])
|
| 104 |
+
cls.admin_update_only = set(admin_update_only or [])
|
| 105 |
+
cls.admin_delete_only = set(admin_delete_only or [])
|
| 106 |
+
|
| 107 |
+
# SYSTEM scopes
|
| 108 |
+
cls.system_read_scoped = set(system_read_scoped or [])
|
| 109 |
+
cls.system_create_scoped = set(system_create_scoped or [])
|
| 110 |
+
cls.system_update_scoped = set(system_update_scoped or [])
|
| 111 |
+
cls.system_delete_scoped = set(system_delete_scoped or [])
|
| 112 |
+
|
| 113 |
+
cls._registered = True
|
| 114 |
+
|
| 115 |
+
logger.info("✅ DBServiceConfig registered successfully")
|
| 116 |
+
logger.info(f" Models registered: {len(cls.all_models)}")
|
| 117 |
+
logger.info(f" User filter column: {cls.user_filter_column}")
|
| 118 |
+
logger.info(f" Soft delete column: {cls.soft_delete_column}")
|
| 119 |
+
logger.info(f" USER scopes: {len(cls.user_read_scoped)} read, {len(cls.user_create_scoped)} create")
|
| 120 |
+
logger.info(f" ADMIN scopes: {len(cls.admin_read_only)} read, {len(cls.admin_create_only)} create")
|
| 121 |
+
logger.info(f" SYSTEM scopes: {len(cls.system_read_scoped)} read, {len(cls.system_create_scoped)} create")
|
| 122 |
+
|
| 123 |
+
@classmethod
|
| 124 |
+
def is_registered(cls) -> bool:
|
| 125 |
+
"""Check if configuration has been registered."""
|
| 126 |
+
return cls._registered
|
| 127 |
+
|
| 128 |
+
@classmethod
|
| 129 |
+
def assert_registered(cls) -> None:
|
| 130 |
+
"""Assert that configuration has been registered, raise if not."""
|
| 131 |
+
if not cls._registered:
|
| 132 |
+
raise RuntimeError(
|
| 133 |
+
"DBServiceConfig not registered! "
|
| 134 |
+
"Call DBServiceConfig.register() at application startup."
|
| 135 |
+
)
|
services/db_service/db_init.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DBService Database Initialization
|
| 3 |
+
|
| 4 |
+
Provides utilities for creating and managing database tables.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
| 9 |
+
|
| 10 |
+
from services.db_service.config import DBServiceConfig
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def init_database(engine: AsyncEngine) -> None:
|
| 16 |
+
"""
|
| 17 |
+
Initialize database tables based on registered models.
|
| 18 |
+
|
| 19 |
+
Creates all tables defined in DBServiceConfig.all_models.
|
| 20 |
+
"""
|
| 21 |
+
DBServiceConfig.assert_registered()
|
| 22 |
+
|
| 23 |
+
if not DBServiceConfig.db_base:
|
| 24 |
+
raise RuntimeError(
|
| 25 |
+
"No database base registered! "
|
| 26 |
+
"Pass db_base parameter to DBServiceConfig.register()"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
logger.info("Creating database tables...")
|
| 30 |
+
|
| 31 |
+
async with engine.begin() as conn:
|
| 32 |
+
await conn.run_sync(DBServiceConfig.db_base.metadata.create_all)
|
| 33 |
+
|
| 34 |
+
model_count = len(DBServiceConfig.all_models)
|
| 35 |
+
logger.info(f"✅ Database initialized with {model_count} models")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def drop_database(engine: AsyncEngine) -> None:
|
| 39 |
+
"""Drop all database tables. WARNING: Deletes all data!"""
|
| 40 |
+
DBServiceConfig.assert_registered()
|
| 41 |
+
|
| 42 |
+
if not DBServiceConfig.db_base:
|
| 43 |
+
raise RuntimeError("No database base registered!")
|
| 44 |
+
|
| 45 |
+
logger.warning("⚠️ Dropping all database tables...")
|
| 46 |
+
|
| 47 |
+
async with engine.begin() as conn:
|
| 48 |
+
await conn.run_sync(DBServiceConfig.db_base.metadata.drop_all)
|
| 49 |
+
|
| 50 |
+
logger.info("✅ All tables dropped")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
async def reset_database(engine: AsyncEngine) -> None:
|
| 54 |
+
"""Reset database (drop + create). WARNING: Deletes all data!"""
|
| 55 |
+
await drop_database(engine)
|
| 56 |
+
await init_database(engine)
|
| 57 |
+
logger.info("✅ Database reset complete")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_registered_models() -> list:
|
| 61 |
+
"""Get list of all registered models."""
|
| 62 |
+
DBServiceConfig.assert_registered()
|
| 63 |
+
return DBServiceConfig.all_models
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_model_by_name(model_name: str):
|
| 67 |
+
"""Get model class by name."""
|
| 68 |
+
DBServiceConfig.assert_registered()
|
| 69 |
+
|
| 70 |
+
for model in DBServiceConfig.all_models:
|
| 71 |
+
if model.__name__ == model_name:
|
| 72 |
+
return model
|
| 73 |
+
|
| 74 |
+
return None
|
services/db_service/delete_query.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DeleteQuery - Soft delete operations with access control.
|
| 3 |
+
|
| 4 |
+
Inherits from BaseQuery for shared filtering logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Type, TypeVar, List
|
| 9 |
+
from fastapi import HTTPException, status as http_status
|
| 10 |
+
from sqlalchemy import update, select
|
| 11 |
+
from sqlalchemy.sql import func
|
| 12 |
+
|
| 13 |
+
from services.db_service.base_query import BaseQuery
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
T = TypeVar('T')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DeleteQuery(BaseQuery):
|
| 21 |
+
"""
|
| 22 |
+
Handles soft DELETE operations.
|
| 23 |
+
|
| 24 |
+
Inherits filtering logic from BaseQuery:
|
| 25 |
+
- User ownership checks (_is_admin)
|
| 26 |
+
- Records are marked as deleted (deleted_at = NOW()) instead of being physically removed
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
async def soft_delete(self, model_class: Type[T], **filters) -> int:
|
| 30 |
+
"""
|
| 31 |
+
Soft delete records matching filters.
|
| 32 |
+
|
| 33 |
+
Sets deleted_at = NOW() instead of physically deleting.
|
| 34 |
+
"""
|
| 35 |
+
# Verify user has permission to delete this model
|
| 36 |
+
self._verify_operation_access(model_class, 'delete')
|
| 37 |
+
|
| 38 |
+
# Build update statement to set deleted_at
|
| 39 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 40 |
+
stmt = update(model_class).where(
|
| 41 |
+
delete_col == None # Only soft-delete non-deleted records
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Apply ownership filter (shared method from BaseQuery)
|
| 45 |
+
stmt = self._apply_ownership_filter(stmt, model_class, 'delet')
|
| 46 |
+
|
| 47 |
+
# Apply user's filters
|
| 48 |
+
for key, value in filters.items():
|
| 49 |
+
if not hasattr(model_class, key):
|
| 50 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 51 |
+
stmt = stmt.where(getattr(model_class, key) == value)
|
| 52 |
+
|
| 53 |
+
# Set deleted_at timestamp
|
| 54 |
+
stmt = stmt.values({self._config.soft_delete_column: func.now()})
|
| 55 |
+
|
| 56 |
+
result = await self.db.execute(stmt)
|
| 57 |
+
await self.db.commit()
|
| 58 |
+
|
| 59 |
+
count = result.rowcount
|
| 60 |
+
logger.info(f"Soft-deleted {count} {model_class.__name__} record(s)")
|
| 61 |
+
return count
|
| 62 |
+
|
| 63 |
+
async def soft_delete_one(self, instance: T) -> bool:
|
| 64 |
+
"""Soft delete a single model instance."""
|
| 65 |
+
# Check if already deleted
|
| 66 |
+
delete_column = self._config.soft_delete_column
|
| 67 |
+
if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
|
| 68 |
+
logger.warning(f"Attempted to delete already-deleted {instance.__class__.__name__}")
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
# Verify user has permission
|
| 72 |
+
self._verify_operation_access(instance.__class__, 'delete')
|
| 73 |
+
|
| 74 |
+
# Check ownership for non-admins
|
| 75 |
+
filter_column = self._config.user_filter_column
|
| 76 |
+
if not self.is_admin and hasattr(instance, filter_column):
|
| 77 |
+
if getattr(instance, filter_column) != self.user.id:
|
| 78 |
+
raise HTTPException(
|
| 79 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 80 |
+
detail="You do not have permission to delete this record"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Set deleted_at
|
| 84 |
+
setattr(instance, self._config.soft_delete_column, func.now())
|
| 85 |
+
await self.db.commit()
|
| 86 |
+
await self.db.refresh(instance)
|
| 87 |
+
|
| 88 |
+
logger.info(f"Soft-deleted {instance.__class__.__name__} instance")
|
| 89 |
+
return True
|
| 90 |
+
|
| 91 |
+
async def restore(self, model_class: Type[T], **filters) -> int:
|
| 92 |
+
"""
|
| 93 |
+
Restore soft-deleted records.
|
| 94 |
+
Only admins can restore.
|
| 95 |
+
"""
|
| 96 |
+
if not self.is_admin:
|
| 97 |
+
raise HTTPException(
|
| 98 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 99 |
+
detail="Only administrators can restore deleted records"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Build update to clear deleted_at
|
| 103 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 104 |
+
stmt = update(model_class).where(
|
| 105 |
+
delete_col != None # Only restore deleted records
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Apply filters
|
| 109 |
+
for key, value in filters.items():
|
| 110 |
+
if not hasattr(model_class, key):
|
| 111 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 112 |
+
stmt = stmt.where(getattr(model_class, key) == value)
|
| 113 |
+
|
| 114 |
+
# Clear deleted_at
|
| 115 |
+
stmt = stmt.values({self._config.soft_delete_column: None})
|
| 116 |
+
|
| 117 |
+
result = await self.db.execute(stmt)
|
| 118 |
+
await self.db.commit()
|
| 119 |
+
|
| 120 |
+
count = result.rowcount
|
| 121 |
+
logger.info(f"Admin {self.user.email} restored {count} {model_class.__name__} record(s)")
|
| 122 |
+
return count
|
| 123 |
+
|
| 124 |
+
async def restore_one(self, instance: T) -> bool:
|
| 125 |
+
"""Restore a single soft-deleted model instance. Admin only."""
|
| 126 |
+
if not self.is_admin:
|
| 127 |
+
raise HTTPException(
|
| 128 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 129 |
+
detail="Only administrators can restore deleted records"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Check if actually deleted
|
| 133 |
+
delete_column = self._config.soft_delete_column
|
| 134 |
+
if not hasattr(instance, delete_column) or getattr(instance, delete_column) is None:
|
| 135 |
+
logger.warning(f"Attempted to restore non-deleted {instance.__class__.__name__}")
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
# Clear deleted_at
|
| 139 |
+
setattr(instance, self._config.soft_delete_column, None)
|
| 140 |
+
await self.db.commit()
|
| 141 |
+
await self.db.refresh(instance)
|
| 142 |
+
|
| 143 |
+
logger.info(f"Admin restored {instance.__class__.__name__} instance")
|
| 144 |
+
return True
|
| 145 |
+
|
| 146 |
+
async def list_deleted(self, model_class: Type[T], limit: int = 100) -> List[T]:
|
| 147 |
+
"""List soft-deleted records. Admin only."""
|
| 148 |
+
if not self.is_admin:
|
| 149 |
+
raise HTTPException(
|
| 150 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 151 |
+
detail="Only administrators can view deleted records"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 155 |
+
query = select(model_class).where(
|
| 156 |
+
delete_col != None
|
| 157 |
+
).order_by(delete_col.desc()).limit(limit)
|
| 158 |
+
|
| 159 |
+
result = await self.db.execute(query)
|
| 160 |
+
return result.scalars().all()
|
services/db_service/query_service.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QueryService - Main entry point for database operations.
|
| 3 |
+
|
| 4 |
+
Factory class that provides access to SelectQuery, UpdateQuery, and DeleteQuery.
|
| 5 |
+
All query types inherit from BaseQuery for shared filtering logic.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 10 |
+
|
| 11 |
+
from core.models import User
|
| 12 |
+
from services.db_service.select_query import SelectQuery
|
| 13 |
+
from services.db_service.update_query import UpdateQuery
|
| 14 |
+
from services.db_service.delete_query import DeleteQuery
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class QueryService:
|
| 18 |
+
"""
|
| 19 |
+
Main query service factory.
|
| 20 |
+
|
| 21 |
+
Provides access to:
|
| 22 |
+
- SelectQuery: Read operations (SELECT)
|
| 23 |
+
- UpdateQuery: Update operations (UPDATE)
|
| 24 |
+
- DeleteQuery: Soft delete operations (soft DELETE)
|
| 25 |
+
|
| 26 |
+
Usage:
|
| 27 |
+
qs = QueryService(user, db)
|
| 28 |
+
|
| 29 |
+
# SELECT
|
| 30 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 31 |
+
|
| 32 |
+
# UPDATE
|
| 33 |
+
await qs.update().update(Job, {...}, id=123)
|
| 34 |
+
|
| 35 |
+
# DELETE
|
| 36 |
+
await qs.delete().soft_delete(Job, id=123)
|
| 37 |
+
|
| 38 |
+
# Check admin status
|
| 39 |
+
if qs.is_admin():
|
| 40 |
+
# Admin-only logic
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, user: User, db: AsyncSession, is_system: bool = False):
|
| 44 |
+
"""
|
| 45 |
+
Initialize QueryService.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
user: Current authenticated user (or system user for background ops)
|
| 49 |
+
db: Database session
|
| 50 |
+
is_system: True if this is a system/background operation (OAuth, webhooks, workers)
|
| 51 |
+
"""
|
| 52 |
+
self.user = user
|
| 53 |
+
self.db = db
|
| 54 |
+
self.is_system = is_system
|
| 55 |
+
|
| 56 |
+
def select(self) -> SelectQuery:
|
| 57 |
+
"""Get SelectQuery instance for read operations."""
|
| 58 |
+
return SelectQuery(self.user, self.db, self.is_system)
|
| 59 |
+
|
| 60 |
+
def update(self) -> UpdateQuery:
|
| 61 |
+
"""Get UpdateQuery instance for update operations."""
|
| 62 |
+
return UpdateQuery(self.user, self.db, self.is_system)
|
| 63 |
+
|
| 64 |
+
def delete(self) -> DeleteQuery:
|
| 65 |
+
"""Get DeleteQuery instance for soft delete operations."""
|
| 66 |
+
return DeleteQuery(self.user, self.db, self.is_system)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_query_service(user: User, db: AsyncSession) -> QueryService:
|
| 70 |
+
"""
|
| 71 |
+
Factory function to create QueryService.
|
| 72 |
+
|
| 73 |
+
Can be used as a FastAPI dependency:
|
| 74 |
+
from services.db_service import get_query_service
|
| 75 |
+
|
| 76 |
+
qs: QueryService = Depends(lambda user=Depends(get_current_user), db=Depends(get_db): get_query_service(user, db))
|
| 77 |
+
"""
|
| 78 |
+
return QueryService(user, db)
|
services/db_service/register_config.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DB Service Configuration Registration
|
| 3 |
+
|
| 4 |
+
Add this to your main.py or app initialization:
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from services.db_service import DBServiceConfig
|
| 8 |
+
from core.models import (
|
| 9 |
+
Base, # SQLAlchemy declarative base
|
| 10 |
+
User, GeminiJob, PaymentTransaction, Contact,
|
| 11 |
+
RateLimit, ApiKeyUsage, ClientUser, AuditLog
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def register_db_service_config():
|
| 16 |
+
"""Register DB Service configuration at application startup."""
|
| 17 |
+
DBServiceConfig.register(
|
| 18 |
+
# Database metadata (for table creation)
|
| 19 |
+
db_base=Base,
|
| 20 |
+
all_models=[
|
| 21 |
+
User, GeminiJob, PaymentTransaction, Contact,
|
| 22 |
+
RateLimit, ApiKeyUsage, ClientUser, AuditLog
|
| 23 |
+
],
|
| 24 |
+
|
| 25 |
+
# Column names
|
| 26 |
+
user_filter_column="user_id", # Column name for user ownership
|
| 27 |
+
user_id_column="id", # Column name for user ID
|
| 28 |
+
soft_delete_column="deleted_at", # Column name for soft delete tracking
|
| 29 |
+
|
| 30 |
+
# Special models
|
| 31 |
+
special_user_model=User, # Model that uses 'id' instead of 'user_id'
|
| 32 |
+
|
| 33 |
+
# ================================================================
|
| 34 |
+
# USER SCOPES (Regular authenticated users)
|
| 35 |
+
# ================================================================
|
| 36 |
+
user_read_scoped=[
|
| 37 |
+
User, # Users can read own profile
|
| 38 |
+
GeminiJob, # Users can read own jobs
|
| 39 |
+
PaymentTransaction, # Users can read own payments
|
| 40 |
+
Contact, # Users can read own contacts
|
| 41 |
+
],
|
| 42 |
+
|
| 43 |
+
user_create_scoped=[
|
| 44 |
+
GeminiJob, # Users can create jobs
|
| 45 |
+
PaymentTransaction, # Users can create payments (via API)
|
| 46 |
+
Contact, # Users can submit contact forms
|
| 47 |
+
],
|
| 48 |
+
|
| 49 |
+
user_update_scoped=[
|
| 50 |
+
User, # Users can update own profile
|
| 51 |
+
GeminiJob, # Users can update own jobs
|
| 52 |
+
],
|
| 53 |
+
|
| 54 |
+
user_delete_scoped=[
|
| 55 |
+
GeminiJob, # Users can delete own jobs
|
| 56 |
+
Contact, # Users can delete own contacts
|
| 57 |
+
],
|
| 58 |
+
|
| 59 |
+
# ================================================================
|
| 60 |
+
# ADMIN SCOPES (Administrators only - via ADMIN_EMAILS env var)
|
| 61 |
+
# ================================================================
|
| 62 |
+
admin_read_only=[
|
| 63 |
+
RateLimit, # Only admins view rate limits
|
| 64 |
+
ApiKeyUsage, # Only admins view API usage
|
| 65 |
+
ClientUser, # Only admins view client mappings
|
| 66 |
+
AuditLog, # Only admins view audit logs
|
| 67 |
+
],
|
| 68 |
+
|
| 69 |
+
admin_create_only=[
|
| 70 |
+
RateLimit, # Only admins create rate limits
|
| 71 |
+
ApiKeyUsage, # Only admins create API usage entries
|
| 72 |
+
ClientUser, # Only admins create client mappings
|
| 73 |
+
AuditLog, # Only admins create audit entries
|
| 74 |
+
],
|
| 75 |
+
|
| 76 |
+
admin_update_only=[
|
| 77 |
+
RateLimit, # Only admins update rate limits
|
| 78 |
+
ApiKeyUsage, # Only admins update API settings
|
| 79 |
+
ClientUser, # Only admins modify client mappings
|
| 80 |
+
PaymentTransaction,# Only admins refund/adjust payments
|
| 81 |
+
],
|
| 82 |
+
|
| 83 |
+
admin_delete_only=[
|
| 84 |
+
RateLimit, # Only admins delete rate limits
|
| 85 |
+
ApiKeyUsage, # Only admins remove API tracking
|
| 86 |
+
User, # Only admins delete user accounts
|
| 87 |
+
],
|
| 88 |
+
|
| 89 |
+
# ================================================================
|
| 90 |
+
# SYSTEM SCOPES (Background processes - OAuth, webhooks, workers)
|
| 91 |
+
# ================================================================
|
| 92 |
+
system_read_scoped=[
|
| 93 |
+
User, GeminiJob, PaymentTransaction, RateLimit,
|
| 94 |
+
ApiKeyUsage, ClientUser, AuditLog,
|
| 95 |
+
],
|
| 96 |
+
|
| 97 |
+
system_create_scoped=[
|
| 98 |
+
User, # OAuth creates users
|
| 99 |
+
ClientUser, # Middleware creates client mappings
|
| 100 |
+
AuditLog, # System creates audit entries
|
| 101 |
+
PaymentTransaction,# Webhooks create transactions
|
| 102 |
+
ApiKeyUsage, # Middleware tracks API usage
|
| 103 |
+
GeminiJob, # System creates jobs
|
| 104 |
+
RateLimit, # Middleware creates rate limits
|
| 105 |
+
],
|
| 106 |
+
|
| 107 |
+
system_update_scoped=[
|
| 108 |
+
User, # OAuth updates profiles
|
| 109 |
+
GeminiJob, # Workers update job status
|
| 110 |
+
PaymentTransaction,# Webhooks update verification
|
| 111 |
+
ApiKeyUsage, # System updates API stats
|
| 112 |
+
RateLimit, # System updates rate counters
|
| 113 |
+
ClientUser, # System updates mappings
|
| 114 |
+
],
|
| 115 |
+
|
| 116 |
+
system_delete_scoped=[
|
| 117 |
+
GeminiJob, # Cleanup deletes expired jobs
|
| 118 |
+
RateLimit, # Cleanup deletes old limits
|
| 119 |
+
ApiKeyUsage, # Cleanup deletes old usage
|
| 120 |
+
],
|
| 121 |
+
)
|
services/db_service/select_query.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SelectQuery - Read operations with access control.
|
| 3 |
+
|
| 4 |
+
Inherits from BaseQuery for shared filtering logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import TypeVar, Optional, List, Type
|
| 9 |
+
from fastapi import HTTPException, status as http_status
|
| 10 |
+
from sqlalchemy import Select, select, func
|
| 11 |
+
|
| 12 |
+
from services.db_service.base_query import BaseQuery
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
T = TypeVar('T')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SelectQuery(BaseQuery):
|
| 20 |
+
"""
|
| 21 |
+
Handles SELECT operations with automatic filtering.
|
| 22 |
+
|
| 23 |
+
Inherits filtering logic from BaseQuery:
|
| 24 |
+
- User filtering (_apply_user_filter)
|
| 25 |
+
- Deleted record filtering (_filter_deleted)
|
| 26 |
+
- Admin checks (_check_admin, _is_admin)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
async def execute(self, query: Select) -> List[T]:
|
| 30 |
+
"""
|
| 31 |
+
Execute a query with automatic filtering.
|
| 32 |
+
|
| 33 |
+
Filtering is applied automatically based on:
|
| 34 |
+
- Model type (USER_SCOPED, ADMIN_ONLY)
|
| 35 |
+
- User's admin status
|
| 36 |
+
- Deleted records (always excluded)
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
List of results
|
| 40 |
+
"""
|
| 41 |
+
query = self._apply_user_filter(query)
|
| 42 |
+
query = self._filter_deleted(query)
|
| 43 |
+
|
| 44 |
+
result = await self.db.execute(query)
|
| 45 |
+
return result.scalars().all()
|
| 46 |
+
|
| 47 |
+
async def execute_one(self, query: Select) -> Optional[T]:
|
| 48 |
+
"""
|
| 49 |
+
Execute a query expecting a single result.
|
| 50 |
+
Automatic filtering applied.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Single result or None
|
| 54 |
+
"""
|
| 55 |
+
query = self._apply_user_filter(query)
|
| 56 |
+
query = self._filter_deleted(query)
|
| 57 |
+
|
| 58 |
+
result = await self.db.execute(query)
|
| 59 |
+
return result.scalar_one_or_none()
|
| 60 |
+
|
| 61 |
+
async def count(self, query: Select) -> int:
|
| 62 |
+
"""
|
| 63 |
+
Count query results with automatic filtering.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Count of results
|
| 67 |
+
"""
|
| 68 |
+
query = self._apply_user_filter(query)
|
| 69 |
+
query = self._filter_deleted(query)
|
| 70 |
+
|
| 71 |
+
# Convert to count query
|
| 72 |
+
count_query = select(func.count()).select_from(query.alias())
|
| 73 |
+
result = await self.db.execute(count_query)
|
| 74 |
+
return result.scalar() or 0
|
| 75 |
+
|
| 76 |
+
async def count_deleted(self, model_class: Type[T]) -> int:
|
| 77 |
+
"""
|
| 78 |
+
Count soft-deleted records for a model.
|
| 79 |
+
Only admins can access this.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Count of deleted records
|
| 83 |
+
|
| 84 |
+
Raises:
|
| 85 |
+
HTTPException: 403 if non-admin tries to access
|
| 86 |
+
"""
|
| 87 |
+
if not self.is_admin:
|
| 88 |
+
raise HTTPException(
|
| 89 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 90 |
+
detail="Only administrators can view deleted records"
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 94 |
+
query = select(func.count()).select_from(model_class).where(
|
| 95 |
+
delete_col != None
|
| 96 |
+
)
|
| 97 |
+
result = await self.db.execute(query)
|
| 98 |
+
return result.scalar() or 0
|
services/db_service/update_query.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
UpdateQuery - Update operations with access control.
|
| 3 |
+
|
| 4 |
+
Inherits from BaseQuery for shared filtering logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Type, TypeVar, Dict, Any
|
| 9 |
+
from fastapi import HTTPException, status as http_status
|
| 10 |
+
from sqlalchemy import update
|
| 11 |
+
|
| 12 |
+
from services.db_service.base_query import BaseQuery
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
T = TypeVar('T')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UpdateQuery(BaseQuery):
|
| 20 |
+
"""
|
| 21 |
+
Handles UPDATE operations with automatic access control.
|
| 22 |
+
|
| 23 |
+
Inherits filtering logic from BaseQuery:
|
| 24 |
+
- User ownership checks (_is_admin)
|
| 25 |
+
- Deleted record protection
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
async def update(self, model_class: Type[T], values: Dict[str, Any], **filters) -> int:
|
| 29 |
+
"""
|
| 30 |
+
Update records matching filters with new values.
|
| 31 |
+
|
| 32 |
+
Automatically applies ownership filter for non-admins.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
model_class: Model class to update
|
| 36 |
+
values: Dictionary of field=value pairs to update
|
| 37 |
+
**filters: Field=value filters to match records
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Number of records updated
|
| 41 |
+
|
| 42 |
+
Raises:
|
| 43 |
+
HTTPException: 403 if user doesn't have permission
|
| 44 |
+
ValueError: If invalid field names provided
|
| 45 |
+
"""
|
| 46 |
+
# Validate that fields exist
|
| 47 |
+
for key in values.keys():
|
| 48 |
+
if not hasattr(model_class, key):
|
| 49 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 50 |
+
|
| 51 |
+
# Verify user has permission to update this model
|
| 52 |
+
self._verify_operation_access(model_class, 'update')
|
| 53 |
+
|
| 54 |
+
# Build update statement
|
| 55 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 56 |
+
stmt = update(model_class).where(
|
| 57 |
+
delete_col == None # Don't update deleted records
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Apply ownership filter (shared method from BaseQuery)
|
| 61 |
+
stmt = self._apply_ownership_filter(stmt, model_class, 'updat')
|
| 62 |
+
|
| 63 |
+
# Apply user's filters
|
| 64 |
+
for key, value in filters.items():
|
| 65 |
+
if not hasattr(model_class, key):
|
| 66 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 67 |
+
stmt = stmt.where(getattr(model_class, key) == value)
|
| 68 |
+
|
| 69 |
+
# Apply values
|
| 70 |
+
stmt = stmt.values(**values)
|
| 71 |
+
|
| 72 |
+
# Execute
|
| 73 |
+
result = await self.db.execute(stmt)
|
| 74 |
+
await self.db.commit()
|
| 75 |
+
|
| 76 |
+
count = result.rowcount
|
| 77 |
+
logger.info(f"Updated {count} {model_class.__name__} record(s)")
|
| 78 |
+
return count
|
| 79 |
+
|
| 80 |
+
async def update_one(self, instance: T, values: Dict[str, Any]) -> T:
|
| 81 |
+
"""
|
| 82 |
+
Update a single model instance with validation.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
instance: Model instance to update
|
| 86 |
+
values: Dictionary of field=value pairs to update
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Updated model instance
|
| 90 |
+
|
| 91 |
+
Raises:
|
| 92 |
+
HTTPException: 403 if user doesn't have permission
|
| 93 |
+
ValueError: If invalid field names provided
|
| 94 |
+
"""
|
| 95 |
+
# Check if deleted
|
| 96 |
+
delete_column = self._config.soft_delete_column
|
| 97 |
+
if hasattr(instance, delete_column) and getattr(instance, delete_column) is not None:
|
| 98 |
+
raise HTTPException(
|
| 99 |
+
status_code=http_status.HTTP_400_BAD_REQUEST,
|
| 100 |
+
detail="Cannot update deleted record"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# Verify user has permission
|
| 104 |
+
self._verify_operation_access(instance.__class__, 'update')
|
| 105 |
+
|
| 106 |
+
# Check ownership for non-admins
|
| 107 |
+
filter_column = self._config.user_filter_column
|
| 108 |
+
if not self.is_admin and hasattr(instance, filter_column):
|
| 109 |
+
if getattr(instance, filter_column) != self.user.id:
|
| 110 |
+
raise HTTPException(
|
| 111 |
+
status_code=http_status.HTTP_403_FORBIDDEN,
|
| 112 |
+
detail="You do not have permission to update this record"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Validate fields
|
| 116 |
+
for key in values.keys():
|
| 117 |
+
if not hasattr(instance, key):
|
| 118 |
+
raise ValueError(f"{instance.__class__.__name__} has no attribute '{key}'")
|
| 119 |
+
|
| 120 |
+
# Apply updates
|
| 121 |
+
for key, value in values.items():
|
| 122 |
+
setattr(instance, key, value)
|
| 123 |
+
|
| 124 |
+
# Commit
|
| 125 |
+
await self.db.commit()
|
| 126 |
+
await self.db.refresh(instance)
|
| 127 |
+
|
| 128 |
+
logger.info(f"Updated {instance.__class__.__name__} instance")
|
| 129 |
+
return instance
|
| 130 |
+
|
| 131 |
+
async def increment(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
|
| 132 |
+
"""
|
| 133 |
+
Increment a numeric field by a specified amount.
|
| 134 |
+
"""
|
| 135 |
+
if not hasattr(model_class, field):
|
| 136 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{field}'")
|
| 137 |
+
|
| 138 |
+
self._verify_operation_access(model_class, 'update')
|
| 139 |
+
|
| 140 |
+
field_obj = getattr(model_class, field)
|
| 141 |
+
|
| 142 |
+
delete_col = getattr(model_class, self._config.soft_delete_column)
|
| 143 |
+
stmt = update(model_class).where(delete_col == None)
|
| 144 |
+
stmt = self._apply_ownership_filter(stmt, model_class, 'updat')
|
| 145 |
+
|
| 146 |
+
for key, value in filters.items():
|
| 147 |
+
if not hasattr(model_class, key):
|
| 148 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 149 |
+
stmt = stmt.where(getattr(model_class, key) == value)
|
| 150 |
+
|
| 151 |
+
stmt = stmt.values({field: field_obj + amount})
|
| 152 |
+
|
| 153 |
+
result = await self.db.execute(stmt)
|
| 154 |
+
await self.db.commit()
|
| 155 |
+
|
| 156 |
+
count = result.rowcount
|
| 157 |
+
logger.info(f"Incremented {field} by {amount} for {count} {model_class.__name__} record(s)")
|
| 158 |
+
return count
|
| 159 |
+
|
| 160 |
+
async def decrement(self, model_class: Type[T], field: str, amount: int = 1, **filters) -> int:
|
| 161 |
+
"""Decrement a numeric field by a specified amount."""
|
| 162 |
+
return await self.increment(model_class, field, -amount, **filters)
|
| 163 |
+
|
| 164 |
+
async def toggle_boolean(self, model_class: Type[T], field: str, **filters) -> int:
|
| 165 |
+
"""Toggle a boolean field (True -> False, False -> True)."""
|
| 166 |
+
if not hasattr(model_class, field):
|
| 167 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{field}'")
|
| 168 |
+
|
| 169 |
+
self._verify_operation_access(model_class, 'update')
|
| 170 |
+
|
| 171 |
+
field_obj = getattr(model_class, field)
|
| 172 |
+
|
| 173 |
+
stmt = update(model_class).where(model_class.deleted_at == None)
|
| 174 |
+
stmt = self._apply_ownership_filter(stmt, model_class, 'updat')
|
| 175 |
+
|
| 176 |
+
for key, value in filters.items():
|
| 177 |
+
if not hasattr(model_class, key):
|
| 178 |
+
raise ValueError(f"{model_class.__name__} has no attribute '{key}'")
|
| 179 |
+
stmt = stmt.where(getattr(model_class, key) == value)
|
| 180 |
+
|
| 181 |
+
stmt = stmt.values({field: ~field_obj})
|
| 182 |
+
|
| 183 |
+
result = await self.db.execute(stmt)
|
| 184 |
+
await self.db.commit()
|
| 185 |
+
|
| 186 |
+
count = result.rowcount
|
| 187 |
+
logger.info(f"Toggled {field} for {count} {model_class.__name__} record(s)")
|
| 188 |
+
return count
|
tests/README.md
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DB Service - Quick Start
|
| 2 |
+
|
| 3 |
+
## Run Tests
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
# Install test dependencies
|
| 7 |
+
pip install pytest pytest-asyncio httpx
|
| 8 |
+
|
| 9 |
+
# Run all tests
|
| 10 |
+
ADMIN_EMAILS="admin@example.com" pytest tests/test_db_service.py -v
|
| 11 |
+
|
| 12 |
+
# Run specific test
|
| 13 |
+
pytest tests/test_db_service.py::TestPermissions::test_user_can_read_own_data -v
|
| 14 |
+
|
| 15 |
+
# Run with coverage
|
| 16 |
+
pytest tests/test_db_service.py --cov=services/db_service --cov-report=html -v
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Start Application
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
# Development (reset DB on startup)
|
| 23 |
+
RESET_DB=true ADMIN_EMAILS="your@email.com" python app.py
|
| 24 |
+
|
| 25 |
+
# Production (preserve DB)
|
| 26 |
+
ADMIN_EMAILS="admin@example.com" python app.py
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Verify Integration
|
| 30 |
+
|
| 31 |
+
Check startup logs for:
|
| 32 |
+
```
|
| 33 |
+
✅ DB Service configured
|
| 34 |
+
✅ Database initialized
|
| 35 |
+
✅ Database reset complete (if RESET_DB=true)
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Test Endpoints
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
# As regular user
|
| 42 |
+
curl -X GET http://localhost:8000/gemini/jobs \
|
| 43 |
+
-H "Authorization: Bearer <user_token>"
|
| 44 |
+
|
| 45 |
+
# As admin
|
| 46 |
+
curl -X GET http://localhost:8000/blink/audit-logs \
|
| 47 |
+
-H "Authorization: Bearer <admin_token>"
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Troubleshooting
|
| 51 |
+
|
| 52 |
+
**Config not registered:**
|
| 53 |
+
```
|
| 54 |
+
RuntimeError: DBServiceConfig not registered!
|
| 55 |
+
```
|
| 56 |
+
→ Ensure `register_db_service_config()` is called in app startup
|
| 57 |
+
|
| 58 |
+
**Module not found:**
|
| 59 |
+
```
|
| 60 |
+
ModuleNotFoundError: No module named 'services.db_service'
|
| 61 |
+
```
|
| 62 |
+
→ Verify `__init__.py` exists in `services/db_service/`
|
| 63 |
+
|
| 64 |
+
**Permission denied:**
|
| 65 |
+
```
|
| 66 |
+
403 Forbidden: Only administrators can...
|
| 67 |
+
```
|
| 68 |
+
→ Check `ADMIN_EMAILS` environment variable
|
| 69 |
+
|
| 70 |
+
## Success Indicators
|
| 71 |
+
|
| 72 |
+
✅ All files compile without errors
|
| 73 |
+
✅ Application starts successfully
|
| 74 |
+
✅ Database tables created
|
| 75 |
+
✅ Tests pass
|
| 76 |
+
✅ Endpoints return data correctly
|
| 77 |
+
✅ User isolation working (users see only own data)
|
| 78 |
+
✅ Admin access working (admins see all data)
|
tests/requirements-test.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest>=7.4.0
|
| 2 |
+
pytest-asyncio>=0.21.0
|
| 3 |
+
httpx>=0.24.0
|
tests/test_db_service.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test Suite for DB Service
|
| 3 |
+
|
| 4 |
+
Comprehensive tests for the plug-and-play DB Service including:
|
| 5 |
+
- Configuration
|
| 6 |
+
- Permissions (USER/ADMIN/SYSTEM)
|
| 7 |
+
- Filtering (user ownership, soft deletes)
|
| 8 |
+
- CRUD operations
|
| 9 |
+
- Database initialization
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import pytest
|
| 13 |
+
import os
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from sqlalchemy import create_async_engine, select
|
| 16 |
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
| 17 |
+
from sqlalchemy.orm import declarative_base
|
| 18 |
+
|
| 19 |
+
from services.db_service import (
|
| 20 |
+
DBServiceConfig,
|
| 21 |
+
QueryService,
|
| 22 |
+
init_database,
|
| 23 |
+
reset_database,
|
| 24 |
+
get_registered_models,
|
| 25 |
+
)
|
| 26 |
+
from core.models import (
|
| 27 |
+
Base, User, GeminiJob, PaymentTransaction, Contact,
|
| 28 |
+
RateLimit, ApiKeyUsage, ClientUser, AuditLog
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Test database URL
|
| 33 |
+
TEST_DB_URL = "sqlite+aiosqlite:///:memory:"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.fixture
|
| 37 |
+
async def engine():
|
| 38 |
+
"""Create test database engine."""
|
| 39 |
+
engine = create_async_engine(TEST_DB_URL, echo=False)
|
| 40 |
+
yield engine
|
| 41 |
+
await engine.dispose()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.fixture
|
| 45 |
+
async def session(engine):
|
| 46 |
+
"""Create test database session."""
|
| 47 |
+
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
| 48 |
+
|
| 49 |
+
async with async_session() as session:
|
| 50 |
+
yield session
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@pytest.fixture(autouse=True)
|
| 54 |
+
async def setup_db(engine):
|
| 55 |
+
"""Setup test database with configuration."""
|
| 56 |
+
# Register configuration
|
| 57 |
+
DBServiceConfig.register(
|
| 58 |
+
db_base=Base,
|
| 59 |
+
all_models=[User, GeminiJob, PaymentTransaction, Contact,
|
| 60 |
+
RateLimit, ApiKeyUsage, ClientUser, AuditLog],
|
| 61 |
+
user_filter_column="user_id",
|
| 62 |
+
user_id_column="id",
|
| 63 |
+
soft_delete_column="deleted_at",
|
| 64 |
+
special_user_model=User,
|
| 65 |
+
user_read_scoped=[User, GeminiJob, PaymentTransaction, Contact],
|
| 66 |
+
user_create_scoped=[GeminiJob, PaymentTransaction, Contact],
|
| 67 |
+
user_update_scoped=[User, GeminiJob],
|
| 68 |
+
user_delete_scoped=[GeminiJob, Contact],
|
| 69 |
+
admin_read_only=[RateLimit, ApiKeyUsage, ClientUser, AuditLog],
|
| 70 |
+
admin_create_only=[RateLimit, ApiKeyUsage, ClientUser, AuditLog],
|
| 71 |
+
admin_update_only=[RateLimit, ApiKeyUsage, ClientUser, PaymentTransaction],
|
| 72 |
+
admin_delete_only=[RateLimit, ApiKeyUsage, User],
|
| 73 |
+
system_read_scoped=[User, GeminiJob, PaymentTransaction, RateLimit,
|
| 74 |
+
ApiKeyUsage, ClientUser, AuditLog],
|
| 75 |
+
system_create_scoped=[User, ClientUser, AuditLog, PaymentTransaction,
|
| 76 |
+
ApiKeyUsage, GeminiJob, RateLimit],
|
| 77 |
+
system_update_scoped=[User, GeminiJob, PaymentTransaction, ApiKeyUsage,
|
| 78 |
+
RateLimit, ClientUser],
|
| 79 |
+
system_delete_scoped=[GeminiJob, RateLimit, ApiKeyUsage],
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Initialize database
|
| 83 |
+
await init_database(engine)
|
| 84 |
+
|
| 85 |
+
yield
|
| 86 |
+
|
| 87 |
+
# Cleanup
|
| 88 |
+
await reset_database(engine)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@pytest.fixture
|
| 92 |
+
async def regular_user(session):
|
| 93 |
+
"""Create a regular test user."""
|
| 94 |
+
user = User(
|
| 95 |
+
email="user@example.com",
|
| 96 |
+
name="Test User",
|
| 97 |
+
credits=100
|
| 98 |
+
)
|
| 99 |
+
session.add(user)
|
| 100 |
+
await session.commit()
|
| 101 |
+
await session.refresh(user)
|
| 102 |
+
return user
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@pytest.fixture
|
| 106 |
+
async def admin_user(session):
|
| 107 |
+
"""Create an admin test user."""
|
| 108 |
+
user = User(
|
| 109 |
+
email=os.getenv("ADMIN_EMAILS", "admin@example.com").split(",")[0],
|
| 110 |
+
name="Admin User",
|
| 111 |
+
credits=1000
|
| 112 |
+
)
|
| 113 |
+
session.add(user)
|
| 114 |
+
await session.commit()
|
| 115 |
+
await session.refresh(user)
|
| 116 |
+
return user
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@pytest.fixture
|
| 120 |
+
async def other_user(session):
|
| 121 |
+
"""Create another test user."""
|
| 122 |
+
user = User(
|
| 123 |
+
email="other@example.com",
|
| 124 |
+
name="Other User",
|
| 125 |
+
credits=50
|
| 126 |
+
)
|
| 127 |
+
session.add(user)
|
| 128 |
+
await session.commit()
|
| 129 |
+
await session.refresh(user)
|
| 130 |
+
return user
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ============================================================================
|
| 134 |
+
# Configuration Tests
|
| 135 |
+
# ============================================================================
|
| 136 |
+
|
| 137 |
+
class TestConfiguration:
|
| 138 |
+
"""Test DB Service configuration."""
|
| 139 |
+
|
| 140 |
+
def test_config_registered(self):
|
| 141 |
+
"""Test that configuration is registered."""
|
| 142 |
+
assert DBServiceConfig.is_registered()
|
| 143 |
+
assert DBServiceConfig.db_base == Base
|
| 144 |
+
assert len(DBServiceConfig.all_models) == 8
|
| 145 |
+
|
| 146 |
+
def test_get_registered_models(self):
|
| 147 |
+
"""Test getting registered models."""
|
| 148 |
+
models = get_registered_models()
|
| 149 |
+
assert len(models) == 8
|
| 150 |
+
assert User in models
|
| 151 |
+
assert GeminiJob in models
|
| 152 |
+
|
| 153 |
+
def test_column_names(self):
|
| 154 |
+
"""Test configured column names."""
|
| 155 |
+
assert DBServiceConfig.user_filter_column == "user_id"
|
| 156 |
+
assert DBServiceConfig.soft_delete_column == "deleted_at"
|
| 157 |
+
assert DBServiceConfig.special_user_model == User
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ============================================================================
|
| 161 |
+
# Permission Tests
|
| 162 |
+
# ============================================================================
|
| 163 |
+
|
| 164 |
+
class TestPermissions:
|
| 165 |
+
"""Test USER/ADMIN/SYSTEM permission hierarchy."""
|
| 166 |
+
|
| 167 |
+
async def test_user_can_read_own_data(self, session, regular_user):
|
| 168 |
+
"""Test that users can read their own data."""
|
| 169 |
+
job = GeminiJob(user_id=regular_user.id, prompt="Test", status="queued")
|
| 170 |
+
session.add(job)
|
| 171 |
+
await session.commit()
|
| 172 |
+
|
| 173 |
+
qs = QueryService(regular_user, session)
|
| 174 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 175 |
+
|
| 176 |
+
assert len(jobs) == 1
|
| 177 |
+
assert jobs[0].id == job.id
|
| 178 |
+
|
| 179 |
+
async def test_user_cannot_read_others_data(self, session, regular_user, other_user):
|
| 180 |
+
"""Test that users cannot read other users' data."""
|
| 181 |
+
# Create job for other user
|
| 182 |
+
job = GeminiJob(user_id=other_user.id, prompt="Other", status="queued")
|
| 183 |
+
session.add(job)
|
| 184 |
+
await session.commit()
|
| 185 |
+
|
| 186 |
+
# Regular user tries to read
|
| 187 |
+
qs = QueryService(regular_user, session)
|
| 188 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 189 |
+
|
| 190 |
+
assert len(jobs) == 0 # Should not see other user's jobs
|
| 191 |
+
|
| 192 |
+
async def test_admin_can_read_all_data(self, session, admin_user, regular_user):
|
| 193 |
+
"""Test that admins can read all users' data."""
|
| 194 |
+
# Create jobs for different users
|
| 195 |
+
job1 = GeminiJob(user_id=regular_user.id, prompt="User Job", status="queued")
|
| 196 |
+
job2 = GeminiJob(user_id=admin_user.id, prompt="Admin Job", status="queued")
|
| 197 |
+
session.add_all([job1, job2])
|
| 198 |
+
await session.commit()
|
| 199 |
+
|
| 200 |
+
qs = QueryService(admin_user, session)
|
| 201 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 202 |
+
|
| 203 |
+
assert len(jobs) == 2 # Admin sees all jobs
|
| 204 |
+
|
| 205 |
+
async def test_user_cannot_access_admin_only_models(self, session, regular_user):
|
| 206 |
+
"""Test that regular users cannot access admin-only models."""
|
| 207 |
+
qs = QueryService(regular_user, session)
|
| 208 |
+
|
| 209 |
+
with pytest.raises(Exception) as exc_info:
|
| 210 |
+
await qs.select().execute(select(RateLimit))
|
| 211 |
+
|
| 212 |
+
assert "403" in str(exc_info.value) or "administrator" in str(exc_info.value).lower()
|
| 213 |
+
|
| 214 |
+
async def test_admin_can_access_admin_only_models(self, session, admin_user):
|
| 215 |
+
"""Test that admins can access admin-only models."""
|
| 216 |
+
rate_limit = RateLimit(identifier="test", endpoint="/api/test", request_count=10)
|
| 217 |
+
session.add(rate_limit)
|
| 218 |
+
await session.commit()
|
| 219 |
+
|
| 220 |
+
qs = QueryService(admin_user, session)
|
| 221 |
+
limits = await qs.select().execute(select(RateLimit))
|
| 222 |
+
|
| 223 |
+
assert len(limits) == 1
|
| 224 |
+
|
| 225 |
+
async def test_system_can_create_user(self, session, regular_user):
|
| 226 |
+
"""Test that system operations can create users."""
|
| 227 |
+
qs = QueryService(regular_user, session, is_system=True)
|
| 228 |
+
|
| 229 |
+
# System should be able to bypass permissions
|
| 230 |
+
# (actual create would use direct SQLAlchemy, but permission check passes)
|
| 231 |
+
assert qs._is_system is True
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ============================================================================
|
| 235 |
+
# Soft Delete Tests
|
| 236 |
+
# ============================================================================
|
| 237 |
+
|
| 238 |
+
class TestSoftDeletes:
|
| 239 |
+
"""Test soft delete functionality."""
|
| 240 |
+
|
| 241 |
+
async def test_soft_delete_marks_record(self, session, regular_user):
|
| 242 |
+
"""Test that soft delete sets deleted_at."""
|
| 243 |
+
job = GeminiJob(user_id=regular_user.id, prompt="Delete Me", status="queued")
|
| 244 |
+
session.add(job)
|
| 245 |
+
await session.commit()
|
| 246 |
+
|
| 247 |
+
qs = QueryService(regular_user, session)
|
| 248 |
+
await qs.delete().soft_delete_one(job)
|
| 249 |
+
|
| 250 |
+
assert job.deleted_at is not None
|
| 251 |
+
|
| 252 |
+
async def test_soft_deleted_not_in_query(self, session, regular_user):
|
| 253 |
+
"""Test that soft-deleted records don't appear in queries."""
|
| 254 |
+
job = GeminiJob(user_id=regular_user.id, prompt="Delete Me", status="queued")
|
| 255 |
+
session.add(job)
|
| 256 |
+
await session.commit()
|
| 257 |
+
|
| 258 |
+
qs = QueryService(regular_user, session)
|
| 259 |
+
|
| 260 |
+
# Before delete
|
| 261 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 262 |
+
assert len(jobs) == 1
|
| 263 |
+
|
| 264 |
+
# After delete
|
| 265 |
+
await qs.delete().soft_delete_one(job)
|
| 266 |
+
jobs = await qs.select().execute(select(GeminiJob))
|
| 267 |
+
assert len(jobs) == 0 # Should not appear
|
| 268 |
+
|
| 269 |
+
async def test_admin_can_restore(self, session, admin_user, regular_user):
|
| 270 |
+
"""Test that admins can restore deleted records."""
|
| 271 |
+
job = GeminiJob(user_id=regular_user.id, prompt="Restore Me", status="queued")
|
| 272 |
+
session.add(job)
|
| 273 |
+
await session.commit()
|
| 274 |
+
job_id = job.id
|
| 275 |
+
|
| 276 |
+
qs = QueryService(admin_user, session)
|
| 277 |
+
|
| 278 |
+
# Delete
|
| 279 |
+
await qs.delete().soft_delete_one(job)
|
| 280 |
+
assert job.deleted_at is not None
|
| 281 |
+
|
| 282 |
+
# Restore
|
| 283 |
+
await qs.delete().restore_one(job)
|
| 284 |
+
assert job.deleted_at is None
|
| 285 |
+
|
| 286 |
+
async def test_user_cannot_restore(self, session, regular_user):
|
| 287 |
+
"""Test that regular users cannot restore records."""
|
| 288 |
+
job = GeminiJob(user_id=regular_user.id, prompt="Deleted", status="queued")
|
| 289 |
+
session.add(job)
|
| 290 |
+
await session.commit()
|
| 291 |
+
|
| 292 |
+
qs = QueryService(regular_user, session)
|
| 293 |
+
await qs.delete().soft_delete_one(job)
|
| 294 |
+
|
| 295 |
+
with pytest.raises(Exception) as exc_info:
|
| 296 |
+
await qs.delete().restore_one(job)
|
| 297 |
+
|
| 298 |
+
assert "403" in str(exc_info.value) or "administrator" in str(exc_info.value).lower()
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# ============================================================================
|
| 302 |
+
# Database Initialization Tests
|
| 303 |
+
# ============================================================================
|
| 304 |
+
|
| 305 |
+
class TestDatabaseInitialization:
|
| 306 |
+
"""Test database initialization utilities."""
|
| 307 |
+
|
| 308 |
+
async def test_init_database_creates_tables(self, engine):
|
| 309 |
+
"""Test that init_database creates all tables."""
|
| 310 |
+
await init_database(engine)
|
| 311 |
+
|
| 312 |
+
# Verify tables exist by querying
|
| 313 |
+
async with AsyncSession(engine) as session:
|
| 314 |
+
result = await session.execute(select(User))
|
| 315 |
+
assert result.scalars().all() == [] # Empty but table exists
|
| 316 |
+
|
| 317 |
+
async def test_reset_database_clears_data(self, engine, session, regular_user):
|
| 318 |
+
"""Test that reset_database clears all data."""
|
| 319 |
+
# Add some data
|
| 320 |
+
user = User(email="test@example.com", name="Test", credits=10)
|
| 321 |
+
session.add(user)
|
| 322 |
+
await session.commit()
|
| 323 |
+
|
| 324 |
+
# Reset
|
| 325 |
+
await reset_database(engine)
|
| 326 |
+
|
| 327 |
+
# Verify data cleared
|
| 328 |
+
async with AsyncSession(engine) as new_session:
|
| 329 |
+
result = await new_session.execute(select(User))
|
| 330 |
+
assert len(result.scalars().all()) == 0
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# ============================================================================
|
| 334 |
+
# Run Tests
|
| 335 |
+
# ============================================================================
|
| 336 |
+
|
| 337 |
+
if __name__ == "__main__":
|
| 338 |
+
pytest.main([__file__, "-v", "--tb=short"])
|