| | """ |
| | BLIP3o Fast - Unified Image Understanding and Generation with Mask Prediction |
| | |
| | This module provides: |
| | - Training: Diffusion-based image editing with mask supervision from SAM |
| | - Inference: Lightweight mask-free editing using learned MaskPredictor |
| | |
| | Key Components (from llava_arch.py): |
| | - MaskPredictor: Learns to predict edit regions from LLM hidden states |
| | - MaskEncoder: Encodes masks for diffusion conditioning |
| | - mask_weight/spatial_weight: Learnable conditioning scales (SAVED with model!) |
| | |
| | Training Flow: |
| | 1. LLM processes image + instruction → hidden states |
| | 2. MaskPredictor predicts edit mask (supervised by SAM) |
| | 3. Diffusion generates edited image with mask conditioning |
| | |
| | Inference Flow: |
| | 1. LLM processes image + instruction → hidden states |
| | 2. MaskPredictor predicts edit mask (NO SAM needed!) |
| | 3. Diffusion generates edited image |
| | """ |
| |
|
| | from typing import List, Optional, Tuple, Union, Dict, Any |
| | import json |
| | import re |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from transformers import ( |
| | AutoConfig, |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | Qwen2Config, |
| | Qwen2Model, |
| | Qwen2ForCausalLM |
| | ) |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from diffusers.training_utils import ( |
| | compute_density_for_timestep_sampling, |
| | compute_loss_weighting_for_sd3 |
| | ) |
| | from diffusers.utils.torch_utils import randn_tensor |
| |
|
| | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Qwen3InstructionParser:
    """Parses edit instructions using Qwen3 LLM. Used only during training.

    Turns a free-form editing instruction ("Remove the red car") into a
    structured dict with keys: operation, source_object, target_object,
    location, attributes. Results are memoized per instruction string.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        # Instruction -> parsed dict cache; instructions repeat heavily within
        # a training epoch, so this avoids redundant LLM calls.
        self._cache: Dict[str, Dict] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Parse one instruction, returning the structured dict (cached)."""
        if instruction in self._cache:
            return self._cache[instruction]

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Greedy decoding for determinism; `temperature` is intentionally not
        # passed — it is ignored when do_sample=False and only triggers
        # transformers generation-config warnings.
        outputs = self.model.generate(
            **inputs, max_new_tokens=256,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens (skip the prompt prefix).
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return parsed

    def _build_prompt(self, instruction: str) -> str:
        """Build the few-shot JSON-extraction prompt for one instruction."""
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Extract the JSON object from a model response, tolerating junk.

        Falls back to regex-extracting the first brace-delimited object, and
        returns a safe default dict when nothing parseable is found or when
        the parsed JSON is not an object (e.g. a bare list).
        """
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None, "attributes": None}
        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match:
                try:
                    parsed = json.loads(match.group())
                except json.JSONDecodeError:
                    # Narrow except: a truly malformed fragment yields the default.
                    return default
            else:
                return default
        # json.loads can legally return lists/strings/numbers; merging those
        # into a dict would raise TypeError, so guard explicitly.
        if not isinstance(parsed, dict):
            return default
        # Merge so that missing keys keep their defaults.
        return {**default, **parsed}
| |
|
| |
|
| | |
| | |
| | |
| |
|
class EditMaskGenerator:
    """Generates ground truth edit masks using Qwen3 + SAM. Training only.

    Pipeline: Qwen3 parses the instruction into structured fields; when a
    source object is named, SAM2 segments the image (prompted with a center
    point) and the highest-scoring mask is returned. Whenever the generator
    is disabled or no source object is found, an all-zero mask is returned.
    """

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        sam_model: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        enabled: bool = True
    ):
        self.enabled = enabled
        self.device = device
        if not enabled:
            # Disabled instances skip loading any models entirely.
            return

        self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)

        try:
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.sam = SAM2ImagePredictor.from_pretrained(sam_model, device=device)
        except ImportError:
            print("WARNING: SAM2 not installed. Mask generation disabled.")
            self.enabled = False

    def _zero_mask(self, height, width):
        """All-zero fallback mask of shape [1, H, W] on this generator's device."""
        return torch.zeros(1, height, width, device=self.device)

    def generate(self, image: torch.Tensor, instruction: str, return_parsed: bool = False):
        """Generate edit mask from image and instruction."""
        # Disabled generator: empty mask plus a generic operation tag.
        if not self.enabled:
            h, w = image.shape[-2:]
            empty = self._zero_mask(h, w)
            return (empty, {"operation": "other"}) if return_parsed else empty

        parsed = self.parser.parse(instruction)

        # No segmentable object named in the instruction -> empty mask.
        if not parsed.get("source_object"):
            h, w = image.shape[-2:]
            empty = self._zero_mask(h, w)
            return (empty, parsed) if return_parsed else empty

        # Ensure a batch dimension, then convert the first image from
        # [-1, 1] CHW tensor to a uint8 HWC numpy array for SAM.
        batched = image.unsqueeze(0) if image.dim() == 3 else image
        image_np = ((batched[0].permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype("uint8")

        with torch.inference_mode():
            self.sam.set_image(image_np)
            h, w = image_np.shape[:2]
            # Prompt SAM with the image center point as a positive click.
            masks, scores, _ = self.sam.predict(
                point_coords=[[w // 2, h // 2]],
                point_labels=[1],
                multimask_output=True
            )

        # Keep the mask SAM scored highest.
        best = torch.from_numpy(masks[scores.argmax()]).float()
        mask = best.unsqueeze(0).to(self.device)

        return (mask, parsed) if return_parsed else mask
| |
|
| |
|
| | |
| | |
| | |
| |
|
class blip3oFastConfig(Qwen2Config):
    """Configuration for BLIP3o Fast: a Qwen2 config extended with
    mask-prediction and diffusion-conditioning options.

    All extra fields are persisted with the config so checkpoints restore
    the same mask/diffusion setup.
    """

    model_type = "llava_qwen2"

    def __init__(
        self,
        use_mask_predictor: bool = True,
        use_mask_conditioning: bool = True,
        use_spatial_conditioning: bool = False,
        use_operation_embedding: bool = False,
        mask_predictor_loss_weight: float = 0.5,
        latent_channels: int = 32,
        latent_size: int = 32,
        num_operation_types: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Feature toggles for the mask pipeline.
        self.use_mask_predictor = use_mask_predictor
        self.use_mask_conditioning = use_mask_conditioning
        self.use_spatial_conditioning = use_spatial_conditioning
        self.use_operation_embedding = use_operation_embedding
        # Loss weighting and latent-space geometry.
        self.mask_predictor_loss_weight = mask_predictor_loss_weight
        self.latent_channels = latent_channels
        self.latent_size = latent_size
        self.num_operation_types = num_operation_types
| |
|
| |
|
| | |
| | |
| | |
| |
|
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone: Qwen2 transformer combined with the LLaVA multimodal mixin."""

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        # Zero-argument super() resolves through the same MRO as the explicit
        # two-argument form did; both mixin initializers run in order.
        super().__init__(config)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """
    BLIP3o Fast model for training.

    All mask-related components (mask_predictor, mask_encoder, mask_weight, etc.)
    are defined in LlavaMetaModel and accessed via properties in LlavaMetaForCausalLM.
    This ensures they are saved/loaded with the model.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen2"

        # Multimodal backbone plus an untied LM head.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Closed set of edit-operation labels; list position defines the index
        # returned by get_operation_index().
        self.operation_types = ["remove", "replace", "add", "extract", "style",
                                "adjust", "compose", "action", "inpaint", "other"]

        # SAM/Qwen3 ground-truth mask generator, created lazily on first use
        # (training only) — see the mask_generator property.
        self._mask_generator = None
        self._mask_generator_initialized = False

        self.post_init()

    def get_model(self):
        """Return the underlying blip3oFastModel backbone."""
        return self.model

    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazy init mask generator (training only).

        Built on first access: enabled only when config.mask_generator_enabled
        (default True) AND the module is currently in training mode; otherwise
        a disabled generator (all-zero masks) is cached.
        NOTE(review): the decision is frozen at first access — switching
        train()/eval() afterwards does not rebuild the generator.
        """
        if not self._mask_generator_initialized:
            enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training
            if enabled:
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_operation_index(self, operation: str) -> int:
        """Map an operation name to its index in self.operation_types;
        unknown names map to the index of "other"."""
        # Defensive: operation_types is always set in __init__, but a subclass
        # or checkpoint surgery could clear it.
        if self.operation_types is None:
            return 0
        return self.operation_types.index(operation) if operation in self.operation_types else self.operation_types.index("other")

    def _normalize_mask(self, mask, H, W, device):
        """Normalize mask to [1, H, W] format.

        Accepts None (-> zeros), numpy arrays, and tensors of rank 2-4.
        A 4-D [N, C, H, W] input is reduced by taking channel 0 and then an
        element-wise max over N (union of instance masks).
        NOTE(review): a 3-D input passes through unchanged — if its leading
        dim is > 1 it is NOT reduced to [1, H, W]; confirm callers only pass
        single-mask 3-D tensors.
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)

        mask = mask.to(device)

        if mask.dim() == 4:
            mask = mask[:, 0]
            mask = mask.max(dim=0, keepdim=True)[0]
        elif mask.dim() == 3:
            pass
        elif mask.dim() == 2:
            mask = mask.unsqueeze(0)
        else:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")

        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM (training only).

        und_images is a [B, C, H, W] batch; returns ([B, 1, H, W] masks,
        list of B operation names). Best-effort per sample: any failure logs
        and falls back to a zero mask + "other" so training never crashes here.
        """
        masks, operations = [], []
        B, _, H, W = und_images.shape
        for i in range(und_images.shape[0]):
            try:
                mask, parsed = self.mask_generator.generate(und_images[i], instructions[i], return_parsed=True)
                mask = self._normalize_mask(mask, H=H, W=W, device=und_images.device)
                masks.append(mask)
                operations.append(parsed.get("operation", "other"))
            except Exception as e:
                # Deliberate broad catch: mask supervision is auxiliary, so a
                # SAM/parser failure must not abort the training step.
                print(f"Mask generation failed: {e}")
                masks.append(torch.zeros(1, H, W, device=und_images.device))
                operations.append("other")
        return torch.stack(masks).to(und_images.device), operations

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward: combines LM cross-entropy, mask-prediction BCE,
        and a flow-matching diffusion loss.

        When `latents` (target image latents from multimodal preprocessing)
        are absent, this degenerates to a plain causal-LM forward and only the
        CE loss is returned.
        NOTE(review): `cache_position`, `categories`, and `operations` are
        accepted but never used below; `operations` is overwritten when masks
        are generated on the fly.
        """

        # Hidden states are always needed: the mask predictor and the
        # diffusion connector both consume them.
        output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # Splices image features into the token sequence and (for editing
            # batches) returns the target image latents to denoise.
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values,
                labels, gen_image, und_image
            )
        else:
            # Caller supplied embeddings directly: no diffusion targets.
            latents = None

        # Run the LM trunk; explicit class call avoids recursing into this
        # overridden forward.
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True
        )
        logits = output.logits
        # Tuple of per-layer hidden states (all layers, not just the last).
        img_hidden_states = output.hidden_states

        # Standard next-token CE with the usual one-position shift.
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            ce_loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        else:
            ce_loss = torch.tensor(0.0, device=logits.device)

        # Understanding-only batch: nothing to denoise, return early.
        if latents is None:
            return CausalLMOutputWithPast(
                loss=ce_loss, logits=logits, past_key_values=output.past_key_values,
                hidden_states=output.hidden_states, attentions=output.attentions
            )

        # No precomputed GT mask: derive one from SAM + Qwen3 (training only).
        if edit_mask is None and instructions is not None and self.training:
            edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        mask_pred_loss = torch.tensor(0.0, device=latents.device)

        # Supervise the learned MaskPredictor against the (resized) GT mask.
        if self.mask_predictor is not None:
            last_hidden = img_hidden_states[-1]
            mask_logits = self.mask_predictor(last_hidden, return_logits=True)

            if edit_mask is not None and self.training:
                # Nearest-neighbor keeps the mask binary at latent resolution.
                gt_mask_resized = F.interpolate(
                    edit_mask.float().to(latents.device),
                    size=(latents.shape[2], latents.shape[3]),
                    mode='nearest'
                )

                # Skip the loss entirely if either side contains NaNs.
                if not torch.isnan(mask_logits).any() and not torch.isnan(gt_mask_resized).any():
                    mask_pred_loss = F.binary_cross_entropy_with_logits(
                        mask_logits,
                        gt_mask_resized,
                        reduction='mean'
                    )

        # --- Flow-matching diffusion loss setup ---
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
        timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # Encode the GT mask as an additive conditioning signal; mask_drop
        # presumably implements classifier-free-guidance style dropout — TODO confirm.
        mask_cond = 0
        if self.mask_encoder is not None and edit_mask is not None:
            mask_latent = F.interpolate(
                edit_mask.float().to(latents.device),
                size=(latents.shape[2], latents.shape[3]),
                mode='nearest'
            ).clamp(0.0, 1.0)
            mask_cond = self.mask_encoder(mask_latent)
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))

        # Linear interpolation between clean latents and noise (rectified-flow
        # style corruption).
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents

        # mask_cond stays the int 0 when conditioning is off; the isinstance
        # check distinguishes that sentinel from a real tensor.
        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        # Connector fuses the full tuple of LM hidden states into DiT
        # cross-attention features.
        fused_features = self.get_model().diffusion_connector(img_hidden_states)

        # NOTE(review): attention_mask here is the (possibly multimodal-
        # expanded) LM mask — verify its length matches fused_features.
        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        # Regression target is (latents - noise), i.e. the flow direction for
        # the interpolation above; SD3-style per-sample loss weighting.
        target = latents - noise
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # Fixed 0.2 weight on CE; mask loss weight comes from config.
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss

        # NOTE(review): stdout print on every training step — consider logging.
        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item():.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )

    @torch.no_grad()
    def generate_edited_image(
        self,
        und_image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        mask_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Generate edited image using learned mask predictor.
        NO external segmentation model needed!

        Runs the LM once to get hidden states, predicts the edit mask from
        them, then denoises fresh latents with the DiT (classifier-free
        guidance when guidance_scale > 1) and decodes through the VAE.
        Returns the decoded image tensor from vae.decode(...).sample.
        """
        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]

        # Same multimodal preprocessing as training, but with no labels and
        # no generation-target image (latents output is discarded).
        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )

        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = output.hidden_states

        # Inference-time mask comes from the learned predictor (no SAM).
        predicted_mask = None
        if self.mask_predictor is not None:
            last_hidden = hidden_states[-1]
            predicted_mask = self.mask_predictor(last_hidden)

        # Encode the reference image only to discover the latent geometry
        # (channels / spatial size) for sampling.
        vae = self.get_model().get_sana_vae()
        ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
        ref_latents = ref_latents.to(device)

        latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3]
        latent_channels = ref_latents.shape[1]

        # Soft mask -> bilinear resize (training used nearest on binary GT).
        if predicted_mask is not None:
            predicted_mask = F.interpolate(
                predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )

        mask_cond = torch.zeros_like(ref_latents)
        if self.mask_encoder is not None and predicted_mask is not None:
            mask_cond = self.mask_encoder(predicted_mask.to(dtype))

        fused_features = self.get_model().diffusion_connector(hidden_states)

        # CFG: unconditional branch uses zeroed mask and zeroed LM features.
        if guidance_scale > 1.0:
            mask_cond_cfg = torch.cat([torch.zeros_like(mask_cond), mask_cond])
            fused_features_cfg = torch.cat([torch.zeros_like(fused_features), fused_features])
        else:
            mask_cond_cfg = mask_cond
            fused_features_cfg = fused_features

        # Start from pure noise in latent space.
        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        scheduler = self.get_model().noise_scheduler
        scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = scheduler.timesteps

        for t in timesteps:
            # Duplicate the batch for the uncond/cond CFG pair.
            if guidance_scale > 1.0:
                latent_model_input = torch.cat([latents] * 2)
                t_input = torch.cat([t.unsqueeze(0)] * 2 * batch_size)
            else:
                latent_model_input = latents
                t_input = t.unsqueeze(0).expand(batch_size)

            # Additive mask conditioning, scaled by mask_guidance_scale on
            # top of the learned mask_weight.
            combined_input = latent_model_input
            if self.mask_weight is not None:
                combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg

            # NOTE(review): unlike training, no encoder_attention_mask is
            # passed here — confirm the DiT tolerates the asymmetry.
            noise_pred = self.get_model().dit(
                hidden_states=combined_input,
                timestep=t_input,
                encoder_hidden_states=fused_features_cfg,
            ).sample

            # Standard CFG combination.
            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the VAE scaling and decode back to pixel space.
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents.to(vae.device)).sample

        return image
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Register the config/model pair with the HuggingFace Auto* factories so that
# AutoConfig/AutoModelForCausalLM.from_pretrained(...) can resolve the
# "llava_qwen2" model_type to these classes.
AutoConfig.register("llava_qwen2", blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)
| |
|