| |
| |
| |
|
|
| from typing import List, Optional, Tuple, Union, Dict, Any |
| import re |
| import json |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| from transformers import ( |
| AutoConfig, |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| Qwen2Config, |
| Qwen2Model, |
| Qwen2ForCausalLM |
| ) |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from diffusers.training_utils import ( |
| compute_density_for_timestep_sampling, |
| compute_loss_weighting_for_sd3 |
| ) |
| from diffusers.utils.torch_utils import randn_tensor |
|
|
| from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
|
|
|
|
| |
| |
| |
|
|
class Qwen3InstructionParser:
    """Parses edit instructions using Qwen3 LLM. Used only during training.

    The instruction is wrapped in a few-shot prompt, the model is asked for a
    single JSON object, and the decoded answer is validated and normalised.
    Results are memoised per instruction string.
    """

    # Operation labels accepted from the model; anything else becomes "other".
    VALID_OPERATIONS = (
        "remove", "replace", "add", "extract", "style",
        "adjust", "compose", "action", "other",
    )

    def __init__(
            self,
            model_name: str = "Qwen/Qwen3-1.7B",
            device: str = "cuda",
            torch_dtype: torch.dtype = torch.float16
    ):
        """Load tokenizer and model and put the model in eval mode.

        Args:
            model_name: Hub id or path passed to ``from_pretrained``.
            device: Device used for ``device_map`` and for the tokenized inputs.
            torch_dtype: Dtype the model weights are loaded in.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        self._cache: Dict[str, Dict] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Return the structured description of ``instruction``.

        Returns:
            Dict with keys ``operation``, ``source_object``, ``target_object``,
            ``location`` and ``attributes``. A copy is returned so callers
            cannot corrupt the cached entry by mutating the result.
        """
        if instruction in self._cache:
            return dict(self._cache[instruction])

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs, max_new_tokens=256, temperature=0.1,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        # Only the newly generated tokens are decoded; the prompt is sliced off.
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return dict(parsed)

    def _build_prompt(self, instruction: str) -> str:
        """Build the few-shot prompt that asks for a JSON-only answer."""
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    @staticmethod
    def _load_json_object(text: str) -> Optional[Dict[str, Any]]:
        """Decode ``text`` as JSON and return it only if it is a JSON object."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            return None
        # ``null``, numbers, strings and lists are valid JSON but useless here.
        return value if isinstance(value, dict) else None

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Turn the raw model answer into a fully populated instruction dict.

        The whole answer is tried first; if that is not a JSON object the first
        brace-delimited, non-nested span is tried. Missing keys are filled with
        defaults and unknown operations are mapped to ``"other"``. Any failure
        yields the all-default dict instead of raising.
        """
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None,
                   "attributes": None}

        parsed = self._load_json_object(response.strip())
        if parsed is None:
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match:
                parsed = self._load_json_object(match.group())
        if parsed is None:
            return default

        for key, fallback in default.items():
            parsed.setdefault(key, fallback)

        operation = parsed["operation"]
        if isinstance(operation, str):
            # Tolerate "Remove" / " remove " style answers.
            operation = operation.strip().lower()
        parsed["operation"] = operation if operation in self.VALID_OPERATIONS else "other"
        return parsed
