Spaces:
Running
on
Zero
Running
on
Zero
"""
Backend Router
==============
Unified router for selecting between different image generation backends:
- Gemini (Flash/Pro) - Cloud API (requires an API key)
- FLUX.2 klein 4B/9B - Local model
- Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
- Z-Image Base - Local model, 6B, 50 steps, CFG / negative prompt support
- LongCat-Image-Edit - Local model, instruction-following image editing
- Qwen-Image-Edit-2511 - Local model (direct diffusers)
- Qwen-Image-Edit-2511-FP8 - Served through a separately started ComfyUI instance
"""
import logging
from typing import Optional, Protocol, Union
from enum import Enum, auto
from PIL import Image
from .models import GenerationRequest, GenerationResult

logger = logging.getLogger(__name__)
class BackendType(Enum):
    """Available backend types.

    Each member's value is the string identifier shown in the UI backend
    dropdown and parsed back into a member by ``backend_from_string``.
    """
    GEMINI_FLASH = "gemini_flash"  # Gemini Flash - cloud API, needs a Gemini API key
    GEMINI_PRO = "gemini_pro"  # Gemini Pro - cloud API, needs a Gemini API key
    FLUX_KLEIN = "flux_klein"  # 4B model (~13GB VRAM), loaded without CPU offload
    # 9B FP8 model (best quality), loaded with CPU offload.
    # NOTE(review): VRAM estimates disagree (~20GB here vs ~29GB where the
    # model is loaded) - confirm the real footprint.
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"
    ZIMAGE_TURBO = "zimage_turbo"  # Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"  # Z-Image Base 6B (50 steps, CFG / negative prompt support)
    LONGCAT_EDIT = "longcat_edit"  # LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"  # Direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"  # Via ComfyUI with FP8 quantization; ComfyUI must already be running
