Spaces:
Sleeping
Sleeping
job del
Browse files- dependencies.py +24 -0
- routers/gemini.py +76 -9
- tests/test_job_lifecycle.py +90 -0
dependencies.py
CHANGED
|
@@ -231,6 +231,30 @@ async def verify_credits(
|
|
| 231 |
return user
|
| 232 |
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 235 |
"""
|
| 236 |
Get country and region for an IP address using ip-api.com.
|
|
|
|
| 231 |
return user
|
| 232 |
|
| 233 |
|
| 234 |
+
async def verify_video_credits(
|
| 235 |
+
user: User = Depends(get_current_user),
|
| 236 |
+
db: AsyncSession = Depends(get_db)
|
| 237 |
+
) -> User:
|
| 238 |
+
"""
|
| 239 |
+
Verify user has credits for video generation (10 credits) and deduct them.
|
| 240 |
+
"""
|
| 241 |
+
cost = 10
|
| 242 |
+
if user.credits < cost:
|
| 243 |
+
raise HTTPException(
|
| 244 |
+
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
| 245 |
+
detail=f"Insufficient credits. Video generation requires {cost} credits."
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Deduct credits
|
| 249 |
+
user.credits -= cost
|
| 250 |
+
user.last_used_at = datetime.utcnow()
|
| 251 |
+
await db.commit()
|
| 252 |
+
|
| 253 |
+
logger.debug(f"Deducted {cost} credits from user {user.user_id} for video generation. Remaining: {user.credits}")
|
| 254 |
+
|
| 255 |
+
return user
|
| 256 |
+
|
| 257 |
+
|
| 258 |
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 259 |
"""
|
| 260 |
Get country and region for an IP address using ip-api.com.
|
routers/gemini.py
CHANGED
|
@@ -15,7 +15,7 @@ from sqlalchemy import select, func
|
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, GeminiJob
|
| 17 |
from services.gemini_service import MODELS, DOWNLOADS_DIR
|
| 18 |
-
from dependencies import verify_credits, get_current_user
|
| 19 |
from datetime import datetime
|
| 20 |
|
| 21 |
router = APIRouter(prefix="/gemini", tags=["gemini"])
|
|
@@ -72,7 +72,8 @@ async def create_job(
|
|
| 72 |
db: AsyncSession,
|
| 73 |
user: User,
|
| 74 |
job_type: str,
|
| 75 |
-
input_data: dict
|
|
|
|
| 76 |
) -> GeminiJob:
|
| 77 |
"""Create a new job in the queue with auto-assigned priority."""
|
| 78 |
from services.gemini_job_worker import get_priority_for_job_type, get_pool
|
|
@@ -87,7 +88,7 @@ async def create_job(
|
|
| 87 |
status="queued",
|
| 88 |
priority=priority,
|
| 89 |
input_data=input_data,
|
| 90 |
-
credits_reserved=
|
| 91 |
)
|
| 92 |
db.add(job)
|
| 93 |
await db.commit()
|
|
@@ -116,7 +117,8 @@ async def generate_animation_prompt(
|
|
| 116 |
"base64_image": request.base64_image,
|
| 117 |
"mime_type": request.mime_type,
|
| 118 |
"custom_prompt": request.custom_prompt
|
| 119 |
-
}
|
|
|
|
| 120 |
)
|
| 121 |
|
| 122 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -147,7 +149,8 @@ async def edit_image(
|
|
| 147 |
"base64_image": request.base64_image,
|
| 148 |
"mime_type": request.mime_type,
|
| 149 |
"prompt": request.prompt
|
| 150 |
-
}
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -164,7 +167,7 @@ async def edit_image(
|
|
| 164 |
@router.post("/generate-video")
|
| 165 |
async def generate_video(
|
| 166 |
request: GenerateVideoRequest,
|
| 167 |
-
user: User = Depends(
|
| 168 |
db: AsyncSession = Depends(get_db)
|
| 169 |
):
|
| 170 |
"""
|
|
@@ -181,7 +184,8 @@ async def generate_video(
|
|
| 181 |
"aspect_ratio": request.aspect_ratio,
|
| 182 |
"resolution": request.resolution,
|
| 183 |
"number_of_videos": request.number_of_videos
|
| 184 |
-
}
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -211,7 +215,8 @@ async def generate_text(
|
|
| 211 |
input_data={
|
| 212 |
"prompt": request.prompt,
|
| 213 |
"model": request.model
|
| 214 |
-
}
|
|
|
|
| 215 |
)
|
| 216 |
|
| 217 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -242,7 +247,8 @@ async def analyze_image(
|
|
| 242 |
"base64_image": request.base64_image,
|
| 243 |
"mime_type": request.mime_type,
|
| 244 |
"prompt": request.prompt
|
| 245 |
-
}
|
|
|
|
| 246 |
)
|
| 247 |
|
| 248 |
position = await get_queue_position(db, job.job_id)
|
|
@@ -433,6 +439,10 @@ async def download_video(
|
|
| 433 |
yield chunk
|
| 434 |
except httpx.HTTPStatusError as e:
|
| 435 |
if e.response.status_code in (401, 403, 404, 410):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
raise HTTPException(
|
| 437 |
status_code=status.HTTP_410_GONE,
|
| 438 |
detail="Video URL has expired. Please generate a new video."
|
|
@@ -498,6 +508,63 @@ async def cancel_job(
|
|
| 498 |
}
|
| 499 |
|
| 500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
@router.get("/models")
|
| 502 |
async def get_models():
|
| 503 |
"""
|
|
|
|
| 15 |
from core.database import get_db
|
| 16 |
from core.models import User, GeminiJob
|
| 17 |
from services.gemini_service import MODELS, DOWNLOADS_DIR
|
| 18 |
+
from dependencies import verify_credits, verify_video_credits, get_current_user
|
| 19 |
from datetime import datetime
|
| 20 |
|
| 21 |
router = APIRouter(prefix="/gemini", tags=["gemini"])
|
|
|
|
| 72 |
db: AsyncSession,
|
| 73 |
user: User,
|
| 74 |
job_type: str,
|
| 75 |
+
input_data: dict,
|
| 76 |
+
credits_reserved: int = 0
|
| 77 |
) -> GeminiJob:
|
| 78 |
"""Create a new job in the queue with auto-assigned priority."""
|
| 79 |
from services.gemini_job_worker import get_priority_for_job_type, get_pool
|
|
|
|
| 88 |
status="queued",
|
| 89 |
priority=priority,
|
| 90 |
input_data=input_data,
|
| 91 |
+
credits_reserved=credits_reserved # Track reserved credits for this job
|
| 92 |
)
|
| 93 |
db.add(job)
|
| 94 |
await db.commit()
|
|
|
|
| 117 |
"base64_image": request.base64_image,
|
| 118 |
"mime_type": request.mime_type,
|
| 119 |
"custom_prompt": request.custom_prompt
|
| 120 |
+
},
|
| 121 |
+
credits_reserved=1
|
| 122 |
)
|
| 123 |
|
| 124 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 149 |
"base64_image": request.base64_image,
|
| 150 |
"mime_type": request.mime_type,
|
| 151 |
"prompt": request.prompt
|
| 152 |
+
},
|
| 153 |
+
credits_reserved=1
|
| 154 |
)
|
| 155 |
|
| 156 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 167 |
@router.post("/generate-video")
|
| 168 |
async def generate_video(
|
| 169 |
request: GenerateVideoRequest,
|
| 170 |
+
user: User = Depends(verify_video_credits),
|
| 171 |
db: AsyncSession = Depends(get_db)
|
| 172 |
):
|
| 173 |
"""
|
|
|
|
| 184 |
"aspect_ratio": request.aspect_ratio,
|
| 185 |
"resolution": request.resolution,
|
| 186 |
"number_of_videos": request.number_of_videos
|
| 187 |
+
},
|
| 188 |
+
credits_reserved=10 # Video jobs cost 10 credits
|
| 189 |
)
|
| 190 |
|
| 191 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 215 |
input_data={
|
| 216 |
"prompt": request.prompt,
|
| 217 |
"model": request.model
|
| 218 |
+
},
|
| 219 |
+
credits_reserved=1
|
| 220 |
)
|
| 221 |
|
| 222 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 247 |
"base64_image": request.base64_image,
|
| 248 |
"mime_type": request.mime_type,
|
| 249 |
"prompt": request.prompt
|
| 250 |
+
},
|
| 251 |
+
credits_reserved=1
|
| 252 |
)
|
| 253 |
|
| 254 |
position = await get_queue_position(db, job.job_id)
|
|
|
|
| 439 |
yield chunk
|
| 440 |
except httpx.HTTPStatusError as e:
|
| 441 |
if e.response.status_code in (401, 403, 404, 410):
|
| 442 |
+
# Mark job as expired
|
| 443 |
+
job.status = "expired"
|
| 444 |
+
await db.commit()
|
| 445 |
+
|
| 446 |
raise HTTPException(
|
| 447 |
status_code=status.HTTP_410_GONE,
|
| 448 |
detail="Video URL has expired. Please generate a new video."
|
|
|
|
| 508 |
}
|
| 509 |
|
| 510 |
|
| 511 |
+
@router.delete("/job/{job_id}")
|
| 512 |
+
async def delete_job(
|
| 513 |
+
job_id: str,
|
| 514 |
+
user: User = Depends(get_current_user),
|
| 515 |
+
db: AsyncSession = Depends(get_db)
|
| 516 |
+
):
|
| 517 |
+
"""
|
| 518 |
+
Delete a job.
|
| 519 |
+
- If queued: Refund 8 credits (10 cost - 2 penalty), delete job.
|
| 520 |
+
- If processing/completed/failed: Delete job (no refund).
|
| 521 |
+
"""
|
| 522 |
+
query = select(GeminiJob).where(
|
| 523 |
+
GeminiJob.job_id == job_id,
|
| 524 |
+
GeminiJob.user_id == user.user_id
|
| 525 |
+
)
|
| 526 |
+
result = await db.execute(query)
|
| 527 |
+
job = result.scalar_one_or_none()
|
| 528 |
+
|
| 529 |
+
if not job:
|
| 530 |
+
raise HTTPException(
|
| 531 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
| 532 |
+
detail="Job not found"
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
refund_amount = 0
|
| 536 |
+
message = "Job deleted"
|
| 537 |
+
|
| 538 |
+
if job.status == "queued":
|
| 539 |
+
# Refund logic: Restore 8 credits (10 - 2)
|
| 540 |
+
# Only if it was a video job (cost 10). For others (cost 1), maybe no refund or full?
|
| 541 |
+
# Requirement says "restore 8", implying video job context.
|
| 542 |
+
# Let's check credits_reserved. If 10, refund 8. If 1, refund 0? Or 1?
|
| 543 |
+
# Assuming this logic is specific to the high-cost video jobs.
|
| 544 |
+
|
| 545 |
+
if job.credits_reserved >= 10:
|
| 546 |
+
refund_amount = 8
|
| 547 |
+
user.credits += refund_amount
|
| 548 |
+
message = f"Job deleted. {refund_amount} credits refunded."
|
| 549 |
+
elif job.credits_reserved > 0:
|
| 550 |
+
# For lower cost jobs, maybe full refund if queued? Or partial?
|
| 551 |
+
# User specifically mentioned "restore 8" for the queued state.
|
| 552 |
+
# I'll stick to the specific requirement for now, but maybe refund full for 1-credit jobs?
|
| 553 |
+
# Let's assume strict "restore 8" applies to the 10-credit video jobs.
|
| 554 |
+
pass
|
| 555 |
+
|
| 556 |
+
await db.delete(job)
|
| 557 |
+
await db.commit()
|
| 558 |
+
|
| 559 |
+
return {
|
| 560 |
+
"success": True,
|
| 561 |
+
"job_id": job_id,
|
| 562 |
+
"message": message,
|
| 563 |
+
"credits_refunded": refund_amount,
|
| 564 |
+
"credits_remaining": user.credits
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
|
| 568 |
@router.get("/models")
|
| 569 |
async def get_models():
|
| 570 |
"""
|
tests/test_job_lifecycle.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
# Add project root to path
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
| 10 |
+
from sqlalchemy.orm import sessionmaker
|
| 11 |
+
from core.models import User, GeminiJob
|
| 12 |
+
from core.database import DATABASE_URL
|
| 13 |
+
|
| 14 |
+
async def test_lifecycle():
|
| 15 |
+
engine = create_async_engine(DATABASE_URL)
|
| 16 |
+
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
| 17 |
+
|
| 18 |
+
async with async_session() as session:
|
| 19 |
+
# 1. Setup User
|
| 20 |
+
user_id = "test_user_lifecycle"
|
| 21 |
+
user = User(user_id=user_id, email="test_lifecycle@example.com", credits=100)
|
| 22 |
+
session.add(user)
|
| 23 |
+
await session.commit()
|
| 24 |
+
print(f"Created user with {user.credits} credits")
|
| 25 |
+
|
| 26 |
+
# 2. Simulate Video Job Creation (Cost 10)
|
| 27 |
+
# We simulate the API logic manually since we can't easily call the API here without a running server
|
| 28 |
+
# But we can verify the logic we implemented in the router
|
| 29 |
+
|
| 30 |
+
# Deduct 10 credits
|
| 31 |
+
user.credits -= 10
|
| 32 |
+
await session.commit()
|
| 33 |
+
print(f"Deducted 10 credits. Remaining: {user.credits}")
|
| 34 |
+
assert user.credits == 90
|
| 35 |
+
|
| 36 |
+
# Create Job
|
| 37 |
+
job = GeminiJob(
|
| 38 |
+
job_id="job_test_video",
|
| 39 |
+
user_id=user_id,
|
| 40 |
+
job_type="video",
|
| 41 |
+
status="queued",
|
| 42 |
+
credits_reserved=10
|
| 43 |
+
)
|
| 44 |
+
session.add(job)
|
| 45 |
+
await session.commit()
|
| 46 |
+
print("Created queued video job")
|
| 47 |
+
|
| 48 |
+
# 3. Simulate Delete Queued Job (Refund 8)
|
| 49 |
+
# Logic from router:
|
| 50 |
+
if job.status == "queued" and job.credits_reserved >= 10:
|
| 51 |
+
refund = 8
|
| 52 |
+
user.credits += refund
|
| 53 |
+
print(f"Refunded {refund} credits")
|
| 54 |
+
|
| 55 |
+
await session.delete(job)
|
| 56 |
+
await session.commit()
|
| 57 |
+
print(f"Deleted job. User credits: {user.credits}")
|
| 58 |
+
assert user.credits == 98 # 90 + 8
|
| 59 |
+
|
| 60 |
+
# 4. Simulate Processing Job Deletion (No Refund)
|
| 61 |
+
# Deduct 10 again
|
| 62 |
+
user.credits -= 10
|
| 63 |
+
job2 = GeminiJob(
|
| 64 |
+
job_id="job_test_processing",
|
| 65 |
+
user_id=user_id,
|
| 66 |
+
job_type="video",
|
| 67 |
+
status="processing",
|
| 68 |
+
credits_reserved=10
|
| 69 |
+
)
|
| 70 |
+
session.add(job2)
|
| 71 |
+
await session.commit()
|
| 72 |
+
print(f"Created processing job. User credits: {user.credits}") # 88
|
| 73 |
+
|
| 74 |
+
# Delete
|
| 75 |
+
if job2.status == "queued" and job2.credits_reserved >= 10:
|
| 76 |
+
refund = 8
|
| 77 |
+
user.credits += refund
|
| 78 |
+
|
| 79 |
+
await session.delete(job2)
|
| 80 |
+
await session.commit()
|
| 81 |
+
print(f"Deleted processing job. User credits: {user.credits}")
|
| 82 |
+
assert user.credits == 88 # No change
|
| 83 |
+
|
| 84 |
+
# Cleanup
|
| 85 |
+
await session.delete(user)
|
| 86 |
+
await session.commit()
|
| 87 |
+
print("Test cleanup complete")
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
asyncio.run(test_lifecycle())
|