content-engine / src /content_engine /api /routes_generation.py
dippoo's picture
Sync all local changes: video routes, pod management, wavespeed, UI updates
e808ae1
raw
history blame
25.7 kB
"""Generation API routes — submit single and batch image generation jobs."""
from __future__ import annotations
import asyncio
import logging
import uuid
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from content_engine.models.schemas import (
BatchRequest,
BatchStatusResponse,
GenerationRequest,
GenerationResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["generation"])

# Service handles. They are injected at startup from main.py (via
# init_routes() / set_runpod_provider()) and stay None until then.
_local_worker = _template_engine = _variation_engine = _character_profiles = None
_wavespeed_provider = _runpod_provider = _catalog = _comfyui_client = None

# Per-batch progress counters keyed by batch_id.
# Held in process memory only (v1) — a database is needed for production.
_batch_tracker: dict[str, dict] = {}

# Per-job status records for cloud generations, keyed by job_id.
_job_tracker: dict[str, dict] = {}
def init_routes(local_worker, template_engine, variation_engine, character_profiles,
                wavespeed_provider=None, catalog=None, comfyui_client=None):
    """Wire the module-level service handles used by the route handlers.

    main.py calls this once during startup. The keyword-only-by-convention
    services default to None: routes needing ``wavespeed_provider`` or
    ``comfyui_client`` then answer 503, and cloud jobs cannot persist images
    without ``catalog``.
    """
    global _local_worker, _template_engine, _variation_engine, _character_profiles
    global _wavespeed_provider, _catalog, _comfyui_client

    # Core generation services.
    _local_worker, _template_engine = local_worker, template_engine
    _variation_engine, _character_profiles = variation_engine, character_profiles

    # Optional services.
    _wavespeed_provider, _catalog, _comfyui_client = wavespeed_provider, catalog, comfyui_client
def set_runpod_provider(provider):
    """Register the RunPod serverless provider.

    Called from main.py after init_routes(). While the stored provider is
    None, POST /generate/runpod responds with 503.
    """
    global _runpod_provider
    _runpod_provider = provider
@router.post("/generate", response_model=GenerationResponse)
async def generate_single(request: GenerationRequest):
    """Submit a single image generation job to the local worker.

    The job runs asynchronously in a background task. This endpoint returns
    immediately with a fresh job ID and ``status="queued"``.

    Raises:
        HTTPException(503): the local worker has not been initialized.
    """
    if _local_worker is None:
        raise HTTPException(503, "Worker not initialized")

    job_id = str(uuid.uuid4())
    # An explicit seed of 0 is a real seed; only a missing seed means random (-1).
    seed = request.seed if request.seed is not None else -1
    loras = [lora.model_dump() for lora in request.loras] if request.loras else None

    # Fire and forget — run in background.
    # NOTE(review): the task handle is not retained; asyncio keeps only weak
    # references to tasks, so a strong reference should be held until done.
    asyncio.create_task(
        _run_generation(
            job_id=job_id,
            character_id=request.character_id,
            template_id=request.template_id,
            content_rating=request.content_rating,
            positive_prompt=request.positive_prompt,
            negative_prompt=request.negative_prompt,
            checkpoint=request.checkpoint,
            loras=loras,
            seed=seed,
            steps=request.steps,
            cfg=request.cfg,
            sampler=request.sampler,
            scheduler=request.scheduler,
            width=request.width,
            height=request.height,
            variables=request.variables,
        )
    )
    return GenerationResponse(job_id=job_id, status="queued", backend="local")
