# CharacterForgePro / src/backend_router.py
# (Hugging Face page residue: commit "Deploy full Character Sheet Pro with HF auth",
#  da23dfe, by ghmk — kept as a comment so the module remains importable.)
"""
Backend Router
==============
Unified router for selecting between different image generation backends:
- Gemini (Flash/Pro) - Cloud API
- FLUX.2 klein 4B/9B - Local model
- Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
- Qwen-Image-Edit-2511 - Local model
"""
import logging
from typing import Optional, Protocol, Union
from enum import Enum, auto
from PIL import Image
from .models import GenerationRequest, GenerationResult
logger = logging.getLogger(__name__)
class BackendType(Enum):
    """Enumerates every image-generation backend the router can serve.

    The string values are stable identifiers used by UI dropdowns and
    :meth:`BackendRouter.backend_from_string`; do not rename them.
    """

    # Cloud backends (Google Gemini API).
    GEMINI_FLASH = "gemini_flash"
    GEMINI_PRO = "gemini_pro"

    # Local text-to-image backends.
    FLUX_KLEIN = "flux_klein"                # FLUX.2 klein 4B (~13GB VRAM)
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"  # FLUX.2 klein 9B FP8 (~20GB VRAM, best quality)
    ZIMAGE_TURBO = "zimage_turbo"            # Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"              # Z-Image Base 6B (50 steps, CFG support)

    # Local image-editing backends.
    LONGCAT_EDIT = "longcat_edit"            # LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"      # Direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"            # Via ComfyUI with FP8 quantization
