"""
Qwen-Image-Edit Client
======================

Client for Qwen-Image-Edit-2511 local image editing.
Supports multi-image editing with improved consistency.

GPU loading strategies (benchmarked on A6000 + A5000):
  Pinned 2-GPU:        169.9s (4.25s/step) - 1.36x vs baseline
  Balanced single-GPU: 184.4s (4.61s/step) - 1.25x vs baseline
  CPU offload:         231.5s (5.79s/step) - baseline
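
Typical configuration (sketch; GPU indices are illustrative and the arguments
mirror QwenImageEditClient.__init__ below):

    client = QwenImageEditClient(device="cuda:0", encoder_device="cuda:1")
    client.load_model()  # picks pinned 2-GPU when both GPUs meet the VRAM thresholds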
"""

import logging
import time
import types
from typing import Optional
from PIL import Image

import torch

from .models import GenerationRequest, GenerationResult


logger = logging.getLogger(__name__)


class QwenImageEditClient:
    """
    Client for Qwen-Image-Edit-2511 model.

    Supports:
    - Multi-image editing (multiple reference images accepted; only the
      first is passed to the pipeline, see generate())
    - Precise text editing
    - Improved character consistency
    - LoRA integration
    """

    # Model variants
    MODELS = {
        "full": "Qwen/Qwen-Image-Edit",           # Official Qwen model
    }

    # Legacy compatibility
    MODEL_ID = MODELS["full"]

    # Aspect ratio to dimensions mapping (target output sizes)
    ASPECT_RATIOS = {
        "1:1": (1328, 1328),
        "16:9": (1664, 928),
        "9:16": (928, 1664),
        "21:9": (1680, 720),    # Cinematic ultra-wide
        "3:2": (1584, 1056),
        "2:3": (1056, 1584),
        "3:4": (1104, 1472),
        "4:3": (1472, 1104),
        "4:5": (1056, 1320),
        "5:4": (1320, 1056),
    }

    # Proven native generation resolution.  Tested resolutions:
    #   1104x1472 (3:4) → CLEAN output (face views in v1 test)
    #   928x1664  (9:16) → VAE tiling noise / garbage
    #   1328x1328 (1:1)  → VAE tiling noise / garbage
    #   896x1184  (auto) → garbage
    # Always generate at 1104x1472, then crop+resize to target.
    NATIVE_RESOLUTION = (1104, 1472)

    # VRAM thresholds for loading strategies
    # Qwen-Image-Edit components: transformer ~40.9GB, text_encoder ~16.6GB, VAE ~0.25GB
    BALANCED_VRAM_THRESHOLD_GB = 45    # Single GPU balanced (needs ~42GB + headroom)
    MAIN_GPU_MIN_VRAM_GB = 42          # Transformer + VAE minimum
    ENCODER_GPU_MIN_VRAM_GB = 17       # Text encoder minimum

    def __init__(
        self,
        model_variant: str = "full",  # Use full model (~50GB)
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
        encoder_device: Optional[str] = None,
    ):
        """
        Initialize Qwen-Image-Edit client.

        Args:
            model_variant: Model variant ("full" for ~50GB)
            device: Device to use for transformer+VAE (cuda or cuda:N)
            dtype: Data type for model weights
            enable_cpu_offload: Enable CPU offload to save VRAM
            encoder_device: Explicit device for text_encoder (e.g. "cuda:3").
                           If None, auto-detected from available GPUs.
        """
        self.model_variant = model_variant
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.encoder_device = encoder_device
        self.pipe = None
        self._loaded = False
        self._loading_strategy = None

        logger.info(f"QwenImageEditClient initialized (variant: {model_variant})")

    @staticmethod
    def _get_gpu_vram_gb(device_idx: int) -> float:
        """Get total VRAM in GB for a specific GPU."""
        if not torch.cuda.is_available():
            return 0.0
        if device_idx >= torch.cuda.device_count():
            return 0.0
        return torch.cuda.get_device_properties(device_idx).total_memory / 1e9

    def _get_vram_gb(self) -> float:
        """Get available VRAM in GB for the main target device."""
        device_idx = self._parse_device_idx(self.device)
        return self._get_gpu_vram_gb(device_idx)

    @staticmethod
    def _parse_device_idx(device: str) -> int:
        """Parse CUDA device index from device string."""
        if device.startswith("cuda:"):
            try:
                return int(device.split(":")[1])
            except (ValueError, IndexError):
                pass
        return 0

    def _find_encoder_gpu(self, main_idx: int) -> Optional[int]:
        """Find a secondary GPU suitable for text_encoder (>= 17GB VRAM).

        Prefers GPUs with more VRAM. Skips the main GPU.
        """
        if not torch.cuda.is_available():
            return None

        candidates = []
        for i in range(torch.cuda.device_count()):
            if i == main_idx:
                continue
            vram = self._get_gpu_vram_gb(i)
            if vram >= self.ENCODER_GPU_MIN_VRAM_GB:
                name = torch.cuda.get_device_name(i)
                candidates.append((i, vram, name))

        if not candidates:
            return None

        # Pick the GPU with the most VRAM
        candidates.sort(key=lambda x: x[1], reverse=True)
        best = candidates[0]
        logger.info(f"Found encoder GPU: cuda:{best[0]} ({best[2]}, {best[1]:.1f} GB)")
        return best[0]

    @staticmethod
    def _patched_get_qwen_prompt_embeds(self, prompt, image=None, device=None, dtype=None):
        """Patched prompt encoding that routes inputs to text_encoder's device.

        The original _get_qwen_prompt_embeds sends model_inputs to
        execution_device (main GPU), then calls text_encoder on a different
        GPU, causing a device mismatch. This patch:
        1. Sends model_inputs to text_encoder's device for encoding
        2. Moves outputs back to execution_device for the transformer
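
        Note: this is a @staticmethod on the client that gets bound to the
        pipeline via types.MethodType in _load_pinned_multi_gpu, so ``self``
        here refers to the pipeline instance, not the client.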
        """
        te_device = next(self.text_encoder.parameters()).device
        execution_device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt

        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]

        # Route to text_encoder's device, NOT execution_device
        model_inputs = self.processor(
            text=txt, images=image, padding=True, return_tensors="pt"
        ).to(te_device)

        outputs = self.text_encoder(
            input_ids=model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pixel_values=model_inputs.pixel_values,
            image_grid_thw=model_inputs.image_grid_thw,
            output_hidden_states=True,
        )

        hidden_states = outputs.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(
            hidden_states, model_inputs.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [
            torch.ones(e.size(0), dtype=torch.long, device=e.device)
            for e in split_hidden_states
        ]
        max_seq_len = max([e.size(0) for e in split_hidden_states])
        prompt_embeds = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
            for u in split_hidden_states
        ])
        encoder_attention_mask = torch.stack([
            torch.cat([u, u.new_zeros(max_seq_len - u.size(0))])
            for u in attn_mask_list
        ])

        # Move outputs to execution_device for transformer
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=execution_device)
        encoder_attention_mask = encoder_attention_mask.to(device=execution_device)

        return prompt_embeds, encoder_attention_mask

    def _load_pinned_multi_gpu(self, model_id: str, main_idx: int, encoder_idx: int) -> bool:
        """Load with pinned multi-GPU: transformer+VAE on main, text_encoder on secondary.

        Benchmarked at 169.9s (4.25s/step) - 1.36x faster than cpu_offload baseline.
        """
        from diffusers import QwenImageEditPipeline
        from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
        from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
        from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
        from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor

        main_dev = f"cuda:{main_idx}"
        enc_dev = f"cuda:{encoder_idx}"

        logger.info(f"Loading pinned 2-GPU: transformer+VAE β†’ {main_dev}, text_encoder β†’ {enc_dev}")

        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            model_id, subfolder="scheduler")
        tokenizer = Qwen2Tokenizer.from_pretrained(
            model_id, subfolder="tokenizer")
        processor = Qwen2VLProcessor.from_pretrained(
            model_id, subfolder="processor")

        text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=self.dtype,
        ).to(enc_dev)
        logger.info(f"  text_encoder loaded on {enc_dev}")

        transformer = QwenImageTransformer2DModel.from_pretrained(
            model_id, subfolder="transformer", torch_dtype=self.dtype,
        ).to(main_dev)
        logger.info(f"  transformer loaded on {main_dev}")

        vae = AutoencoderKLQwenImage.from_pretrained(
            model_id, subfolder="vae", torch_dtype=self.dtype,
        ).to(main_dev)
        vae.enable_tiling()
        logger.info(f"  VAE loaded on {main_dev}")

        self.pipe = QwenImageEditPipeline(
            scheduler=scheduler, vae=vae, text_encoder=text_encoder,
            tokenizer=tokenizer, processor=processor, transformer=transformer,
        )

        # Fix 1: Override _execution_device to force main GPU
        # Without this, pipeline returns text_encoder's device, causing VAE
        # to receive tensors on the wrong GPU
        main_device = torch.device(main_dev)
        QwenImageEditPipeline._execution_device = property(lambda self: main_device)

        # Fix 2: Monkey-patch prompt encoding to route inputs to text_encoder's device
        self.pipe._get_qwen_prompt_embeds = types.MethodType(
            self._patched_get_qwen_prompt_embeds, self.pipe)

        self._loading_strategy = "pinned_multi_gpu"
        logger.info(f"Pinned 2-GPU pipeline ready")
        return True

    def load_model(self) -> bool:
        """Load the model with the best available strategy.

        Strategy priority (GPU strategies always attempted first):
        1. Pinned 2-GPU: transformer+VAE on large GPU, text_encoder on secondary
           (requires main GPU >= 42GB, secondary >= 17GB)
           Benchmark: 169.9s (4.25s/step) - 1.36x
        2. Balanced single-GPU: device_map="balanced" on single large GPU
           (requires GPU >= 45GB)
           Benchmark: 184.4s (4.61s/step) - 1.25x
        3. CPU offload: model components shuttle between CPU and GPU
           (requires enable_cpu_offload=True)
           Benchmark: 231.5s (5.79s/step) - 1.0x baseline
        4. Direct load: entire model on single GPU (may OOM)
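
        The strategy actually used is recorded in self._loading_strategy
        ("pinned_multi_gpu", "balanced_single", "cpu_offload", or "direct").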
        """
        if self._loaded:
            return True

        try:
            from diffusers import QwenImageEditPipeline

            model_id = self.MODELS.get(self.model_variant, self.MODELS["full"])
            main_idx = self._parse_device_idx(self.device)
            main_vram = self._get_gpu_vram_gb(main_idx)
            logger.info(f"Loading Qwen-Image-Edit ({self.model_variant}) from {model_id}...")
            logger.info(f"Main GPU cuda:{main_idx}: {main_vram:.1f} GB VRAM")

            start_time = time.time()
            loaded = False

            # Strategy 1: Pinned 2-GPU (always try first if main GPU is large enough)
            if not loaded and main_vram >= self.MAIN_GPU_MIN_VRAM_GB:
                encoder_idx = None
                if self.encoder_device:
                    encoder_idx = self._parse_device_idx(self.encoder_device)
                    enc_vram = self._get_gpu_vram_gb(encoder_idx)
                    if enc_vram < self.ENCODER_GPU_MIN_VRAM_GB:
                        logger.warning(
                            f"Specified encoder device cuda:{encoder_idx} has "
                            f"{enc_vram:.1f} GB, need {self.ENCODER_GPU_MIN_VRAM_GB} GB. "
                            f"Falling back to auto-detect.")
                        encoder_idx = None

                if encoder_idx is None:
                    encoder_idx = self._find_encoder_gpu(main_idx)

                if encoder_idx is not None:
                    self._load_pinned_multi_gpu(model_id, main_idx, encoder_idx)
                    loaded = True

            # Strategy 2: Balanced single-GPU
            if not loaded and main_vram >= self.BALANCED_VRAM_THRESHOLD_GB:
                max_mem_gb = int(main_vram - 4)
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype,
                    device_map="balanced",
                    max_memory={main_idx: f"{max_mem_gb}GiB"},
                )
                self._loading_strategy = "balanced_single"
                logger.info(f"Loaded with device_map='balanced', max_memory={max_mem_gb}GiB")
                loaded = True

            # Strategy 3: CPU offload (only if allowed)
            if not loaded and self.enable_cpu_offload:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.enable_model_cpu_offload()
                self._loading_strategy = "cpu_offload"
                logger.info("Loaded with enable_model_cpu_offload()")
                loaded = True

            # Strategy 4: Direct load (last resort, may OOM)
            if not loaded:
                self.pipe = QwenImageEditPipeline.from_pretrained(
                    model_id, torch_dtype=self.dtype)
                self.pipe.to(self.device)
                self._loading_strategy = "direct"
                logger.info(f"Loaded directly to {self.device}")

            self.pipe.set_progress_bar_config(disable=None)

            load_time = time.time() - start_time
            logger.info(f"Qwen-Image-Edit loaded in {load_time:.1f}s (strategy: {self._loading_strategy})")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load Qwen-Image-Edit: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload model from memory."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None
            self._loaded = False

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Qwen-Image-Edit-2511 unloaded")

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: int = 40,
        guidance_scale: float = 1.0,
        true_cfg_scale: float = 4.0
    ) -> GenerationResult:
        """
        Generate/edit image using Qwen-Image-Edit-2511.

        Args:
            request: GenerationRequest object
            num_inference_steps: Number of denoising steps
            guidance_scale: Classifier-free guidance scale
            true_cfg_scale: True CFG scale for better control

        Returns:
            GenerationResult object
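
        Example (sketch; GenerationRequest is defined in .models and its
        exact constructor signature is assumed here):

            request = GenerationRequest(
                prompt="Change the jacket to red leather",
                aspect_ratio="3:4",
                input_images=[Image.open("reference.png")],
            )
            result = client.generate(request, num_inference_steps=30)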
        """
        if not self._loaded:
            if not self.load_model():
                return GenerationResult.error_result("Failed to load Qwen-Image-Edit-2511 model")

        try:
            start_time = time.time()

            # Target dimensions for post-processing crop+resize
            target_w, target_h = self._get_dimensions(request.aspect_ratio)

            # Build input images list
            input_images = []
            if request.has_input_images:
                input_images = [img for img in request.input_images if img is not None]

            # Always generate at the proven native resolution (1104x1472).
            # Other resolutions cause VAE tiling artifacts.
            native_w, native_h = self.NATIVE_RESOLUTION
            gen_kwargs = {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or " ",
                "height": native_h,
                "width": native_w,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "true_cfg_scale": true_cfg_scale,
                "num_images_per_prompt": 1,
                "generator": torch.manual_seed(42),
            }

            # Qwen-Image-Edit is a single-image editor: use only the first image.
            # The character service passes multiple references (face, body, costume)
            # but the costume/view info is already encoded in the text prompt.
            if input_images:
                gen_kwargs["image"] = input_images[0]

            logger.info(f"Generating with Qwen-Image-Edit: {request.prompt[:80]}...")
            logger.info(f"Input images: {len(input_images)} (using first)")
            logger.info(f"Native: {native_w}x{native_h}, target: {target_w}x{target_h}")

            # Generate at proven native resolution
            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Generated in {generation_time:.2f}s: {image.size}")

            # Crop + resize to requested aspect ratio
            image = self._crop_and_resize(image, target_w, target_h)
            logger.info(f"Post-processed to: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Generated with Qwen-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Qwen-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Qwen-Image-Edit error: {str(e)}")

    @staticmethod
    def _crop_and_resize(image: Image.Image, target_w: int, target_h: int) -> Image.Image:
        """Crop image to target aspect ratio, then resize to target dimensions.

        Centers the crop on the image so equal amounts are trimmed from
        each side.  Uses LANCZOS resampling for the final resize.
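
        Worked example: cropping the 1104x1472 native output to 16:9
        (1664x928) keeps the full 1104 px width, center-crops the height to
        int(1104 / (1664 / 928)) = 615 px, then resizes the 1104x615 crop
        to 1664x928.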
        """
        src_w, src_h = image.size
        target_ratio = target_w / target_h
        src_ratio = src_w / src_h

        if abs(target_ratio - src_ratio) < 0.01:
            # Already the right aspect ratio, just resize
            return image.resize((target_w, target_h), Image.LANCZOS)

        if target_ratio < src_ratio:
            # Target is taller/narrower than source → crop sides
            crop_w = int(src_h * target_ratio)
            offset = (src_w - crop_w) // 2
            image = image.crop((offset, 0, offset + crop_w, src_h))
        else:
            # Target is wider than source → crop top/bottom
            crop_h = int(src_w / target_ratio)
            offset = (src_h - crop_h) // 2
            image = image.crop((0, offset, src_w, offset + crop_h))

        return image.resize((target_w, target_h), Image.LANCZOS)

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return self.ASPECT_RATIOS.get(ratio, (1024, 1024))

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio."""
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))