|
|
|
|
class SAM3MaskGenerator:
    """
    Generates segmentation masks using SAM3.
    SAM3 natively supports text prompts - no Grounding DINO needed!

    The SAM3 model is built lazily on the first call to ``generate_mask``.
    """

    def __init__(self, device: str = "cuda"):
        # NOTE(review): ``device`` is stored but never handed to SAM3 in this
        # class; the model lives wherever ``build_sam3_image_model`` puts it.
        self.device = device
        self._model = None
        self._processor = None

    def _load_model(self):
        """Lazy load SAM3 model."""
        if self._model is None:
            from sam3.model_builder import build_sam3_image_model
            from sam3.model.sam3_image_processor import Sam3Processor

            print("Loading SAM3...")
            self._model = build_sam3_image_model()
            self._processor = Sam3Processor(self._model)
            print("SAM3 loaded!")

    @staticmethod
    def _to_uint8(array) -> np.ndarray:
        """Convert an image array to uint8 pixels that PIL can ingest.

        Arrays whose maximum is <= 1.0 are treated as unit-range and scaled by
        255. Values are clipped to [0, 255] before the cast so negative or
        overshooting inputs do not wrap around. A trailing singleton channel
        is dropped because ``Image.fromarray`` rejects ``(H, W, 1)`` input.
        """
        array = np.asarray(array)
        if array.ndim == 3 and array.shape[-1] == 1:
            array = array[..., 0]
        if array.size and array.max() <= 1.0:
            array = array * 255
        return np.clip(array, 0, 255).astype(np.uint8)

    def _prepare_image(self, image):
        """Convert various image formats to PIL Image.

        Accepts a PIL image, a torch tensor (a 4-D tensor contributes its
        first element; a 3-D tensor with 1 or 3 leading channels is moved to
        channels-last) or anything ``np.asarray`` understands.
        """
        from PIL import Image as PILImage

        if isinstance(image, PILImage.Image):
            return image.convert("RGB")

        if isinstance(image, torch.Tensor):
            if image.dim() == 4:
                image = image[0]
            # numpy cannot represent bfloat16, so half types go through float32.
            if image.dtype in (torch.bfloat16, torch.float16):
                image = image.float()
            if image.dim() == 3 and image.shape[0] in (1, 3):
                image = image.permute(1, 2, 0)
            pixels = image.detach().cpu().numpy()
        else:
            pixels = np.asarray(image)

        return PILImage.fromarray(self._to_uint8(pixels)).convert("RGB")

    @torch.no_grad()
    def generate_mask(
            self,
            image,
            parsed: Dict,
            detect_all: bool = False
    ) -> torch.Tensor:
        """
        Generate segmentation mask using SAM3 with text prompt.

        Args:
            image: Input image (PIL, tensor, or numpy)
            parsed: Parsed instruction dict with 'source_object', 'operation', etc.
            detect_all: If True, return the element-wise maximum over all
                returned masks; otherwise return the highest-scoring mask.

        Returns:
            Float mask tensor with a new leading dimension of size 1. With no
            usable prompt this is all ones for "style" and all zeros
            otherwise; with no detections it is all zeros (both ``[1, H, W]``).
            NOTE(review): the trailing shape of a real detection follows
            whatever layout SAM3 uses for ``output["masks"]`` - confirm.
        """
        self._load_model()

        image_pil = self._prepare_image(image)
        W, H = image_pil.size

        text_prompt = self._build_text_prompt(parsed)
        if not text_prompt:
            # A style edit touches the whole image; anything else has no target.
            if parsed.get("operation") == "style":
                return torch.ones(1, H, W)
            return torch.zeros(1, H, W)

        inference_state = self._processor.set_image(image_pil)
        output = self._processor.set_text_prompt(
            state=inference_state,
            prompt=text_prompt
        )

        masks = output["masks"]
        scores = output["scores"]

        if masks is None or len(masks) == 0:
            return torch.zeros(1, H, W)

        if isinstance(masks, np.ndarray):
            masks = torch.from_numpy(masks)
        elif isinstance(masks, list):
            masks = torch.stack([torch.from_numpy(m) if isinstance(m, np.ndarray) else m for m in masks])

        if detect_all:
            combined_mask = masks.float().max(dim=0)[0]
            return combined_mask.unsqueeze(0)

        if isinstance(scores, (list, np.ndarray)):
            scores = torch.tensor(scores)
        best_idx = scores.argmax()
        return masks[best_idx].unsqueeze(0).float()

    def _build_text_prompt(self, parsed: Dict) -> str:
        """Build SAM3 text prompt from parsed instruction.

        "style" never yields a prompt. "add" falls back to the location when
        no source object is given. Every other operation uses the source
        object. An empty string means "nothing to segment".
        """
        operation = parsed.get("operation", "other")
        source = parsed.get("source_object")

        if operation == "style":
            return ""
        if operation == "add":
            return source or parsed.get("location") or ""
        return source or ""
