# FastVLM_SANA / code/blip3o_fast.py
# Uploaded by Fahad-S ("Upload code/blip3o_fast.py with huggingface_hub"),
# commit 9f4b773 (verified).
"""
BLIP3o Fast - Unified Image Understanding and Generation with Mask Prediction
This module provides:
- Training: Diffusion-based image editing with mask supervision from SAM
- Inference: Lightweight mask-free editing using learned MaskPredictor
Key Components (from llava_arch.py):
- MaskPredictor: Learns to predict edit regions from LLM hidden states
- MaskEncoder: Encodes masks for diffusion conditioning
- mask_weight/spatial_weight: Learnable conditioning scales (SAVED with model!)
Training Flow:
1. LLM processes image + instruction → hidden states
2. MaskPredictor predicts edit mask (supervised by SAM)
3. Diffusion generates edited image with mask conditioning
Inference Flow:
1. LLM processes image + instruction → hidden states
2. MaskPredictor predicts edit mask (NO SAM needed!)
3. Diffusion generates edited image
"""
from typing import List, Optional, Tuple, Union, Dict, Any
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
Qwen2Config,
Qwen2Model,
Qwen2ForCausalLM
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from diffusers.training_utils import (
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3
)
from diffusers.utils.torch_utils import randn_tensor
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
# ============================================================
# TRAINING ONLY: Qwen3 Client for Instruction Parsing
# ============================================================
class Qwen3InstructionParser:
    """Parse free-form edit instructions into a structured dict with a Qwen3 LLM.

    Used only during training. Parsed results are memoized per instruction
    string for the lifetime of the parser.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        # instruction -> parsed dict (unbounded; one entry per distinct instruction)
        self._cache: Dict[str, Dict[str, Any]] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Return the structured form of *instruction*.

        The returned dict always has the keys ``operation``, ``source_object``,
        ``target_object``, ``location`` and ``attributes``. Unparseable model
        output yields ``operation="other"`` with every other field ``None``.
        """
        cached = self._cache.get(instruction)
        if cached is not None:
            return cached
        messages = [{"role": "user", "content": self._build_prompt(instruction)}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Greedy decoding. A temperature would be ignored with do_sample=False,
        # so none is passed.
        outputs = self.model.generate(
            **inputs, max_new_tokens=256,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        prompt_len = inputs.input_ids.shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return parsed

    def _build_prompt(self, instruction: str) -> str:
        """Few-shot prompt asking the LLM for a single JSON object."""
        return f"""You are an image editing instruction parser. Extract structured information.
Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}
Operation types: remove, replace, add, extract, style, adjust, compose, action, other
Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}
Input: "{instruction}"
Output:"""

    @staticmethod
    def _try_load_dict(text: str) -> Optional[Dict[str, Any]]:
        """Decode *text* as JSON; return it only if it is a JSON object, else None."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError:
            return None
        return value if isinstance(value, dict) else None

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Extract the JSON object from the raw LLM *response*.

        The whole (stripped) response is tried first. Then the first flat
        ``{...}`` span inside it is tried, which covers chatter or code fences
        around the JSON. Anything that is not a JSON object falls back to the
        defaults. Keys present in the parsed object override the defaults.
        """
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None, "attributes": None}
        parsed = self._try_load_dict(response.strip())
        if parsed is None:
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match:
                parsed = self._try_load_dict(match.group())
        if parsed is None:
            return default
        return {**default, **parsed}
# ============================================================
# TRAINING ONLY: Edit Mask Generator (SAM + Qwen3)
# ============================================================
class EditMaskGenerator:
    """Generate pseudo ground-truth edit masks with Qwen3 + SAM2 (training only).

    The instruction is parsed by ``Qwen3InstructionParser``. If it names a
    ``source_object``, SAM2 is prompted with a single foreground point at the
    image centre, and its highest-scoring mask is returned. In every other case
    (generator disabled, SAM2 not installed, no source object) an all-zero mask
    is returned. The parsed object name only gates segmentation; it does not
    steer where SAM2 segments.
    """

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        sam_model: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        enabled: bool = True
    ):
        self.enabled = enabled
        self.device = device
        if not enabled:
            return
        # Load SAM2 first. If it is missing, the generator is disabled and there
        # is no reason to pay for loading the Qwen3 parser.
        try:
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.sam = SAM2ImagePredictor.from_pretrained(sam_model, device=device)
        except ImportError:
            print("WARNING: SAM2 not installed. Mask generation disabled.")
            self.enabled = False
            return
        self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)

    @staticmethod
    def _empty_mask(image: torch.Tensor) -> torch.Tensor:
        """All-zero ``(1, H, W)`` mask matching *image*'s spatial size and device."""
        height, width = image.shape[-2:]
        return torch.zeros(1, height, width, device=image.device)

    @staticmethod
    def _to_uint8_image(image: torch.Tensor):
        """Convert the first image of *image* to an ``(H, W, C)`` uint8 NumPy array.

        Accepts ``(C, H, W)`` or ``(B, C, H, W)``. Pixel values are assumed to
        lie in ``[-1, 1]``; NOTE(review): confirm this against the und_image
        preprocessing. Values are clamped to that range first, so outliers
        saturate instead of wrapping around in the uint8 cast. The tensor is
        detached and upcast to float32 because NumPy cannot represent bf16 and
        refuses tensors that require grad.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        pixels = image[0].detach().float().clamp(-1.0, 1.0)
        return ((pixels.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype("uint8")

    def generate(self, image: torch.Tensor, instruction: str, return_parsed: bool = False):
        """Generate an edit mask for *image* given *instruction*.

        Args:
            image: ``(C, H, W)`` or ``(B, C, H, W)`` tensor. Only the first item
                of a batch is segmented.
            instruction: Free-form edit instruction.
            return_parsed: Also return the parsed instruction dict.

        Returns:
            A float ``(1, H, W)`` mask on ``image.device``, or ``(mask, parsed)``
            when *return_parsed* is True. ``parsed`` is ``{"operation": "other"}``
            when the generator is disabled.
        """
        if not self.enabled:
            mask = self._empty_mask(image)
            return (mask, {"operation": "other"}) if return_parsed else mask

        parsed = self.parser.parse(instruction)
        if not parsed.get("source_object"):
            mask = self._empty_mask(image)
            return (mask, parsed) if return_parsed else mask

        image_np = self._to_uint8_image(image)
        height, width = image_np.shape[:2]
        with torch.inference_mode():
            self.sam.set_image(image_np)
            # Only a centre-point prompt is used; the parsed object name does not
            # guide SAM2.
            masks, scores, _ = self.sam.predict(
                point_coords=[[width // 2, height // 2]],
                point_labels=[1],
                multimask_output=True
            )
        best_mask = masks[scores.argmax()]
        mask = torch.from_numpy(best_mask).float().unsqueeze(0).to(image.device)
        return (mask, parsed) if return_parsed else mask
# ============================================================
# Configuration
# ============================================================
class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration extended with BLIP3o-Fast mask/conditioning options.

    Every option below is stored verbatim as a config attribute with the same
    name. All remaining keyword arguments are forwarded to ``Qwen2Config``.
    """

    model_type = "llava_qwen2"

    def __init__(
        self,
        use_mask_predictor: bool = True,
        use_mask_conditioning: bool = True,
        use_spatial_conditioning: bool = False,
        use_operation_embedding: bool = False,
        mask_predictor_loss_weight: float = 0.5,
        latent_channels: int = 32,
        latent_size: int = 32,
        num_operation_types: int = 10,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Feature switches (presumably consumed by llava_arch; not read in this file).
        (
            self.use_mask_predictor,
            self.use_mask_conditioning,
            self.use_spatial_conditioning,
            self.use_operation_embedding,
        ) = (
            use_mask_predictor,
            use_mask_conditioning,
            use_spatial_conditioning,
            use_operation_embedding,
        )
        # Scale applied to the mask-prediction loss term.
        self.mask_predictor_loss_weight = mask_predictor_loss_weight
        # Latent geometry and size of the edit-operation vocabulary.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        self.num_operation_types = num_operation_types
# ============================================================
# Base Model
# ============================================================
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 decoder backbone combined with the ``LlavaMetaModel`` mixin.

    Construction is delegated entirely along the MRO (``LlavaMetaModel`` first,
    then ``Qwen2Model``).
    """

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
# ============================================================
# Main Model for Training
# ============================================================
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """
    BLIP3o Fast model: LLM understanding plus diffusion-based image editing.

    All mask-related components (mask_predictor, mask_encoder, mask_weight, etc.)
    are defined in LlavaMetaModel and accessed via properties in LlavaMetaForCausalLM.
    This ensures they are saved/loaded with the model.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen2"
        # Multimodal backbone and LM head used by this model.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Operation vocabulary for edit classification; unknown names map to "other".
        self.operation_types = ["remove", "replace", "add", "extract", "style",
                                "adjust", "compose", "action", "inpaint", "other"]

        # Mask generator (training only); built lazily by the `mask_generator` property.
        self._mask_generator = None
        self._mask_generator_initialized = False

        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``blip3oFastModel``)."""
        return self.model

    # ============================================================
    # Mask Generator (Training Only)
    # ============================================================
    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazily build the Qwen3 + SAM2 mask generator on first access.

        The generator is enabled only when ``config.mask_generator_enabled``
        (default True) is set AND the model is in training mode at first access.
        Otherwise a disabled generator, which only returns empty masks, is cached.
        """
        if not self._mask_generator_initialized:
            enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training
            if enabled:
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_operation_index(self, operation: str) -> int:
        """Map an operation name to its index in ``operation_types`` ("other" if unknown)."""
        if self.operation_types is None:
            return 0
        if operation not in self.operation_types:
            operation = "other"
        return self.operation_types.index(operation)

    def _normalize_mask(self, mask, H, W, device):
        """Coerce *mask* to a float ``(1, H, W)`` tensor on *device*.

        ``None`` becomes an all-zero mask and NumPy arrays are converted. A 2-D
        mask gains a leading axis. For a 4-D ``(N, C, h, w)`` mask the first
        channel is taken. Any remaining leading axis larger than one is collapsed
        with an element-wise max (union). A mask whose spatial size differs from
        ``(H, W)`` is resized with nearest-neighbour interpolation so that masks
        of one batch can be stacked.
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)
        mask = mask.to(device).float()
        if mask.dim() == 4:
            mask = mask[:, 0]
        elif mask.dim() == 2:
            mask = mask.unsqueeze(0)
        elif mask.dim() != 3:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")
        if mask.shape[0] != 1:
            mask = mask.amax(dim=0, keepdim=True)
        if tuple(mask.shape[-2:]) != (H, W):
            mask = F.interpolate(mask.unsqueeze(0), size=(H, W), mode='nearest')[0]
        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM2 (training only).

        Returns a ``(B, 1, H, W)`` mask batch on ``und_images.device`` plus one
        operation name per sample. Generation is best-effort: a failing sample
        gets an all-zero mask and ``"other"`` so that training continues.
        """
        _, _, H, W = und_images.shape
        device = und_images.device
        masks: List[torch.Tensor] = []
        operations: List[str] = []
        for i, image in enumerate(und_images):
            try:
                raw_mask, parsed = self.mask_generator.generate(image, instructions[i], return_parsed=True)
                mask = self._normalize_mask(raw_mask, H=H, W=W, device=device)
                operation = parsed.get("operation", "other")
            except Exception as e:  # deliberate best-effort; never abort a training step
                print(f"Mask generation failed: {e}")
                mask = torch.zeros(1, H, W, device=device)
                operation = "other"
            # Append only after the whole sample succeeded or fell back, so the
            # two lists always stay the same length.
            masks.append(mask)
            operations.append(operation)
        return torch.stack(masks), operations

    @staticmethod
    def _shifted_ce_loss(logits: torch.Tensor, labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Next-token cross-entropy averaged over supervised (``!= -100``) tokens.

        This equals ``F.cross_entropy(..., reduction='mean')`` whenever at least
        one token is supervised. Unlike the mean reduction, it returns 0 rather
        than NaN (with zero gradients) when no token is supervised, so a batch of
        pure generation samples cannot poison the total loss.
        """
        if labels is None:
            return torch.tensor(0.0, device=logits.device)
        shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
        shift_labels = labels[..., 1:].contiguous().view(-1)
        loss_sum = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction='sum')
        num_supervised = (shift_labels != -100).sum().clamp(min=1)
        return loss_sum / num_supervised

    # ============================================================
    # TRAINING FORWARD
    # ============================================================
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward: text CE loss + diffusion loss + mask-prediction loss.

        Hidden states are always requested from the LLM, so ``output_hidden_states``
        is ignored. A ``CausalLMOutputWithPast`` is always returned, so
        ``return_dict`` is ignored. ``categories`` and ``cache_position`` are
        accepted for interface compatibility but unused.

        Args:
            gen_image: Target image. When ``prepare_inputs_labels_for_multimodal``
                yields no latents, only the CE loss is returned.
            und_image: Source image ``(B, C, H, W)``, also used for on-the-fly
                mask generation.
            edit_mask: Optional GT edit mask, ``(B, 1, H, W)`` or ``(B, H, W)``.
                If absent during training and ``instructions`` are given, masks
                are generated with Qwen3 + SAM2.
            operations: Operation names. They are overwritten by on-the-fly
                generation and are not used in the loss.
            instructions: Per-sample edit instructions.

        Returns:
            ``CausalLMOutputWithPast`` whose loss is
            ``diff_loss + 0.2 * ce_loss + mask_predictor_loss_weight * mask_pred_loss``,
            or just the CE loss when there are no generation latents.
        """
        # Prepare multimodal inputs
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values,
                labels, gen_image, und_image
            )
        else:
            latents = None

        # LLM forward
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True
        )
        logits = output.logits
        hidden_states = output.hidden_states

        ce_loss = self._shifted_ce_loss(logits, labels)

        # Without a generation image, only the CE loss applies.
        if latents is None:
            return CausalLMOutputWithPast(
                loss=ce_loss, logits=logits, past_key_values=output.past_key_values,
                hidden_states=output.hidden_states, attentions=output.attentions
            )

        # Generate GT masks on the fly when none were provided (training only).
        if edit_mask is None and instructions is not None and und_image is not None and self.training:
            edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        if edit_mask is not None:
            edit_mask = edit_mask.float().to(latents.device)
            if edit_mask.dim() == 3:  # (B, H, W) -> (B, 1, H, W) for F.interpolate
                edit_mask = edit_mask.unsqueeze(1)

        # Mask predictor loss (only computed when it contributes to training).
        mask_pred_loss = torch.tensor(0.0, device=latents.device)
        if self.mask_predictor is not None and edit_mask is not None and self.training:
            mask_logits = self.mask_predictor(hidden_states[-1], return_logits=True)
            gt_mask_resized = F.interpolate(edit_mask, size=tuple(mask_logits.shape[-2:]), mode='nearest')
            if not torch.isnan(mask_logits).any() and not torch.isnan(gt_mask_resized).any():
                # BCE in fp32 for numerical stability under bf16/fp16 training.
                mask_pred_loss = F.binary_cross_entropy_with_logits(
                    mask_logits.float(),
                    gt_mask_resized,
                    reduction='mean'
                )

        # Diffusion timestep / sigma sampling
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        noise_scheduler = self.get_model().noise_scheduler
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # Mask conditioning (randomly dropped via mask_drop)
        mask_cond = None
        if self.mask_encoder is not None and edit_mask is not None:
            mask_latent = F.interpolate(
                edit_mask, size=tuple(latents.shape[-2:]), mode='nearest'
            ).clamp(0.0, 1.0)
            mask_cond = self.mask_encoder(mask_latent)
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))

        # Noisy latents with additive mask conditioning
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents
        if self.mask_weight is not None and mask_cond is not None:
            combined_input = combined_input + self.mask_weight * mask_cond

        # DiT forward
        fused_features = self.get_model().diffusion_connector(hidden_states)
        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        # Flow-matching regression target (latents - noise).
        # NOTE(review): SD3-style scripts use noise - latents; confirm this sign
        # matches how scheduler.step() consumes the prediction at inference.
        target = latents - noise
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # Total loss
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss
        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item():.4f}")
        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )

    # ============================================================
    # INFERENCE
    # ============================================================
    @torch.no_grad()
    def generate_edited_image(
        self,
        und_image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        mask_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Generate an edited image using the learned mask predictor.
        NO external segmentation model needed!

        Args:
            und_image: Source image batch. It is fed to the LLM (via
                ``prepare_inputs_labels_for_multimodal``) and to the SANA VAE,
                where it only determines the latent shape.
            input_ids: Tokenised instruction.
            attention_mask: Attention mask for ``input_ids``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: CFG scale. CFG is enabled when it is > 1; the
                unconditional branch uses zeroed LLM features and zeroed mask
                conditioning.
            mask_guidance_scale: Multiplier on ``mask_weight * mask_cond``.
            generator: Optional RNG for the initial latents.

        Returns:
            The image returned by ``vae.decode(...).sample``.
        """
        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]
        do_cfg = guidance_scale > 1.0
        model = self.get_model()

        # LLM hidden states for image + instruction
        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = output.hidden_states

        # Predict the edit mask with the trained MaskPredictor
        predicted_mask = None
        if self.mask_predictor is not None:
            predicted_mask = self.mask_predictor(hidden_states[-1])

        # Latent geometry from the VAE encoding of the reference image.
        # The input is cast to the VAE's dtype to avoid a dtype mismatch.
        vae = model.get_sana_vae()
        ref_latents = vae.encode(und_image.to(device=vae.device, dtype=vae.dtype)).latent
        ref_latents = (ref_latents * vae.config.scaling_factor).to(device=device, dtype=dtype)
        latent_channels, latent_h, latent_w = ref_latents.shape[1:4]

        # Mask conditioning (zeros when there is no encoder or no prediction)
        mask_cond = torch.zeros_like(ref_latents)
        if self.mask_encoder is not None and predicted_mask is not None:
            predicted_mask = F.interpolate(
                predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )
            mask_cond = self.mask_encoder(predicted_mask.to(dtype))

        # LLM conditioning for the DiT. The attention mask is passed exactly as
        # in training, so padded tokens are masked the same way.
        fused_features = model.diffusion_connector(hidden_states)
        encoder_mask = attention_mask_mm
        if do_cfg:
            # Unconditional half first: zeroed mask conditioning and LLM features.
            mask_cond = torch.cat([torch.zeros_like(mask_cond), mask_cond])
            fused_features = torch.cat([torch.zeros_like(fused_features), fused_features])
            if encoder_mask is not None:
                encoder_mask = torch.cat([encoder_mask, encoder_mask])

        # Additive mask term is the same at every step; compute it once.
        mask_term = None
        if self.mask_weight is not None:
            mask_term = mask_guidance_scale * self.mask_weight * mask_cond

        # Use a private scheduler instance. set_timesteps() replaces the
        # scheduler's timestep table, and forward() indexes the shared
        # noise_scheduler.timesteps with training-range indices.
        base_scheduler = model.noise_scheduler
        scheduler = type(base_scheduler).from_config(base_scheduler.config)
        scheduler.set_timesteps(num_inference_steps, device=device)

        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        # Denoising loop
        for t in scheduler.timesteps:
            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
            t_input = t.unsqueeze(0).expand(latent_model_input.shape[0])
            if mask_term is not None:
                latent_model_input = latent_model_input + mask_term
            noise_pred = model.dit(
                hidden_states=latent_model_input,
                timestep=t_input,
                encoder_hidden_states=fused_features,
                encoder_attention_mask=encoder_mask,
            ).sample
            if do_cfg:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode
        latents = latents / vae.config.scaling_factor
        return vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
# ============================================================
# Register Model
# ============================================================
# Make model_type "llava_qwen2" resolvable through the transformers Auto*
# factories. The config class's own model_type is used so that both
# registrations refer to the same type.
AutoConfig.register(blip3oFastConfig.model_type, blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)