Fahad-S commited on
Commit
9f4b773
·
verified ·
1 Parent(s): df4b253

Upload code/blip3o_fast.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/blip3o_fast.py +622 -0
code/blip3o_fast.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BLIP3o Fast - Unified Image Understanding and Generation with Mask Prediction
3
+
4
+ This module provides:
5
+ - Training: Diffusion-based image editing with mask supervision from SAM
6
+ - Inference: Lightweight mask-free editing using learned MaskPredictor
7
+
8
+ Key Components (from llava_arch.py):
9
+ - MaskPredictor: Learns to predict edit regions from LLM hidden states
10
+ - MaskEncoder: Encodes masks for diffusion conditioning
11
+ - mask_weight/spatial_weight: Learnable conditioning scales (SAVED with model!)
12
+
13
+ Training Flow:
14
+ 1. LLM processes image + instruction → hidden states
15
+ 2. MaskPredictor predicts edit mask (supervised by SAM)
16
+ 3. Diffusion generates edited image with mask conditioning
17
+
18
+ Inference Flow:
19
+ 1. LLM processes image + instruction → hidden states
20
+ 2. MaskPredictor predicts edit mask (NO SAM needed!)
21
+ 3. Diffusion generates edited image
22
+ """
23
+
24
+ from typing import List, Optional, Tuple, Union, Dict, Any
25
+ import json
26
+ import re
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ from transformers import (
33
+ AutoConfig,
34
+ AutoModelForCausalLM,
35
+ AutoTokenizer,
36
+ Qwen2Config,
37
+ Qwen2Model,
38
+ Qwen2ForCausalLM
39
+ )
40
+ from transformers.modeling_outputs import CausalLMOutputWithPast
41
+ from diffusers.training_utils import (
42
+ compute_density_for_timestep_sampling,
43
+ compute_loss_weighting_for_sd3
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+
47
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
48
+
49
+
50
+ # ============================================================
51
+ # TRAINING ONLY: Qwen3 Client for Instruction Parsing
52
+ # ============================================================
53
+
54
class Qwen3InstructionParser:
    """Parses edit instructions using Qwen3 LLM. Used only during training.

    Results are memoized per instruction string, so repeated instructions in a
    dataset incur only one LLM call.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        # instruction text -> parsed dict (unbounded; instruction sets repeat heavily)
        self._cache: Dict[str, Dict] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Parse *instruction* into a structured edit dict (see _parse_response)."""
        if instruction in self._cache:
            return self._cache[instruction]

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Greedy decoding: `temperature` is ignored when do_sample=False, so it
        # is deliberately not passed (avoids a transformers generation warning).
        outputs = self.model.generate(
            **inputs, max_new_tokens=256,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        # Strip the prompt tokens; decode only the newly generated suffix.
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return parsed

    def _build_prompt(self, instruction: str) -> str:
        """Return the few-shot prompt asking the LLM for a JSON edit description."""
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Extract the JSON payload from the model response.

        Falls back to the first brace-delimited span when the model wraps the
        JSON in prose, and to a default "other" dict when nothing parses or the
        payload is not a JSON object.
        """
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None, "attributes": None}
        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            # The model sometimes wraps the JSON in prose; grab the first {...} span.
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match is None:
                return default
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:  # was a bare `except:` — far too broad
                return default
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a list) — merging would raise.
            return default
        return {**default, **parsed}