|
|
|
|
class EditMaskGenerator:
    """
    Complete mask generation pipeline using Qwen3 + SAM3.

    Simplified from: Qwen3 → Grounding DINO → SAM-2
    To: Qwen3 → SAM3 (native text support)
    """

    # Matches the standalone word "all"; a plain substring test would also
    # fire on "wall", "ball", "small", ...
    _ALL_INSTANCES = re.compile(r"\ball\b", re.IGNORECASE)

    def __init__(
            self,
            qwen_model: str = "Qwen/Qwen3-1.7B",
            device: str = "cuda",
            enabled: bool = True
    ):
        """Create the parser and segmenter, or a no-op generator.

        Args:
            qwen_model: Model id forwarded to ``Qwen3InstructionParser``.
            device: Device string forwarded to both sub-components.
            enabled: When False nothing is loaded and ``generate`` returns
                all-zero masks.
        """
        self.device = device
        self.enabled = enabled

        if enabled:
            print("Initializing EditMaskGenerator with SAM3...")
            self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)
            self.segmenter = SAM3MaskGenerator(device=device)
            print("EditMaskGenerator ready!")
        else:
            self.parser = None
            self.segmenter = None

    @classmethod
    def _wants_all_instances(cls, instruction: str) -> bool:
        """Return True when the instruction contains the whole word "all"."""
        return cls._ALL_INSTANCES.search(instruction) is not None

    @torch.no_grad()
    def generate(
            self,
            image,
            instruction: str,
            return_parsed: bool = False
    ):
        """Generate edit mask from image and instruction.

        Args:
            image: Tensor (spatial size read from the last two dims) or any
                object ``np.array`` converts to an ``(H, W, ...)`` array.
            instruction: Natural-language edit instruction.
            return_parsed: Also return the parsed instruction dict.

        Returns:
            ``mask`` or ``(mask, parsed)``. When the generator is disabled the
            mask is ``zeros(1, H, W)`` and ``parsed`` is
            ``{"operation": "other"}``.
        """
        if not self.enabled:
            if isinstance(image, torch.Tensor):
                H, W = image.shape[-2:]
            else:
                H, W = np.array(image).shape[:2]
            mask = torch.zeros(1, H, W)
            return (mask, {"operation": "other"}) if return_parsed else mask

        parsed = self.parser.parse(instruction)
        detect_all = self._wants_all_instances(instruction)
        mask = self.segmenter.generate_mask(image, parsed, detect_all=detect_all)

        return (mask, parsed) if return_parsed else mask
|
|
|
|
| |
| |
| |
class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration extended with the editing / conditioning options.

    Every extra option is read from ``kwargs`` and falls back to the default
    listed in ``__init__`` when absent.
    """

    model_type = "llava_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # name -> fallback, assigned in this order after the base initialiser.
        extra_defaults = {
            # Channel count of the latent space.
            "latent_channels": 32,
            # Feature switches for the conditioning branches / mask head.
            "use_spatial_conditioning": False,
            "use_mask_conditioning": True,
            "use_operation_embedding": True,
            "use_mask_predictor": True,
            "mask_predictor_loss_weight": 0.5,
            # Per-sample probabilities used to drop a conditioning signal.
            "spatial_drop_prob": 0.1,
            "mask_drop_prob": 0.1,
            # On-the-fly mask generation during training.
            "mask_generator_enabled": True,
            "qwen_model": "Qwen/Qwen3-1.7B",
        }
        for option, fallback in extra_defaults.items():
            setattr(self, option, kwargs.get(option, fallback))
|
|
|
|
| |
| |
| |
|
|
class BF16SafeLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, computed in float32.

    Scale and shift are stored as float32 parameters; the input is upcast for
    the statistics and the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=torch.float32))
        self.eps = eps
        self.hidden_size = hidden_size
        self.reset_parameters()

    def reset_parameters(self):
        """Restore the identity transform (scale 1, shift 0)."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last axis and return it in ``x``'s dtype."""
        caller_dtype = x.dtype
        x32 = x.float()

        # Biased (population) variance, matching standard layer norm.
        centered = x32 - x32.mean(dim=-1, keepdim=True)
        variance = x32.var(dim=-1, keepdim=True, unbiased=False)
        normalised = centered / torch.sqrt(variance + self.eps)

        scaled = normalised * self.weight.float() + self.bias.float()
        return scaled.to(caller_dtype)