@router.post("/batch", response_model=GenerationResponse)
async def generate_batch(request: BatchRequest):
    """Expand a batch request into variation jobs and queue all of them locally.

    The variation engine produces ``request.count`` jobs that differ in pose,
    outfit, emotion, etc. Progress counters are registered under the batch id
    and each job runs as its own background task.

    Raises:
        HTTPException(503): worker/variation engine or character profiles missing.
        HTTPException(404): the requested character does not exist.
    """
    services_ready = _local_worker is not None and _variation_engine is not None
    if not services_ready:
        raise HTTPException(503, "Services not initialized")
    if _character_profiles is None:
        raise HTTPException(503, "No character profiles loaded")

    character = _character_profiles.get(request.character_id)
    if character is None:
        raise HTTPException(404, f"Character not found: {request.character_id}")

    variation_jobs = _variation_engine.generate_batch(
        template_id=request.template_id,
        character=character,
        content_rating=request.content_rating,
        count=request.count,
        variation_mode=request.variation_mode,
        pin=request.pin,
        seed_strategy=request.seed_strategy,
    )
    job_count = len(variation_jobs)
    # The batch id comes from the first job (presumably shared by all jobs);
    # a fresh one is minted when the engine produced nothing.
    batch_id = variation_jobs[0].batch_id if variation_jobs else str(uuid.uuid4())

    _batch_tracker[batch_id] = dict(
        total=job_count,
        completed=0,
        failed=0,
        pending=job_count,
        running=0,
    )

    for variation_job in variation_jobs:
        asyncio.create_task(_run_batch_job(batch_id, variation_job))

    logger.info("Batch %s: %d jobs queued", batch_id, job_count)
    return GenerationResponse(job_id=batch_id, batch_id=batch_id, status="queued", backend="local")
@router.get("/batch/{batch_id}/status", response_model=BatchStatusResponse)
async def get_batch_status(batch_id: str):
    """Report the progress counters of a batch created via POST /batch.

    Raises:
        HTTPException(404): no tracker entry exists for ``batch_id``.
    """
    counters = _batch_tracker.get(batch_id)
    if counters is None:
        raise HTTPException(404, f"Batch not found: {batch_id}")

    return BatchStatusResponse(
        batch_id=batch_id,
        total_jobs=counters["total"],
        **{name: counters[name] for name in ("completed", "failed", "pending", "running")},
    )
@router.post("/generate/cloud", response_model=GenerationResponse)
async def generate_cloud(request: GenerationRequest):
    """Generate an image using the WaveSpeed cloud API (NanoBanana, SeeDream).

    Supported models via the 'checkpoint' field:
      - nano-banana, nano-banana-pro
      - seedream-3, seedream-3.1, seedream-4, seedream-4.5

    Only the first LoRA in ``request.loras`` is forwarded (strength 0.85 when
    none is given). Missing prompts become empty strings and missing
    dimensions default to 1024x1024.

    Raises:
        HTTPException(503): the WaveSpeed provider is not configured.
    """
    if _wavespeed_provider is None:
        raise HTTPException(503, "WaveSpeed cloud provider not configured. Set WAVESPEED_API_KEY in .env")

    job_id = str(uuid.uuid4())
    first_lora = request.loras[0] if request.loras else None
    lora_path = first_lora.name if first_lora is not None else None
    lora_strength = first_lora.strength_model if first_lora is not None else 0.85
    # An explicit seed of 0 is a real seed; only a missing seed means random (-1).
    seed = request.seed if request.seed is not None else -1

    # NOTE(review): the task handle is not retained; asyncio keeps only weak
    # references to tasks, so a strong reference should be held until done.
    asyncio.create_task(
        _run_cloud_generation(
            job_id=job_id,
            positive_prompt=request.positive_prompt or "",
            negative_prompt=request.negative_prompt or "",
            model=request.checkpoint,
            width=request.width or 1024,
            height=request.height or 1024,
            seed=seed,
            content_rating=request.content_rating,
            character_id=request.character_id,
            template_id=request.template_id,
            variables=request.variables,
            lora_path=lora_path,
            lora_strength=lora_strength,
        )
    )
    return GenerationResponse(job_id=job_id, status="queued", backend="wavespeed")
