"""
LongCat-Image-Edit Client
=========================

Client for Meituan's LongCat-Image-Edit model.
Supports instruction-following image editing with bilingual (Chinese-English) support.

This is a SOTA open-source image editing model with excellent:
- Global editing, local editing, text modification
- Reference-guided editing
- Consistency preservation (layout, texture, color tone, identity)
- Multi-turn editing capabilities
"""

import logging
import time
from typing import Optional, List
from PIL import Image

import torch

from .models import GenerationRequest, GenerationResult


logger = logging.getLogger(__name__)


class LongCatEditClient:
    """
    Client for LongCat-Image-Edit model from Meituan.

    Features:
    - Instruction-following image editing
    - Bilingual support (Chinese-English)
    - Excellent consistency preservation
    - Multi-turn editing

    Requires ~18GB VRAM with CPU offload.
    """

    MODEL_ID = "meituan-longcat/LongCat-Image-Edit"

    # Aspect ratio label -> (width, height) in pixels.
    ASPECT_RATIOS = {
        "1:1": (1024, 1024),
        "16:9": (1344, 768),
        "9:16": (768, 1344),
        "21:9": (1536, 640),    # Cinematic ultra-wide
        "3:2": (1248, 832),
        "2:3": (832, 1248),
        "3:4": (896, 1152),
        "4:3": (1152, 896),
        "4:5": (896, 1120),
        "5:4": (1120, 896),
    }

    # Default generation settings
    DEFAULT_STEPS = 50
    DEFAULT_GUIDANCE = 4.5
    DEFAULT_SEED = 42

    def __init__(
        self,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        enable_cpu_offload: bool = True,
    ):
        """
        Initialize LongCat-Image-Edit client.

        Args:
            device: Device to use (cuda or cpu)
            dtype: Data type for model weights (bfloat16 recommended)
            enable_cpu_offload: Enable CPU offload to save VRAM (~18GB required)
        """
        self.device = device
        self.dtype = dtype
        self.enable_cpu_offload = enable_cpu_offload
        self.pipe = None       # Lazily populated by load_model()
        self._loaded = False

        logger.info(f"LongCatEditClient initialized (cpu_offload: {enable_cpu_offload})")

    def load_model(self) -> bool:
        """
        Load the pipeline into memory.

        Idempotent: returns True immediately if already loaded.

        Returns:
            True on success, False if loading failed (error is logged).
        """
        if self._loaded:
            return True

        try:
            logger.info(f"Loading LongCat-Image-Edit from {self.MODEL_ID}...")

            start_time = time.time()

            # Import is deferred so the module can be imported without diffusers.
            # Requires latest diffusers:
            #   pip install git+https://github.com/huggingface/diffusers
            from diffusers import LongCatImageEditPipeline

            self.pipe = LongCatImageEditPipeline.from_pretrained(
                self.MODEL_ID,
                torch_dtype=self.dtype,
            )

            # Memory strategy: offload trades throughput for VRAM headroom.
            if self.enable_cpu_offload:
                self.pipe.enable_model_cpu_offload()
                logger.info("CPU offload enabled (~18GB VRAM)")
            else:
                self.pipe.to(self.device, self.dtype)
                logger.info(f"Model moved to {self.device} (high VRAM mode)")

            load_time = time.time() - start_time
            logger.info(f"LongCat-Image-Edit loaded in {load_time:.1f}s")

            self._loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load LongCat-Image-Edit: {e}", exc_info=True)
            return False

    def unload_model(self):
        """Unload model from memory and release cached CUDA allocations."""
        if self.pipe is not None:
            del self.pipe
            self.pipe = None

        self._loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("LongCat-Image-Edit unloaded")

    def _ensure_loaded(self) -> Optional["GenerationResult"]:
        """Load the model on demand; return an error result on failure, else None."""
        if self._loaded or self.load_model():
            return None
        return GenerationResult.error_result("Failed to load LongCat-Image-Edit model")

    def _resolve_settings(
        self,
        num_inference_steps: Optional[int],
        guidance_scale: Optional[float],
    ) -> tuple:
        """Fill in class defaults for any unspecified sampler settings."""
        return (
            self.DEFAULT_STEPS if num_inference_steps is None else num_inference_steps,
            self.DEFAULT_GUIDANCE if guidance_scale is None else guidance_scale,
        )

    def generate(
        self,
        request: GenerationRequest,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: int = DEFAULT_SEED,
    ) -> GenerationResult:
        """
        Edit image using LongCat-Image-Edit.

        Args:
            request: GenerationRequest object with:
                - prompt: The editing instruction (e.g., "Change the background to a forest")
                - input_images: List with the source image to edit
                - aspect_ratio: Output aspect ratio
            num_inference_steps: Number of denoising steps (default: 50)
            guidance_scale: Classifier-free guidance scale (default: 4.5)
            seed: Random seed for reproducibility (default: 42, the previous
                hard-coded value, so existing callers are unaffected)

        Returns:
            GenerationResult object
        """
        load_error = self._ensure_loaded()
        if load_error is not None:
            return load_error

        num_inference_steps, guidance_scale = self._resolve_settings(
            num_inference_steps, guidance_scale
        )

        try:
            start_time = time.time()

            if not request.has_input_images:
                return GenerationResult.error_result("LongCat-Image-Edit requires an input image to edit")

            # First non-None entry is the source image.
            input_image = next(
                (img for img in request.input_images if img is not None), None
            )
            if input_image is None:
                return GenerationResult.error_result("No valid input image provided")

            # Resize the source to the requested output dimensions.
            width, height = self._get_dimensions(request.aspect_ratio)
            input_image = input_image.convert('RGB')
            input_image = input_image.resize((width, height), Image.Resampling.LANCZOS)

            logger.info(f"Editing with LongCat: steps={num_inference_steps}, guidance={guidance_scale}")
            logger.info(f"Edit instruction: {request.prompt[:100]}...")

            # Build generation kwargs
            gen_kwargs = {
                "image": input_image,
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or "",
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "num_images_per_prompt": 1,
                # CPU generator keeps the seed reproducible across devices.
                "generator": torch.Generator("cpu").manual_seed(seed),
            }

            with torch.inference_mode():
                output = self.pipe(**gen_kwargs)
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edited in {generation_time:.2f}s: {image.size}")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with LongCat-Image-Edit in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"LongCat-Image-Edit generation failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"LongCat-Image-Edit error: {str(e)}")

    def edit_with_instruction(
        self,
        source_image: Image.Image,
        instruction: str,
        negative_prompt: str = "",
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: int = 42
    ) -> GenerationResult:
        """
        Simplified method for instruction-based image editing.

        Unlike generate(), the source image is passed through at its own size
        (no aspect-ratio resize).

        Args:
            source_image: The image to edit
            instruction: Natural language editing instruction
                Examples:
                - "Change the background to a sunset beach"
                - "Make the person wear a red dress"
                - "Add snow to the scene"
                - "Change the cat to a dog"
            negative_prompt: What to avoid in the output
            num_inference_steps: Denoising steps (default: 50)
            guidance_scale: CFG scale (default: 4.5)
            seed: Random seed for reproducibility

        Returns:
            GenerationResult with the edited image
        """
        load_error = self._ensure_loaded()
        if load_error is not None:
            return load_error

        num_inference_steps, guidance_scale = self._resolve_settings(
            num_inference_steps, guidance_scale
        )

        try:
            start_time = time.time()

            # Ensure RGB (model expects 3-channel input)
            source_image = source_image.convert('RGB')

            logger.info(f"Editing image with instruction: {instruction[:100]}...")

            with torch.inference_mode():
                output = self.pipe(
                    image=source_image,
                    prompt=instruction,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=1,
                    generator=torch.Generator("cpu").manual_seed(seed),
                )
                image = output.images[0]

            generation_time = time.time() - start_time
            logger.info(f"Edit completed in {generation_time:.2f}s")

            return GenerationResult.success_result(
                image=image,
                message=f"Edited with instruction in {generation_time:.2f}s",
                generation_time=generation_time
            )

        except Exception as e:
            logger.error(f"Instruction-based edit failed: {e}", exc_info=True)
            return GenerationResult.error_result(f"Edit error: {str(e)}")

    def _get_dimensions(self, aspect_ratio: str) -> tuple:
        """Get pixel dimensions for aspect ratio (delegates to get_dimensions)."""
        return self.get_dimensions(aspect_ratio)

    def is_healthy(self) -> bool:
        """Check if model is loaded and ready."""
        return self._loaded and self.pipe is not None

    @classmethod
    def get_dimensions(cls, aspect_ratio: str) -> tuple:
        """
        Get pixel dimensions for an aspect-ratio label.

        Accepts labels with a trailing description (e.g. "16:9 Widescreen");
        only the first whitespace-separated token is used. Unknown ratios
        fall back to 1024x1024.
        """
        ratio = aspect_ratio.split()[0] if " " in aspect_ratio else aspect_ratio
        return cls.ASPECT_RATIOS.get(ratio, (1024, 1024))