"""Generation API routes — submit single and batch image generation jobs.""" from __future__ import annotations import asyncio import logging import uuid from fastapi import APIRouter, File, Form, HTTPException, UploadFile from content_engine.models.schemas import ( BatchRequest, BatchStatusResponse, GenerationRequest, GenerationResponse, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["generation"]) # These are injected at startup from main.py _local_worker = None _template_engine = None _variation_engine = None _character_profiles = None _wavespeed_provider = None _runpod_provider = None _catalog = None _comfyui_client = None # In-memory batch tracking (v1 — move to DB for production) _batch_tracker: dict[str, dict] = {} # Job status tracking for cloud generations _job_tracker: dict[str, dict] = {} def init_routes(local_worker, template_engine, variation_engine, character_profiles, wavespeed_provider=None, catalog=None, comfyui_client=None): """Initialize route dependencies. Called from main.py on startup.""" global _local_worker, _template_engine, _variation_engine, _character_profiles global _wavespeed_provider, _catalog, _comfyui_client _local_worker = local_worker _template_engine = template_engine _variation_engine = variation_engine _character_profiles = character_profiles _wavespeed_provider = wavespeed_provider _catalog = catalog _comfyui_client = comfyui_client def set_runpod_provider(provider): """Set RunPod generation provider. Called from main.py after init_routes.""" global _runpod_provider _runpod_provider = provider @router.post("/generate", response_model=GenerationResponse) async def generate_single(request: GenerationRequest): """Submit a single image generation job. The job runs asynchronously — returns immediately with a job ID. """ if _local_worker is None: raise HTTPException(503, "Worker not initialized") job_id = str(uuid.uuid4()) # Fire and forget — run in background asyncio.create_task( _run_generation( job_id=job_id, character_id=request.character_id, template_id=request.template_id, content_rating=request.content_rating, positive_prompt=request.positive_prompt, negative_prompt=request.negative_prompt, checkpoint=request.checkpoint, loras=[l.model_dump() for l in request.loras] if request.loras else None, seed=request.seed or -1, steps=request.steps, cfg=request.cfg, sampler=request.sampler, scheduler=request.scheduler, width=request.width, height=request.height, variables=request.variables, ) ) return GenerationResponse(job_id=job_id, status="queued", backend="local") @router.post("/batch", response_model=GenerationResponse) async def generate_batch(request: BatchRequest): """Submit a batch of variation-based generation jobs. Uses the variation engine to generate multiple images with different poses, outfits, emotions, etc. """ if _local_worker is None or _variation_engine is None: raise HTTPException(503, "Services not initialized") if _character_profiles is None: raise HTTPException(503, "No character profiles loaded") character = _character_profiles.get(request.character_id) if character is None: raise HTTPException(404, f"Character not found: {request.character_id}") # Generate variation jobs jobs = _variation_engine.generate_batch( template_id=request.template_id, character=character, content_rating=request.content_rating, count=request.count, variation_mode=request.variation_mode, pin=request.pin, seed_strategy=request.seed_strategy, ) batch_id = jobs[0].batch_id if jobs else str(uuid.uuid4()) _batch_tracker[batch_id] = { "total": len(jobs), "completed": 0, "failed": 0, "pending": len(jobs), "running": 0, } # Fire all jobs in background for job in jobs: asyncio.create_task( _run_batch_job(batch_id, job) ) logger.info("Batch %s: %d jobs queued", batch_id, len(jobs)) return GenerationResponse( job_id=batch_id, batch_id=batch_id, status="queued", backend="local" ) @router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse) async def get_batch_status(batch_id: str): """Get the status of a batch generation.""" if batch_id not in _batch_tracker: raise HTTPException(404, f"Batch not found: {batch_id}") tracker = _batch_tracker[batch_id] return BatchStatusResponse( batch_id=batch_id, total_jobs=tracker["total"], completed=tracker["completed"], failed=tracker["failed"], pending=tracker["pending"], running=tracker["running"], ) @router.post("/generate/cloud", response_model=GenerationResponse) async def generate_cloud(request: GenerationRequest): """Generate an image using WaveSpeed cloud API (NanoBanana, SeeDream). Supported models via the 'checkpoint' field: - nano-banana, nano-banana-pro - seedream-3, seedream-3.1, seedream-4, seedream-4.5 """ if _wavespeed_provider is None: raise HTTPException(503, "WaveSpeed cloud provider not configured. Set WAVESPEED_API_KEY in .env") job_id = str(uuid.uuid4()) lora_path = request.loras[0].name if request.loras else None lora_strength = request.loras[0].strength_model if request.loras else 0.85 asyncio.create_task( _run_cloud_generation( job_id=job_id, positive_prompt=request.positive_prompt or "", negative_prompt=request.negative_prompt or "", model=request.checkpoint, width=request.width or 1024, height=request.height or 1024, seed=request.seed or -1, content_rating=request.content_rating, character_id=request.character_id, template_id=request.template_id, variables=request.variables, lora_path=lora_path, lora_strength=lora_strength, ) ) return GenerationResponse(job_id=job_id, status="queued", backend="wavespeed") @router.get("/cloud/models") async def list_cloud_models(): """List available cloud models (WaveSpeed and RunPod).""" return { "wavespeed": { "available": _wavespeed_provider is not None, "models": [ {"id": "nano-banana", "name": "NanoBanana", "provider": "Google", "type": "txt2img"}, {"id": "nano-banana-pro", "name": "NanoBanana Pro", "provider": "Google", "type": "txt2img"}, {"id": "seedream-3", "name": "SeeDream v3", "provider": "ByteDance", "type": "txt2img"}, {"id": "seedream-3.1", "name": "SeeDream v3.1", "provider": "ByteDance", "type": "txt2img"}, {"id": "seedream-4", "name": "SeeDream v4", "provider": "ByteDance", "type": "txt2img"}, {"id": "seedream-4.5", "name": "SeeDream v4.5", "provider": "ByteDance", "type": "txt2img"}, ], "edit_models": [ {"id": "seedream-4.5-edit", "name": "SeeDream v4.5 Edit", "provider": "ByteDance", "type": "img2img", "price": "$0.04/img"}, {"id": "seedream-4-edit", "name": "SeeDream v4 Edit", "provider": "ByteDance", "type": "img2img", "price": "$0.04/img"}, {"id": "nano-banana-edit", "name": "NanoBanana Edit", "provider": "Google", "type": "img2img", "price": "$0.038/img"}, {"id": "nano-banana-pro-edit", "name": "NanoBanana Pro Edit", "provider": "Google", "type": "img2img", "price": "$0.14/img"}, ], }, "runpod": { "available": _runpod_provider is not None, "description": "Pay-per-second serverless GPU. Uses your deployed endpoint.", "pricing": "~$0.00025/sec (RTX 4090)", }, } @router.post("/generate/runpod", response_model=GenerationResponse) async def generate_runpod(request: GenerationRequest): """Generate an image using RunPod serverless GPU. Uses your deployed RunPod endpoint. Pay per second of GPU time. Requires RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env. """ if _runpod_provider is None: raise HTTPException( 503, "RunPod not configured. Set RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env" ) job_id = str(uuid.uuid4()) asyncio.create_task( _run_runpod_generation( job_id=job_id, positive_prompt=request.positive_prompt or "", negative_prompt=request.negative_prompt or "", checkpoint=request.checkpoint, loras=request.loras, seed=request.seed or -1, steps=request.steps or 28, cfg=request.cfg or 7.0, width=request.width or 832, height=request.height or 1216, character_id=request.character_id, template_id=request.template_id, content_rating=request.content_rating, ) ) return GenerationResponse(job_id=job_id, status="queued", backend="runpod") @router.get("/generate/jobs/{job_id}") async def get_job_status(job_id: str): """Get the status of a generation job.""" job = _job_tracker.get(job_id) if not job: return {"job_id": job_id, "status": "unknown", "message": "Job not found or completed"} return job @router.get("/generate/jobs") async def list_jobs(): """List recent generation jobs with their status.""" # Return most recent 20 jobs jobs = list(_job_tracker.values()) jobs.sort(key=lambda x: x.get("started_at", 0), reverse=True) return jobs[:20] @router.post("/generate/jobs/{job_id}/cancel") async def cancel_job(job_id: str): """Cancel a running generation job.""" job = _job_tracker.get(job_id) if not job: raise HTTPException(404, "Job not found") if job["status"] not in ["running", "queued"]: return {"job_id": job_id, "status": job["status"], "message": "Job already finished"} _job_tracker[job_id]["status"] = "cancelled" _job_tracker[job_id]["message"] = "Cancelled by user" logger.info("Job %s cancelled by user", job_id) return {"job_id": job_id, "status": "cancelled", "message": "Job cancelled"} @router.post("/generate/img2img", response_model=GenerationResponse) async def generate_img2img( image: UploadFile = File(...), image2: UploadFile | None = File(default=None), positive_prompt: str = Form(""), negative_prompt: str = Form(""), character_id: str | None = Form(None), template_id: str | None = Form(None), variables_json: str = Form("{}"), content_rating: str = Form("sfw"), checkpoint: str | None = Form(None), seed: int = Form(-1), steps: int = Form(28), cfg: float = Form(7.0), denoise: float = Form(0.65), width: int | None = Form(None), height: int | None = Form(None), backend: str = Form("local"), ): """Generate an image using a reference image (img2img). Supports both local (ComfyUI) and cloud (WaveSpeed edit) backends. - Local: denoise-based img2img via ComfyUI - Cloud: prompt-guided editing via SeeDream/NanoBanana Edit APIs Multi-reference: Pass a second image (pose/style reference) for models that support it. """ import json as json_module job_id = str(uuid.uuid4()) image_bytes = await image.read() # Read second reference image if provided (for multi-ref models) image_bytes_2 = None if image2 is not None: image_bytes_2 = await image2.read() # Parse template variables try: variables = json_module.loads(variables_json) if variables_json else {} except json_module.JSONDecodeError: variables = {} if backend == "cloud": # Cloud img2img via WaveSpeed Edit API if _wavespeed_provider is None: raise HTTPException(503, "WaveSpeed cloud provider not configured. Set WAVESPEED_API_KEY in .env") asyncio.create_task( _run_cloud_img2img( job_id=job_id, image_bytes=image_bytes, image_bytes_2=image_bytes_2, positive_prompt=positive_prompt, model=checkpoint, content_rating=content_rating, character_id=character_id, template_id=template_id, variables=variables, width=width, height=height, ) ) return GenerationResponse(job_id=job_id, status="queued", backend="wavespeed") # Local img2img via ComfyUI if _local_worker is None or _comfyui_client is None: raise HTTPException(503, "Worker not initialized") ref_filename = f"ref_{job_id[:8]}.png" try: uploaded_name = await _comfyui_client.upload_image(image_bytes, ref_filename) except Exception as e: raise HTTPException(500, f"Failed to upload reference image to ComfyUI: {e}") asyncio.create_task( _run_generation( job_id=job_id, character_id=character_id, template_id=template_id, variables=variables, content_rating=content_rating, positive_prompt=positive_prompt, negative_prompt=negative_prompt, checkpoint=checkpoint, seed=seed, steps=steps, cfg=cfg, width=width, height=height, denoise=denoise, reference_image=uploaded_name, mode="img2img", ) ) return GenerationResponse(job_id=job_id, status="queued", backend="local") async def _run_cloud_generation( *, job_id: str, positive_prompt: str, negative_prompt: str, model: str | None, width: int, height: int, seed: int, content_rating: str, character_id: str | None, template_id: str | None, variables: dict | None, lora_path: str | None = None, lora_strength: float = 0.85, ): """Background task to run a WaveSpeed cloud generation.""" import time _job_tracker[job_id] = { "job_id": job_id, "status": "running", "type": "txt2img", "model": model, "started_at": time.time(), "message": "Preparing prompt...", } try: # Check if cancelled if _job_tracker.get(job_id, {}).get("status") == "cancelled": return # Apply template rendering if a template is selected final_positive = positive_prompt final_negative = negative_prompt if template_id and _template_engine: try: rendered = _template_engine.render(template_id, variables or {}) # Template prompt becomes the base; user prompt is appended if provided final_positive = rendered.positive_prompt if positive_prompt: final_positive = f"{final_positive}, {positive_prompt}" final_negative = rendered.negative_prompt if negative_prompt: final_negative = f"{final_negative}, {negative_prompt}" # Use template dimensions if user didn't override if rendered.template.width: width = rendered.template.width if rendered.template.height: height = rendered.template.height logger.info("Cloud gen: applied template '%s'", template_id) except Exception: logger.warning("Failed to render template '%s', using raw prompt", template_id, exc_info=True) _job_tracker[job_id]["message"] = f"Calling WaveSpeed API ({model or 'seedream-4.5'})..." # Check if cancelled before API call if _job_tracker.get(job_id, {}).get("status") == "cancelled": return result = await _wavespeed_provider.generate( positive_prompt=final_positive, negative_prompt=final_negative, model=model, width=width, height=height, seed=seed, lora_name=lora_path, lora_strength=lora_strength, ) # Check if cancelled after API call if _job_tracker.get(job_id, {}).get("status") == "cancelled": return _job_tracker[job_id]["message"] = "Saving image..." if _catalog: # Save image to disk output_path = _catalog.resolve_output_path( character_id=character_id or "cloud", content_rating=content_rating, filename=f"wavespeed_{job_id[:8]}.png", ) output_path.write_bytes(result.image_bytes) # Record in catalog await _catalog.insert_image( file_path=str(output_path), image_bytes=result.image_bytes, character_id=character_id, template_id=template_id, content_rating=content_rating, positive_prompt=positive_prompt, negative_prompt=negative_prompt, checkpoint=model or "seedream-4.5", seed=seed if seed >= 0 else None, width=width, height=height, generation_backend="wavespeed", generation_time_seconds=result.generation_time_seconds, variables=variables, ) _job_tracker[job_id]["status"] = "completed" _job_tracker[job_id]["message"] = f"Saved: {output_path.name}" _job_tracker[job_id]["completed_at"] = time.time() logger.info("Cloud generation saved: %s", output_path) except Exception as e: _job_tracker[job_id]["status"] = "failed" _job_tracker[job_id]["message"] = str(e)[:200] logger.error("Cloud generation failed for job %s: %s", job_id, e, exc_info=True) async def _run_cloud_img2img( *, job_id: str, image_bytes: bytes, image_bytes_2: bytes | None, positive_prompt: str, model: str | None, content_rating: str, character_id: str | None, template_id: str | None, variables: dict | None, width: int | None, height: int | None, ): """Background task to run a WaveSpeed cloud image edit (img2img).""" import time _job_tracker[job_id] = { "job_id": job_id, "status": "running", "type": "img2img", "model": model, "started_at": time.time(), "message": "Uploading image to cloud...", } try: # Apply template rendering if a template is selected final_prompt = positive_prompt if template_id and _template_engine: try: rendered = _template_engine.render(template_id, variables or {}) final_prompt = rendered.positive_prompt if positive_prompt: final_prompt = f"{final_prompt}, {positive_prompt}" logger.info("Cloud img2img: applied template '%s'", template_id) except Exception: logger.warning("Failed to render template '%s', using raw prompt", template_id, exc_info=True) # Clean up prompt — remove empty Jinja2 artifacts and leading/trailing commas final_prompt = ", ".join(p.strip() for p in final_prompt.split(",") if p.strip()) if not final_prompt: _job_tracker[job_id]["status"] = "failed" _job_tracker[job_id]["message"] = "Empty prompt - nothing to generate" logger.error("Cloud img2img: empty prompt after template rendering, cannot proceed") return _job_tracker[job_id]["message"] = f"Calling WaveSpeed API ({model or 'seedream-4.5-edit'})..." # Build size string if dimensions provided # WaveSpeed edit API requires output size >= 3686400 pixels (~1920x1920) # If dimensions are too small, omit size to let API use input image dimensions size = None if width and height and (width * height) >= 3686400: size = f"{width}x{height}" result = await _wavespeed_provider.edit_image( prompt=final_prompt, image_bytes=image_bytes, image_bytes_2=image_bytes_2, model=model, size=size, ) _job_tracker[job_id]["message"] = "Saving image..." if _catalog: output_path = _catalog.resolve_output_path( character_id=character_id or "cloud", content_rating=content_rating, filename=f"wavespeed_edit_{job_id[:8]}.png", ) output_path.write_bytes(result.image_bytes) await _catalog.insert_image( file_path=str(output_path), image_bytes=result.image_bytes, character_id=character_id, template_id=template_id, content_rating=content_rating, positive_prompt=final_prompt, negative_prompt="", checkpoint=model or "seedream-4.5-edit", width=width or 0, height=height or 0, generation_backend="wavespeed-edit", generation_time_seconds=result.generation_time_seconds, variables=variables, ) _job_tracker[job_id]["status"] = "completed" _job_tracker[job_id]["message"] = f"Saved: {output_path.name}" _job_tracker[job_id]["completed_at"] = time.time() logger.info("Cloud img2img saved: %s", output_path) except Exception as e: _job_tracker[job_id]["status"] = "failed" _job_tracker[job_id]["message"] = str(e)[:200] logger.error("Cloud img2img failed for job %s: %s", job_id, e, exc_info=True) async def _run_runpod_generation( *, job_id: str, positive_prompt: str, negative_prompt: str, checkpoint: str | None, loras: list | None, seed: int, steps: int, cfg: float, width: int, height: int, character_id: str | None, template_id: str | None, content_rating: str, ): """Background task to run a generation on RunPod serverless.""" try: # Resolve character/template prompts if provided final_prompt = positive_prompt final_negative = negative_prompt if character_id and _character_profiles: character = _character_profiles.get(character_id) if character: final_prompt = f"{character.trigger_word}, {positive_prompt}" # Submit to RunPod runpod_job_id = await _runpod_provider.submit_generation( positive_prompt=final_prompt, negative_prompt=final_negative, checkpoint=checkpoint or "realisticVisionV51_v51VAE", lora_name=loras[0].name if loras else None, lora_strength=loras[0].strength if loras else 0.85, seed=seed, steps=steps, cfg=cfg, width=width, height=height, ) # Wait for completion and get result result = await _runpod_provider.wait_for_completion(runpod_job_id) # Save to catalog if _catalog: from pathlib import Path output_path = await _catalog.insert_image( image_bytes=result.image_bytes, character_id=character_id or "unknown", content_rating=content_rating, job_id=job_id, positive_prompt=final_prompt, negative_prompt=final_negative, checkpoint=checkpoint, seed=seed, steps=steps, cfg=cfg, width=width, height=height, generation_backend="runpod", generation_time_seconds=result.generation_time_seconds, ) logger.info("RunPod generation saved: %s (%.1fs)", output_path, result.generation_time_seconds) except Exception: logger.error("RunPod generation failed for job %s", job_id, exc_info=True) async def _run_generation(**kwargs): """Background task to run a single local generation.""" try: # Remove mode param — it's used by the router, not the worker kwargs.pop("mode", None) await _local_worker.process_job(**kwargs) except Exception: logger.error("Generation failed for job %s", kwargs.get("job_id"), exc_info=True) async def _run_batch_job(batch_id: str, job): """Background task to run a single job within a batch.""" tracker = _batch_tracker.get(batch_id) if tracker: tracker["pending"] -= 1 tracker["running"] += 1 try: await _local_worker.process_job( job_id=job.job_id, batch_id=job.batch_id, character_id=job.character.id, template_id=job.template_id, content_rating=job.content_rating, loras=[l for l in job.loras], seed=job.seed, variables=job.variables, ) if tracker: tracker["completed"] += 1 except Exception: logger.error("Batch job %s failed", job.job_id, exc_info=True) if tracker: tracker["failed"] += 1 finally: if tracker: tracker["running"] -= 1