@router.get("/cloud/models")
async def list_cloud_models():
    """List available cloud models (WaveSpeed and RunPod).

    ``available`` tells whether the corresponding provider was wired in at
    startup. WaveSpeed lists txt2img models plus img2img edit models with a
    per-image price; the RunPod entry carries only a description and pricing.
    """
    txt2img_catalog = (
        ("nano-banana", "NanoBanana", "Google"),
        ("nano-banana-pro", "NanoBanana Pro", "Google"),
        ("seedream-3", "SeeDream v3", "ByteDance"),
        ("seedream-3.1", "SeeDream v3.1", "ByteDance"),
        ("seedream-4", "SeeDream v4", "ByteDance"),
        ("seedream-4.5", "SeeDream v4.5", "ByteDance"),
    )
    edit_catalog = (
        ("seedream-4.5-edit", "SeeDream v4.5 Edit", "ByteDance", "$0.04/img"),
        ("seedream-4-edit", "SeeDream v4 Edit", "ByteDance", "$0.04/img"),
        ("nano-banana-edit", "NanoBanana Edit", "Google", "$0.038/img"),
        ("nano-banana-pro-edit", "NanoBanana Pro Edit", "Google", "$0.14/img"),
    )

    wavespeed_info = {
        "available": _wavespeed_provider is not None,
        "models": [
            {"id": model_id, "name": label, "provider": vendor, "type": "txt2img"}
            for model_id, label, vendor in txt2img_catalog
        ],
        "edit_models": [
            {"id": model_id, "name": label, "provider": vendor, "type": "img2img", "price": price}
            for model_id, label, vendor, price in edit_catalog
        ],
    }
    runpod_info = {
        "available": _runpod_provider is not None,
        "description": "Pay-per-second serverless GPU. Uses your deployed endpoint.",
        "pricing": "~$0.00025/sec (RTX 4090)",
    }
    return {"wavespeed": wavespeed_info, "runpod": runpod_info}
@router.post("/generate/runpod", response_model=GenerationResponse)
async def generate_runpod(request: GenerationRequest):
    """Generate an image using RunPod serverless GPU.

    Uses your deployed RunPod endpoint, billed per second of GPU time.
    Requires RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env. Missing request
    fields fall back to 28 steps, CFG 7.0 and an 832x1216 canvas.

    Raises:
        HTTPException(503): the RunPod provider has not been configured.
    """
    if _runpod_provider is None:
        raise HTTPException(
            503,
            "RunPod not configured. Set RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env"
        )

    job_id = str(uuid.uuid4())
    # An explicit seed of 0 is a real seed; only a missing seed means random (-1).
    seed = request.seed if request.seed is not None else -1

    # NOTE(review): the task handle is not retained; asyncio keeps only weak
    # references to tasks, so a strong reference should be held until done.
    asyncio.create_task(
        _run_runpod_generation(
            job_id=job_id,
            positive_prompt=request.positive_prompt or "",
            negative_prompt=request.negative_prompt or "",
            checkpoint=request.checkpoint,
            loras=request.loras,
            seed=seed,
            steps=request.steps or 28,
            cfg=request.cfg or 7.0,
            width=request.width or 832,
            height=request.height or 1216,
            character_id=request.character_id,
            template_id=request.template_id,
            content_rating=request.content_rating,
        )
    )
    return GenerationResponse(job_id=job_id, status="queued", backend="runpod")
@router.get("/generate/jobs/{job_id}")
async def get_job_status(job_id: str):
    """Return the tracked status record of a generation job.

    Only jobs that wrote an entry into ``_job_tracker`` are found (local
    worker jobs started from this module do not). Unknown ids produce a
    placeholder record with ``status="unknown"`` instead of a 404.
    """
    record = _job_tracker.get(job_id)
    if record:
        return record
    return {"job_id": job_id, "status": "unknown", "message": "Job not found or completed"}