class ImageClient(Protocol):
    """Structural interface every backend client must satisfy.

    Concrete clients (Gemini, FLUX, Z-Image, Qwen, ComfyUI, ...) are not
    required to inherit from this class; matching the method signatures is
    enough for the router to use them interchangeably.
    """

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Produce an image for *request*; extra kwargs are backend-specific."""
        ...

    def is_healthy(self) -> bool:
        """Report whether the client is ready to serve requests."""
        ...
class BackendRouter:
    """
    Router for selecting between image generation backends.

    Clients are created lazily on first use and cached per backend, so local
    models are only loaded into memory when actually requested.
    """

    # Human-readable display names for each backend.
    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    # Backends that load model weights on this machine (everything except the
    # Gemini cloud APIs).  Defined once so unload_local_models(),
    # switch_backend() and is_local_backend() cannot drift out of sync —
    # previously each method carried its own copy of this collection.
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router.

        Args:
            gemini_api_key: API key for Gemini backends
            default_backend: Default backend to use
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        # Cache of created clients, keyed by BackendType.
        self._clients: dict = {}
        # Backend whose client was most recently requested via get_client().
        self._active_backend: Optional[BackendType] = None
        logger.info(f"BackendRouter initialized (default: {default_backend.value})")

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses default if None)

        Returns:
            ImageClient instance

        Raises:
            ValueError: unknown backend, or Gemini requested without an API key.
            RuntimeError: a local model failed to load, or ComfyUI is down.
        """
        if backend is None:
            backend = self.default_backend
        # Return cached client if available.
        if backend in self._clients:
            self._active_backend = backend
            return self._clients[backend]
        # Create, cache, and activate a new client.
        client = self._create_client(backend)
        self._clients[backend] = client
        self._active_backend = backend
        return client

    def _create_client(self, backend: BackendType) -> ImageClient:
        """Create client for specified backend (imports are deferred so heavy
        dependencies are only pulled in for the backend actually used)."""
        logger.info(f"Creating client for {backend.value}...")

        if backend == BackendType.GEMINI_FLASH:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=False)

        elif backend == BackendType.GEMINI_PRO:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=True)

        elif backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast
            client = FluxKleinClient(
                model_variant="4b",
                enable_cpu_offload=False
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 4B model")
            return client

        elif backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model (~29GB VRAM with CPU offload) - best quality
            client = FluxKleinClient(
                model_variant="9b",
                enable_cpu_offload=True  # Required for 24GB VRAM
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 9B model")
            return client

        elif backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            client = ZImageClient(
                model_variant="turbo",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Turbo model")
            return client

        elif backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            client = ZImageClient(
                model_variant="base",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Base model")
            return client

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            client = LongCatEditClient(
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load LongCat-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            client = QwenImageEditClient(enable_cpu_offload=False)
            if not client.load_model():
                raise RuntimeError("Failed to load Qwen-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    "  cd comfyui && python main.py"
                )
            return client

        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Any exception from client creation or generation is caught and turned
        into an error GenerationResult rather than propagated.

        Args:
            request: Generation request
            backend: Backend to use (default if None)
            **kwargs: Backend-specific parameters

        Returns:
            GenerationResult
        """
        # Resolve the default up front so a failure is logged against the
        # actual backend instead of "None".
        if backend is None:
            backend = self.default_backend
        try:
            client = self.get_client(backend)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error(f"Generation failed with {backend}: {e}", exc_info=True)
            return GenerationResult.error_result(f"Backend error: {str(e)}")

    def unload_local_models(self):
        """Unload all local models to free memory."""
        for backend, client in list(self._clients.items()):
            if backend in self.LOCAL_BACKENDS:
                if hasattr(client, 'unload_model'):
                    client.unload_model()
                del self._clients[backend]
                logger.info(f"Unloaded {backend.value}")

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend.

        For local models, this will load the new model and optionally
        unload the previous one to save memory.

        Args:
            backend: Backend to switch to

        Returns:
            True if switch successful
        """
        try:
            # Unload other local models first to save memory.
            if backend in self.LOCAL_BACKENDS:
                for other_local in self.LOCAL_BACKENDS - {backend}:
                    if other_local in self._clients:
                        if hasattr(self._clients[other_local], 'unload_model'):
                            self._clients[other_local].unload_model()
                        del self._clients[other_local]
            # Get/create the new client; also makes it the default.
            self.get_client(backend)
            self.default_backend = backend
            logger.info(f"Switched to {backend.value}")
            return True
        except Exception as e:
            logger.error(f"Failed to switch to {backend}: {e}", exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of active backend ("None" if no client
        has been requested yet)."""
        if self._active_backend:
            return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))
        return "None"

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check if backend is a local model.

        With no argument, checks the currently active backend; returns False
        when nothing is active yet.
        """
        if backend is None:
            backend = self._active_backend
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns dict mapping ratio strings to (width, height) tuples.
        """
        # Import clients lazily to read their ASPECT_RATIOS class attribute.
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient
            return FluxKleinClient.ASPECT_RATIOS
        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient
            return ZImageClient.ASPECT_RATIOS
        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            return LongCatEditClient.ASPECT_RATIOS
        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            return GeminiClient.ASPECT_RATIOS
        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return QwenImageEditClient.ASPECT_RATIOS
        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            return ComfyUIClient.ASPECT_RATIOS
        else:
            # Default fallback
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns list of (label, value) tuples.
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        return [(f"{ratio} ({w}x{h})", ratio) for ratio, (w, h) in ratios.items()]

    def get_available_backends(self) -> list:
        """Get list of available backends.

        Gemini requires an API key; each local backend requires its diffusers
        pipeline class to be importable; ComfyUI must answer a health check.
        """
        available = []
        # Gemini backends require API key.
        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])
        # Local backends available when their dependencies are installed.
        # NOTE(review): only the 4B FLUX variant is advertised here even when
        # the pipeline is importable — presumably because the 9B model's VRAM
        # needs are not checked; confirm this is intentional.
        try:
            from diffusers import Flux2KleinPipeline
            available.append(BackendType.FLUX_KLEIN)
        except ImportError:
            pass
        try:
            from diffusers import ZImagePipeline
            available.append(BackendType.ZIMAGE_TURBO)
            available.append(BackendType.ZIMAGE_BASE)
        except ImportError:
            pass
        try:
            from diffusers import LongCatImageEditPipeline
            available.append(BackendType.LONGCAT_EDIT)
        except ImportError:
            pass
        try:
            from diffusers import QwenImageEditPlusPipeline
            available.append(BackendType.QWEN_IMAGE_EDIT)
        except ImportError:
            pass
        # ComfyUI backend - check if ComfyUI client works.
        try:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if client.is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            pass
        return available

    @staticmethod
    def get_backend_choices() -> list:
        """Get list of backend choices for UI dropdowns.

        NOTE(review): FLUX_KLEIN_9B_FP8 is deliberately (?) absent from this
        menu — confirm whether it should be user-selectable.
        """
        return [
            ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH.value),
            ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO.value),
            ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN.value),
            ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO.value),
            ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE.value),
            ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT.value),
            ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT.value),
            ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI.value),
        ]

    @staticmethod
    def backend_from_string(value: str) -> BackendType:
        """Convert string to BackendType.

        Raises:
            ValueError: if *value* is not a known backend identifier.
        """
        # Enum lookup-by-value replaces the previous manual scan; re-raise
        # to preserve the original error message for callers.
        try:
            return BackendType(value)
        except ValueError:
            raise ValueError(f"Unknown backend: {value}") from None