File size: 9,990 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""SD1.5 model adapter for LightDiffusion-Next.

Provides a clean interface to the SD1.5 model that inherits from
AbstractModel and wraps the existing infrastructure.
"""

import logging
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch

from src.Core.AbstractModel import AbstractModel, ModelCapabilities

if TYPE_CHECKING:
    from src.Core.Context import Context


class SD15Model(AbstractModel):
    """SD1.5 model implementation.

    Wraps the existing SD15 checkpoint loading, prompt encoding,
    sampling, and VAE decoding infrastructure behind the clean
    AbstractModel interface.
    """

    def __init__(self, model_path: Optional[str] = None):
        """Initialize the SD15 model adapter.

        Args:
            model_path: Path to the model checkpoint (safetensors/pt).
                When None, ``load()`` falls back to a default checkpoint.
        """
        super().__init__(model_path)
        # Default CLIP skip: stop at the second-to-last CLIP layer,
        # the conventional setting for SD1.5-family checkpoints.
        self._clip_skip = -2

    def _create_capabilities(self) -> ModelCapabilities:
        """Create capabilities for SD1.5 models."""
        return ModelCapabilities(
            min_resolution=256,
            max_resolution=2048,
            preferred_resolution=512,
            requires_resolution_multiple=64,
            supports_hires_fix=True,
            supports_img2img=True,
            supports_inpainting=True,
            supports_controlnet=True,
            supports_stable_fast=True,
            supports_deepcache=True,
            supports_tome=True,
            uses_dual_clip=False,
            requires_size_conditioning=False,
        )

    def load(self, model_path: Optional[str] = None) -> "SD15Model":
        """Load the SD1.5 model from disk.

        Populates ``self.model``, ``self.clip``, and ``self.vae`` from
        the checkpoint and marks the adapter as loaded. Reloading the
        same path while already loaded is a no-op.

        Args:
            model_path: Optional override for the model path.

        Returns:
            Self for method chaining.

        Raises:
            Exception: Propagates any failure from the checkpoint loader.
        """
        logger = logging.getLogger(__name__)

        path = model_path or self.model_path
        if path is None:
            # Use default checkpoint
            path = "./include/checkpoints/DreamShaper_8_pruned.safetensors"

        # Guard: Don't reload if already loaded with same path
        if self._loaded and self.model_path == path:
            logger.info(
                "SD15Model: Already loaded %s, skipping redundant load", path
            )
            return self

        self.model_path = path

        try:
            from src.FileManaging import Loader

            loader = Loader.CheckpointLoaderSimple()
            result = loader.load_checkpoint(ckpt_name=path)

            # Loader returns (model, clip, vae) as a tuple.
            self.model = result[0]
            self.clip = result[1]
            self.vae = result[2]
            self._loaded = True

            logger.info("SD15Model: loaded %s", path)

        except Exception as e:
            logger.exception("SD15Model: failed to load %s: %s", path, e)
            raise

        return self

    def get_model_object(self, name: str) -> Any:
        """Get an attribute from the underlying model.

        Args:
            name: Attribute name to look up on the wrapped model.

        Returns:
            The requested object, or None if no model is loaded.
        """
        # Identity check rather than truthiness: the model wrapper may
        # define its own __bool__/__len__ semantics.
        if self.model is not None:
            return self.model.get_model_object(name)
        return None

    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] = "",
        clip_skip: Optional[int] = None,
    ) -> tuple[Any, Any]:
        """Encode text prompts into conditioning tensors.

        Args:
            prompt: Positive prompt(s) to encode.
            negative_prompt: Negative prompt(s) to encode.
            clip_skip: Number of CLIP layers to skip (default: -2).

        Returns:
            Tuple of (positive_conditioning, negative_conditioning).

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before encoding prompts")

        clip_skip = clip_skip if clip_skip is not None else self._clip_skip

        try:
            from src.clip import Clip

            # Apply CLIP skip before encoding so both prompts use the
            # same truncated CLIP stack.
            clip_layer = Clip.CLIPSetLastLayer()
            processed_clip = clip_layer.set_last_layer(
                stop_at_clip_layer=clip_skip,
                clip=self.clip,
            )[0]

            # Encode prompts
            encoder = Clip.CLIPTextEncode()

            positive = encoder.encode(
                text=prompt,
                clip=processed_clip,
            )[0]

            negative = encoder.encode(
                text=negative_prompt,
                clip=processed_clip,
            )[0]

            return positive, negative

        except Exception as e:
            logging.getLogger(__name__).exception(
                "Prompt encoding failed: %s", e
            )
            raise

    def generate(
        self,
        ctx: "Context",
        positive: Any,
        negative: Any,
        latent_image: Optional[Any] = None,
        start_step: Optional[int] = None,
        last_step: Optional[int] = None,
        disable_noise: bool = False,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Generate latents using the sampler.

        Args:
            ctx: Pipeline context with generation parameters.
            positive: Positive conditioning.
            negative: Negative conditioning.
            latent_image: Optional existing latent to continue from.
            start_step: Optional step to start sampling from.
            last_step: Optional step to stop sampling at.
            disable_noise: If True, skip adding initial noise to the latent.
            callback: Optional per-step callback; falls back to
                ``ctx.callback`` when not provided.

        Returns:
            Dictionary with 'samples' key containing generated latents.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before generating")

        try:
            from src.sample import sampling
            from src.Utilities import Latent
            from src.hidiffusion import msw_msa_attention

            # Use provided latent or create empty one
            if latent_image is not None:
                latent = latent_image
            else:
                # Create empty latent
                latent_gen = Latent.EmptyLatentImage()
                latent = latent_gen.generate(
                    width=ctx.generation.width,
                    height=ctx.generation.height,
                    batch_size=ctx.generation.batch,
                )[0]

                # Add seeds for deterministic noise
                latent["seeds"] = (
                    ctx.seeds[:ctx.generation.batch] if ctx.seeds else [ctx.seed]
                )

            # Apply HiDiffusion optimization only for very high resolutions
            if ctx.generation.width > 2048 or ctx.generation.height > 2048:
                try:
                    # Clone model before patching so the base model stays
                    # unmodified for subsequent generations.
                    patch_model = self.model.clone()
                    hidiff = msw_msa_attention.ApplyMSWMSAAttentionSimple()
                    optimized_model = hidiff.go(model_type="sd15", model=patch_model)[0]
                except Exception:
                    # Best-effort optimization: fall back to the unpatched
                    # model rather than failing the whole generation.
                    optimized_model = self.model
            else:
                optimized_model = self.model

            # Run sampling
            ksampler = sampling.KSampler()
            result = ksampler.sample(
                seed=ctx.seed,
                steps=ctx.sampling.steps,
                cfg=ctx.sampling.cfg,
                sampler_name=ctx.sampling.sampler,
                scheduler=ctx.sampling.scheduler,
                denoise=ctx.sampling.denoise,
                pipeline=True,
                model=optimized_model,
                positive=positive,
                negative=negative,
                latent_image=latent,
                start_step=start_step,
                last_step=last_step,
                disable_noise=disable_noise,
                callback=callback or ctx.callback,
                enable_multiscale=ctx.sampling.enable_multiscale,
                multiscale_factor=ctx.sampling.multiscale_factor,
                multiscale_fullres_start=ctx.sampling.multiscale_fullres_start,
                multiscale_fullres_end=ctx.sampling.multiscale_fullres_end,
                multiscale_intermittent_fullres=ctx.sampling.multiscale_intermittent_fullres,
                cfg_free_enabled=ctx.sampling.cfg_free_enabled,
                cfg_free_start_percent=ctx.sampling.cfg_free_start_percent,
                batched_cfg=ctx.sampling.batched_cfg,
                dynamic_cfg_rescaling=ctx.sampling.dynamic_cfg_rescaling,
                dynamic_cfg_method=ctx.sampling.dynamic_cfg_method,
                dynamic_cfg_percentile=ctx.sampling.dynamic_cfg_percentile,
                dynamic_cfg_target_scale=ctx.sampling.dynamic_cfg_target_scale,
                adaptive_noise_enabled=ctx.sampling.adaptive_noise_enabled,
                adaptive_noise_method=ctx.sampling.adaptive_noise_method,
            )

            return result[0]

        except Exception as e:
            logging.getLogger(__name__).exception("Generation failed: %s", e)
            raise

    def decode(self, latents: "torch.Tensor | dict") -> torch.Tensor:
        """Decode latents to pixel space.

        Args:
            latents: Latent tensor, or a dict with a 'samples' key
                (both input shapes are accepted, matching the body below).

        Returns:
            Decoded image tensor in [0, 1] range.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model must be loaded before decoding")

        try:
            from src.AutoEncoders import VariationalAE

            decoder = VariationalAE.VAEDecode()

            # Handle both raw tensor and dict input
            if isinstance(latents, dict):
                samples = latents
            else:
                samples = {"samples": latents}

            result = decoder.decode(
                samples=samples,
                vae=self.vae,
            )

            return result[0]

        except Exception as e:
            logging.getLogger(__name__).exception("Decoding failed: %s", e)
            raise