class ImageClient(Protocol):
    """Structural interface shared by the image-generation backend clients.

    As a ``typing.Protocol`` it is satisfied by any object with matching
    methods; clients do not need to inherit from it.
    """

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Produce an image for *request*; backend-specific options arrive as kwargs."""

    def is_healthy(self) -> bool:
        """Report whether the client is ready to serve requests."""
class BackendRouter:
    """
    Router for selecting between image generation backends.

    Clients are created lazily on first use and cached per backend, so local
    models are only loaded when they are actually requested.
    """

    # Backends treated as "local": their cached clients are unloaded by
    # unload_local_models() and when switching to another local backend.
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router. No client is created here.

        Args:
            gemini_api_key: API key for Gemini backends (only required when a
                Gemini backend is actually requested)
            default_backend: Backend used whenever callers pass ``backend=None``
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        # Cache of created clients, keyed by BackendType.
        self._clients: dict = {}
        # Backend whose client was most recently returned by get_client();
        # reset to None when that client is unloaded.
        self._active_backend: Optional[BackendType] = None
        logger.info("BackendRouter initialized (default: %s)", default_backend.value)

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses ``default_backend`` if None)

        Returns:
            The cached client, or a newly created one (which is then cached).

        Raises:
            ValueError, RuntimeError: propagated from client creation.
        """
        if backend is None:
            backend = self.default_backend
        if backend not in self._clients:
            self._clients[backend] = self._create_client(backend)
        self._active_backend = backend
        return self._clients[backend]

    @staticmethod
    def _load_or_raise(client, error_message: str):
        """Call ``client.load_model()`` and raise RuntimeError(error_message) if it returns falsy."""
        if not client.load_model():
            raise RuntimeError(error_message)
        return client

    def _create_client(self, backend: BackendType) -> ImageClient:
        """
        Instantiate (and, for local models, load) the client for *backend*.

        Client modules are imported lazily so only the requested backend's
        module is imported.

        Raises:
            ValueError: a Gemini backend was requested without an API key, or
                *backend* is not a known BackendType.
            RuntimeError: a local model failed to load, or ComfyUI is not
                healthy.
        """
        logger.info("Creating client for %s...", backend.value)

        if backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            from .gemini_client import GeminiClient
            return GeminiClient(
                api_key=self.gemini_api_key,
                use_pro_model=(backend == BackendType.GEMINI_PRO),
            )

        if backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast, loaded without CPU offload
            return self._load_or_raise(
                FluxKleinClient(model_variant="4b", enable_cpu_offload=False),
                "Failed to load FLUX.2 klein 4B model",
            )

        if backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model - best quality. CPU offload is enabled so it fits a 24GB
            # card. NOTE(review): VRAM estimates in this file disagree - confirm.
            return self._load_or_raise(
                FluxKleinClient(model_variant="9b", enable_cpu_offload=True),
                "Failed to load FLUX.2 klein 9B model",
            )

        if backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            return self._load_or_raise(
                ZImageClient(model_variant="turbo", enable_cpu_offload=True),
                "Failed to load Z-Image Turbo model",
            )

        if backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            return self._load_or_raise(
                ZImageClient(model_variant="base", enable_cpu_offload=True),
                "Failed to load Z-Image Base model",
            )

        if backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            return self._load_or_raise(
                LongCatEditClient(enable_cpu_offload=True),
                "Failed to load LongCat-Image-Edit model",
            )

        if backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return self._load_or_raise(
                QwenImageEditClient(enable_cpu_offload=False),
                "Failed to load Qwen-Image-Edit model",
            )

        if backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            # No model is loaded in-process; the ComfyUI server must already be up.
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    " cd comfyui && python main.py"
                )
            return client

        raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Any exception from client creation or generation is logged and turned
        into an error result rather than propagated.

        Args:
            request: Generation request
            backend: Backend to use (``default_backend`` if None)
            **kwargs: Backend-specific parameters forwarded to the client

        Returns:
            GenerationResult (an error result on failure)
        """
        # Resolve the default up front so the failure log names the real backend.
        target = backend if backend is not None else self.default_backend
        try:
            client = self.get_client(target)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error("Generation failed with %s: %s", target, e, exc_info=True)
            return GenerationResult.error_result(f"Backend error: {e}")

    def _unload_client(self, backend: BackendType) -> None:
        """
        Remove *backend*'s cached client, calling its ``unload_model`` if present.

        Best-effort: an exception from ``unload_model`` is logged and the client
        is still dropped from the cache so a half-unloaded client is never reused.
        """
        client = self._clients.pop(backend, None)
        if client is None:
            return
        if self._active_backend == backend:
            # The active client is gone; don't keep reporting it as active.
            self._active_backend = None
        unload = getattr(client, "unload_model", None)
        if unload is not None:
            try:
                unload()
            except Exception:
                logger.warning("Error while unloading %s", backend.value, exc_info=True)
        logger.info("Unloaded %s", backend.value)

    def unload_local_models(self):
        """Unload all cached local-model clients to free memory."""
        # Snapshot the keys first: _unload_client mutates self._clients.
        for backend in [b for b in self._clients if b in self.LOCAL_BACKENDS]:
            self._unload_client(backend)

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend and make it the default.

        When *backend* is local, every other cached local client is unloaded
        first to save memory; cloud backends leave local clients cached.

        Args:
            backend: Backend to switch to

        Returns:
            True if the client could be obtained; False otherwise (the error
            is logged and ``default_backend`` is left unchanged).
        """
        try:
            if backend in self.LOCAL_BACKENDS:
                # Unload other local models first to save memory.
                others = [
                    b for b in self._clients
                    if b in self.LOCAL_BACKENDS and b != backend
                ]
                for other in others:
                    self._unload_client(other)
            self.get_client(backend)
            self.default_backend = backend
            logger.info("Switched to %s", backend.value)
            return True
        except Exception as e:
            logger.error("Failed to switch to %s: %s", backend, e, exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of the active backend, or "None" if there is none."""
        if self._active_backend is None:
            return "None"
        return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check whether *backend* (default: the active backend) is a local backend."""
        if backend is None:
            backend = self._active_backend
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns:
            A new dict mapping ratio strings (e.g. "16:9") to (width, height)
            tuples. Copied, so callers cannot mutate a client class's table.
        """
        # Import the relevant client class to read its ASPECT_RATIOS table.
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient as client_cls
        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient as client_cls
        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient as client_cls
        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient as client_cls
        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient as client_cls
        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient as client_cls
        else:
            # Default fallback
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }
        return dict(client_cls.ASPECT_RATIOS)

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns:
            List of (label, value) tuples, e.g. ("16:9 (1344x768)", "16:9").
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        return [(f"{ratio} ({w}x{h})", ratio) for ratio, (w, h) in ratios.items()]

    def get_available_backends(self) -> list:
        """
        Get list of backends usable in the current environment.

        - Gemini backends are listed only when an API key was supplied.
        - Local backends are listed when the diffusers pipeline class probed
          for them can be imported.
        - The ComfyUI backend is listed only if a ComfyUI client reports healthy.

        Returns:
            List of BackendType members.
        """
        available = []

        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])

        # The diffusers imports below are capability probes only; an
        # ImportError means that pipeline is not available in this install.
        try:
            from diffusers import Flux2KleinPipeline  # noqa: F401
        except ImportError:
            pass
        else:
            # Both klein variants are built by FluxKleinClient, so one probe gates both.
            available.extend([BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8])

        try:
            from diffusers import ZImagePipeline  # noqa: F401
        except ImportError:
            pass
        else:
            available.extend([BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE])

        try:
            from diffusers import LongCatImageEditPipeline  # noqa: F401
        except ImportError:
            pass
        else:
            available.append(BackendType.LONGCAT_EDIT)

        try:
            from diffusers import QwenImageEditPlusPipeline  # noqa: F401
        except ImportError:
            pass
        else:
            available.append(BackendType.QWEN_IMAGE_EDIT)

        # ComfyUI backend - only offered if a running server answers the health check.
        try:
            from .comfyui_client import ComfyUIClient
            if ComfyUIClient().is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            # Best-effort probe: any failure just means ComfyUI is not offered.
            logger.debug("ComfyUI availability check failed", exc_info=True)

        return available
def get_backend_choices() -> list:
    """Return ``(label, value)`` pairs for the backend dropdown in the UI.

    ``value`` is the member's string value, which ``backend_from_string``
    converts back into a BackendType.

    NOTE(review): FLUX_KLEIN_9B_FP8 is not offered here - confirm whether
    that omission is intentional.
    """
    labelled_members = (
        ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH),
        ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO),
        ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN),
        ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO),
        ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE),
        ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT),
        ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT),
        ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI),
    )
    return [(label, member.value) for label, member in labelled_members]
def backend_from_string(value: str) -> BackendType:
    """Map a BackendType string value (e.g. ``"flux_klein"``) back to its member.

    Raises:
        ValueError: if no member's value equals *value*.
    """
    found = next((member for member in BackendType if member.value == value), None)
    if found is None:
        raise ValueError(f"Unknown backend: {value}")
    return found