@router.get("/generate/jobs")
async def list_jobs():
    """Return up to 20 tracked jobs, most recently started first.

    Records without ``started_at`` sort as if they started at time 0.
    NOTE(review): nothing visible in this module prunes ``_job_tracker``.
    """
    newest_first = sorted(
        _job_tracker.values(),
        key=lambda record: record.get("started_at", 0),
        reverse=True,
    )
    return newest_first[:20]
@router.post("/generate/jobs/{job_id}/cancel")
async def cancel_job(job_id: str):
    """Mark a queued or running tracked job as cancelled.

    Cancellation is cooperative: only the tracker record is updated, and
    background tasks notice it at their own checkpoints. In-flight provider
    calls are not interrupted from here.

    Raises:
        HTTPException(404): the job id is not tracked.
    """
    record = _job_tracker.get(job_id)
    if not record:
        raise HTTPException(404, "Job not found")

    current_status = record["status"]
    if current_status in ("running", "queued"):
        record.update(status="cancelled", message="Cancelled by user")
        logger.info("Job %s cancelled by user", job_id)
        return {"job_id": job_id, "status": "cancelled", "message": "Job cancelled"}

    return {"job_id": job_id, "status": current_status, "message": "Job already finished"}
@router.post("/generate/img2img", response_model=GenerationResponse)
async def generate_img2img(
    image: UploadFile = File(...),
    image2: UploadFile | None = File(default=None),
    positive_prompt: str = Form(""),
    negative_prompt: str = Form(""),
    character_id: str | None = Form(None),
    template_id: str | None = Form(None),
    variables_json: str = Form("{}"),
    content_rating: str = Form("sfw"),
    checkpoint: str | None = Form(None),
    seed: int = Form(-1),
    steps: int = Form(28),
    cfg: float = Form(7.0),
    denoise: float = Form(0.65),
    width: int | None = Form(None),
    height: int | None = Form(None),
    backend: str = Form("local"),
):
    """Generate an image using a reference image (img2img).

    Backends, selected by ``backend``:
      - ``"cloud"``: prompt-guided editing via the WaveSpeed SeeDream/NanoBanana
        Edit APIs. ``checkpoint`` names the edit model. ``negative_prompt``,
        ``seed``, ``steps``, ``cfg`` and ``denoise`` are not forwarded.
      - any other value: denoise-based img2img via the local ComfyUI worker.
        The reference image is uploaded to ComfyUI first.

    ``image2`` is an optional second (pose/style) reference for multi-reference
    models. Only the cloud backend forwards it.

    ``variables_json`` should be a JSON object. Malformed JSON, or a value that
    is not an object, is treated as "no variables".

    Raises:
        HTTPException(503): the selected backend is not configured.
        HTTPException(500): uploading the reference image to ComfyUI failed.
    """
    import json as json_module

    job_id = str(uuid.uuid4())
    image_bytes = await image.read()
    # Second reference image is optional (multi-ref models only).
    image_bytes_2 = await image2.read() if image2 is not None else None

    # Parse template variables. Anything that is not a JSON object is dropped
    # so downstream template rendering always receives a dict.
    try:
        parsed_variables = json_module.loads(variables_json) if variables_json else {}
    except json_module.JSONDecodeError:
        parsed_variables = {}
    variables = parsed_variables if isinstance(parsed_variables, dict) else {}

    if backend == "cloud":
        # Cloud img2img via the WaveSpeed Edit API.
        if _wavespeed_provider is None:
            raise HTTPException(503, "WaveSpeed cloud provider not configured. Set WAVESPEED_API_KEY in .env")
        # NOTE(review): task handle not retained (asyncio holds only weak refs).
        asyncio.create_task(
            _run_cloud_img2img(
                job_id=job_id,
                image_bytes=image_bytes,
                image_bytes_2=image_bytes_2,
                positive_prompt=positive_prompt,
                model=checkpoint,
                content_rating=content_rating,
                character_id=character_id,
                template_id=template_id,
                variables=variables,
                width=width,
                height=height,
            )
        )
        return GenerationResponse(job_id=job_id, status="queued", backend="wavespeed")

    # Local img2img via ComfyUI. The reference is uploaded first and then
    # handed to the worker by its uploaded name.
    if _local_worker is None or _comfyui_client is None:
        raise HTTPException(503, "Worker not initialized")
    ref_filename = f"ref_{job_id[:8]}.png"
    try:
        uploaded_name = await _comfyui_client.upload_image(image_bytes, ref_filename)
    except Exception as e:
        raise HTTPException(500, f"Failed to upload reference image to ComfyUI: {e}") from e

    # NOTE(review): task handle not retained (asyncio holds only weak refs).
    asyncio.create_task(
        _run_generation(
            job_id=job_id,
            character_id=character_id,
            template_id=template_id,
            variables=variables,
            content_rating=content_rating,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            checkpoint=checkpoint,
            seed=seed,
            steps=steps,
            cfg=cfg,
            width=width,
            height=height,
            denoise=denoise,
            reference_image=uploaded_name,
            mode="img2img",
        )
    )
    return GenerationResponse(job_id=job_id, status="queued", backend="local")