|
|
|
|
class MaskPredictor(nn.Module):
    """
    Edit-mask head driven by LLM hidden states.

    Intended to let inference run without an external segmenter: the mask is
    predicted from the same hidden states the LLM produced for the
    instruction and image.

    Stages implemented below:
        1. normalise token features and pool them over the sequence axis with
           learned softmax weights,
        2. project the pooled vector to a 64-channel
           ``latent_size`` x ``latent_size`` feature map,
        3. decode that map to one logit per location with a small conv stack.
    """

    def __init__(self, hidden_size: int, latent_channels: int, latent_size: int = 32):
        # ``latent_channels`` is accepted but not used inside this class.
        super().__init__()

        self.latent_size = latent_size
        self.hidden_size = hidden_size

        # Scores one scalar per token; softmax over tokens happens in forward().
        pool_width = hidden_size // 4
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_size, pool_width),
            nn.Tanh(),
            nn.Linear(pool_width, 1),
        )

        self.input_norm = BF16SafeLayerNorm(hidden_size)

        # Pooled vector -> flattened 64 x latent_size x latent_size map.
        mid_width = hidden_size // 2
        flat_map_size = latent_size * latent_size * 64
        self.hidden_proj = nn.Sequential(
            nn.Linear(hidden_size, mid_width),
            nn.LayerNorm(mid_width),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mid_width, mid_width),
            nn.LayerNorm(mid_width),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(mid_width, flat_map_size),
        )

        # (in_channels, out_channels, group-norm groups) per 3x3 conv stage.
        decoder_layers = []
        for in_ch, out_ch, groups in ((64, 256, 32), (256, 128, 16), (128, 64, 8)):
            decoder_layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            decoder_layers.append(nn.GroupNorm(groups, out_ch))
            decoder_layers.append(nn.GELU())
        decoder_layers.append(nn.Conv2d(64, 1, 1))
        self.mask_decoder = nn.Sequential(*decoder_layers)

        self._init_weights()

    @staticmethod
    def _init_small_linear(layer: nn.Linear) -> None:
        """Xavier-uniform weights with gain 0.1 and a zero bias."""
        nn.init.xavier_uniform_(layer.weight, gain=0.1)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    @staticmethod
    def _init_identity_norm(layer) -> None:
        """Scale 1 / shift 0 for LayerNorm and GroupNorm layers."""
        nn.init.ones_(layer.weight)
        nn.init.zeros_(layer.bias)

    def _init_weights(self):
        """Initialize weights for stable training."""
        for layer in self.attention_pool:
            if isinstance(layer, nn.Linear):
                self._init_small_linear(layer)

        if hasattr(self, 'input_norm'):
            self.input_norm.reset_parameters()

        for layer in self.hidden_proj:
            if isinstance(layer, nn.Linear):
                self._init_small_linear(layer)
            elif isinstance(layer, nn.LayerNorm):
                self._init_identity_norm(layer)

        for layer in self.mask_decoder:
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.GroupNorm):
                self._init_identity_norm(layer)

        # The output conv is re-drawn with a tiny std so initial logits sit near 0.
        output_conv = [layer for layer in self.mask_decoder if isinstance(layer, nn.Conv2d)][-1]
        nn.init.normal_(output_conv.weight, mean=0.0, std=0.01)
        nn.init.zeros_(output_conv.bias)

    def _neutral_output(self, batch_size: int, device, return_logits: bool) -> torch.Tensor:
        """Constant fallback used when the input is not finite.

        Logit 0 and probability 0.5 are the same "undecided" prediction. The
        tensor is a fresh leaf with ``requires_grad=True``.
        """
        constant = 0.0 if return_logits else 0.5
        return torch.full((batch_size, 1, self.latent_size, self.latent_size), constant,
                          device=device, dtype=torch.float32, requires_grad=True)

    def forward(self, hidden_states: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, seq_len, hidden_size] from LLM
            return_logits: If True, return logits instead of probabilities

        Returns:
            mask: [B, 1, latent_size, latent_size] float32 logits or
            probabilities
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            print("WARNING: NaN/Inf in hidden_states input to MaskPredictor")
            return self._neutral_output(batch_size, device, return_logits)

        tokens = self.input_norm(hidden_states)
        if torch.isnan(tokens).any():
            print("WARNING: NaN after input_norm in MaskPredictor")
            return self._neutral_output(batch_size, device, return_logits)

        # Match the dtype of the pooling weights before the linear layers.
        tokens = tokens.to(self.attention_pool[0].weight.dtype)

        # Softmax over the sequence axis, then a weighted sum of the tokens.
        token_weights = F.softmax(self.attention_pool(tokens), dim=1)
        pooled = (tokens * token_weights).sum(dim=1)

        feature_map = self.hidden_proj(pooled).view(-1, 64, self.latent_size, self.latent_size)
        logits = self.mask_decoder(feature_map).float()

        if return_logits:
            return logits
        return torch.sigmoid(logits)
