"""
Backend Router
==============

Unified router for selecting between different image generation backends:
- Gemini (Flash/Pro) - Cloud API
- FLUX.2 klein 4B/9B - Local model
- Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
- Qwen-Image-Edit-2511 - Local model
"""

import logging
from typing import Optional, Protocol, Union
from enum import Enum, auto

from PIL import Image

from .models import GenerationRequest, GenerationResult

logger = logging.getLogger(__name__)


class BackendType(Enum):
    """Available backend types."""
    GEMINI_FLASH = "gemini_flash"
    GEMINI_PRO = "gemini_pro"
    FLUX_KLEIN = "flux_klein"                # 4B model (~13GB VRAM)
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"  # 9B FP8 model (~20GB VRAM, best quality)
    ZIMAGE_TURBO = "zimage_turbo"            # Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"              # Z-Image Base 6B (50 steps, CFG support)
    LONGCAT_EDIT = "longcat_edit"            # LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"      # Direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"            # Via ComfyUI with FP8 quantization


class ImageClient(Protocol):
    """Protocol for image generation clients."""

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Generate an image from request."""
        ...

    def is_healthy(self) -> bool:
        """Check if client is ready."""
        ...


class BackendRouter:
    """
    Router for selecting between image generation backends.

    Supports lazy loading of local models to save memory: clients are
    created on first use and cached, and local models can be unloaded
    (or swapped out on backend switch) to free VRAM.
    """

    # Human-readable display names for each backend.
    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    # Backends that hold model weights locally (directly or via ComfyUI)
    # and therefore consume significant memory.  Declared once so that
    # unload_local_models(), switch_backend() and is_local_backend()
    # cannot drift out of sync with each other.
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router.

        Args:
            gemini_api_key: API key for Gemini backends
            default_backend: Default backend to use
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        # Cache of already-created clients, keyed by backend type.
        self._clients: dict = {}
        # Backend most recently returned by get_client(), if any.
        self._active_backend: Optional[BackendType] = None
        logger.info("BackendRouter initialized (default: %s)", default_backend.value)

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses default if None)

        Returns:
            ImageClient instance
        """
        if backend is None:
            backend = self.default_backend

        # Return cached client if available
        if backend not in self._clients:
            self._clients[backend] = self._create_client(backend)
        self._active_backend = backend
        return self._clients[backend]

    @staticmethod
    def _load_or_raise(client, model_name: str):
        """Load a local model client, raising RuntimeError on failure."""
        if not client.load_model():
            raise RuntimeError(f"Failed to load {model_name} model")
        return client

    def _create_client(self, backend: BackendType) -> ImageClient:
        """
        Create a client for the specified backend.

        Local backends are fully loaded here (may take time and VRAM);
        imports are deferred so optional dependencies are only required
        for the backends actually used.

        Raises:
            ValueError: Gemini backend requested without an API key, or
                unknown backend type.
            RuntimeError: a local model failed to load, or ComfyUI is
                not reachable.
        """
        logger.info("Creating client for %s...", backend.value)

        if backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(
                api_key=self.gemini_api_key,
                use_pro_model=(backend == BackendType.GEMINI_PRO),
            )

        elif backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast
            client = FluxKleinClient(model_variant="4b", enable_cpu_offload=False)
            return self._load_or_raise(client, "FLUX.2 klein 4B")

        elif backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model (~29GB VRAM with CPU offload) - best quality
            client = FluxKleinClient(
                model_variant="9b",
                enable_cpu_offload=True  # Required for 24GB VRAM
            )
            return self._load_or_raise(client, "FLUX.2 klein 9B")

        elif backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            client = ZImageClient(model_variant="turbo", enable_cpu_offload=True)
            return self._load_or_raise(client, "Z-Image Turbo")

        elif backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            client = ZImageClient(model_variant="base", enable_cpu_offload=True)
            return self._load_or_raise(client, "Z-Image Base")

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            client = LongCatEditClient(enable_cpu_offload=True)
            return self._load_or_raise(client, "LongCat-Image-Edit")

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            client = QwenImageEditClient(enable_cpu_offload=False)
            return self._load_or_raise(client, "Qwen-Image-Edit")

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    "  cd comfyui && python main.py"
                )
            return client

        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Args:
            request: Generation request
            backend: Backend to use (default if None)
            **kwargs: Backend-specific parameters

        Returns:
            GenerationResult (an error result on failure; never raises)
        """
        # Resolve up front so error logs name the real backend instead
        # of "None" when the default is used.
        resolved = backend if backend is not None else self.default_backend
        try:
            client = self.get_client(resolved)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error("Generation failed with %s: %s", resolved, e, exc_info=True)
            return GenerationResult.error_result(f"Backend error: {str(e)}")

    def unload_local_models(self):
        """Unload all local models to free memory."""
        for backend, client in list(self._clients.items()):
            if backend in self.LOCAL_BACKENDS:
                # Not every client exposes unload_model (e.g. ComfyUI);
                # dropping the cache entry is the common path.
                if hasattr(client, 'unload_model'):
                    client.unload_model()
                del self._clients[backend]
                logger.info("Unloaded %s", backend.value)

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend.

        For local models, this will load the new model and optionally
        unload the previous one to save memory.

        Args:
            backend: Backend to switch to

        Returns:
            True if switch successful
        """
        try:
            # Unload other local models first to save memory
            if backend in self.LOCAL_BACKENDS:
                for other_local in self.LOCAL_BACKENDS - {backend}:
                    if other_local in self._clients:
                        if hasattr(self._clients[other_local], 'unload_model'):
                            self._clients[other_local].unload_model()
                        del self._clients[other_local]

            # Get/create the new client
            self.get_client(backend)
            self.default_backend = backend
            logger.info("Switched to %s", backend.value)
            return True
        except Exception as e:
            logger.error("Failed to switch to %s: %s", backend, e, exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of active backend ("None" if inactive)."""
        if self._active_backend:
            return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))
        return "None"

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check if backend (active backend if None) is a local model."""
        if backend is None:
            backend = self._active_backend
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns dict mapping ratio strings to (width, height) tuples.
        """
        # Import clients lazily to get their ASPECT_RATIOS tables.
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient
            return FluxKleinClient.ASPECT_RATIOS
        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient
            return ZImageClient.ASPECT_RATIOS
        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            return LongCatEditClient.ASPECT_RATIOS
        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            return GeminiClient.ASPECT_RATIOS
        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return QwenImageEditClient.ASPECT_RATIOS
        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            return ComfyUIClient.ASPECT_RATIOS
        else:
            # Default fallback
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns list of (label, value) tuples, e.g. ("1:1 (1024x1024)", "1:1").
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        return [(f"{ratio} ({w}x{h})", ratio) for ratio, (w, h) in ratios.items()]

    def get_available_backends(self) -> list:
        """Get list of backends usable in the current environment."""
        available = []

        # Gemini backends require API key
        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])

        # Local backends are available only if their diffusers pipeline
        # classes can be imported.
        # NOTE(review): FLUX_KLEIN_9B_FP8 is never reported here even when
        # Flux2KleinPipeline imports — confirm whether that is intentional
        # (e.g. gated on VRAM) or an omission.
        try:
            from diffusers import Flux2KleinPipeline
            available.append(BackendType.FLUX_KLEIN)
        except ImportError:
            pass

        try:
            from diffusers import ZImagePipeline
            available.append(BackendType.ZIMAGE_TURBO)
            available.append(BackendType.ZIMAGE_BASE)
        except ImportError:
            pass

        try:
            from diffusers import LongCatImageEditPipeline
            available.append(BackendType.LONGCAT_EDIT)
        except ImportError:
            pass

        try:
            from diffusers import QwenImageEditPlusPipeline
            available.append(BackendType.QWEN_IMAGE_EDIT)
        except ImportError:
            pass

        # ComfyUI backend - check if ComfyUI client works
        try:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if client.is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            pass

        return available

    @staticmethod
    def get_backend_choices() -> list:
        """Get list of (label, value) backend choices for UI dropdowns."""
        return [
            ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH.value),
            ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO.value),
            ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN.value),
            ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO.value),
            ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE.value),
            ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT.value),
            ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT.value),
            ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI.value),
        ]

    @staticmethod
    def backend_from_string(value: str) -> BackendType:
        """
        Convert string to BackendType.

        Raises:
            ValueError: if value is not a known backend value.
        """
        try:
            # Enum value lookup: BackendType("gemini_flash") -> GEMINI_FLASH
            return BackendType(value)
        except ValueError:
            # Re-raise with the message callers historically matched on.
            raise ValueError(f"Unknown backend: {value}") from None