Fahad-S committed on
Commit 63a6381 · verified · 1 Parent(s): 9f4b773

Upload code/blip3o_fast_v0.py with huggingface_hub

Files changed (1)
  1. code/blip3o_fast_v0.py +1160 -0
code/blip3o_fast_v0.py ADDED
@@ -0,0 +1,1160 @@
# blip3o_fast.py
# Training: Qwen3 + SAM3 for mask supervision
# Inference: Lightweight - no external components needed

from typing import List, Optional, Tuple, Union, Dict, Any
import re
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2Config,
    Qwen2Model,
    Qwen2ForCausalLM
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3
)
from diffusers.utils.torch_utils import randn_tensor

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


# ============================================================
# TRAINING ONLY: Qwen3 Client for Instruction Parsing
# ============================================================

class Qwen3InstructionParser:
    """Parses edit instructions using a Qwen3 LLM. Used only during training."""

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        self._cache: Dict[str, Dict] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        if instruction in self._cache:
            return self._cache[instruction]

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs, max_new_tokens=256, temperature=0.1,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return parsed

    def _build_prompt(self, instruction: str) -> str:
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    def _parse_response(self, response: str) -> Dict[str, Any]:
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None,
                   "attributes": None}
        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match:
                try:
                    parsed = json.loads(match.group())
                except json.JSONDecodeError:
                    return default
            else:
                return default
        for key in default:
            if key not in parsed:
                parsed[key] = default[key]
        valid_ops = ["remove", "replace", "add", "extract", "style", "adjust", "compose", "action", "other"]
        if parsed["operation"] not in valid_ops:
            parsed["operation"] = "other"
        return parsed
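
# Illustrative usage sketch (commented out so importing this module stays side-effect free).
# It assumes a CUDA device and that the Qwen3-1.7B weights are available; the expected
# output follows the JSON schema in _build_prompt above.
#
#   parser = Qwen3InstructionParser(model_name="Qwen/Qwen3-1.7B", device="cuda")
#   parsed = parser.parse("Replace the dog with a cat")
#   # -> {"operation": "replace", "source_object": "dog", "target_object": "cat",
#   #     "location": None, "attributes": None}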


class SAM3MaskGenerator:
    """
    Generates segmentation masks using SAM3.
    SAM3 natively supports text prompts - no Grounding DINO needed!
    """

    def __init__(self, device: str = "cuda"):
        self.device = device
        self._model = None
        self._processor = None

    def _load_model(self):
        """Lazy load SAM3 model."""
        if self._model is None:
            from sam3.model_builder import build_sam3_image_model
            from sam3.model.sam3_image_processor import Sam3Processor

            print("Loading SAM3...")
            self._model = build_sam3_image_model()
            self._processor = Sam3Processor(self._model)
            print("SAM3 loaded!")

    def _prepare_image(self, image):
        """Convert various image formats to a PIL Image."""
        from PIL import Image as PILImage

        if isinstance(image, PILImage.Image):
            return image.convert("RGB")
        elif isinstance(image, torch.Tensor):
            if image.dim() == 4:
                image = image[0]

            if image.dtype in (torch.bfloat16, torch.float16):
                image = image.float()
            if image.shape[0] in [1, 3]:
                image_np = image.permute(1, 2, 0).cpu().numpy()
            else:
                image_np = image.cpu().numpy()
            if image_np.max() <= 1.0:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)
            return PILImage.fromarray(image_np).convert("RGB")
        elif isinstance(image, np.ndarray):
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            return PILImage.fromarray(image).convert("RGB")
        else:
            return PILImage.fromarray(np.array(image)).convert("RGB")

    @torch.no_grad()
    def generate_mask(
        self,
        image,
        parsed: Dict,
        detect_all: bool = False
    ) -> torch.Tensor:
        """
        Generate segmentation mask using SAM3 with a text prompt.

        Args:
            image: Input image (PIL, tensor, or numpy)
            parsed: Parsed instruction dict with 'source_object', 'operation', etc.
            detect_all: Whether to return all instances

        Returns:
            mask: [1, H, W] binary mask tensor
        """
        self._load_model()
        # Convert image to PIL
        image_pil = self._prepare_image(image)
        W, H = image_pil.size

        # Build text prompt from parsed instruction
        text_prompt = self._build_text_prompt(parsed)

        if not text_prompt:
            if parsed.get("operation") == "style":
                return torch.ones(1, H, W)
            return torch.zeros(1, H, W)

        # Set image in SAM3
        inference_state = self._processor.set_image(image_pil)

        # Get segmentation with text prompt
        output = self._processor.set_text_prompt(
            state=inference_state,
            prompt=text_prompt
        )

        masks = output["masks"]  # List of masks
        scores = output["scores"]  # Confidence scores

        if masks is None or len(masks) == 0:
            return torch.zeros(1, H, W)

        # Convert masks to tensor
        if isinstance(masks, np.ndarray):
            masks = torch.from_numpy(masks)
        elif isinstance(masks, list):
            masks = torch.stack([torch.from_numpy(m) if isinstance(m, np.ndarray) else m for m in masks])

        if detect_all:
            # Combine all masks
            combined_mask = masks.float().max(dim=0)[0]
            return combined_mask.unsqueeze(0)
        else:
            # Return highest scoring mask
            if isinstance(scores, (list, np.ndarray)):
                scores = torch.tensor(scores)
            best_idx = scores.argmax()
            return masks[best_idx].unsqueeze(0).float()

    def _build_text_prompt(self, parsed: Dict) -> str:
        """Build SAM3 text prompt from parsed instruction."""
        operation = parsed.get("operation", "other")
        source = parsed.get("source_object")
        target = parsed.get("target_object")
        location = parsed.get("location")
        attributes = parsed.get("attributes")

        if operation in ["remove", "replace", "extract", "adjust", "action"]:
            # Need to find the source object
            if source:
                # Add attributes if available
                if attributes and operation == "adjust":
                    return source  # e.g., "dress" for "make the dress blue"
                return source
        elif operation == "add":
            # For add, find where to add (the context object)
            if source:
                return source  # e.g., "woman" for "put sunglasses on the woman"
            elif location:
                return location
        elif operation == "compose":
            if source:
                return source
        elif operation == "style":
            # Style affects whole image, return empty
            return ""

        return source or ""
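
# Illustrative sketch of the expected call pattern (commented out; the sam3 package and
# its Sam3Processor API are assumed to match the imports used in _load_model above):
#
#   segmenter = SAM3MaskGenerator(device="cuda")
#   mask = segmenter.generate_mask(
#       image=pil_image,  # PIL.Image, tensor, or numpy array
#       parsed={"operation": "remove", "source_object": "red car",
#               "target_object": None, "location": None, "attributes": None},
#   )
#   # mask: torch.FloatTensor of shape [1, H, W] with values in {0, 1}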


class EditMaskGenerator:
    """
    Complete mask generation pipeline using Qwen3 + SAM3.

    Simplified from: Qwen3 → Grounding DINO → SAM-2
    To: Qwen3 → SAM3 (native text support)
    """

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        enabled: bool = True
    ):
        self.device = device
        self.enabled = enabled

        if enabled:
            print("Initializing EditMaskGenerator with SAM3...")
            self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)
            self.segmenter = SAM3MaskGenerator(device=device)
            print("EditMaskGenerator ready!")
        else:
            self.parser = None
            self.segmenter = None

    @torch.no_grad()
    def generate(
        self,
        image,
        instruction: str,
        return_parsed: bool = False
    ):
        """Generate edit mask from image and instruction."""
        if not self.enabled:
            if isinstance(image, torch.Tensor):
                H, W = image.shape[-2:]
            else:
                H, W = np.array(image).shape[:2]
            mask = torch.zeros(1, H, W)
            return (mask, {"operation": "other"}) if return_parsed else mask

        # Step 1: Parse instruction with Qwen3
        parsed = self.parser.parse(instruction)

        # Step 2: Generate mask with SAM3 (native text prompt!)
        # Match "all" as a whole word so words like "small" or "ball" do not trigger it
        detect_all = re.search(r"\ball\b", instruction.lower()) is not None
        mask = self.segmenter.generate_mask(image, parsed, detect_all=detect_all)

        return (mask, parsed) if return_parsed else mask
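
# Illustrative end-to-end usage sketch (commented out; assumes CUDA plus the Qwen3 and
# SAM3 weights are available - this path is only exercised during training):
#
#   mask_gen = EditMaskGenerator(qwen_model="Qwen/Qwen3-1.7B", device="cuda")
#   mask, parsed = mask_gen.generate(pil_image, "Remove all the red cars", return_parsed=True)
#   # parsed["operation"] == "remove"; mask has shape [1, H, W]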


# ============================================================
# Model Configuration
# ============================================================
class blip3oFastConfig(Qwen2Config):
    model_type = "llava_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.latent_channels = kwargs.get("latent_channels", 32)

        # Conditioning
        self.use_spatial_conditioning = kwargs.get("use_spatial_conditioning", False)
        self.use_mask_conditioning = kwargs.get("use_mask_conditioning", True)
        self.use_operation_embedding = kwargs.get("use_operation_embedding", True)
        self.use_mask_predictor = kwargs.get("use_mask_predictor", True)
        self.mask_predictor_loss_weight = kwargs.get("mask_predictor_loss_weight", 0.5)

        # Dropout
        self.spatial_drop_prob = kwargs.get("spatial_drop_prob", 0.1)
        self.mask_drop_prob = kwargs.get("mask_drop_prob", 0.1)

        # Mask generator config (SIMPLIFIED - no Grounding DINO!)
        self.mask_generator_enabled = kwargs.get("mask_generator_enabled", True)
        self.qwen_model = kwargs.get("qwen_model", "Qwen/Qwen3-1.7B")
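
# Illustrative configuration sketch (commented out; the field names mirror the defaults
# above, everything else is inherited from Qwen2Config):
#
#   config = blip3oFastConfig(
#       latent_channels=32,
#       use_mask_conditioning=True,
#       use_mask_predictor=True,
#       mask_predictor_loss_weight=0.5,
#       qwen_model="Qwen/Qwen3-1.7B",
#   )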


# ============================================================
# Mask Predictor: Learns to predict edit regions from LLM hidden states
# ============================================================

class BF16SafeLayerNorm(nn.Module):
    """LayerNorm that works correctly with BF16 and PEFT."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # Explicitly initialize with proper values
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=torch.float32))
        self.eps = eps
        self.hidden_size = hidden_size

        # Force initialization
        self.reset_parameters()

    def reset_parameters(self):
        """Ensure weights are properly initialized."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Always compute normalization in float32 for stability
        input_dtype = x.dtype
        x_f32 = x.float()

        # Manual layer norm computation
        mean = x_f32.mean(dim=-1, keepdim=True)
        var = x_f32.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x_f32 - mean) / torch.sqrt(var + self.eps)

        # Apply weight and bias in float32
        output = x_norm * self.weight.float() + self.bias.float()

        # Convert back to original dtype
        return output.to(input_dtype)
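
# Minimal numerical sanity-check sketch (commented out; for float32 inputs this module is
# expected to agree with nn.LayerNorm - the difference is only that the statistics are
# always taken in float32 even when the surrounding model runs in bfloat16):
#
#   ln = BF16SafeLayerNorm(64)
#   ref = nn.LayerNorm(64, eps=1e-6)
#   x = torch.randn(2, 10, 64)
#   assert torch.allclose(ln(x), ref(x), atol=1e-5)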


class MaskPredictor(nn.Module):
    """
    Predicts edit mask from LLM hidden states.
    This is the KEY component that enables mask-free inference.

    The mask predictor learns to identify WHICH object needs to be edited
    based on the instruction (e.g., "remove the white dog") and the image
    understanding encoded in the LLM hidden states.

    Architecture:
        1. Extract instruction-relevant features using attention pooling
        2. Project to spatial features
        3. Decode to mask
    """

    def __init__(self, hidden_size: int, latent_channels: int, latent_size: int = 32):
        super().__init__()

        self.latent_size = latent_size
        self.hidden_size = hidden_size

        # Attention pooling to focus on instruction-relevant tokens
        # Instead of simple mean pooling, learn which tokens are important
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.Tanh(),
            nn.Linear(hidden_size // 4, 1),
        )

        # Layer norm for stability
        self.input_norm = BF16SafeLayerNorm(hidden_size)

        # Project pooled features to spatial representation
        intermediate_size = hidden_size // 2
        spatial_dim = latent_size * latent_size * 64

        self.hidden_proj = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, spatial_dim),
        )

        # Upsample to mask with more capacity
        self.mask_decoder = nn.Sequential(
            nn.Conv2d(64, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        # Initialize attention pooling
        for module in self.attention_pool:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Initialize LayerNorm
        if hasattr(self, 'input_norm'):
            self.input_norm.reset_parameters()

        # Initialize projection layers
        for module in self.hidden_proj:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize conv layers
        for module in self.mask_decoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Initialize final layer with small weights for stable start
        for module in reversed(list(self.mask_decoder)):
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)
                break

    def forward(self, hidden_states: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, seq_len, hidden_size] from LLM
            return_logits: If True, return logits instead of probabilities

        Returns:
            mask: [B, 1, H, W] predicted edit mask
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        # Check for NaN/Inf in input
        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            print("WARNING: NaN/Inf in hidden_states input to MaskPredictor")
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        # Normalize hidden states
        hidden_states = self.input_norm(hidden_states)

        if torch.isnan(hidden_states).any():
            print("WARNING: NaN after input_norm in MaskPredictor")
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        # Get dtype from first layer
        target_dtype = self.attention_pool[0].weight.dtype
        hidden_states = hidden_states.to(target_dtype)

        # Attention pooling: learn which tokens are important for mask prediction
        # [B, seq_len, hidden_size] -> [B, seq_len, 1]
        attn_weights = self.attention_pool(hidden_states)
        attn_weights = F.softmax(attn_weights, dim=1)  # [B, seq_len, 1]

        # Weighted sum of hidden states
        # [B, seq_len, hidden_size] * [B, seq_len, 1] -> [B, hidden_size]
        pooled = (hidden_states * attn_weights).sum(dim=1)

        # Project to spatial features
        spatial = self.hidden_proj(pooled)  # [B, spatial_dim]
        spatial = spatial.view(-1, 64, self.latent_size, self.latent_size)  # [B, 64, H, W]

        # Decode to mask logits
        mask_logits = self.mask_decoder(spatial)  # [B, 1, H, W]

        if return_logits:
            return mask_logits.float()

        # Apply sigmoid to get probabilities
        mask = torch.sigmoid(mask_logits.float())
        return mask
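
# Shape walkthrough sketch (commented out; hidden_size=3584 is just an illustrative value,
# the real value comes from the Qwen2 config at construction time):
#
#   predictor = MaskPredictor(hidden_size=3584, latent_channels=32, latent_size=32)
#   h = torch.randn(2, 512, 3584)              # [B, seq_len, hidden_size] from the LLM
#   m = predictor(h)                           # [2, 1, 32, 32] probabilities in (0, 1)
#   logits = predictor(h, return_logits=True)  # same shape, pre-sigmoid, for BCE loss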


# ============================================================
# Main Model
# ============================================================

class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        super(blip3oFastModel, self).__init__(config)


class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    config_class = blip3oFastConfig

    def __init__(self, config):
        super(Qwen2ForCausalLM, self).__init__(config)

        self.model = blip3oFastModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        latent_channels = getattr(config, 'latent_channels', 32)

        # ============================================================
        # Spatial Reference Encoder
        # ============================================================
        if getattr(config, 'use_spatial_conditioning', True):
            self.spatial_ref_encoder = nn.Sequential(
                nn.Conv2d(latent_channels, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, latent_channels, 3, padding=1),
            )
            self.spatial_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.spatial_ref_encoder = None
            self.spatial_weight = None

        # ============================================================
        # Mask Encoder (encodes mask into conditioning)
        # ============================================================
        if getattr(config, 'use_mask_conditioning', True):
            self.mask_encoder = nn.Sequential(
                nn.Conv2d(1, 64, 3, padding=1),
                nn.GroupNorm(8, 64),
                nn.SiLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.GroupNorm(16, 128),
                nn.SiLU(),
                nn.Conv2d(128, latent_channels, 3, padding=1),
            )
            self.mask_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.mask_encoder = None
            self.mask_weight = None

        # ============================================================
        # Mask Predictor (CRITICAL: enables mask-free inference)
        # ============================================================
        if getattr(config, 'use_mask_predictor', True):
            self.mask_predictor = MaskPredictor(
                hidden_size=config.hidden_size,
                latent_channels=latent_channels,
                latent_size=32  # Adjust based on your latent resolution
            )
        else:
            self.mask_predictor = None

        # ============================================================
        # Operation Embedding
        # ============================================================
        if getattr(config, 'use_operation_embedding', True):
            self.operation_types = ["remove", "replace", "add", "extract", "style", "adjust", "compose", "action",
                                    "other"]
            self.operation_embedding = nn.Embedding(len(self.operation_types), latent_channels)
        else:
            self.operation_types = None
            self.operation_embedding = None

        # Mask generator (training only, lazy init)
        self._mask_generator = None
        self._mask_generator_initialized = False

        self._init_conditioning_layers()
        self.post_init()

    def _init_conditioning_layers(self):
        """Initialize conditioning layers. Called during __init__ and can be called after loading."""
        if self.spatial_ref_encoder is not None:
            for module in self.spatial_ref_encoder:
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.GroupNorm):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)
            # Zero-init the last layer
            nn.init.zeros_(self.spatial_ref_encoder[-1].weight)
            nn.init.zeros_(self.spatial_ref_encoder[-1].bias)

        if self.mask_encoder is not None:
            for module in self.mask_encoder:
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.GroupNorm):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)
            # Zero-init the last layer
            nn.init.zeros_(self.mask_encoder[-1].weight)
            nn.init.zeros_(self.mask_encoder[-1].bias)

    def reinitialize_new_modules(self):
        """
        Reinitialize modules that were added after the base model.
        Call this after loading a pretrained model to fix uninitialized weights.
        """
        print("Reinitializing new modules (mask_predictor, mask_encoder, spatial_ref_encoder)...")

        # Reinitialize mask_predictor
        if self.mask_predictor is not None:
            self.mask_predictor._init_weights()
            print(" - mask_predictor reinitialized")

        # Reinitialize conditioning layers
        self._init_conditioning_layers()
        print(" - conditioning layers reinitialized")

        # Reinitialize operation embedding
        if self.operation_embedding is not None:
            nn.init.normal_(self.operation_embedding.weight, mean=0.0, std=0.02)
            print(" - operation_embedding reinitialized")

        # Reinitialize scalar weights
        if self.spatial_weight is not None:
            nn.init.zeros_(self.spatial_weight)
            print(" - spatial_weight reinitialized to 0")
        if self.mask_weight is not None:
            nn.init.zeros_(self.mask_weight)
            print(" - mask_weight reinitialized to 0")

        print("Reinitialization complete!")

    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazy init mask generator (training only)."""
        if not self._mask_generator_initialized:
            enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training
            if enabled:
                # SIMPLIFIED: Only Qwen3 + SAM3 needed now!
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_model(self):
        return self.model

    def mask_drop(self, latents: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor:
        if drop_prob <= 0 or not self.training:
            return latents
        mask = torch.bernoulli(torch.full((latents.shape[0],), drop_prob, device=latents.device, dtype=latents.dtype))
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        return latents * (1 - mask)

    def get_operation_index(self, operation: str) -> int:
        if self.operation_types is None:
            return 0
        return self.operation_types.index(
            operation) if operation in self.operation_types else self.operation_types.index("other")

    def _normalize_mask(self, mask, H, W, device):
        """
        Always return a single mask: [1, H, W]
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)

        # Convert numpy → torch if needed
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)

        mask = mask.to(device)

        # Remove batch dim if present
        if mask.dim() == 4:  # [N, 1, H, W]
            mask = mask[:, 0]  # [N, H, W]
            # Reduction: union of all objects
            mask = mask.max(dim=0, keepdim=True)[0]

        elif mask.dim() == 3:  # [1, H, W]
            pass

        elif mask.dim() == 2:  # [H, W]
            mask = mask.unsqueeze(0)

        else:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")

        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[
            torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM3 (training only)."""
        masks, operations = [], []
        B, _, H, W = und_images.shape
        for i in range(und_images.shape[0]):
            try:
                mask, parsed = self.mask_generator.generate(und_images[i], instructions[i], return_parsed=True)
                mask = self._normalize_mask(mask, H=H, W=W, device=und_images.device)
                masks.append(mask)
                operations.append(parsed.get("operation", "other"))
            except Exception as e:
                print(f"Mask generation failed: {e}")
                masks.append(torch.zeros(1, H, W, device=und_images.device))
                operations.append("other")
        return torch.stack(masks).to(und_images.device), operations
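
    # Illustrative note on expected shapes (an assumption based on the code above): with a
    # batch of four 512x512 training images and their instructions,
    #
    #   masks, operations = self._generate_masks_on_fly(und_images, instructions)
    #
    # would return masks as a [4, 1, 512, 512] float tensor and operations as a list of
    # four strings, e.g. ["remove", "adjust", "add", "other"].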

    # ============================================================
    # TRAINING FORWARD
    # ============================================================
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, gen_image, und_image)

        # LLM Forward
        output = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels,
            use_cache=use_cache, output_attentions=output_attentions,
            output_hidden_states=output_hidden_states, return_dict=return_dict
        )

        ce_loss = output.loss
        hidden_states = output.hidden_states
        logits = output.logits
        img_hidden_states = hidden_states

        assert latents is not None

        # ============================================================
        # Generate GT Masks (Training Only)
        # ============================================================
        if edit_mask is None and instructions is not None and self.training:
            if getattr(self.config, 'mask_generator_enabled', True):
                edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        # ============================================================
        # Predict Mask from LLM Hidden States (for inference capability)
        # ============================================================
        mask_pred_loss = torch.tensor(0.0, device=latents.device)
        predicted_mask = None
        mask_logits = None
        gt_mask_resized = None

        if self.mask_predictor is not None:
            # Get last layer hidden states
            last_hidden = hidden_states[-1]  # [B, seq_len, hidden_size]

            # Get mask logits (for stable BCE loss computation)
            mask_logits = self.mask_predictor(last_hidden, return_logits=True)  # [B, 1, H, W]

            # Resize to latent size
            mask_logits = F.interpolate(
                mask_logits.float(),
                size=(latents.shape[2], latents.shape[3]),
                mode='bilinear',
                align_corners=False
            )

            # Get probabilities for conditioning
            predicted_mask = torch.sigmoid(mask_logits)

            # Supervision loss (train predictor to match GT mask)
            if edit_mask is not None and self.training:
                gt_mask_resized = F.interpolate(
                    edit_mask.float().to(latents.device),
                    size=(latents.shape[2], latents.shape[3]),
                    mode='nearest'
                )

                # Check for NaN before loss computation
                if not torch.isnan(mask_logits).any() and not torch.isnan(gt_mask_resized).any():
                    # Standard BCE loss
                    mask_pred_loss = F.binary_cross_entropy_with_logits(
                        mask_logits,
                        gt_mask_resized,
                        reduction='mean'
                    )
                else:
                    print("WARNING: NaN in mask_logits or gt_mask, skipping mask_pred_loss")
                    mask_pred_loss = torch.tensor(0.0, device=latents.device)

        # ============================================================
        # Diffusion Setup
        # ============================================================
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
        timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # ============================================================
        # Spatial Conditioning
        # ============================================================
        if self.spatial_ref_encoder is not None:
            vae = self.get_model().get_sana_vae()
            ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
            ref_latents = ref_latents.to(latents.device)
            spatial_cond = self.spatial_ref_encoder(ref_latents)
            spatial_cond = self.mask_drop(spatial_cond, getattr(self.config, 'spatial_drop_prob', 0.1))
        else:
            spatial_cond = 0

        # ============================================================
        # Mask Conditioning (use GT mask during training)
        # ============================================================
        if self.mask_encoder is not None and edit_mask is not None:
            mask_latent = F.interpolate(
                edit_mask.float().to(latents.device),
                size=(latents.shape[2], latents.shape[3]),
                mode='nearest'
            )
            mask_latent = mask_latent.clamp(0.0, 1.0)

            # Do mask encoding in float32 to avoid BF16 issues
            mask_cond = mask_latent
            for layer in self.mask_encoder:
                if isinstance(layer, nn.Conv2d):
                    mask_cond = F.conv2d(mask_cond, layer.weight.float(),
                                         layer.bias.float() if layer.bias is not None else None,
                                         layer.stride, layer.padding)
                elif isinstance(layer, nn.GroupNorm):
                    mask_cond = F.group_norm(mask_cond, layer.num_groups,
                                             layer.weight.float(), layer.bias.float(), layer.eps)
                else:
                    mask_cond = layer(mask_cond)

            # Convert to model dtype and apply dropout
            mask_cond = mask_cond.to(latents.dtype)
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))
        else:
            mask_cond = 0
            mask_latent = None

        # ============================================================
        # Operation Embedding
        # ============================================================
        if self.operation_embedding is not None and operations is not None:
            op_indices = torch.tensor([self.get_operation_index(op) for op in operations], device=latents.device)
            op_embed = self.operation_embedding(op_indices)[:, :, None, None]
            op_cond = op_embed * mask_latent if mask_latent is not None else op_embed.expand(-1, -1, latents.shape[2],
                                                                                             latents.shape[3])
        else:
            op_cond = 0

        # ============================================================
        # Combine Conditioning
        # ============================================================
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents

        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        # ============================================================
        # DiT Forward
        # ============================================================
        fused_features = self.get_model().diffusion_connector(img_hidden_states)

        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        target = latents - noise

        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # ============================================================
        # Total Loss
        # ============================================================
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss

        # Logging
        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item() if isinstance(mask_pred_loss, torch.Tensor) else 0:.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )
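
    # Illustrative training-step sketch (commented out; the batch keys follow the forward()
    # signature above, and the collator/tokenizer wiring is assumed to come from the
    # surrounding training code):
    #
    #   out = model(
    #       input_ids=batch["input_ids"],
    #       attention_mask=batch["attention_mask"],
    #       labels=batch["labels"],
    #       gen_image=batch["gen_image"],        # target image for the diffusion loss
    #       und_image=batch["und_image"],        # source image to be edited
    #       instructions=batch["instructions"],  # lets GT masks be generated on the fly
    #   )
    #   out.loss.backward()  # with the default config: diff_loss + 0.2*ce_loss + 0.5*mask_pred_loss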

    # ============================================================
    # INFERENCE: Lightweight - No Qwen3/SAM3 needed!
    # ============================================================
    @torch.no_grad()
    def generate_edited_image(
        self,
        und_image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        spatial_guidance_scale: float = 1.0,
        mask_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Lightweight inference - uses the learned mask predictor instead of SAM3.

        Args:
            und_image: Input image tensor [B, C, H, W]
            input_ids: Tokenized prompt [B, seq_len]
            attention_mask: Attention mask [B, seq_len]
            num_inference_steps: Denoising steps
            guidance_scale: CFG scale for text
            spatial_guidance_scale: Scale for spatial conditioning
            mask_guidance_scale: Scale for predicted mask conditioning
            generator: Random generator for reproducibility

        Returns:
            Edited image (VAE-decoded) [B, C, H, W]
        """

        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]

        # ============================================================
        # 1. Get LLM Hidden States
        # ============================================================
        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )

        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict=True
        )

        hidden_states = output.hidden_states
        img_hidden_states = hidden_states

        # ============================================================
        # 2. Predict Edit Mask (no external segmenter needed)
        # ============================================================
        if self.mask_predictor is not None:
            last_hidden = hidden_states[-1]
            predicted_mask = self.mask_predictor(last_hidden)  # [B, 1, H, W]
        else:
            predicted_mask = None

        # ============================================================
        # 3. Encode Reference Image
        # ============================================================
        vae = self.get_model().get_sana_vae()
        ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
        ref_latents = ref_latents.to(device)

        latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3]
        latent_channels = ref_latents.shape[1]

        # Resize predicted mask to latent size
        if predicted_mask is not None:
            predicted_mask = F.interpolate(
                predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )

        # ============================================================
        # 4. Prepare Conditioning
        # ============================================================
        # Spatial conditioning
        if self.spatial_ref_encoder is not None:
            spatial_cond = self.spatial_ref_encoder(ref_latents)
        else:
            spatial_cond = torch.zeros_like(ref_latents)

        # Mask conditioning
        if self.mask_encoder is not None and predicted_mask is not None:
            mask_cond = self.mask_encoder(predicted_mask.to(dtype=self.mask_encoder[0].weight.dtype))
        else:
            mask_cond = torch.zeros_like(ref_latents)

        # Semantic conditioning from LLM
        fused_features = self.get_model().diffusion_connector(img_hidden_states)

        # ============================================================
        # 5. Prepare for CFG
        # ============================================================
        if guidance_scale > 1.0:
            # Unconditional: zero out conditioning
            spatial_cond_uncond = torch.zeros_like(spatial_cond)
            mask_cond_uncond = torch.zeros_like(mask_cond)
            fused_features_uncond = torch.zeros_like(fused_features)

            # Stack [uncond, cond]
            spatial_cond_cfg = torch.cat([spatial_cond_uncond, spatial_cond])
            mask_cond_cfg = torch.cat([mask_cond_uncond, mask_cond])
            fused_features_cfg = torch.cat([fused_features_uncond, fused_features])
        else:
            spatial_cond_cfg = spatial_cond
            mask_cond_cfg = mask_cond
            fused_features_cfg = fused_features

        # ============================================================
        # 6. Initialize Latents
        # ============================================================
        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        # ============================================================
        # 7. Setup Scheduler
        # ============================================================
        scheduler = self.get_model().noise_scheduler
        scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = scheduler.timesteps

        # ============================================================
        # 8. Denoising Loop
        # ============================================================
        for t in timesteps:
            # Expand for CFG
            if guidance_scale > 1.0:
                latent_model_input = torch.cat([latents] * 2)
                t_input = torch.cat([t.unsqueeze(0)] * 2 * batch_size)
            else:
                latent_model_input = latents
                t_input = t.unsqueeze(0).expand(batch_size)

            # Add conditioning
            combined_input = latent_model_input
            if self.spatial_weight is not None:
                combined_input = combined_input + spatial_guidance_scale * self.spatial_weight * spatial_cond_cfg
            if self.mask_weight is not None:
                combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg

            # DiT forward
            noise_pred = self.get_model().dit(
                hidden_states=combined_input,
                timestep=t_input,
                encoder_hidden_states=fused_features_cfg,
            ).sample

            # CFG
            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Scheduler step
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # ============================================================
        # 9. Decode Latents
        # ============================================================
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents.to(vae.device)).sample

        return image
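
    # Illustrative inference sketch (commented out; the tokenizer/prompt formatting is
    # assumed to come from the surrounding evaluation code):
    #
    #   prompt = tokenizer("Remove the red car", return_tensors="pt").to(model.device)
    #   edited = model.generate_edited_image(
    #       und_image=source_image,                # [1, 3, H, W]
    #       input_ids=prompt.input_ids,
    #       attention_mask=prompt.attention_mask,
    #       num_inference_steps=50,
    #       guidance_scale=7.5,
    #   )
    #   # edited: VAE-decoded image tensor, [1, 3, H, W]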


# ============================================================
# Register Model
# ============================================================

AutoConfig.register("llava_qwen2", blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)
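
# Illustrative loading sketch (commented out; assumes a checkpoint directory saved with
# model_type "llava_qwen2" so the registrations above resolve to these classes, and the
# checkpoint path is a placeholder):
#
#   config = AutoConfig.from_pretrained("/path/to/checkpoint")
#   model = AutoModelForCausalLM.from_pretrained("/path/to/checkpoint", config=config)
#   model.reinitialize_new_modules()  # only if the editing heads were newly added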