Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from fastapi import APIRouter, Depends, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from sqlalchemy import select | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from app.database import get_db | |
| from app.models.user import User | |
| from app.models.project import Project | |
| from app.models.generation_job import GenerationJob | |
| from app.schemas.generation import GenerationStart, GenerationJobResponse | |
| from app.services.auth import get_current_user | |
| from app.pipeline.orchestrator import run_pipeline | |
# Module-level logger, namespaced to this module.
logger = logging.getLogger(__name__)

# All routes in this module are scoped under a single project.
router = APIRouter(prefix="/api/projects/{project_id}", tags=["generation"])

# Keep references to running pipeline tasks so they don't get garbage-collected
# (the event loop only holds weak references to tasks; see asyncio.create_task docs).
_running_tasks: set[asyncio.Task] = set()
def _launch_pipeline(job_id: int, resume: bool = False, chapter_ids: list[int] | None = None):
    """Fire off the generation pipeline as a background task.

    Nothing awaits the task, so failures are logged here instead of
    propagating. A strong reference is kept in ``_running_tasks`` until the
    task completes, which prevents it from being garbage-collected mid-run.
    """

    async def _run_and_log() -> None:
        try:
            await run_pipeline(job_id, resume=resume, chapter_ids=chapter_ids)
        except Exception:
            logger.exception("Pipeline failed for job %s", job_id)

    pipeline_task = asyncio.create_task(_run_and_log())
    _running_tasks.add(pipeline_task)
    pipeline_task.add_done_callback(_running_tasks.discard)
| # ── Asset endpoints ────────────────────────────────────────────────── | |
async def list_images(
    project_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List generated image filenames for a project owned by the caller.

    Returns a sorted list of file names (not full paths) found under
    ``workdir/projects/<project_id>/images``, or an empty list when the
    directory does not exist yet.

    Raises:
        HTTPException: 404 if the project does not exist or belongs to
            another user.
    """
    # NOTE(review): no @router decorator is visible on the endpoints in this
    # module even though `router` is defined above and GenerationJobResponse
    # is imported unused — confirm route registration wasn't lost.
    result = await db.execute(
        select(Project).where(Project.id == project_id, Project.user_id == user.id)
    )
    if not result.scalar_one_or_none():
        raise HTTPException(status_code=404, detail="Project not found")

    img_dir = Path("workdir/projects") / str(project_id) / "images"
    if not img_dir.exists():
        return []
    # Only include regular files: a *directory* named e.g. "foo.png" would
    # otherwise pass the suffix check and be reported as an image.
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    return sorted(
        f.name for f in img_dir.iterdir() if f.is_file() and f.suffix.lower() in allowed
    )
| # ── Generation endpoints ───────────────────────────────────────────── | |
async def start_generation(
    project_id: int,
    data: GenerationStart,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a queued GenerationJob for the project and start the pipeline.

    Raises:
        HTTPException: 404 if the project does not exist or belongs to
            another user.
    """
    owned = await db.execute(
        select(Project).where(Project.id == project_id, Project.user_id == user.id)
    )
    if owned.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Project not found")

    new_job = GenerationJob(
        project_id=project_id,
        user_id=user.id,
        episode_id=data.episode_id,
        chapter_ids_json=data.chapter_ids,
        status="queued",
        current_stage="ingest",
        progress_pct=0.0,
    )
    db.add(new_job)
    await db.commit()
    await db.refresh(new_job)

    # Run the pipeline in the background; errors are logged by the launcher.
    _launch_pipeline(new_job.id, chapter_ids=data.chapter_ids)
    return new_job
async def list_jobs(
    project_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the caller's generation jobs for this project, newest first."""
    query = (
        select(GenerationJob)
        .where(GenerationJob.project_id == project_id, GenerationJob.user_id == user.id)
        .order_by(GenerationJob.created_at.desc())
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
async def get_job(
    project_id: int,
    job_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Fetch a single generation job, scoped to the project and owner.

    Raises:
        HTTPException: 404 if no matching job exists for this user/project.
    """
    query = select(GenerationJob).where(
        GenerationJob.id == job_id,
        GenerationJob.project_id == project_id,
        GenerationJob.user_id == user.id,
    )
    row = await db.execute(query)
    job = row.scalar_one_or_none()
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return job
async def stream_progress(
    project_id: int,
    job_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """SSE endpoint for real-time generation progress.

    Polls the job row once per second and emits a ``data:`` event each poll
    (even when unchanged, so the client can tell the connection is alive).
    The stream ends when the job reaches a terminal state
    (completed/failed/cancelled), disappears, or a server error occurs.

    Raises:
        HTTPException: 404 if the job does not exist or belongs to another user.
    """
    result = await db.execute(
        select(GenerationJob).where(
            GenerationJob.id == job_id,
            GenerationJob.project_id == project_id,
            GenerationJob.user_id == user.id,
        )
    )
    if result.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Job not found")

    # The request-scoped session is closed when this handler returns, so the
    # generator below opens a fresh session for every poll instead.
    from app.database import async_session as session_factory

    async def event_stream():
        # Note: the original kept a `last_data` variable that was written but
        # never read (events are always sent) — removed as dead code.
        while True:
            try:
                async with session_factory() as session:
                    poll = await session.execute(
                        select(GenerationJob).where(GenerationJob.id == job_id)
                    )
                    current_job = poll.scalar_one_or_none()
                    if current_job is None:
                        yield f"data: {json.dumps({'status': 'not_found'})}\n\n"
                        break
                    data = {
                        "status": current_job.status or "queued",
                        "stage": current_job.current_stage or "ingest",
                        "progress": current_job.progress_pct or 0,
                        "detail": (current_job.progress_detail or "")[:200],
                        "error": (current_job.error_message or "")[:500]
                        if current_job.status == "failed"
                        else None,
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    if current_job.status in ("completed", "failed", "cancelled"):
                        break
            except Exception:
                logger.exception("SSE stream error for job %s", job_id)
                yield f"data: {json.dumps({'status': 'error', 'detail': 'Server error'})}\n\n"
                break
            await asyncio.sleep(1)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable nginx proxy buffering so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
async def retry_generation(
    project_id: int,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Retry the last failed job for a project, resuming from where it left off."""
    project_row = await db.execute(
        select(Project).where(Project.id == project_id, Project.user_id == user.id)
    )
    if project_row.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Project not found")

    # Locate the most recent failed job, if any.
    failed_row = await db.execute(
        select(GenerationJob)
        .where(
            GenerationJob.project_id == project_id,
            GenerationJob.user_id == user.id,
            GenerationJob.status == "failed",
        )
        .order_by(GenerationJob.created_at.desc())
        .limit(1)
    )
    failed_job = failed_row.scalar_one_or_none()
    if failed_job is None:
        raise HTTPException(status_code=404, detail="No failed job to retry")

    # The new job inherits the failed job's episode and chapter selection.
    retry_job = GenerationJob(
        project_id=project_id,
        user_id=user.id,
        episode_id=failed_job.episode_id,
        chapter_ids_json=failed_job.chapter_ids_json,
        status="queued",
        current_stage="ingest",
        progress_pct=0.0,
    )
    db.add(retry_job)
    await db.commit()
    await db.refresh(retry_job)

    _launch_pipeline(retry_job.id, resume=True, chapter_ids=failed_job.chapter_ids_json)
    return retry_job