async def _run_cloud_generation(
    *,
    job_id: str,
    positive_prompt: str,
    negative_prompt: str,
    model: str | None,
    width: int,
    height: int,
    seed: int,
    content_rating: str,
    character_id: str | None,
    template_id: str | None,
    variables: dict | None,
    lora_path: str | None = None,
    lora_strength: float = 0.85,
):
    """Background task to run a WaveSpeed cloud (txt2img) generation.

    Progress is published in ``_job_tracker[job_id]``. The status is
    ``running`` while working and then becomes ``completed`` or ``failed``.
    Exceptions are caught, logged and stored as a truncated message; they are
    not re-raised.

    A ``cancelled`` status set by cancel_job() is honoured before and after the
    API call, and a cancelled job is not saved.

    When ``template_id`` is set and a template engine is available:
      - The rendered template prompts become the base, and the user's prompts
        are appended.
      - The template's width/height, when defined, replace the given
        dimensions.
      - If rendering fails, the raw prompts and dimensions are used unchanged.

    Without a catalog the image cannot be persisted. The job is still moved to
    a terminal ``completed`` status instead of remaining ``running``.
    """
    import time

    job = {
        "job_id": job_id,
        "status": "running",
        "type": "txt2img",
        "model": model,
        "started_at": time.time(),
        "message": "Preparing prompt...",
    }
    _job_tracker[job_id] = job

    def cancelled() -> bool:
        # cancel_job() mutates this same tracker dict from the request side.
        return job.get("status") == "cancelled"

    try:
        final_positive = positive_prompt
        final_negative = negative_prompt
        if template_id and _template_engine:
            try:
                rendered = _template_engine.render(template_id, variables or {})
                # Build into temporaries so a failure part-way through falls
                # back to the raw prompts and size instead of a mix.
                templ_positive = rendered.positive_prompt
                if positive_prompt:
                    templ_positive = f"{templ_positive}, {positive_prompt}"
                templ_negative = rendered.negative_prompt
                if negative_prompt:
                    templ_negative = f"{templ_negative}, {negative_prompt}"
                # Template dimensions take precedence whenever the template defines them.
                templ_width = rendered.template.width or width
                templ_height = rendered.template.height or height
            except Exception:
                logger.warning("Failed to render template '%s', using raw prompt", template_id, exc_info=True)
            else:
                final_positive, final_negative = templ_positive, templ_negative
                width, height = templ_width, templ_height
                logger.info("Cloud gen: applied template '%s'", template_id)

        job["message"] = f"Calling WaveSpeed API ({model or 'seedream-4.5'})..."
        if cancelled():
            return

        result = await _wavespeed_provider.generate(
            positive_prompt=final_positive,
            negative_prompt=final_negative,
            model=model,
            width=width,
            height=height,
            seed=seed,
            lora_name=lora_path,
            lora_strength=lora_strength,
        )
        if cancelled():
            return

        if not _catalog:
            # Nowhere to persist the image; finish the job rather than leave it "running".
            job["status"] = "completed"
            job["message"] = "Generated, but no catalog is configured - image not saved"
            job["completed_at"] = time.time()
            logger.warning("Cloud generation %s finished without a catalog; image not saved", job_id)
            return

        job["message"] = "Saving image..."
        output_path = _catalog.resolve_output_path(
            character_id=character_id or "cloud",
            content_rating=content_rating,
            filename=f"wavespeed_{job_id[:8]}.png",
        )
        output_path.write_bytes(result.image_bytes)
        # The user's raw prompts are recorded (alongside template_id and
        # variables), not the rendered ones.
        await _catalog.insert_image(
            file_path=str(output_path),
            image_bytes=result.image_bytes,
            character_id=character_id,
            template_id=template_id,
            content_rating=content_rating,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            checkpoint=model or "seedream-4.5",
            seed=seed if seed >= 0 else None,
            width=width,
            height=height,
            generation_backend="wavespeed",
            generation_time_seconds=result.generation_time_seconds,
            variables=variables,
        )
        job["status"] = "completed"
        job["message"] = f"Saved: {output_path.name}"
        job["completed_at"] = time.time()
        logger.info("Cloud generation saved: %s", output_path)
    except Exception as e:
        job["status"] = "failed"
        job["message"] = str(e)[:200]
        logger.error("Cloud generation failed for job %s: %s", job_id, e, exc_info=True)