123
+
124
+
125
+ # ============================================================
126
+ # TRAINING ONLY: Edit Mask Generator (SAM + Qwen3)
127
+ # ============================================================
128
+
129
class EditMaskGenerator:
    """Generates ground truth edit masks using Qwen3 + SAM. Training only."""

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        sam_model: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        enabled: bool = True
    ):
        self.enabled = enabled
        self.device = device
        if not enabled:
            # Disabled instances skip all model loading entirely.
            return

        self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)

        try:
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.sam = SAM2ImagePredictor.from_pretrained(sam_model, device=device)
        except ImportError:
            print("WARNING: SAM2 not installed. Mask generation disabled.")
            self.enabled = False

    def generate(self, image: torch.Tensor, instruction: str, return_parsed: bool = False):
        """Generate edit mask from image and instruction.

        Returns a [1, H, W] float tensor on self.device; when return_parsed is
        True, returns a (mask, parsed_instruction_dict) pair instead.
        """
        def _empty_mask():
            # Zero mask matching the image's spatial size.
            h, w = image.shape[-2:]
            return torch.zeros(1, h, w, device=self.device)

        if not self.enabled:
            blank = _empty_mask()
            return (blank, {"operation": "other"}) if return_parsed else blank

        # Ask the LLM which object the instruction targets.
        parsed = self.parser.parse(instruction)
        if not parsed.get("source_object"):
            # Nothing to segment (e.g. global style edits) — empty mask.
            blank = _empty_mask()
            return (blank, parsed) if return_parsed else blank

        # SAM expects a HWC uint8 numpy image; input is assumed to be in
        # [-1, 1] (the +1 / *127.5 rescale below).
        batched = image.unsqueeze(0) if image.dim() == 3 else image
        image_np = ((batched[0].permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype("uint8")

        with torch.inference_mode():
            self.sam.set_image(image_np)

            # NOTE(review): the parsed object name is not yet used as a prompt;
            # SAM is seeded with the image center point only.
            height, width = image_np.shape[:2]
            masks, scores, _ = self.sam.predict(
                point_coords=[[width // 2, height // 2]],
                point_labels=[1],
                multimask_output=True
            )

        # Keep the highest-scoring candidate mask.
        best = scores.argmax()
        result = torch.from_numpy(masks[best]).float().unsqueeze(0).to(self.device)
        return (result, parsed) if return_parsed else result
194
+
195
+
196
+ # ============================================================
197
+ # Configuration
198
+ # ============================================================
199
+
200
class blip3oFastConfig(Qwen2Config):
    """Configuration for BLIP3o Fast.

    Extends Qwen2Config with toggles for the mask-prediction / conditioning
    pathways plus the diffusion latent geometry.
    """

    model_type = "llava_qwen2"

    def __init__(
        self,
        use_mask_predictor: bool = True,
        use_mask_conditioning: bool = True,
        use_spatial_conditioning: bool = False,
        use_operation_embedding: bool = False,
        mask_predictor_loss_weight: float = 0.5,
        latent_channels: int = 32,
        latent_size: int = 32,
        num_operation_types: int = 10,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Feature toggles for the mask pathways.
        self.use_mask_predictor = use_mask_predictor
        self.use_mask_conditioning = use_mask_conditioning
        self.use_spatial_conditioning = use_spatial_conditioning
        self.use_operation_embedding = use_operation_embedding
        # Mask-prediction loss weight in the total training loss.
        self.mask_predictor_loss_weight = mask_predictor_loss_weight
        # Diffusion latent geometry and operation-label vocabulary size.
        self.latent_channels = latent_channels
        self.latent_size = latent_size
        self.num_operation_types = num_operation_types
224
+
225
+
226
+ # ============================================================
227
+ # Base Model
228
+ # ============================================================
229
+
230
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone: Qwen2 transformer combined with the LLaVA multimodal mixin."""

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        # Cooperative multiple inheritance: both bases initialize via the MRO.
        super().__init__(config)
235
+
236
+
237
+ # ============================================================
238
+ # Main Model for Training
239
+ # ============================================================
240
+
241
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """
    BLIP3o Fast model for training.

    All mask-related components (mask_predictor, mask_encoder, mask_weight, etc.)
    are defined in LlavaMetaModel and accessed via properties in LlavaMetaForCausalLM.
    This ensures they are saved/loaded with the model.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        """Build the LLM backbone; mask-generator state is lazily created later."""
        super(blip3oFastForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen2"

        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Operation types for edit classification; list position is the label
        # index returned by get_operation_index().
        self.operation_types = ["remove", "replace", "add", "extract", "style",
                                "adjust", "compose", "action", "inpaint", "other"]

        # Mask generator (training only, lazy init). EditMaskGenerator is a
        # plain object (not an nn.Module), so it is never saved in checkpoints.
        self._mask_generator = None
        self._mask_generator_initialized = False

        self.post_init()

    def get_model(self):
        # Accessor used throughout to reach backbone-owned components
        # (noise_scheduler, dit, diffusion_connector, VAE).
        return self.model

    # ============================================================
    # Mask Generator (Training Only)
    # ============================================================

    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazy init mask generator (training only).

        Built on first access; a disabled (no-op) generator is created when the
        config opts out or the model is in eval mode. Note the enabled/disabled
        decision is frozen at first access — toggling train/eval afterwards
        does not rebuild it.
        """
        if not self._mask_generator_initialized:
            enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training
            if enabled:
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_operation_index(self, operation: str) -> int:
        """Map an operation name to its label index; unknown names map to "other"."""
        if self.operation_types is None:
            return 0
        return self.operation_types.index(operation) if operation in self.operation_types else self.operation_types.index("other")

    def _normalize_mask(self, mask, H, W, device):
        """Normalize mask to [1, H, W] format.

        Accepts None (zero mask), numpy arrays, or tensors of rank 2–4.
        NOTE(review): H/W are only used for the None case; non-None masks are
        NOT resized here, so their spatial size is assumed to already match.
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)

        mask = mask.to(device)

        if mask.dim() == 4:
            # [N, C, H, W] -> take first channel, then union (max) over N.
            mask = mask[:, 0]
            mask = mask.max(dim=0, keepdim=True)[0]
        elif mask.dim() == 3:
            pass
        elif mask.dim() == 2:
            mask = mask.unsqueeze(0)
        else:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")

        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM (training only).

        Returns a stacked [B, 1, H, W] mask tensor and the per-sample operation
        names. Any per-sample failure degrades to a zero mask + "other" so a
        single bad sample cannot abort the training step.
        """
        masks, operations = [], []
        B, _, H, W = und_images.shape
        for i in range(und_images.shape[0]):
            try:
                mask, parsed = self.mask_generator.generate(und_images[i], instructions[i], return_parsed=True)
                mask = self._normalize_mask(mask, H=H, W=W, device=und_images.device)
                masks.append(mask)
                operations.append(parsed.get("operation", "other"))
            except Exception as e:
                print(f"Mask generation failed: {e}")
                masks.append(torch.zeros(1, H, W, device=und_images.device))
                operations.append("other")
        return torch.stack(masks).to(und_images.device), operations

    # ============================================================
    # TRAINING FORWARD
    # ============================================================

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward: LM cross-entropy + mask-prediction BCE + diffusion loss.

        When no generation target is present (latents is None) only the CE loss
        is returned; otherwise the three losses are combined.
        """
        # Hidden states are always needed: they feed the mask predictor and
        # the diffusion connector below (caller's value is overridden).
        output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare multimodal inputs. `latents` are the generation-image latents
        # produced by the LLaVA preprocessing (defined outside this file —
        # presumably VAE latents of gen_image; confirm in llava_arch.py).
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values,
                labels, gen_image, und_image
            )
        else:
            latents = None

        # LLM forward
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True
        )
        logits = output.logits
        # Tuple of per-layer hidden states (the connector consumes the full tuple).
        img_hidden_states = output.hidden_states

        # CE Loss: standard next-token prediction (shift by one, -100 ignored).
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        else:
            ce_loss = torch.tensor(0.0, device=logits.device)

        # If no generation image, return CE loss only
        if latents is None:
            return CausalLMOutputWithPast(
                loss=ce_loss, logits=logits, past_key_values=output.past_key_values,
                hidden_states=output.hidden_states, attentions=output.attentions
            )

        # ============================================================
        # Generate masks if not provided (training)
        # ============================================================
        # Fallback to on-the-fly Qwen3+SAM masks when the dataloader did not
        # supply ground-truth edit masks.
        if edit_mask is None and instructions is not None and self.training:
            edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        # ============================================================
        # Mask Predictor Loss
        # ============================================================
        mask_pred_loss = torch.tensor(0.0, device=latents.device)

        if self.mask_predictor is not None:
            last_hidden = img_hidden_states[-1]
            mask_logits = self.mask_predictor(last_hidden, return_logits=True)

            if edit_mask is not None and self.training:
                # Downsample GT mask to the latent grid for BCE supervision.
                gt_mask_resized = F.interpolate(
                    edit_mask.float().to(latents.device),
                    size=(latents.shape[2], latents.shape[3]),
                    mode='nearest'
                )

                # NaN guard: skip the loss rather than poison the whole step.
                if not torch.isnan(mask_logits).any() and not torch.isnan(gt_mask_resized).any():
                    mask_pred_loss = F.binary_cross_entropy_with_logits(
                        mask_logits,
                        gt_mask_resized,
                        reduction='mean'
                    )

        # ============================================================
        # Diffusion Training
        # ============================================================
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        # Sample one timestep per batch element (SD3-style density sampling).
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
        timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
        # get_sigmas is provided by the LlavaMeta mixin (defined outside this file).
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # Mask conditioning
        mask_cond = 0
        if self.mask_encoder is not None and edit_mask is not None:
            mask_latent = F.interpolate(
                edit_mask.float().to(latents.device),
                size=(latents.shape[2], latents.shape[3]),
                mode='nearest'
            ).clamp(0.0, 1.0)
            mask_cond = self.mask_encoder(mask_latent)
            # mask_drop comes from the mixin — presumably random dropout of the
            # mask conditioning (CFG-style); confirm in llava_arch.py.
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))

        # Noisy latents: linear interpolation between clean latents and noise
        # (flow-matching style forward process).
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents

        # mask_cond stays the int 0 when conditioning is off, hence the
        # isinstance check before adding the learned-weight term.
        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        # DiT forward: LLM hidden states are fused into the diffusion
        # conditioning sequence by the connector.
        fused_features = self.get_model().diffusion_connector(img_hidden_states)

        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        # Regression target latents - noise (rectified-flow style target;
        # the pairing with the SD3 loss weighting matches diffusers' flow
        # matching recipe).
        target = latents - noise
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        # Per-sample weighted MSE, then batch mean.
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # Total loss: diffusion + fixed 0.2x CE + configurable mask-pred weight.
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss

        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item():.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )

    # ============================================================
    # INFERENCE
    # ============================================================

    @torch.no_grad()
    def generate_edited_image(
        self,
        und_image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        mask_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Generate edited image using learned mask predictor.
        NO external segmentation model needed!

        Args:
            und_image: reference image batch to edit.
            input_ids / attention_mask: tokenized edit instruction.
            num_inference_steps: denoising steps for the scheduler.
            guidance_scale: classifier-free guidance strength (>1 enables CFG).
            mask_guidance_scale: extra multiplier on the mask-conditioning term.
            generator: optional RNG for reproducible initial noise.

        Returns:
            Decoded image tensor from the VAE.
        """
        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]

        # Get LLM hidden states (positional args: labels=None, gen_image=None).
        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )

        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = output.hidden_states

        # Predict mask using trained MaskPredictor (replaces SAM at inference).
        predicted_mask = None
        if self.mask_predictor is not None:
            last_hidden = hidden_states[-1]
            predicted_mask = self.mask_predictor(last_hidden)

        # Encode reference image into the VAE latent space.
        vae = self.get_model().get_sana_vae()
        ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
        ref_latents = ref_latents.to(device)

        latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3]
        latent_channels = ref_latents.shape[1]

        # Resize predicted mask to the latent grid.
        if predicted_mask is not None:
            predicted_mask = F.interpolate(
                predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )

        # Mask conditioning (zeros when no predictor/encoder is available).
        mask_cond = torch.zeros_like(ref_latents)
        if self.mask_encoder is not None and predicted_mask is not None:
            mask_cond = self.mask_encoder(predicted_mask.to(dtype))

        # LLM conditioning
        fused_features = self.get_model().diffusion_connector(hidden_states)

        # Prepare CFG: unconditional branch uses zeroed conditioning tensors.
        if guidance_scale > 1.0:
            mask_cond_cfg = torch.cat([torch.zeros_like(mask_cond), mask_cond])
            fused_features_cfg = torch.cat([torch.zeros_like(fused_features), fused_features])
        else:
            mask_cond_cfg = mask_cond
            fused_features_cfg = fused_features

        # Initialize latents from pure noise at the reference latent geometry.
        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        # Denoising loop
        scheduler = self.get_model().noise_scheduler
        scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = scheduler.timesteps

        for t in timesteps:
            if guidance_scale > 1.0:
                # Double the batch: [uncond, cond] processed in one DiT call.
                latent_model_input = torch.cat([latents] * 2)
                t_input = torch.cat([t.unsqueeze(0)] * 2 * batch_size)
            else:
                latent_model_input = latents
                t_input = t.unsqueeze(0).expand(batch_size)

            # Add mask conditioning, scaled by the learned weight and the
            # user-supplied mask_guidance_scale.
            combined_input = latent_model_input
            if self.mask_weight is not None:
                combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg

            # DiT forward
            noise_pred = self.get_model().dit(
                hidden_states=combined_input,
                timestep=t_input,
                encoder_hidden_states=fused_features_cfg,
            ).sample

            # CFG: extrapolate from unconditional towards conditional.
            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Scheduler step
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode: undo the VAE scaling applied at encode time, then decode.
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents.to(vae.device)).sample

        return image
615
+
616
+
617
# ============================================================
# Register Model
# ============================================================

# Register under model_type "llava_qwen2" so checkpoints load transparently
# through AutoConfig / AutoModelForCausalLM.
AutoConfig.register("llava_qwen2", blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)