Fahad-S commited on
Commit
65bd55b
·
verified ·
1 Parent(s): 63a6381

Upload code/llava_arch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/llava_arch.py +875 -0
code/llava_arch.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaVA Architecture with Integrated Mask Prediction for Image Editing
3
+
4
+ This module contains:
5
+ - LlavaMetaModel: Base model with vision tower, diffusion components, and mask prediction
6
+ - LlavaMetaForCausalLM: Mixin for causal LM with multimodal support
7
+ - MaskPredictor: Predicts edit regions from LLM hidden states
8
+ - BF16SafeLayerNorm: Numerically stable LayerNorm for BF16 training
9
+
10
+ Key Innovation: MaskPredictor enables mask-free inference by learning to predict
11
+ edit regions from LLM understanding, eliminating the need for external segmentation.
12
+ """
13
+
14
+ from abc import ABC, abstractmethod
15
+ from typing import Optional, Tuple, List
16
+ import math
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from diffusers import FlowMatchEulerDiscreteScheduler, DPMSolverMultistepScheduler
22
+ from diffusers.models.normalization import RMSNorm
23
+
24
+ from .mobile_block import MobileConditioningProjector
25
+ from .multimodal_llava_encoder.builder import build_vision_tower
26
+ from .multimodal_llava_projector.builder import build_vision_projector
27
+ from .multimodal_projector.builder import build_down_projector
28
+ from .multimodal_decoder.builder import build_vae, build_sana
29
+ from blip3o.constants import (
30
+ DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,
31
+ DEFAULT_IMAGE_PATCH_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
32
+ )
33
+
34
+
35
+ # ============================================================
36
+ # BF16-Safe LayerNorm
37
+ # ============================================================
38
+
39
class BF16SafeLayerNorm(nn.Module):
    """
    LayerNorm that's safe for BF16 training.

    All statistics (mean, variance) are computed in float32 and the result is
    cast back to the caller's dtype, avoiding BF16 precision loss in the
    normalization itself.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the incoming dtype so we can cast back at the end.
        orig_dtype = x.dtype
        x32 = x.float()
        mu = x32.mean(-1, keepdim=True)
        centered = x32 - mu
        # Biased variance (mean of squared deviations), matching nn.LayerNorm.
        var = centered.pow(2).mean(-1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        # Affine transform is also applied in float32 for stability.
        out = self.weight.float() * normed + self.bias.float()
        return out.to(orig_dtype)

    def reset_parameters(self):
        """Restore the identity affine transform (weight=1, bias=0)."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
63
+
64
+
65
+ # ============================================================
66
+ # Mask Predictor - Enables Mask-Free Inference
67
+ # ============================================================
68
+
69
class MaskPredictor(nn.Module):
    """
    Predicts edit mask from LLM hidden states.

    This is the KEY component that enables mask-free inference.
    During training: Supervised by SAM-generated masks
    During inference: Predicts mask directly from LLM understanding

    Architecture:
    1. Attention pooling to focus on instruction-relevant tokens
    2. Project to spatial features
    3. Decode to mask
    """

    def __init__(self, hidden_size: int, latent_channels: int, latent_size: int = 32):
        # NOTE(review): `latent_channels` is accepted but never used anywhere
        # in this class — confirm whether it was meant to size the decoder.
        super().__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size

        # Attention pooling to focus on instruction-relevant tokens
        # (scores each token with a tiny MLP; softmaxed over the sequence in forward()).
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.Tanh(),
            nn.Linear(hidden_size // 4, 1),
        )

        # Layer norm for stability (float32 statistics even under BF16)
        self.input_norm = BF16SafeLayerNorm(hidden_size)

        # Project pooled features to spatial representation
        intermediate_size = hidden_size // 2
        # 64 channels per spatial cell; reshaped to (64, latent_size, latent_size) in forward().
        spatial_dim = latent_size * latent_size * 64

        self.hidden_proj = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, spatial_dim),
        )

        # Decode to mask with sufficient capacity (64 -> 256 -> 128 -> 64 -> 1 channel)
        self.mask_decoder = nn.Sequential(
            nn.Conv2d(64, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        # Initialize attention pooling (small gain keeps initial attention near-uniform)
        for module in self.attention_pool:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Initialize LayerNorm
        self.input_norm.reset_parameters()

        # Initialize projection layers
        for module in self.hidden_proj:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize conv layers
        for module in self.mask_decoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize final layer with small weights for stable start
        # (reversed() locates the last Conv2d, i.e. the 1x1 output projection,
        # overriding the kaiming init it received in the loop above)
        for module in reversed(list(self.mask_decoder)):
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)
                break

    def forward(self, hidden_states: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Predict edit mask from LLM hidden states.

        Args:
            hidden_states: [B, seq_len, hidden_size] from LLM
            return_logits: If True, return logits instead of probabilities

        Returns:
            mask: [B, 1, H, W] predicted edit mask (probabilities in [0, 1]
            unless return_logits=True; H = W = latent_size)
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        # Check for NaN/Inf in input and bail out with a neutral mask.
        # NOTE(review): these fallbacks are fresh leaf tensors — despite
        # requires_grad=True, no gradient reaches this module's parameters
        # on this path; confirm this is the intended training behavior.
        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        # Normalize hidden states
        hidden_states = self.input_norm(hidden_states)

        # Get dtype from first layer so the matmuls below don't mix dtypes
        target_dtype = self.attention_pool[0].weight.dtype
        hidden_states = hidden_states.to(target_dtype)

        # Attention pooling: learn which tokens are important
        attn_weights = self.attention_pool(hidden_states)
        attn_weights = F.softmax(attn_weights, dim=1)  # softmax over the sequence dim

        # Weighted sum of hidden states -> one pooled vector per batch element
        pooled = (hidden_states * attn_weights).sum(dim=1)

        # Project to spatial features
        spatial = self.hidden_proj(pooled)
        spatial = spatial.view(-1, 64, self.latent_size, self.latent_size)

        # Decode to mask logits
        mask_logits = self.mask_decoder(spatial)

        if return_logits:
            return mask_logits.float()

        return torch.sigmoid(mask_logits.float())
216
+
217
+
218
+ # ============================================================
219
+ # Diffusion Connector
220
+ # ============================================================
221
+
222
class DiffusionConnector(nn.Module):
    """
    Two-layer MLP bridging LLM hidden states to the DiT conditioning space.

    Maps input_dim -> hidden_dim -> output_dim with a tanh-approximated GELU
    in between, then applies an affine RMSNorm whose gain starts at sqrt(5.5).
    """

    def __init__(self, input_dim=896, hidden_dim=1024, output_dim=2304, eps=1e-5):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU(approximate="tanh")
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        self.norm = RMSNorm(output_dim, eps=eps, elementwise_affine=True)

        # Xavier init for both projections, zero biases (same order as before:
        # linear1 weight/bias, then linear2 weight/bias).
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        # Non-unit RMSNorm gain sets the output scale expected by the DiT.
        with torch.no_grad():
            self.norm.weight.fill_(math.sqrt(5.5))

    def forward(self, x):
        """Project [*, input_dim] features to normalized [*, output_dim] conditioning."""
        return self.norm(self.linear2(self.act(self.linear1(x))))
243
+
244
+
245
+ # ============================================================
246
+ # Mask Encoder - Encodes masks for diffusion conditioning
247
+ # ============================================================
248
+
249
class MaskEncoder(nn.Module):
    """Encodes binary mask into latent conditioning for diffusion.

    A small Conv/GroupNorm/SiLU stack maps a [B, 1, H, W] mask to
    [B, latent_channels, H, W] at the same spatial resolution.
    """

    def __init__(self, latent_channels: int = 32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            nn.Conv2d(128, latent_channels, 3, padding=1),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming init for convs, identity init for norms, small final layer."""
        for module in self.encoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        # Last layer: small random weights, NOT zeros!
        nn.init.normal_(self.encoder[-1].weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.encoder[-1].bias)

    def forward(self, mask: torch.Tensor) -> torch.Tensor:
        """Encode a [B, 1, H, W] mask into [B, latent_channels, H, W] conditioning.

        Fix: cast the input to the encoder's own parameter dtype instead of
        hard-coding torch.bfloat16. The unconditional bf16 cast raised a dtype
        mismatch whenever the module was kept in float32/float16 (e.g. CPU
        inference); in the intended bf16 training setup this is a no-op, so
        behavior there is unchanged.
        """
        target_dtype = next(self.encoder.parameters()).dtype
        return self.encoder(mask.to(target_dtype))
280
+
281
+
282
+ # ============================================================
283
+ # Spatial Reference Encoder
284
+ # ============================================================
285
+
286
class SpatialRefEncoder(nn.Module):
    """Encodes reference image latents for spatial conditioning."""

    def __init__(self, latent_channels: int = 32):
        super().__init__()
        # Conv stack: latent_channels -> 64 -> 128 -> latent_channels;
        # 3x3 kernels with padding=1 preserve spatial size throughout.
        self.encoder = nn.Sequential(
            nn.Conv2d(latent_channels, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            nn.Conv2d(128, latent_channels, 3, padding=1),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init convs, identity-init norms, then shrink the output conv."""
        for layer in self.encoder:
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.GroupNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
        # Small random weights on the final projection keep early conditioning gentle.
        final = self.encoder[-1]
        nn.init.normal_(final.weight, mean=0.0, std=0.01)
        nn.init.zeros_(final.bias)

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Map [B, C, H, W] reference latents to same-shaped conditioning features."""
        return self.encoder(latents)
317
+
318
+
319
+ # ============================================================
320
+ # LlavaMetaModel - Base Model with All Components
321
+ # ============================================================
322
+
323
class LlavaMetaModel:
    """
    Base model containing:
    - Vision tower for image understanding
    - DiT for diffusion generation
    - VAE for latent encoding/decoding
    - MaskPredictor for edit region prediction
    - MaskEncoder for mask conditioning
    - Conditioning weights (mask_weight, spatial_weight)

    Designed as a mixin over a config-carrying base model (hence the
    super().__init__(config) call despite no declared base class here).
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        # Vision components (only when a vision tower is configured)
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

        # Diffusion components (only when a diffusion checkpoint is configured)
        if hasattr(config, "diffusion_name_or_path"):
            self.dit = build_sana(config)
            self.vae = build_vae(config)

            # Diffusion connector: LLM hidden states (896) -> DiT cond dim (2304)
            self.diffusion_connector = MobileConditioningProjector(
                input_dim=896,
                hidden_dim=512,
                output_dim=2304,
                num_layers=config.vlm_num_layers
            )

            # Noise scheduler: flow matching for training, DPM-Solver for sampling
            if getattr(config, 'is_train', False):
                print("Using FlowMatchEulerDiscreteScheduler for training")
                self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
                    config.diffusion_name_or_path, subfolder="scheduler"
                )
            else:
                print("Using DPMSolverMultistepScheduler for inference")
                self.noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(
                    config.diffusion_name_or_path, subfolder="scheduler"
                )

        # Get latent config
        latent_channels = getattr(config, 'latent_channels', 32)
        latent_size = getattr(config, 'latent_size', 32)

        # ============================================================
        # Mask Prediction Components (for image editing)
        # ============================================================

        # Mask predictor: predicts edit region from LLM hidden states
        if getattr(config, 'use_mask_predictor', True):
            self.mask_predictor = MaskPredictor(
                hidden_size=config.hidden_size,
                latent_channels=latent_channels,
                latent_size=latent_size
            )
        else:
            self.mask_predictor = None

        # Mask encoder: encodes mask for diffusion conditioning
        if getattr(config, 'use_mask_conditioning', True):
            self.mask_encoder = MaskEncoder(latent_channels=latent_channels)
            # CRITICAL: This is inside self (LlavaMetaModel), so it gets saved!
            self.mask_weight = nn.Parameter(torch.tensor(1.0))
        else:
            self.mask_encoder = None
            self.mask_weight = None

        # Spatial reference encoder
        if getattr(config, 'use_spatial_conditioning', False):
            self.spatial_ref_encoder = SpatialRefEncoder(latent_channels=latent_channels)
            self.spatial_weight = nn.Parameter(torch.tensor(0.5))
        else:
            self.spatial_ref_encoder = None
            self.spatial_weight = None

        # Operation embedding for edit type
        if getattr(config, 'use_operation_embedding', False):
            num_operations = getattr(config, 'num_operation_types', 10)
            self.operation_embedding = nn.Embedding(num_operations, latent_channels)
        else:
            self.operation_embedding = None

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the single-element FSDP list form."""
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_sana(self):
        """Return the (SANA) DiT, unwrapping FSDP lists; moves it to self.device as a side effect."""
        dit = getattr(self, 'dit', None)
        if type(dit) is list:
            dit = dit[0]
        if dit is not None:
            dit.to(self.device)
        return dit

    def get_sana_vae(self):
        """Return the VAE, unwrapping FSDP lists; moves it to self.device as a side effect."""
        vae = getattr(self, 'vae', None)
        if type(vae) is list:
            vae = vae[0]
        if vae is not None:
            vae.to(self.device)
        return vae

    def reinitialize_mask_components(self):
        """
        Reinitialize mask-related components.
        Call after loading pretrained weights if these components weren't in the original model.
        """
        print("Reinitializing mask components...")

        if self.mask_predictor is not None:
            self.mask_predictor._init_weights()
            print(" ✓ mask_predictor reinitialized")

        if self.mask_encoder is not None:
            self.mask_encoder._init_weights()
            print(" ✓ mask_encoder reinitialized")

        if self.spatial_ref_encoder is not None:
            self.spatial_ref_encoder._init_weights()
            print(" ✓ spatial_ref_encoder reinitialized")

        if self.mask_weight is not None:
            nn.init.ones_(self.mask_weight)
            print(" ✓ mask_weight set to 1.0")

        if self.spatial_weight is not None:
            nn.init.constant_(self.spatial_weight, 0.5)
            print(" ✓ spatial_weight set to 0.5")

        #if self.operation_embedding is not None:
        #    nn.init.normal_(self.operation_embedding.weight, mean=0.0, std=0.02)
        #    print(" ✓ operation_embedding reinitialized")

        print("Reinitialization complete!")

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Initialize vision and diffusion modules.

        Builds (or re-uses) the DiT, VAE, vision tower, diffusion connector and
        down projector, and sets the requires_grad flags and config fields used
        by training. `fsdp` being a non-empty list switches the stored modules
        to single-element list form.
        """
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        mm_patch_merge_type = model_args.mm_patch_merge_type

        # Initialize DiT
        if self.get_sana() is None:
            dit = build_sana(model_args)
            if hasattr(model_args, "is_train"):
                if model_args.is_train:
                    print("FLOW MATCHING !!")
                    self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler")
                else:
                    print("DPM SOLVER !!")
                    self.noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_args.diffusion_name_or_path, subfolder="scheduler")

            if fsdp is not None and len(fsdp) > 0:
                self.dit = [dit]
            else:
                self.dit = dit
        else:
            # DiT already exists (e.g. loaded from checkpoint): freeze it.
            if fsdp is not None and len(fsdp) > 0:
                dit = self.dit[0]
            else:
                dit = self.dit
            for p in dit.parameters():
                p.requires_grad = False

        if self.get_sana_vae() is None:
            vae = build_vae(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vae = [vae]
            else:
                self.vae = vae
        else:
            # VAE already exists: keep it frozen.
            if fsdp is not None and len(fsdp) > 0:
                vae = self.vae[0]
            else:
                vae = self.vae
            for p in vae.parameters():
                p.requires_grad = False


        if self.get_vision_tower() is None:
            print("=" * 20, "Building vision tower", "=" * 20)
            vision_tower = build_vision_tower(model_args)


            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()


        if getattr(self, 'diffusion_connector', None) is None:
            #self.diffusion_connector = DiffusionConnector(input_dim=self.config.hidden_size,hidden_dim=1024,output_dim=2304)
            self.diffusion_connector = MobileConditioningProjector(input_dim=896, hidden_dim=512, output_dim=2304, num_layers=model_args.vlm_num_layers)


            '''
            norm = RMSNorm(2304, eps=1e-5, elementwise_affine=True)
            with torch.no_grad():
                norm.weight.fill_(math.sqrt(5.5))
            self.diffusion_connector = nn.Sequential(
                nn.Linear(self.config.hidden_size, 1024),
                nn.GELU(approximate="tanh"),
                nn.Linear(1024, 2304),
                norm,
            )
            '''
        else:
            # In case it was frozen (e.g. by LoRA), make the connector trainable.
            for p in self.diffusion_connector.parameters():
                p.requires_grad = True

        # freeze all parameters in dit except for caption_projection
        # NOTE(review): assumes self.dit is a module, not the FSDP list form.
        for name, param in self.dit.named_parameters():
            if "caption" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False


        # NOTE(review): this re-enables grad on ALL DiT params, undoing the
        # caption-only freeze immediately above — confirm which is intended.
        for p in dit.parameters():
            p.requires_grad = True
        for p in vision_tower.parameters():
            p.requires_grad = False
        # vision_tower().eval()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.diffusion_name_or_path = model_args.diffusion_name_or_path
        self.config.is_train = False #model_args.is_train

        if getattr(self, 'down_projector', None) is None:
            self.down_projector = build_down_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.down_projector.parameters():
                p.requires_grad = True
574
+
575
+
576
+
577
+
578
+
579
+
580
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    The image was letterboxed (padded symmetrically on one axis) to reach the
    current shape; this crops that padding back off.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of PIL image (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    if original_width / original_height > current_width / current_height:
        # Original is wider than current: padding was added top and bottom.
        scale = current_width / original_width
        target_height = int(original_height * scale)
        pad = (current_height - target_height) // 2
        return tensor[:, pad:current_height - pad, :]

    # Original is taller (or same aspect): padding was added left and right.
    scale = current_height / original_height
    target_width = int(original_width * scale)
    pad = (current_width - target_width) // 2
    return tensor[:, :, pad:current_width - pad]
609
+
610
+
611
class LlavaMetaForCausalLM(ABC):
    """Mixin adding multimodal (vision + diffusion) input handling to a causal LM."""

    @abstractmethod
    def get_model(self):
        """Return the underlying LlavaMetaModel instance."""
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def visual(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixels with the vision tower and project into the LLM embedding space."""
        image_features = self.get_model().get_vision_tower()(pixel_values)
        image_features = self.get_model().mm_projector(image_features)
        return image_features


    def get_mm_projector(self):
        return self.get_model().mm_projector


    def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
        """Look up the scheduler sigma for each timestep and right-pad dims to n_dim."""
        sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device)
        timesteps = timesteps.to(device)
        # One scheduler index per requested timestep; .item() requires exactly one match.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Unsqueeze trailing dims so sigma broadcasts against [B, C, H, W]-style tensors.
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def mask_drop(self, latents, drop_prob=0.1):
        """Zero out whole samples with probability drop_prob (CFG-style conditioning dropout)."""
        if drop_prob <= 0:
            return latents
        # bernoulli(p=drop_prob) gives 1 for the samples to DROP...
        mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        mask = 1 - mask # need to flip 0 <-> 1
        return latents * mask

    # ============================================================
    # Convenience Properties for Mask Components
    # ============================================================
    # All delegate to the inner model so checkpoint state stays in one place.

    @property
    def mask_predictor(self):
        return getattr(self.get_model(), 'mask_predictor', None)

    @property
    def mask_encoder(self):
        return getattr(self.get_model(), 'mask_encoder', None)

    @property
    def mask_weight(self):
        return getattr(self.get_model(), 'mask_weight', None)

    @property
    def spatial_weight(self):
        return getattr(self.get_model(), 'spatial_weight', None)

    @property
    def spatial_ref_encoder(self):
        return getattr(self.get_model(), 'spatial_ref_encoder', None)

    @property
    def operation_embedding(self):
        return getattr(self.get_model(), 'operation_embedding', None)

    # ============================================================
    # Multimodal Input Preparation
    # ============================================================

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        gen_images=None, und_images=None
    ):
        """Splice image features into the token embedding sequence.

        Replaces each IMAGE_TOKEN_INDEX placeholder with the corresponding
        projected image features, rebuilds labels/attention/position tensors,
        and re-pads the batch to the new max length. Also VAE-encodes
        gen_images into target latents for the diffusion loss.
        """
        # NOTE(review): this early return has 9 elements while the final
        # return below has 7 — confirm which arity callers unpack.
        if (gen_images is None and und_images is None) or input_ids.shape[1] == 1 or self.get_vision_tower() is None:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None
        if gen_images is not None:
            # Encode generation targets to scaled VAE latents (detached: used as targets).
            vae = self.get_model().get_sana_vae()
            vae_device = vae.device
            prompt_image_embeds = vae.encode(gen_images.to(vae_device)).latent if gen_images is not None else None
            prompt_image_embeds = prompt_image_embeds * vae.config.scaling_factor if prompt_image_embeds is not None else None
            target_image_embeds = torch.clone(prompt_image_embeds).detach()
        else:
            target_image_embeds = None


        images = und_images
        if type(images) is list or images.ndim == 5:
            # Multiple images per sample: batch them through the tower in one pass.
            if type(images) is list:
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.visual(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.visual(images) # [B, image_tokens, hidden_size]


        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        new_input_ids = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # No image token: append a zero-length slice of image features so
                # the multimodal weights still participate in the graph.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            # Sentinel indices (-1 and len) bracket the text segments between image tokens.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split back per segment.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
            cur_new_input_ids = []

            # Interleave: text segment, image features, text segment, ...
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_new_input_ids.append(cur_input_ids_noim[i])
                if i < num_images:
                    # NOTE(review): .shape[0] assumes image_features is a tensor;
                    # the multi-image branch above produces a list — verify.
                    if cur_image_idx < image_features.shape[0]:
                        cur_image_features = image_features[cur_image_idx]
                    else:
                        cur_image_features = image_features[-1]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions are never supervised (IGNORE_INDEX labels).
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                    cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), IMAGE_TOKEN_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            cur_new_input_ids = torch.cat(cur_new_input_ids, dim=0)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)
            new_input_ids.append(cur_new_input_ids)

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        # Right-pad every sample to max_len; padding gets IGNORE_INDEX labels,
        # zero attention, and -300 as a padding sentinel in the rebuilt ids.
        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        new_input_ids_padded = torch.full((batch_size, max_len), -300, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device) if len(new_input_ids) > 0 else None


        for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
                new_input_ids_padded[i, :cur_len] = cur_new_input_ids

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore None for any input the caller originally passed as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, target_image_embeds


    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image special tokens to the tokenizer and size/initialize their embeddings."""
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # New token embeddings start at the mean of the existing vocabulary.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                # Only the input embeddings train; the LM head stays frozen.
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # Copy the two special-token embeddings from a pretrained adapter checkpoint.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
867
+
868
+
869
+
870
+
871
+
872
+
873
+
874
+
875
+