async def _run_cloud_img2img(
    *,
    job_id: str,
    image_bytes: bytes,
    image_bytes_2: bytes | None,
    positive_prompt: str,
    model: str | None,
    content_rating: str,
    character_id: str | None,
    template_id: str | None,
    variables: dict | None,
    width: int | None,
    height: int | None,
):
    """Background task to run a WaveSpeed cloud image edit (img2img).

    Progress is published in ``_job_tracker[job_id]``. The status is
    ``running`` while working and then becomes ``completed`` or ``failed``.
    Exceptions are caught, logged and stored as a truncated message.

    A ``cancelled`` status set by cancel_job() is honoured before and after the
    edit call. A cancelled job is not saved and its status is not overwritten.

    If a template is selected, its rendered positive prompt becomes the base
    and the user prompt is appended. Empty comma-separated fragments are then
    removed. An empty final prompt fails the job without calling the API.

    Without a catalog the job is still moved to a terminal ``completed``
    status instead of remaining ``running``.
    """
    import time

    # WaveSpeed edit API requires output size >= 3686400 pixels (~1920x1920).
    min_edit_pixels = 3686400

    job = {
        "job_id": job_id,
        "status": "running",
        "type": "img2img",
        "model": model,
        "started_at": time.time(),
        "message": "Uploading image to cloud...",
    }
    _job_tracker[job_id] = job

    def cancelled() -> bool:
        # cancel_job() mutates this same tracker dict from the request side.
        return job.get("status") == "cancelled"

    try:
        final_prompt = positive_prompt
        if template_id and _template_engine:
            try:
                rendered = _template_engine.render(template_id, variables or {})
                templated = rendered.positive_prompt
                if positive_prompt:
                    templated = f"{templated}, {positive_prompt}"
            except Exception:
                logger.warning("Failed to render template '%s', using raw prompt", template_id, exc_info=True)
            else:
                final_prompt = templated
                logger.info("Cloud img2img: applied template '%s'", template_id)

        # Drop empty fragments (e.g. left by empty template variables) and
        # stray leading/trailing commas.
        final_prompt = ", ".join(p.strip() for p in final_prompt.split(",") if p.strip())
        if not final_prompt:
            job["status"] = "failed"
            job["message"] = "Empty prompt - nothing to generate"
            logger.error("Cloud img2img: empty prompt after template rendering, cannot proceed")
            return

        job["message"] = f"Calling WaveSpeed API ({model or 'seedream-4.5-edit'})..."
        # Below the minimum, omit size so the API uses the input image dimensions.
        size = None
        if width and height and (width * height) >= min_edit_pixels:
            size = f"{width}x{height}"

        if cancelled():
            return
        result = await _wavespeed_provider.edit_image(
            prompt=final_prompt,
            image_bytes=image_bytes,
            image_bytes_2=image_bytes_2,
            model=model,
            size=size,
        )
        if cancelled():
            return

        if not _catalog:
            # Nowhere to persist the image; finish the job rather than leave it "running".
            job["status"] = "completed"
            job["message"] = "Generated, but no catalog is configured - image not saved"
            job["completed_at"] = time.time()
            logger.warning("Cloud img2img %s finished without a catalog; image not saved", job_id)
            return

        job["message"] = "Saving image..."
        output_path = _catalog.resolve_output_path(
            character_id=character_id or "cloud",
            content_rating=content_rating,
            filename=f"wavespeed_edit_{job_id[:8]}.png",
        )
        output_path.write_bytes(result.image_bytes)
        await _catalog.insert_image(
            file_path=str(output_path),
            image_bytes=result.image_bytes,
            character_id=character_id,
            template_id=template_id,
            content_rating=content_rating,
            positive_prompt=final_prompt,
            negative_prompt="",
            checkpoint=model or "seedream-4.5-edit",
            width=width or 0,
            height=height or 0,
            generation_backend="wavespeed-edit",
            generation_time_seconds=result.generation_time_seconds,
            variables=variables,
        )
        job["status"] = "completed"
        job["message"] = f"Saved: {output_path.name}"
        job["completed_at"] = time.time()
        logger.info("Cloud img2img saved: %s", output_path)
    except Exception as e:
        job["status"] = "failed"
        job["message"] = str(e)[:200]
        logger.error("Cloud img2img failed for job %s: %s", job_id, e, exc_info=True)