|
|
|
|
| |
| |
| |
|
|
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone that mixes ``LlavaMetaModel`` into ``Qwen2Model``.

    No behaviour is added here; construction is delegated along the MRO.
    """

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
|
|
|
|
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Causal LM with a diffusion head and optional edit conditioning.

    Optional sub-modules, each switched by a config flag:
      * ``spatial_ref_encoder`` / ``spatial_weight`` - encodes reference latents,
      * ``mask_encoder`` / ``mask_weight`` - encodes an edit mask,
      * ``mask_predictor`` - predicts the edit mask from LLM hidden states,
      * ``operation_embedding`` - embeds the parsed operation type.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        # Starts the MRO lookup *after* Qwen2ForCausalLM, so that class's own
        # __init__ is skipped; model and lm_head are built explicitly below.
        super(Qwen2ForCausalLM, self).__init__(config)

        self.model = blip3oFastModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        latent_channels = getattr(config, 'latent_channels', 32)

        # --- spatial reference conditioning ---
        if getattr(config, 'use_spatial_conditioning', True):
            self.spatial_ref_encoder = nn.Sequential(
                nn.Conv2d(latent_channels, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, latent_channels, 3, padding=1),
            )
            # Scalar gate; 0 means the branch contributes nothing at start.
            self.spatial_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.spatial_ref_encoder = None
            self.spatial_weight = None

        # --- mask conditioning ---
        if getattr(config, 'use_mask_conditioning', True):
            self.mask_encoder = nn.Sequential(
                nn.Conv2d(1, 64, 3, padding=1),
                nn.GroupNorm(8, 64),
                nn.SiLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.GroupNorm(16, 128),
                nn.SiLU(),
                nn.Conv2d(128, latent_channels, 3, padding=1),
            )
            self.mask_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.mask_encoder = None
            self.mask_weight = None

        # --- learned mask predictor ---
        if getattr(config, 'use_mask_predictor', True):
            self.mask_predictor = MaskPredictor(
                hidden_size=config.hidden_size,
                latent_channels=latent_channels,
                latent_size=32
            )
        else:
            self.mask_predictor = None

        # --- operation embedding ---
        if getattr(config, 'use_operation_embedding', True):
            self.operation_types = ["remove", "replace", "add", "extract", "style", "adjust", "compose", "action",
                                    "other"]
            self.operation_embedding = nn.Embedding(len(self.operation_types), latent_channels)
        else:
            self.operation_types = None
            self.operation_embedding = None

        # Built on first access of the ``mask_generator`` property.
        self._mask_generator = None
        self._mask_generator_initialized = False

        self._init_conditioning_layers()
        self.post_init()

    def _init_conditioning_layers(self):
        """Initialize conditioning layers. Called during __init__ and can be called after loading.

        Convs get Kaiming-normal weights and zero biases; GroupNorms get the
        identity transform. The output conv is deliberately NOT zeroed: the
        branch is already silenced by its scalar gate (``spatial_weight`` /
        ``mask_weight`` start at 0). Zeroing both factors of
        ``gate * encoder(x)`` makes the gradient of each one identically zero,
        so neither could ever leave zero and the branch would never train.
        """
        for encoder in (self.spatial_ref_encoder, self.mask_encoder):
            if encoder is None:
                continue
            for module in encoder:
                if isinstance(module, nn.Conv2d):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.GroupNorm):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)

    def reinitialize_new_modules(self):
        """
        Reinitialize modules that were added after the base model.
        Call this after loading a pretrained model to fix uninitialized weights.
        """
        print("Reinitializing new modules (mask_predictor, mask_encoder, spatial_ref_encoder)...")

        if self.mask_predictor is not None:
            self.mask_predictor._init_weights()
            print(" - mask_predictor reinitialized")

        self._init_conditioning_layers()
        print(" - conditioning layers reinitialized")

        if self.operation_embedding is not None:
            nn.init.normal_(self.operation_embedding.weight, mean=0.0, std=0.02)
            print(" - operation_embedding reinitialized")

        # Gates return to 0 so the conditioning branches start switched off.
        if self.spatial_weight is not None:
            nn.init.zeros_(self.spatial_weight)
            print(" - spatial_weight reinitialized to 0")
        if self.mask_weight is not None:
            nn.init.zeros_(self.mask_weight)
            print(" - mask_weight reinitialized to 0")

        print("Reinitialization complete!")

    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazy init mask generator (training only).

        A real generator is built only when the config enables it and the
        module is in training mode. If the first access happened in eval mode
        a disabled stub was cached; it is replaced as soon as an access
        happens under conditions that call for the real one.
        """
        wanted = bool(getattr(self.config, 'mask_generator_enabled', True)) and self.training
        stale_stub = (
            self._mask_generator is not None
            and wanted
            and not self._mask_generator.enabled
        )
        if not self._mask_generator_initialized or stale_stub:
            if wanted:
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_model(self):
        """Return the wrapped backbone (``self.model``)."""
        return self.model

    def mask_drop(self, latents: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor:
        """Zero whole samples of ``latents`` with probability ``drop_prob``.

        Active only in training mode and for ``drop_prob > 0``; otherwise the
        input is returned unchanged. One Bernoulli draw is made per batch
        element and broadcast over the remaining dimensions.
        """
        if drop_prob <= 0 or not self.training:
            return latents
        mask = torch.bernoulli(torch.full((latents.shape[0],), drop_prob, device=latents.device, dtype=latents.dtype))
        while len(mask.shape) < len(latents.shape):
            mask = mask.unsqueeze(-1)
        return latents * (1 - mask)

    def get_operation_index(self, operation: str) -> int:
        """Map an operation name to its embedding row.

        Unknown names map to the row of ``"other"``; 0 is returned when the
        operation embedding is disabled.
        """
        if self.operation_types is None:
            return 0
        return self.operation_types.index(
            operation) if operation in self.operation_types else self.operation_types.index("other")

    def _normalize_mask(self, mask, H, W, device):
        """
        Always return a single float mask: [1, H, W]

        ``None`` becomes all zeros. 4-D input uses channel 0 of every entry;
        2-D input gains a leading axis. Multiple instances (leading axis > 1)
        are merged with an element-wise maximum, and a mask whose spatial size
        differs from ``(H, W)`` is resized with nearest-neighbour so that
        per-sample masks can be stacked.

        Raises:
            ValueError: if the mask has fewer than 2 or more than 4 dimensions.
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.asarray(mask))

        # Float so that bool masks support max() and stack with zero fallbacks.
        mask = mask.to(device).float()

        if mask.dim() == 4:
            mask = mask[:, 0]
        elif mask.dim() == 2:
            mask = mask.unsqueeze(0)
        elif mask.dim() != 3:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")

        # Union over instances; a no-op when there is exactly one.
        mask = mask.max(dim=0, keepdim=True)[0]

        if tuple(mask.shape[-2:]) != (H, W):
            mask = F.interpolate(mask.unsqueeze(0), size=(H, W), mode='nearest')[0]

        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[
        torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM3 (training only).

        Any per-sample failure is reported and replaced by an all-zero mask
        with operation "other" so one bad sample cannot abort the batch.

        Returns:
            ``(masks, operations)`` where ``masks`` is ``[B, 1, H, W]`` on the
            device of ``und_images`` and ``operations`` has one label per sample.
        """
        masks, operations = [], []
        _, _, H, W = und_images.shape
        for i, image in enumerate(und_images):
            try:
                mask, parsed = self.mask_generator.generate(image, instructions[i], return_parsed=True)
                mask = self._normalize_mask(mask, H=H, W=W, device=und_images.device)
                operation = parsed.get("operation", "other")
            except Exception as e:  # deliberate best-effort boundary
                print(f"Mask generation failed: {e}")
                mask = torch.zeros(1, H, W, device=und_images.device)
                operation = "other"
            masks.append(mask)
            operations.append(operation)
        return torch.stack(masks).to(und_images.device), operations

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            gen_image: Optional[torch.FloatTensor] = None,
            und_image: Optional[torch.FloatTensor] = None,
            edit_mask: Optional[torch.FloatTensor] = None,
            operations: Optional[List[str]] = None,
            instructions: Optional[List[str]] = None,
            categories: Optional[List[str]] = None,
            return_dict: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward pass: LM loss + diffusion loss + mask-prediction loss.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, labels,
            use_cache, output_attentions: forwarded to the language model.
            inputs_embeds: Not supported on its own - target latents come out
                of ``prepare_inputs_labels_for_multimodal``, which only runs
                when this is None.
            output_hidden_states: Ignored; hidden states are always requested
                because the mask predictor and diffusion connector need them.
            gen_image, und_image: Passed to
                ``prepare_inputs_labels_for_multimodal``; ``und_image`` is also
                the source for on-the-fly masks and the reference latents.
            edit_mask: Optional mask, ``[B, 1, H, W]`` or ``[B, H, W]``. When
                missing in training, masks are generated from ``instructions``
                if the config enables it.
            operations: Optional operation label per sample.
            instructions: Optional instruction per sample.
            categories, cache_position: Accepted for interface compatibility;
                unused here.
            return_dict: Ignored; a ``CausalLMOutputWithPast`` is always
                returned.

        Returns:
            ``CausalLMOutputWithPast`` whose ``loss`` is
            ``diff_loss + 0.2 * ce_loss + mask_predictor_loss_weight * mask_pred_loss``.

        Raises:
            ValueError: if no target latents are available.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = True

        # Stays None when the caller supplies inputs_embeds directly.
        latents = None
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, gen_image, und_image)

        if latents is None:
            raise ValueError(
                "forward() needs target latents from prepare_inputs_labels_for_multimodal; "
                "none were produced (was inputs_embeds passed, or gen_image missing?)"
            )

        # return_dict is forced on: the attribute accesses below need a ModelOutput.
        output = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels,
            use_cache=use_cache, output_attentions=output_attentions,
            output_hidden_states=output_hidden_states, return_dict=True
        )

        # Without labels the LM returns no loss; treat that as a zero CE term.
        ce_loss = output.loss if output.loss is not None else torch.zeros((), device=latents.device)
        hidden_states = output.hidden_states
        logits = output.logits
        img_hidden_states = hidden_states

        latent_hw = (latents.shape[2], latents.shape[3])

        # ------------------------------------------------------------------
        # Ground-truth mask (training only, generated when not supplied)
        # ------------------------------------------------------------------
        if (edit_mask is None and instructions is not None and und_image is not None
                and self.training and getattr(self.config, 'mask_generator_enabled', True)):
            edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        # Resize once and share between the BCE target and the mask encoder.
        mask_target = None
        if edit_mask is not None:
            if edit_mask.dim() == 3:
                edit_mask = edit_mask.unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]
            mask_target = F.interpolate(
                edit_mask.float().to(latents.device), size=latent_hw, mode='nearest'
            ).clamp(0.0, 1.0)  # BCE targets and the encoder input must lie in [0, 1]

        # ------------------------------------------------------------------
        # Mask predictor loss
        # ------------------------------------------------------------------
        mask_pred_loss = torch.tensor(0.0, device=latents.device)
        if self.mask_predictor is not None:
            mask_logits = self.mask_predictor(hidden_states[-1], return_logits=True)
            mask_logits = F.interpolate(
                mask_logits.float(), size=latent_hw, mode='bilinear', align_corners=False
            )
            if mask_target is not None and self.training:
                if not torch.isnan(mask_logits).any() and not torch.isnan(mask_target).any():
                    mask_pred_loss = F.binary_cross_entropy_with_logits(
                        mask_logits, mask_target, reduction='mean'
                    )
                else:
                    print("WARNING: NaN in mask_logits or gt_mask, skipping mask_pred_loss")

        # ------------------------------------------------------------------
        # Noise / timestep sampling
        # ------------------------------------------------------------------
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        noise_scheduler = self.get_model().noise_scheduler
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # ------------------------------------------------------------------
        # Spatial reference conditioning
        # NOTE(review): spatial_cond is computed but never added to the DiT
        # input below, although generate_edited_image() does add it. Left
        # as-is because the reference latents may not share the target
        # latents' spatial size - confirm intent before wiring it in.
        # ------------------------------------------------------------------
        if self.spatial_ref_encoder is not None:
            vae = self.get_model().get_sana_vae()
            ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
            ref_latents = ref_latents.to(latents.device)
            spatial_cond = self.spatial_ref_encoder(ref_latents)
            spatial_cond = self.mask_drop(spatial_cond, getattr(self.config, 'spatial_drop_prob', 0.1))
        else:
            spatial_cond = 0

        # ------------------------------------------------------------------
        # Mask conditioning
        # ------------------------------------------------------------------
        if self.mask_encoder is not None and mask_target is not None:
            mask_latent = mask_target

            # Run the encoder in float32 by hand: its weights may be stored in
            # a lower precision than the float32 mask.
            mask_cond = mask_latent
            for layer in self.mask_encoder:
                if isinstance(layer, nn.Conv2d):
                    mask_cond = F.conv2d(mask_cond, layer.weight.float(),
                                         layer.bias.float() if layer.bias is not None else None,
                                         layer.stride, layer.padding)
                elif isinstance(layer, nn.GroupNorm):
                    mask_cond = F.group_norm(mask_cond, layer.num_groups,
                                             layer.weight.float(), layer.bias.float(), layer.eps)
                else:
                    mask_cond = layer(mask_cond)

            mask_cond = mask_cond.to(latents.dtype)
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))
        else:
            mask_cond = 0
            mask_latent = None

        # ------------------------------------------------------------------
        # Operation conditioning
        # NOTE(review): op_cond is computed but not consumed anywhere below.
        # ------------------------------------------------------------------
        if self.operation_embedding is not None and operations is not None:
            op_indices = torch.tensor([self.get_operation_index(op) for op in operations], device=latents.device)
            op_embed = self.operation_embedding(op_indices)[:, :, None, None]
            op_cond = op_embed * mask_latent if mask_latent is not None else op_embed.expand(
                -1, -1, latents.shape[2], latents.shape[3])
        else:
            op_cond = 0

        # ------------------------------------------------------------------
        # Diffusion loss
        # ------------------------------------------------------------------
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents
        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        fused_features = self.get_model().diffusion_connector(img_hidden_states)
        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        target = latents - noise
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # ------------------------------------------------------------------
        # Total loss
        # ------------------------------------------------------------------
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss

        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item() if isinstance(mask_pred_loss, torch.Tensor) else 0:.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )

    @torch.no_grad()
    def generate_edited_image(
            self,
            und_image: torch.Tensor,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            spatial_guidance_scale: float = 1.0,
            mask_guidance_scale: float = 1.0,
            generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Lightweight inference - uses the learned mask predictor instead of an
        external segmenter.

        Args:
            und_image: Input image tensor [B, C, H, W]
            input_ids: Tokenized prompt [B, seq_len]
            attention_mask: Attention mask [B, seq_len]
            num_inference_steps: Denoising steps
            guidance_scale: CFG scale; values > 1.0 enable classifier-free
                guidance with an all-zero unconditional branch
            spatial_guidance_scale: Scale for spatial conditioning
            mask_guidance_scale: Scale for predicted mask conditioning
            generator: Random generator for reproducibility

        Returns:
            The decoded edited image (``vae.decode(...).sample``), not latents.
        """
        import copy  # local import keeps this block independent of the file header

        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]
        use_cfg = guidance_scale > 1.0

        # ------------------------------------------------------------------
        # LLM pass
        # ------------------------------------------------------------------
        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )

        # Calls the plain LM forward, bypassing this class's training forward().
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = output.hidden_states
        img_hidden_states = hidden_states

        # ------------------------------------------------------------------
        # Reference latents
        # ------------------------------------------------------------------
        vae = self.get_model().get_sana_vae()
        ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
        ref_latents = ref_latents.to(device)

        latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3]
        latent_channels = ref_latents.shape[1]

        # ------------------------------------------------------------------
        # Predicted mask: resize the logits first, then squash - the same
        # order forward() uses when it trains the predictor.
        # ------------------------------------------------------------------
        if self.mask_predictor is not None:
            mask_logits = self.mask_predictor(hidden_states[-1], return_logits=True)
            mask_logits = F.interpolate(
                mask_logits.float(), size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )
            predicted_mask = torch.sigmoid(mask_logits)
        else:
            predicted_mask = None

        # ------------------------------------------------------------------
        # Conditioning signals
        # ------------------------------------------------------------------
        if self.spatial_ref_encoder is not None:
            spatial_cond = self.spatial_ref_encoder(ref_latents)
        else:
            spatial_cond = torch.zeros_like(ref_latents)

        if self.mask_encoder is not None and predicted_mask is not None:
            mask_cond = self.mask_encoder(predicted_mask.to(dtype=self.mask_encoder[0].weight.dtype))
        else:
            mask_cond = torch.zeros_like(ref_latents)

        fused_features = self.get_model().diffusion_connector(img_hidden_states)
        # forward() hands the DiT this mask during training, so do the same here.
        encoder_mask = attention_mask_mm

        # ------------------------------------------------------------------
        # Classifier-free guidance: batch = [unconditional, conditional]
        # ------------------------------------------------------------------
        if use_cfg:
            spatial_cond_cfg = torch.cat([torch.zeros_like(spatial_cond), spatial_cond])
            mask_cond_cfg = torch.cat([torch.zeros_like(mask_cond), mask_cond])
            fused_features_cfg = torch.cat([torch.zeros_like(fused_features), fused_features])
            if encoder_mask is not None:
                encoder_mask = torch.cat([encoder_mask, encoder_mask])
        else:
            spatial_cond_cfg = spatial_cond
            mask_cond_cfg = mask_cond
            fused_features_cfg = fused_features

        dit_kwargs = {}
        if encoder_mask is not None:
            dit_kwargs["encoder_attention_mask"] = encoder_mask

        # ------------------------------------------------------------------
        # Initial noise and sampling schedule
        # ------------------------------------------------------------------
        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        # Work on a private copy: set_timesteps() rewrites the scheduler's
        # timestep table, which forward() indexes with training timesteps.
        scheduler = copy.deepcopy(self.get_model().noise_scheduler)
        scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = scheduler.timesteps

        # ------------------------------------------------------------------
        # Denoising loop
        # ------------------------------------------------------------------
        for t in timesteps:
            if use_cfg:
                latent_model_input = torch.cat([latents] * 2)
                t_input = torch.cat([t.unsqueeze(0)] * 2 * batch_size)
            else:
                latent_model_input = latents
                t_input = t.unsqueeze(0).expand(batch_size)

            combined_input = latent_model_input
            if self.spatial_weight is not None:
                combined_input = combined_input + spatial_guidance_scale * self.spatial_weight * spatial_cond_cfg
            if self.mask_weight is not None:
                combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg

            noise_pred = self.get_model().dit(
                hidden_states=combined_input,
                timestep=t_input,
                encoder_hidden_states=fused_features_cfg,
                **dit_kwargs
            ).sample

            if use_cfg:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # ------------------------------------------------------------------
        # Decode
        # ------------------------------------------------------------------
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents.to(vae.device)).sample

        return image
|
|
|
|
| |
| |
| |
|
|
# Make the custom config and model discoverable through the transformers
# Auto* factories; the registered key is the config's own ``model_type``.
AutoConfig.register(blip3oFastConfig.model_type, blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)
|
|