File size: 15,065 Bytes
da23dfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""
Backend Router
==============

Unified router for selecting between different image generation backends:
- Gemini (Flash/Pro) - Cloud API
- FLUX.2 klein 4B/9B - Local model
- Z-Image Turbo (Tongyi-MAI) - Local model, 6B, 9 steps, 16GB VRAM
- Qwen-Image-Edit-2511 - Local model
"""

import logging
from typing import Optional, Protocol, Union
from enum import Enum, auto
from PIL import Image

from .models import GenerationRequest, GenerationResult


logger = logging.getLogger(__name__)


class BackendType(Enum):
    """Enumerates every image-generation backend the router can select.

    The string values are stable identifiers used for serialization and
    for round-tripping through UI dropdowns; do not change them.
    """

    GEMINI_FLASH = "gemini_flash"            # Cloud API, fast tier
    GEMINI_PRO = "gemini_pro"                # Cloud API, quality tier
    FLUX_KLEIN = "flux_klein"                # Local 4B model (~13GB VRAM)
    FLUX_KLEIN_9B_FP8 = "flux_klein_9b_fp8"  # Local 9B FP8 model (~20GB VRAM, best quality)
    ZIMAGE_TURBO = "zimage_turbo"            # Local Z-Image Turbo 6B (9 steps, 16GB VRAM)
    ZIMAGE_BASE = "zimage_base"              # Local Z-Image Base 6B (50 steps, CFG support)
    LONGCAT_EDIT = "longcat_edit"            # Local LongCat-Image-Edit (instruction-following, 18GB)
    QWEN_IMAGE_EDIT = "qwen_image_edit"      # Local via direct diffusers (slow, high VRAM)
    QWEN_COMFYUI = "qwen_comfyui"            # Via ComfyUI with FP8 quantization


class ImageClient(Protocol):
    """Structural interface that every backend client must satisfy.

    Both cloud clients (Gemini) and local-model clients implement this
    protocol, which lets BackendRouter treat them interchangeably.
    """

    def generate(self, request: GenerationRequest, **kwargs) -> GenerationResult:
        """Generate an image for *request*; **kwargs are backend-specific."""
        ...

    def is_healthy(self) -> bool:
        """Return True if the client is ready to serve requests."""
        ...


class BackendRouter:
    """
    Router for selecting between image generation backends.

    Clients are created lazily on first use and cached per backend, so
    heavyweight local models are only loaded when actually requested.
    Switching between local backends unloads the previous model first to
    free GPU memory.
    """

    # Human-readable display names for each backend.
    BACKEND_NAMES = {
        BackendType.GEMINI_FLASH: "Gemini Flash",
        BackendType.GEMINI_PRO: "Gemini Pro",
        BackendType.FLUX_KLEIN: "FLUX.2 klein 4B",
        BackendType.FLUX_KLEIN_9B_FP8: "FLUX.2 klein 9B-FP8",
        BackendType.ZIMAGE_TURBO: "Z-Image Turbo 6B",
        BackendType.ZIMAGE_BASE: "Z-Image Base 6B",
        BackendType.LONGCAT_EDIT: "LongCat-Image-Edit",
        BackendType.QWEN_IMAGE_EDIT: "Qwen-Image-Edit-2511",
        BackendType.QWEN_COMFYUI: "Qwen-Image-Edit-2511-FP8 (ComfyUI)",
    }

    # Backends that load a model locally and therefore hold GPU/CPU
    # memory that must be released when switching away.  Single source
    # of truth — previously this set was duplicated as literals in
    # unload_local_models(), switch_backend() and is_local_backend().
    LOCAL_BACKENDS = frozenset({
        BackendType.FLUX_KLEIN,
        BackendType.FLUX_KLEIN_9B_FP8,
        BackendType.ZIMAGE_TURBO,
        BackendType.ZIMAGE_BASE,
        BackendType.LONGCAT_EDIT,
        BackendType.QWEN_IMAGE_EDIT,
        BackendType.QWEN_COMFYUI,
    })

    def __init__(
        self,
        gemini_api_key: Optional[str] = None,
        default_backend: BackendType = BackendType.GEMINI_FLASH
    ):
        """
        Initialize backend router.

        Args:
            gemini_api_key: API key for Gemini backends (required only if
                a Gemini backend is actually used)
            default_backend: Default backend to use when none is specified
        """
        self.gemini_api_key = gemini_api_key
        self.default_backend = default_backend
        # Cache of instantiated clients, keyed by BackendType.
        self._clients: dict = {}
        # Last backend a client was fetched for; None until first use.
        self._active_backend: Optional[BackendType] = None

        logger.info(f"BackendRouter initialized (default: {default_backend.value})")

    def get_client(self, backend: Optional[BackendType] = None) -> ImageClient:
        """
        Get or create client for specified backend.

        Args:
            backend: Backend type (uses default if None)

        Returns:
            ImageClient instance (cached if previously created)

        Raises:
            ValueError: unknown backend or missing Gemini API key
            RuntimeError: a local model failed to load
        """
        if backend is None:
            backend = self.default_backend

        # Fast path: reuse the cached client.
        if backend not in self._clients:
            self._clients[backend] = self._create_client(backend)

        self._active_backend = backend
        return self._clients[backend]

    def _create_client(self, backend: BackendType) -> ImageClient:
        """Instantiate (and for local models, load) a client for *backend*.

        Imports are done lazily inside each branch so that optional
        dependencies are only required for the backend actually used.
        """
        logger.info(f"Creating client for {backend.value}...")

        if backend == BackendType.GEMINI_FLASH:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=False)

        elif backend == BackendType.GEMINI_PRO:
            from .gemini_client import GeminiClient
            if not self.gemini_api_key:
                raise ValueError("Gemini API key required for Gemini backends")
            return GeminiClient(api_key=self.gemini_api_key, use_pro_model=True)

        elif backend == BackendType.FLUX_KLEIN:
            from .flux_klein_client import FluxKleinClient
            # 4B model (~13GB VRAM) - fast
            client = FluxKleinClient(
                model_variant="4b",
                enable_cpu_offload=False
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 4B model")
            return client

        elif backend == BackendType.FLUX_KLEIN_9B_FP8:
            from .flux_klein_client import FluxKleinClient
            # 9B model (~29GB VRAM with CPU offload) - best quality
            client = FluxKleinClient(
                model_variant="9b",
                enable_cpu_offload=True  # Required for 24GB VRAM
            )
            if not client.load_model():
                raise RuntimeError("Failed to load FLUX.2 klein 9B model")
            return client

        elif backend == BackendType.ZIMAGE_TURBO:
            from .zimage_client import ZImageClient
            # Z-Image Turbo 6B - fast (9 steps), fits 16GB VRAM
            client = ZImageClient(
                model_variant="turbo",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Turbo model")
            return client

        elif backend == BackendType.ZIMAGE_BASE:
            from .zimage_client import ZImageClient
            # Z-Image Base 6B - quality (50 steps), CFG support, negative prompts
            client = ZImageClient(
                model_variant="base",
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load Z-Image Base model")
            return client

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            # LongCat-Image-Edit - instruction-following editing (~18GB VRAM)
            client = LongCatEditClient(
                enable_cpu_offload=True
            )
            if not client.load_model():
                raise RuntimeError("Failed to load LongCat-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            client = QwenImageEditClient(enable_cpu_offload=False)
            if not client.load_model():
                raise RuntimeError("Failed to load Qwen-Image-Edit model")
            return client

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            # ComfyUI runs out of process; we only verify it is reachable.
            client = ComfyUIClient()
            if not client.is_healthy():
                raise RuntimeError(
                    "ComfyUI is not running. Please start ComfyUI first:\n"
                    "  cd comfyui && python main.py"
                )
            return client

        else:
            raise ValueError(f"Unknown backend: {backend}")

    def generate(
        self,
        request: GenerationRequest,
        backend: Optional[BackendType] = None,
        **kwargs
    ) -> GenerationResult:
        """
        Generate image using specified backend.

        Args:
            request: Generation request
            backend: Backend to use (default if None)
            **kwargs: Backend-specific parameters

        Returns:
            GenerationResult (an error result on failure; never raises)
        """
        # Resolve up front so error logs name the real backend instead of
        # "None" when the caller relies on the default.
        resolved = backend if backend is not None else self.default_backend
        try:
            client = self.get_client(resolved)
            return client.generate(request, **kwargs)
        except Exception as e:
            logger.error(f"Generation failed with {resolved.value}: {e}", exc_info=True)
            return GenerationResult.error_result(f"Backend error: {str(e)}")

    def _unload_client(self, backend: BackendType) -> None:
        """Unload and drop the cached client for *backend*, if present."""
        client = self._clients.pop(backend, None)
        if client is None:
            return
        if hasattr(client, 'unload_model'):
            client.unload_model()
        # Don't leave a dangling reference to a backend we just unloaded.
        if self._active_backend == backend:
            self._active_backend = None
        logger.info(f"Unloaded {backend.value}")

    def unload_local_models(self):
        """Unload all cached local-model clients to free memory."""
        for backend in list(self._clients):
            if backend in self.LOCAL_BACKENDS:
                self._unload_client(backend)

    def switch_backend(self, backend: BackendType) -> bool:
        """
        Switch to a different backend.

        For local models this loads the new model after unloading any
        other cached local models to save memory.

        Args:
            backend: Backend to switch to

        Returns:
            True if switch successful, False otherwise (error is logged)
        """
        try:
            # Unload other local models first to save memory.
            if backend in self.LOCAL_BACKENDS:
                for other_local in self.LOCAL_BACKENDS - {backend}:
                    self._unload_client(other_local)

            # Get/create the new client and make it the default.
            self.get_client(backend)
            self.default_backend = backend

            logger.info(f"Switched to {backend.value}")
            return True

        except Exception as e:
            logger.error(f"Failed to switch to {backend}: {e}", exc_info=True)
            return False

    def get_active_backend_name(self) -> str:
        """Get human-readable name of active backend ("None" if unset)."""
        if self._active_backend:
            return self.BACKEND_NAMES.get(self._active_backend, str(self._active_backend))
        return "None"

    def is_local_backend(self, backend: Optional[BackendType] = None) -> bool:
        """Check if *backend* (or the active backend if None) runs locally."""
        if backend is None:
            backend = self._active_backend
        return backend in self.LOCAL_BACKENDS

    @staticmethod
    def get_supported_aspect_ratios(backend: BackendType) -> dict:
        """
        Get supported aspect ratios for a backend.

        Returns dict mapping ratio strings to (width, height) tuples.
        Falls back to a minimal generic set for unknown backends.
        """
        # Import clients lazily to read their ASPECT_RATIOS class tables.
        if backend in (BackendType.FLUX_KLEIN, BackendType.FLUX_KLEIN_9B_FP8):
            from .flux_klein_client import FluxKleinClient
            return FluxKleinClient.ASPECT_RATIOS

        elif backend in (BackendType.ZIMAGE_TURBO, BackendType.ZIMAGE_BASE):
            from .zimage_client import ZImageClient
            return ZImageClient.ASPECT_RATIOS

        elif backend == BackendType.LONGCAT_EDIT:
            from .longcat_edit_client import LongCatEditClient
            return LongCatEditClient.ASPECT_RATIOS

        elif backend in (BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO):
            from .gemini_client import GeminiClient
            return GeminiClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_IMAGE_EDIT:
            from .qwen_image_edit_client import QwenImageEditClient
            return QwenImageEditClient.ASPECT_RATIOS

        elif backend == BackendType.QWEN_COMFYUI:
            from .comfyui_client import ComfyUIClient
            return ComfyUIClient.ASPECT_RATIOS

        else:
            # Default fallback for any backend without a published table.
            return {
                "1:1": (1024, 1024),
                "16:9": (1344, 768),
                "9:16": (768, 1344),
            }

    @staticmethod
    def get_aspect_ratio_choices(backend: BackendType) -> list:
        """
        Get aspect ratio choices for UI dropdowns.

        Returns list of (label, value) tuples, e.g. ("1:1 (1024x1024)", "1:1").
        """
        ratios = BackendRouter.get_supported_aspect_ratios(backend)
        return [(f"{ratio} ({w}x{h})", ratio) for ratio, (w, h) in ratios.items()]

    def get_available_backends(self) -> list:
        """Get list of backends usable in this environment.

        Gemini backends require an API key; local backends require their
        diffusers pipeline classes to be importable; the ComfyUI backend
        requires a reachable ComfyUI server.
        """
        available = []

        # Gemini backends require API key
        if self.gemini_api_key:
            available.extend([BackendType.GEMINI_FLASH, BackendType.GEMINI_PRO])

        # Local backends: probe for the pipeline class each one needs.
        try:
            from diffusers import Flux2KleinPipeline
            available.append(BackendType.FLUX_KLEIN)
        except ImportError:
            pass

        try:
            from diffusers import ZImagePipeline
            available.append(BackendType.ZIMAGE_TURBO)
            available.append(BackendType.ZIMAGE_BASE)
        except ImportError:
            pass

        try:
            from diffusers import LongCatImageEditPipeline
            available.append(BackendType.LONGCAT_EDIT)
        except ImportError:
            pass

        try:
            from diffusers import QwenImageEditPlusPipeline
            available.append(BackendType.QWEN_IMAGE_EDIT)
        except ImportError:
            pass

        # ComfyUI backend - check if ComfyUI client works
        try:
            from .comfyui_client import ComfyUIClient
            client = ComfyUIClient()
            if client.is_healthy():
                available.append(BackendType.QWEN_COMFYUI)
        except Exception:
            pass

        return available

    @staticmethod
    def get_backend_choices() -> list:
        """Get list of (label, value) backend choices for UI dropdowns.

        NOTE(review): FLUX_KLEIN_9B_FP8 is deliberately absent here while
        present in BACKEND_NAMES — presumably hidden from the UI due to
        VRAM cost; confirm before adding it.
        """
        return [
            ("Gemini Flash (Cloud)", BackendType.GEMINI_FLASH.value),
            ("Gemini Pro (Cloud)", BackendType.GEMINI_PRO.value),
            ("FLUX.2 klein 4B (Local)", BackendType.FLUX_KLEIN.value),
            ("Z-Image Turbo 6B (Fast, 9 steps, 16GB)", BackendType.ZIMAGE_TURBO.value),
            ("Z-Image Base 6B (Quality, 50 steps, CFG)", BackendType.ZIMAGE_BASE.value),
            ("LongCat-Image-Edit (Instruction Editing, 18GB)", BackendType.LONGCAT_EDIT.value),
            ("Qwen-Image-Edit-2511 (Local, High VRAM)", BackendType.QWEN_IMAGE_EDIT.value),
            ("Qwen-Image-Edit-2511-FP8 (ComfyUI)", BackendType.QWEN_COMFYUI.value),
        ]

    @staticmethod
    def backend_from_string(value: str) -> BackendType:
        """Convert a backend value string (e.g. "gemini_flash") to BackendType.

        Raises:
            ValueError: if *value* matches no backend.
        """
        try:
            # Enum lookup by value is the canonical conversion.
            return BackendType(value)
        except ValueError:
            # Re-raise with the message callers of this helper expect.
            raise ValueError(f"Unknown backend: {value}") from None