async def _run_runpod_generation(
    *,
    job_id: str,
    positive_prompt: str,
    negative_prompt: str,
    checkpoint: str | None,
    loras: list | None,
    seed: int,
    steps: int,
    cfg: float,
    width: int,
    height: int,
    character_id: str | None,
    template_id: str | None,
    content_rating: str,
):
    """Background task to run a generation on RunPod serverless and save it.

    Progress is published in ``_job_tracker[job_id]``, in the same shape the
    WaveSpeed tasks use. This makes RunPod jobs visible to the job status,
    listing and cancel endpoints. Exceptions are caught, logged and recorded
    as a ``failed`` status.

    Prompt and model handling:
      - When ``character_id`` names a known character, its trigger word is
        prepended to the prompt.
      - Only the first LoRA is forwarded.
      - A missing checkpoint defaults to ``realisticVisionV51_v51VAE``.

    The remote job cannot be interrupted from here. A cancel request only
    prevents the result from being saved.
    """
    import time

    resolved_checkpoint = checkpoint or "realisticVisionV51_v51VAE"
    job = {
        "job_id": job_id,
        "status": "running",
        "type": "txt2img",
        "model": resolved_checkpoint,
        "started_at": time.time(),
        "message": "Preparing prompt...",
    }
    _job_tracker[job_id] = job

    def cancelled() -> bool:
        # cancel_job() mutates this same tracker dict from the request side.
        return job.get("status") == "cancelled"

    try:
        final_prompt = positive_prompt
        if character_id and _character_profiles:
            character = _character_profiles.get(character_id)
            if character:
                # Prepend the trigger word without leaving a dangling ", ".
                final_prompt = ", ".join(p for p in (character.trigger_word, positive_prompt) if p)

        first_lora = loras[0] if loras else None
        lora_name = first_lora.name if first_lora is not None else None
        # LoRA specs from GenerationRequest expose ``strength_model`` (as the
        # WaveSpeed route reads it); fall back to ``strength`` / 0.85 otherwise.
        lora_strength = (
            getattr(first_lora, "strength_model", getattr(first_lora, "strength", 0.85))
            if first_lora is not None
            else 0.85
        )

        job["message"] = f"Submitting to RunPod ({resolved_checkpoint})..."
        runpod_job_id = await _runpod_provider.submit_generation(
            positive_prompt=final_prompt,
            negative_prompt=negative_prompt,
            checkpoint=resolved_checkpoint,
            lora_name=lora_name,
            lora_strength=lora_strength,
            seed=seed,
            steps=steps,
            cfg=cfg,
            width=width,
            height=height,
        )

        job["message"] = "Waiting for RunPod worker..."
        result = await _runpod_provider.wait_for_completion(runpod_job_id)
        if cancelled():
            return

        if not _catalog:
            # Nowhere to persist the image; finish the job rather than leave it "running".
            job["status"] = "completed"
            job["message"] = "Generated, but no catalog is configured - image not saved"
            job["completed_at"] = time.time()
            logger.warning("RunPod generation %s finished without a catalog; image not saved", job_id)
            return

        job["message"] = "Saving image..."
        output_path = _catalog.resolve_output_path(
            character_id=character_id or "unknown",
            content_rating=content_rating,
            filename=f"runpod_{job_id[:8]}.png",
        )
        output_path.write_bytes(result.image_bytes)
        # Same keyword set as the WaveSpeed save path.
        # NOTE(review): steps/cfg are not recorded because that call does not
        # show insert_image accepting them — confirm its signature and add them.
        await _catalog.insert_image(
            file_path=str(output_path),
            image_bytes=result.image_bytes,
            character_id=character_id,
            template_id=template_id,
            content_rating=content_rating,
            positive_prompt=final_prompt,
            negative_prompt=negative_prompt,
            checkpoint=resolved_checkpoint,
            seed=seed if seed >= 0 else None,
            width=width,
            height=height,
            generation_backend="runpod",
            generation_time_seconds=result.generation_time_seconds,
        )
        job["status"] = "completed"
        job["message"] = f"Saved: {output_path.name}"
        job["completed_at"] = time.time()
        logger.info("RunPod generation saved: %s (%.1fs)", output_path, result.generation_time_seconds)
    except Exception as e:
        job["status"] = "failed"
        job["message"] = str(e)[:200]
        logger.error("RunPod generation failed for job %s", job_id, exc_info=True)
