"""
LLaVA Architecture with Integrated Mask Prediction for Image Editing

This module contains:
- LlavaMetaModel: Base model with vision tower, diffusion components, and mask prediction
- LlavaMetaForCausalLM: Mixin for causal LM with multimodal support
- MaskPredictor: Predicts edit regions from LLM hidden states
- BF16SafeLayerNorm: Numerically stable LayerNorm for BF16 training

Key Innovation: MaskPredictor enables mask-free inference by learning to predict
edit regions from LLM understanding, eliminating the need for external segmentation.
"""

from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.models.normalization import RMSNorm

from .mobile_block import MobileConditioningProjector
from .multimodal_llava_encoder.builder import build_vision_tower
from .multimodal_llava_projector.builder import build_vision_projector
from .multimodal_projector.builder import build_down_projector
from .multimodal_decoder.builder import build_vae, build_sana
from blip3o.constants import (
    DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, 
    DEFAULT_IMAGE_PATCH_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
)


# ============================================================
# BF16-Safe LayerNorm
# ============================================================

class BF16SafeLayerNorm(nn.Module):
    """
    LayerNorm that's safe for BF16 training.
    Performs normalization in float32 for numerical stability.
    """
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        mean = x.mean(-1, keepdim=True)
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) / torch.sqrt(variance + self.eps)
        x = self.weight.float() * x + self.bias.float()
        return x.to(input_dtype)

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
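

# A minimal usage sketch (illustrative; not used by the model): the layer
# upcasts to float32 internally, so outputs keep the input dtype while the
# statistics are computed at full precision.
def _bf16_layernorm_example() -> None:
    layer = BF16SafeLayerNorm(hidden_size=8)
    x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    y = layer(x)
    assert y.dtype == torch.bfloat16              # dtype preserved
    assert y.float().mean(-1).abs().max() < 5e-2  # ~zero-mean per feature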


# ============================================================
# Mask Predictor - Enables Mask-Free Inference
# ============================================================

class MaskPredictor(nn.Module):
    """
    Predicts edit mask from LLM hidden states.
    
    This is the KEY component that enables mask-free inference.
    During training: Supervised by SAM-generated masks
    During inference: Predicts mask directly from LLM understanding
    
    Architecture:
    1. Attention pooling to focus on instruction-relevant tokens
    2. Project to spatial features
    3. Decode to mask
    """

    def __init__(self, hidden_size: int, latent_channels: int, latent_size: int = 32):
        super().__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size

        # Attention pooling to focus on instruction-relevant tokens
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.Tanh(),
            nn.Linear(hidden_size // 4, 1),
        )

        # Layer norm for stability
        self.input_norm = BF16SafeLayerNorm(hidden_size)

        # Project pooled features to spatial representation
        intermediate_size = hidden_size // 2
        spatial_dim = latent_size * latent_size * 64

        self.hidden_proj = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, spatial_dim),
        )

        # Decode to mask with sufficient capacity
        self.mask_decoder = nn.Sequential(
            nn.Conv2d(64, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        # Initialize attention pooling
        for module in self.attention_pool:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Initialize LayerNorm
        self.input_norm.reset_parameters()

        # Initialize projection layers
        for module in self.hidden_proj:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize conv layers
        for module in self.mask_decoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize final layer with small weights for stable start
        for module in reversed(list(self.mask_decoder)):
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)
                break

    def forward(self, hidden_states: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Predict edit mask from LLM hidden states.

        Args:
            hidden_states: [B, seq_len, hidden_size] from LLM
            return_logits: If True, return logits instead of probabilities

        Returns:
            mask: [B, 1, H, W] predicted edit mask
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        # Check for NaN/Inf in input
        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        # Normalize hidden states
        hidden_states = self.input_norm(hidden_states)

        # Get dtype from first layer
        target_dtype = self.attention_pool[0].weight.dtype
        hidden_states = hidden_states.to(target_dtype)

        # Attention pooling: learn which tokens are important
        attn_weights = self.attention_pool(hidden_states)
        attn_weights = F.softmax(attn_weights, dim=1)

        # Weighted sum of hidden states
        pooled = (hidden_states * attn_weights).sum(dim=1)

        # Project to spatial features
        spatial = self.hidden_proj(pooled)
        spatial = spatial.view(-1, 64, self.latent_size, self.latent_size)

        # Decode to mask logits
        mask_logits = self.mask_decoder(spatial)

        if return_logits:
            return mask_logits.float()

        return torch.sigmoid(mask_logits.float())
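

# Shape sketch (illustrative; the sizes are small stand-ins for the real
# config): a full LLM hidden-state sequence is pooled via learned attention
# and decoded into a single-channel mask at the latent resolution.
def _mask_predictor_example() -> None:
    predictor = MaskPredictor(hidden_size=64, latent_channels=32, latent_size=8)
    hidden = torch.randn(2, 10, 64)   # [B, seq_len, hidden_size]
    mask = predictor(hidden)          # sigmoid probabilities in [0, 1]
    assert mask.shape == (2, 1, 8, 8)
    assert (mask >= 0).all() and (mask <= 1).all()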


# ============================================================
# Diffusion Connector
# ============================================================

class DiffusionConnector(nn.Module):
    def __init__(self, input_dim=896, hidden_dim=1024, output_dim=2304, eps=1e-5):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU(approximate="tanh")
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        self.norm = RMSNorm(output_dim, eps=eps, elementwise_affine=True)

        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
        with torch.no_grad():
            self.norm.weight.fill_(math.sqrt(5.5))

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        x = self.linear2(x)
        x = self.norm(x)
        return x
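

# Shape sketch (illustrative): the connector lifts 896-d LLM features to the
# 2304-d caption-embedding space consumed by the diffusion transformer.
def _diffusion_connector_example() -> None:
    connector = DiffusionConnector(input_dim=896, hidden_dim=1024, output_dim=2304)
    h = torch.randn(2, 77, 896)       # [B, tokens, input_dim]
    assert connector(h).shape == (2, 77, 2304)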


# ============================================================
# Mask Encoder - Encodes masks for diffusion conditioning
# ============================================================

class MaskEncoder(nn.Module):
    """Encodes binary mask into latent conditioning for diffusion."""
    
    def __init__(self, latent_channels: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            nn.Conv2d(128, latent_channels, 3, padding=1),
        )
        self._init_weights()

    def _init_weights(self):
        for module in self.encoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        # Last layer: small random weights, NOT zeros!
        nn.init.normal_(self.encoder[-1].weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.encoder[-1].bias)

    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        # Cast the mask to the encoder's own parameter dtype so mixed-precision
        # setups (bf16/fp16/fp32) all work.
        return self.encoder(mask.to(next(self.encoder.parameters()).dtype))
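

# Hedged sketch of how the encoded mask could be combined with VAE latents.
# The real blending lives in the forward/sampling code, not in this file; the
# additive form with the learnable `mask_weight` is an assumption made for
# illustration only.
def _mask_conditioning_sketch(
    latents: torch.Tensor,            # [B, C, H, W] VAE latents
    mask: torch.Tensor,               # [B, 1, H, W] predicted or GT edit mask
    mask_encoder: "MaskEncoder",
    mask_weight: torch.Tensor,        # scalar parameter (see LlavaMetaModel)
) -> torch.Tensor:
    mask_cond = mask_encoder(mask)    # [B, C, H, W] conditioning features
    return latents + mask_weight * mask_cond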


# ============================================================
# Spatial Reference Encoder
# ============================================================

class SpatialRefEncoder(nn.Module):
    """Encodes reference image latents for spatial conditioning."""
    
    def __init__(self, latent_channels: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(latent_channels, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            nn.Conv2d(128, latent_channels, 3, padding=1),
        )
        self._init_weights()

    def _init_weights(self):
        for module in self.encoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        # Last layer: small random weights
        nn.init.normal_(self.encoder[-1].weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.encoder[-1].bias)

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        return self.encoder(latents)


# ============================================================
# LlavaMetaModel - Base Model with All Components
# ============================================================

class LlavaMetaModel:
    """
    Base model containing:
    - Vision tower for image understanding
    - DiT for diffusion generation
    - VAE for latent encoding/decoding
    - MaskPredictor for edit region prediction
    - MaskEncoder for mask conditioning
    - Conditioning weights (mask_weight, spatial_weight)
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        # Vision components
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

        # Diffusion components
        if hasattr(config, "diffusion_name_or_path"):
            self.dit = build_sana(config)
            self.vae = build_vae(config)
            
            # Diffusion connector
            self.diffusion_connector = MobileConditioningProjector(
                input_dim=896, 
                hidden_dim=512, 
                output_dim=2304, 
                num_layers=config.vlm_num_layers
            )

            # Noise scheduler
            if getattr(config, 'is_train', False):
                print("Using FlowMatchEulerDiscreteScheduler for training")
                self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                    config.diffusion_name_or_path, subfolder="scheduler"
                )
            else:
                print("Using DPMSolverMultistepScheduler for inference")
                self.noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(
                    config.diffusion_name_or_path, subfolder="scheduler"
                )

        # Get latent config
        latent_channels = getattr(config, 'latent_channels', 32)
        latent_size = getattr(config, 'latent_size', 32)

        # ============================================================
        # Mask Prediction Components (for image editing)
        # ============================================================
        
        # Mask predictor: predicts edit region from LLM hidden states
        if getattr(config, 'use_mask_predictor', True):
            self.mask_predictor = MaskPredictor(
                hidden_size=config.hidden_size,
                latent_channels=latent_channels,
                latent_size=latent_size
            )
        else:
            self.mask_predictor = None

        # Mask encoder: encodes mask for diffusion conditioning
        if getattr(config, 'use_mask_conditioning', True):
            self.mask_encoder = MaskEncoder(latent_channels=latent_channels)
            # CRITICAL: This is inside self (LlavaMetaModel), so it gets saved!
            self.mask_weight = nn.Parameter(torch.tensor(1.0))
        else:
            self.mask_encoder = None
            self.mask_weight = None

        # Spatial reference encoder
        if getattr(config, 'use_spatial_conditioning', False):
            self.spatial_ref_encoder = SpatialRefEncoder(latent_channels=latent_channels)
            self.spatial_weight = nn.Parameter(torch.tensor(0.5))
        else:
            self.spatial_ref_encoder = None
            self.spatial_weight = None

        # Operation embedding for edit type
        if getattr(config, 'use_operation_embedding', False):
            num_operations = getattr(config, 'num_operation_types', 10)
            self.operation_embedding = nn.Embedding(num_operations, latent_channels)
        else:
            self.operation_embedding = None

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower
    
    def get_sana(self):
        dit = getattr(self, 'dit', None)
        if type(dit) is list:
            dit = dit[0]
        if dit is not None:
            dit.to(self.device)
        return dit

    def get_sana_vae(self):
        vae = getattr(self, 'vae', None)
        if type(vae) is list:
            vae = vae[0]
        if vae is not None:
            vae.to(self.device)
        return vae

    def reinitialize_mask_components(self):
        """
        Reinitialize mask-related components.
        Call after loading pretrained weights if these components weren't in the original model.
        """
        print("Reinitializing mask components...")
        
        if self.mask_predictor is not None:
            self.mask_predictor._init_weights()
            print("  ✓ mask_predictor reinitialized")

        if self.mask_encoder is not None:
            self.mask_encoder._init_weights()
            print("  ✓ mask_encoder reinitialized")

        if self.spatial_ref_encoder is not None:
            self.spatial_ref_encoder._init_weights()
            print("  ✓ spatial_ref_encoder reinitialized")

        if self.mask_weight is not None:
            nn.init.ones_(self.mask_weight)
            print("  ✓ mask_weight set to 1.0")

        if self.spatial_weight is not None:
            nn.init.constant_(self.spatial_weight, 0.5)
            print("  ✓ spatial_weight set to 0.5")

        #if self.operation_embedding is not None:
        #    nn.init.normal_(self.operation_embedding.weight, mean=0.0, std=0.02)
        #    print("  ✓ operation_embedding reinitialized")

        print("Reinitialization complete!")

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Initialize vision and diffusion modules."""
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        mm_patch_merge_type = model_args.mm_patch_merge_type

        # Initialize DiT
        if self.get_sana() is None:
            dit = build_sana(model_args)
            if hasattr(model_args, "is_train"):
                if model_args.is_train:
                    print("Using FlowMatchEulerDiscreteScheduler for training")
                    self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                        model_args.diffusion_name_or_path, subfolder="scheduler"
                    )
                else:
                    print("Using DPMSolverMultistepScheduler for inference")
                    self.noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(
                        model_args.diffusion_name_or_path, subfolder="scheduler"
                    )

            if fsdp is not None and len(fsdp) > 0:
                self.dit = [dit]
            else:
                self.dit = dit
        else:
            if fsdp is not None and len(fsdp) > 0:
                dit = self.dit[0]
            else:
                dit = self.dit
        for p in dit.parameters():
            p.requires_grad = False
                
        if self.get_sana_vae() is None:
            vae = build_vae(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vae = [vae]
            else:
                self.vae = vae
        else:
            if fsdp is not None and len(fsdp) > 0:
                vae = self.vae[0]
            else:
                vae = self.vae
        for p in vae.parameters():
            p.requires_grad = False
    

        if self.get_vision_tower() is None:
            print("=" * 20, "Building vision tower", "=" * 20)
            vision_tower = build_vision_tower(model_args)
            

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()
        
        
        if getattr(self, 'diffusion_connector', None) is None:
            self.diffusion_connector = MobileConditioningProjector(
                input_dim=896, hidden_dim=512, output_dim=2304,
                num_layers=model_args.vlm_num_layers
            )
        else:
            for p in self.diffusion_connector.parameters():
                p.requires_grad = True
        
        # Freeze all DiT parameters except the caption projection. Use the
        # local `dit` handle, since `self.dit` may be a list under FSDP.
        for name, param in dit.named_parameters():
            if "caption" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Keep the vision tower frozen.
        for p in vision_tower.parameters():
            p.requires_grad = False

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.diffusion_name_or_path = model_args.diffusion_name_or_path
        self.config.is_train = False  # NOTE: hardcoded to False; model_args.is_train is not propagated here

        if getattr(self, 'down_projector', None) is None:
            self.down_projector = build_down_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.down_projector.parameters():
                p.requires_grad = True

        




def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the PIL image (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor
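

# Worked example (illustrative): a 640x480 (WxH) image letterboxed into a
# 336x336 square gains horizontal bands of padding; unpad_image removes them
# so the tensor matches the original aspect ratio again.
def _unpad_image_example() -> None:
    padded = torch.zeros(3, 336, 336)  # CxHxW, letterboxed square input
    out = unpad_image(padded, original_size=(640, 480))
    # scale = 336 / 640 = 0.525 -> new_height = int(480 * 0.525) = 252
    # padding = (336 - 252) // 2 = 42 -> rows 42..293 are kept
    assert out.shape == (3, 252, 336)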


class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()
    
    def visual(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_features = self.get_model().get_vision_tower()(pixel_values)
        image_features = self.get_model().mm_projector(image_features)
        return image_features


    def get_mm_projector(self):
        return self.get_model().mm_projector


    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma
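
    # Typical flow-matching use of get_sigmas (a sketch, assuming the standard
    # FlowMatchEulerDiscreteScheduler convention x_t = (1 - sigma) * x0 + sigma * noise):
    #
    #   sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim)
    #   noisy_latents = (1.0 - sigmas) * latents + sigmas * noise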

    def mask_drop(self, latents, drop_prob=0.1):
        if drop_prob <= 0:
            return latents
        mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        mask = 1 - mask  # need to flip 0 <-> 1
        return latents * mask
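
    # mask_drop is classifier-free-guidance-style conditioning dropout: with
    # probability drop_prob, a sample's conditioning is zeroed out so the model
    # also learns the unconditional distribution. Sketch:
    #
    #   cond = self.mask_drop(cond, drop_prob=0.1)  # ~10% of rows zeroed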

    # ============================================================
    # Convenience Properties for Mask Components
    # ============================================================
    
    @property
    def mask_predictor(self):
        return getattr(self.get_model(), 'mask_predictor', None)

    @property
    def mask_encoder(self):
        return getattr(self.get_model(), 'mask_encoder', None)

    @property
    def mask_weight(self):
        return getattr(self.get_model(), 'mask_weight', None)

    @property
    def spatial_weight(self):
        return getattr(self.get_model(), 'spatial_weight', None)

    @property
    def spatial_ref_encoder(self):
        return getattr(self.get_model(), 'spatial_ref_encoder', None)

    @property
    def operation_embedding(self):
        return getattr(self.get_model(), 'operation_embedding', None)

    # ============================================================
    # Multimodal Input Preparation
    # ============================================================

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        gen_images=None, und_images=None
    ):
        if (gen_images is None and und_images is None) or input_ids.shape[1] == 1 or self.get_vision_tower() is None:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels, None
        if gen_images is not None:
            # Encode the generation targets to scaled VAE latents; a detached
            # clone serves as the diffusion training target.
            vae = self.get_model().get_sana_vae()
            vae_device = vae.device
            prompt_image_embeds = vae.encode(gen_images.to(vae_device)).latent
            prompt_image_embeds = prompt_image_embeds * vae.config.scaling_factor
            target_image_embeds = torch.clone(prompt_image_embeds).detach()
        else:
            target_image_embeds = None
            

        images = und_images
        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.visual(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.visual(images) # [B, image_tokens, hidden_size]


        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        new_input_ids = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
            cur_new_input_ids = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_new_input_ids.append(cur_input_ids_noim[i])
                if i < num_images:
                    if cur_image_idx < len(image_features):
                        cur_image_features = image_features[cur_image_idx]
                    else:
                        cur_image_features = image_features[-1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                    cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), IMAGE_TOKEN_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            cur_new_input_ids = torch.cat(cur_new_input_ids, dim=0)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)
            new_input_ids.append(cur_new_input_ids)

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        # -300 is a padding sentinel that does not collide with any real token id.
        new_input_ids_padded = torch.full((batch_size, max_len), -300, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device) if len(new_input_ids) > 0 else None
        

        for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
                new_input_ids_padded[i, :cur_len] = cur_new_input_ids

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, target_image_embeds
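
    # Splicing example (illustrative): for input ids [t0, t1, IMAGE, t2] the
    # returned embeddings are [emb(t0), emb(t1), <N image features>, emb(t2)],
    # with labels over the image span set to IGNORE_INDEX and the sequence then
    # right-padded to the batch max length.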


    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
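

# End-to-end sketch of the mask-free editing path (illustrative only; the real
# wiring lives in the forward/sampling code, and the tiny dimensions are
# stand-ins for the actual config):
def _mask_free_pipeline_sketch() -> None:
    hidden_size, latent_channels, latent_size = 64, 32, 8
    predictor = MaskPredictor(hidden_size, latent_channels, latent_size)
    encoder = MaskEncoder(latent_channels)
    mask_weight = torch.tensor(1.0)

    llm_hidden = torch.randn(1, 16, hidden_size)  # LLM output hidden states
    latents = torch.randn(1, latent_channels, latent_size, latent_size)

    mask = predictor(llm_hidden)                  # [1, 1, 8, 8] -- no SAM needed
    conditioned = latents + mask_weight * encoder(mask)
    assert conditioned.shape == latents.shape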