async def _run_generation(**kwargs):
    """Background task: hand one local job to the worker, logging any failure.

    ``mode`` is a routing hint used by the endpoints and is stripped before
    the remaining keyword arguments reach ``_local_worker.process_job``.
    Exceptions are logged with a traceback and not re-raised.
    """
    job_id = kwargs.get("job_id")
    worker_kwargs = {key: value for key, value in kwargs.items() if key != "mode"}
    try:
        await _local_worker.process_job(**worker_kwargs)
    except Exception:
        logger.exception("Generation failed for job %s", job_id)
async def _run_batch_job(batch_id: str, job):
    """Background task: run one variation job and keep its batch counters current.

    On start the job moves from ``pending`` to ``running``. Afterwards it
    counts as ``completed`` or ``failed`` and leaves ``running``. Counters are
    skipped when the batch has no tracker entry. Failures are logged, not
    re-raised.
    """
    counters = _batch_tracker.get(batch_id)

    def adjust(**deltas: int) -> None:
        # No-op when the batch is untracked.
        if counters:
            for name, delta in deltas.items():
                counters[name] += delta

    adjust(pending=-1, running=1)
    try:
        await _local_worker.process_job(
            job_id=job.job_id,
            batch_id=job.batch_id,
            character_id=job.character.id,
            template_id=job.template_id,
            content_rating=job.content_rating,
            loras=list(job.loras),
            seed=job.seed,
            variables=job.variables,
        )
        adjust(completed=1)
    except Exception:
        logger.exception("Batch job %s failed", job.job_id)
        adjust(failed=1)
    finally:
        adjust(running=-1)