|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
import re |
|
|
import json |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
Qwen2Config, |
|
|
Qwen2Model, |
|
|
Qwen2ForCausalLM |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from diffusers.training_utils import ( |
|
|
compute_density_for_timestep_sampling, |
|
|
compute_loss_weighting_for_sd3 |
|
|
) |
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
|
|
|
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Qwen3InstructionParser:
    """Parses free-form edit instructions into structured dicts with a Qwen3 LLM.

    Used only during training. Each instruction is mapped to a dict with keys
    ``operation``, ``source_object``, ``target_object``, ``location`` and
    ``attributes``; results are cached per instruction string because the same
    instruction can recur across training steps.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        # instruction text -> parsed dict; avoids re-running generation.
        self._cache: Dict[str, Dict] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Return the structured parse of ``instruction`` (cached)."""
        if instruction in self._cache:
            return self._cache[instruction]

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Greedy decoding. The previous `temperature=0.1` was ignored under
        # do_sample=False (transformers only emits a warning), so it is dropped.
        outputs = self.model.generate(
            **inputs, max_new_tokens=256,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return parsed

    def _build_prompt(self, instruction: str) -> str:
        """Build a few-shot prompt that asks the LLM for a strict-JSON parse."""
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Parse the model's reply into a validated dict.

        Guarantees the result has every expected key and a recognized
        operation type, falling back to a safe default for any malformed reply.
        """
        default = {"operation": "other", "source_object": None, "target_object": None, "location": None,
                   "attributes": None}
        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            # The model may wrap the JSON in extra prose; grab the first {...} span.
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match is None:
                return default
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:  # was a bare `except:` — too broad
                return default
        # Valid JSON that is not an object (e.g. a list or a number) used to
        # crash on the key assignment below; treat it as unparseable instead.
        if not isinstance(parsed, dict):
            return default
        for key in default:
            if key not in parsed:
                parsed[key] = default[key]
        valid_ops = ["remove", "replace", "add", "extract", "style", "adjust", "compose", "action", "other"]
        if parsed["operation"] not in valid_ops:
            parsed["operation"] = "other"
        return parsed
|
|
|
|
|
|
|
|
class SAM3MaskGenerator:
    """
    Generates segmentation masks using SAM3.
    SAM3 natively supports text prompts - no Grounding DINO needed!
    """

    def __init__(self, device: str = "cuda"):
        # NOTE(review): `device` is stored but not visibly used below —
        # presumably SAM3 picks its own device; confirm.
        self.device = device
        # Model/processor are loaded lazily on first generate_mask() call.
        self._model = None
        self._processor = None

    def _load_model(self):
        """Lazy load SAM3 model."""
        if self._model is None:
            # Imported here so the sam3 package is only required when masks
            # are actually generated (training-time only).
            from sam3.model_builder import build_sam3_image_model
            from sam3.model.sam3_image_processor import Sam3Processor

            print("Loading SAM3...")
            self._model = build_sam3_image_model()
            self._processor = Sam3Processor(self._model)
            print("SAM3 loaded!")

    def _prepare_image(self, image):
        """Convert various image formats to PIL Image."""
        from PIL import Image as PILImage

        if isinstance(image, PILImage.Image):
            return image.convert("RGB")
        elif isinstance(image, torch.Tensor):
            # Drop a leading batch dimension if present.
            if image.dim() == 4:
                image = image[0]

            # numpy cannot represent bf16/fp16; promote to fp32 first.
            if image.dtype in (torch.bfloat16, torch.float16):
                image = image.float()
            # Heuristic: a leading dim of 1 or 3 is treated as channels-first
            # (CHW) and moved to HWC for PIL.
            if image.shape[0] in [1, 3]:
                image_np = image.permute(1, 2, 0).cpu().numpy()
            else:
                image_np = image.cpu().numpy()
            # Heuristic: values in [0, 1] are assumed normalized and rescaled
            # to uint8 range.
            if image_np.max() <= 1.0:
                image_np = (image_np * 255).astype(np.uint8)
            else:
                image_np = image_np.astype(np.uint8)
            return PILImage.fromarray(image_np).convert("RGB")
        elif isinstance(image, np.ndarray):
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            return PILImage.fromarray(image).convert("RGB")
        else:
            # Last resort: let numpy coerce whatever this is.
            return PILImage.fromarray(np.array(image)).convert("RGB")

    @torch.no_grad()
    def generate_mask(
        self,
        image,
        parsed: Dict,
        detect_all: bool = False
    ) -> torch.Tensor:
        """
        Generate segmentation mask using SAM3 with text prompt.

        Args:
            image: Input image (PIL, tensor, or numpy)
            parsed: Parsed instruction dict with 'source_object', 'operation', etc.
            detect_all: Whether to return all instances

        Returns:
            mask: [1, H, W] binary mask tensor
        """
        self._load_model()

        image_pil = self._prepare_image(image)
        # PIL's .size is (width, height).
        W, H = image_pil.size

        text_prompt = self._build_text_prompt(parsed)

        if not text_prompt:
            # No promptable object: "style" edits affect the whole image
            # (all-ones mask); everything else gets an empty mask.
            if parsed.get("operation") == "style":
                return torch.ones(1, H, W)
            return torch.zeros(1, H, W)

        inference_state = self._processor.set_image(image_pil)

        output = self._processor.set_text_prompt(
            state=inference_state,
            prompt=text_prompt
        )

        masks = output["masks"]
        scores = output["scores"]

        if masks is None or len(masks) == 0:
            return torch.zeros(1, H, W)

        # Normalize SAM3's output (ndarray or list of ndarrays/tensors) to a
        # single stacked tensor [N, H, W].
        if isinstance(masks, np.ndarray):
            masks = torch.from_numpy(masks)
        elif isinstance(masks, list):
            masks = torch.stack([torch.from_numpy(m) if isinstance(m, np.ndarray) else m for m in masks])

        if detect_all:
            # Union of all detected instances via elementwise max.
            combined_mask = masks.float().max(dim=0)[0]
            return combined_mask.unsqueeze(0)
        else:
            # Keep only the highest-scoring instance.
            if isinstance(scores, (list, np.ndarray)):
                scores = torch.tensor(scores)
            best_idx = scores.argmax()
            return masks[best_idx].unsqueeze(0).float()

    def _build_text_prompt(self, parsed: Dict) -> str:
        """Build SAM3 text prompt from parsed instruction."""
        operation = parsed.get("operation", "other")
        source = parsed.get("source_object")
        target = parsed.get("target_object")
        location = parsed.get("location")
        attributes = parsed.get("attributes")

        # Object-centric edits: segment the object being acted upon.
        if operation in ["remove", "replace", "extract", "adjust", "action"]:
            if source:
                # NOTE(review): both branches return `source`; the
                # attributes/adjust special case is currently a no-op.
                if attributes and operation == "adjust":
                    return source
                return source
        elif operation == "add":
            # For additions, prefer an explicit anchor object, else the
            # described location.
            if source:
                return source
            elif location:
                return location
        elif operation == "compose":
            if source:
                return source
        elif operation == "style":
            # Style transfers have no single target object; empty prompt makes
            # generate_mask() fall back to a full-image mask.
            return ""

        # Fallback for "other"/unhandled cases (may be the empty string).
        return source or ""
|
|
|
|
|
|
|
|
class EditMaskGenerator:
    """
    Complete mask generation pipeline using Qwen3 + SAM3.

    The earlier pipeline (Qwen3 -> Grounding DINO -> SAM-2) is collapsed to
    Qwen3 -> SAM3, since SAM3 accepts text prompts natively.
    """

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        enabled: bool = True
    ):
        self.device = device
        self.enabled = enabled
        self.parser = None
        self.segmenter = None

        if enabled:
            print("Initializing EditMaskGenerator with SAM3...")
            self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)
            self.segmenter = SAM3MaskGenerator(device=device)
            print("EditMaskGenerator ready!")

    @torch.no_grad()
    def generate(
        self,
        image,
        instruction: str,
        return_parsed: bool = False
    ):
        """Generate edit mask from image and instruction."""
        if not self.enabled:
            # Disabled stub: return an all-zeros mask of the image's size.
            if isinstance(image, torch.Tensor):
                height, width = image.shape[-2:]
            else:
                height, width = np.array(image).shape[:2]
            empty = torch.zeros(1, height, width)
            if return_parsed:
                return empty, {"operation": "other"}
            return empty

        parsed = self.parser.parse(instruction)

        # "remove all the cats" etc. should union every matching instance.
        wants_all = "all" in instruction.lower()
        mask = self.segmenter.generate_mask(image, parsed, detect_all=wants_all)

        if return_parsed:
            return mask, parsed
        return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class blip3oFastConfig(Qwen2Config):
    """Qwen2 config extended with image-editing / diffusion options."""

    model_type = "llava_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Editing-specific options. Values arrive through **kwargs (e.g. when
        # the config is restored from JSON) and fall back to these defaults.
        option_defaults = {
            # latent-space geometry
            "latent_channels": 32,
            # conditioning toggles
            "use_spatial_conditioning": False,
            "use_mask_conditioning": True,
            "use_operation_embedding": True,
            "use_mask_predictor": True,
            "mask_predictor_loss_weight": 0.5,
            # CFG-style conditioning dropout probabilities
            "spatial_drop_prob": 0.1,
            "mask_drop_prob": 0.1,
            # on-the-fly GT-mask generation (training only)
            "mask_generator_enabled": True,
            "qwen_model": "Qwen/Qwen3-1.7B",
        }
        for option, fallback in option_defaults.items():
            setattr(self, option, kwargs.get(option, fallback))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BF16SafeLayerNorm(nn.Module):
    """LayerNorm that stays numerically stable under BF16/FP16 and PEFT.

    Affine parameters are held in FP32 and the normalization itself is done
    in FP32, with the result cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()

        # Parameters deliberately created in float32 so wrappers cannot
        # silently downcast them.
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(hidden_size, dtype=torch.float32))
        self.eps = eps
        self.hidden_size = hidden_size

        self.reset_parameters()

    def reset_parameters(self):
        """Restore the identity affine transform (weight=1, bias=0)."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Normalize entirely in FP32 (biased variance, matching layer_norm),
        # then return in the caller's dtype.
        normalized = F.layer_norm(
            x.float(),
            (self.hidden_size,),
            self.weight.float(),
            self.bias.float(),
            self.eps,
        )
        return normalized.to(orig_dtype)
|
|
|
|
|
|
|
|
class MaskPredictor(nn.Module):
    """
    Predicts edit mask from LLM hidden states.
    This is the KEY component that enables mask-free inference.

    The mask predictor learns to identify WHICH object needs to be edited
    based on the instruction (e.g., "remove the white dog") and the image
    understanding encoded in the LLM hidden states.

    Architecture:
        1. Extract instruction-relevant features using attention pooling
        2. Project to spatial features
        3. Decode to mask
    """

    def __init__(self, hidden_size: int, latent_channels: int, latent_size: int = 32):
        super().__init__()

        self.latent_size = latent_size
        self.hidden_size = hidden_size
        # NOTE(review): `latent_channels` is accepted but never referenced in
        # this module — presumably kept for caller-side uniformity; confirm.

        # Scores each sequence token; softmax over the sequence (in forward)
        # turns the scores into pooling weights.
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.Tanh(),
            nn.Linear(hidden_size // 4, 1),
        )

        # FP32-internal LayerNorm to keep normalization stable for BF16 inputs.
        self.input_norm = BF16SafeLayerNorm(hidden_size)

        intermediate_size = hidden_size // 2
        # Pooled vector expands to a 64-channel latent_size x latent_size map.
        spatial_dim = latent_size * latent_size * 64

        # MLP: pooled token features -> flattened spatial feature map.
        self.hidden_proj = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, intermediate_size),
            nn.LayerNorm(intermediate_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(intermediate_size, spatial_dim),
        )

        # Conv stack refining the 64-channel map into a 1-channel logit map.
        self.mask_decoder = nn.Sequential(
            nn.Conv2d(64, 256, 3, padding=1),
            nn.GroupNorm(32, 256),
            nn.GELU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.GELU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.GELU(),
            nn.Conv2d(64, 1, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        # Small-gain Xavier so the attention scores start near-uniform.
        for module in self.attention_pool:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # hasattr guard: allows calling this before input_norm exists
        # (e.g. if invoked from a subclass mid-construction).
        if hasattr(self, 'input_norm'):
            self.input_norm.reset_parameters()

        for module in self.hidden_proj:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        for module in self.mask_decoder:
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

        # Re-init only the LAST conv to near-zero so initial logits are ~0
        # (sigmoid ≈ 0.5 everywhere) — a neutral starting mask.
        for module in reversed(list(self.mask_decoder)):
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)
                break

    def forward(self, hidden_states: torch.Tensor, return_logits: bool = False) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, seq_len, hidden_size] from LLM
            return_logits: If True, return logits instead of probabilities

        Returns:
            mask: [B, 1, H, W] predicted edit mask
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device

        # Guard: on NaN/Inf input, return a neutral mask (logits 0, prob 0.5)
        # instead of propagating the corruption.
        # NOTE(review): requires_grad=True on a fresh leaf does not connect
        # these fallbacks to the autograd graph of the inputs — confirm intent.
        if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any():
            print("WARNING: NaN/Inf in hidden_states input to MaskPredictor")
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        hidden_states = self.input_norm(hidden_states)

        # Same neutral-mask guard after normalization.
        if torch.isnan(hidden_states).any():
            print("WARNING: NaN after input_norm in MaskPredictor")
            if return_logits:
                return torch.zeros(batch_size, 1, self.latent_size, self.latent_size,
                                   device=device, dtype=torch.float32, requires_grad=True)
            return torch.full((batch_size, 1, self.latent_size, self.latent_size), 0.5,
                              device=device, dtype=torch.float32, requires_grad=True)

        # Match the module's parameter dtype (may differ from the LLM's).
        target_dtype = self.attention_pool[0].weight.dtype
        hidden_states = hidden_states.to(target_dtype)

        # Attention pooling over the sequence dimension: [B, seq, 1] weights.
        attn_weights = self.attention_pool(hidden_states)
        attn_weights = F.softmax(attn_weights, dim=1)

        # Weighted sum over tokens -> [B, hidden_size].
        pooled = (hidden_states * attn_weights).sum(dim=1)

        # Project to a 64-channel latent_size x latent_size feature map.
        spatial = self.hidden_proj(pooled)
        spatial = spatial.view(-1, 64, self.latent_size, self.latent_size)

        mask_logits = self.mask_decoder(spatial)

        if return_logits:
            return mask_logits.float()

        # Probabilities in [0, 1]; sigmoid computed in fp32 for stability.
        mask = torch.sigmoid(mask_logits.float())
        return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone model: Qwen2 transformer augmented with the Llava multimodal mixin."""

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
|
|
|
|
|
|
|
|
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): |
|
|
config_class = blip3oFastConfig |
|
|
|
|
|
    def __init__(self, config):
        # Skip Qwen2ForCausalLM.__init__ (go straight to its parent) so its
        # default `model`/`lm_head` are not built; we build our own below.
        super(Qwen2ForCausalLM, self).__init__(config)

        self.model = blip3oFastModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        latent_channels = getattr(config, 'latent_channels', 32)

        # --- Spatial conditioning: encodes reference-image VAE latents. ---
        # NOTE(review): fallback default here is True, but blip3oFastConfig
        # defaults use_spatial_conditioning to False — confirm which is intended.
        if getattr(config, 'use_spatial_conditioning', True):
            self.spatial_ref_encoder = nn.Sequential(
                nn.Conv2d(latent_channels, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, 320, 3, padding=1),
                nn.GroupNorm(32, 320),
                nn.SiLU(),
                nn.Conv2d(320, latent_channels, 3, padding=1),
            )
            # Learnable gate, starts at 0 so conditioning begins as a no-op.
            self.spatial_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.spatial_ref_encoder = None
            self.spatial_weight = None

        # --- Mask conditioning: encodes a 1-channel edit mask to latent_channels. ---
        if getattr(config, 'use_mask_conditioning', True):
            self.mask_encoder = nn.Sequential(
                nn.Conv2d(1, 64, 3, padding=1),
                nn.GroupNorm(8, 64),
                nn.SiLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.GroupNorm(16, 128),
                nn.SiLU(),
                nn.Conv2d(128, latent_channels, 3, padding=1),
            )
            # Learnable gate, starts at 0 (no mask influence initially).
            self.mask_weight = nn.Parameter(torch.tensor(0.0))
        else:
            self.mask_encoder = None
            self.mask_weight = None

        # --- Mask predictor: learns masks from LLM states (mask-free inference). ---
        if getattr(config, 'use_mask_predictor', True):
            self.mask_predictor = MaskPredictor(
                hidden_size=config.hidden_size,
                latent_channels=latent_channels,
                latent_size=32
            )
        else:
            self.mask_predictor = None

        # --- Operation-type embedding (remove/replace/add/...). ---
        if getattr(config, 'use_operation_embedding', True):
            self.operation_types = ["remove", "replace", "add", "extract", "style", "adjust", "compose", "action",
                                    "other"]
            self.operation_embedding = nn.Embedding(len(self.operation_types), latent_channels)
        else:
            self.operation_types = None
            self.operation_embedding = None

        # Lazy EditMaskGenerator (see the `mask_generator` property).
        self._mask_generator = None
        self._mask_generator_initialized = False

        self._init_conditioning_layers()
        self.post_init()
|
|
|
|
|
def _init_conditioning_layers(self): |
|
|
"""Initialize conditioning layers. Called during __init__ and can be called after loading.""" |
|
|
if self.spatial_ref_encoder is not None: |
|
|
for module in self.spatial_ref_encoder: |
|
|
if isinstance(module, nn.Conv2d): |
|
|
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.GroupNorm): |
|
|
nn.init.ones_(module.weight) |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
nn.init.zeros_(self.spatial_ref_encoder[-1].weight) |
|
|
nn.init.zeros_(self.spatial_ref_encoder[-1].bias) |
|
|
|
|
|
if self.mask_encoder is not None: |
|
|
for module in self.mask_encoder: |
|
|
if isinstance(module, nn.Conv2d): |
|
|
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.GroupNorm): |
|
|
nn.init.ones_(module.weight) |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
nn.init.zeros_(self.mask_encoder[-1].weight) |
|
|
nn.init.zeros_(self.mask_encoder[-1].bias) |
|
|
|
|
|
    def reinitialize_new_modules(self):
        """
        Reinitialize modules that were added after the base model.
        Call this after loading a pretrained model to fix uninitialized weights.
        """
        print("Reinitializing new modules (mask_predictor, mask_encoder, spatial_ref_encoder)...")

        # Mask predictor re-runs its own init scheme.
        if self.mask_predictor is not None:
            self.mask_predictor._init_weights()
            print("  - mask_predictor reinitialized")

        # Spatial / mask conditioning encoders (last conv re-zeroed).
        self._init_conditioning_layers()
        print("  - conditioning layers reinitialized")

        # Operation-type embedding: standard small-std normal init.
        if self.operation_embedding is not None:
            nn.init.normal_(self.operation_embedding.weight, mean=0.0, std=0.02)
            print("  - operation_embedding reinitialized")

        # Gating scalars back to 0 so conditioning restarts as a no-op.
        if self.spatial_weight is not None:
            nn.init.zeros_(self.spatial_weight)
            print("  - spatial_weight reinitialized to 0")
        if self.mask_weight is not None:
            nn.init.zeros_(self.mask_weight)
            print("  - mask_weight reinitialized to 0")

        print("Reinitialization complete!")
|
|
|
|
|
@property |
|
|
def mask_generator(self) -> EditMaskGenerator: |
|
|
"""Lazy init mask generator (training only).""" |
|
|
if not self._mask_generator_initialized: |
|
|
enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training |
|
|
if enabled: |
|
|
|
|
|
self._mask_generator = EditMaskGenerator( |
|
|
qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"), |
|
|
device=str(self.device), |
|
|
enabled=True |
|
|
) |
|
|
else: |
|
|
self._mask_generator = EditMaskGenerator(enabled=False) |
|
|
self._mask_generator_initialized = True |
|
|
return self._mask_generator |
|
|
|
|
|
    def get_model(self):
        # Accessor for the inner blip3oFastModel; used by the Llava mixins and
        # by forward() to reach noise_scheduler / VAE / diffusion_connector / dit.
        return self.model
|
|
|
|
|
def mask_drop(self, latents: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor: |
|
|
if drop_prob <= 0 or not self.training: |
|
|
return latents |
|
|
mask = torch.bernoulli(torch.full((latents.shape[0],), drop_prob, device=latents.device, dtype=latents.dtype)) |
|
|
while len(mask.shape) < len(latents.shape): |
|
|
mask = mask.unsqueeze(-1) |
|
|
return latents * (1 - mask) |
|
|
|
|
|
def get_operation_index(self, operation: str) -> int: |
|
|
if self.operation_types is None: |
|
|
return 0 |
|
|
return self.operation_types.index( |
|
|
operation) if operation in self.operation_types else self.operation_types.index("other") |
|
|
|
|
|
def _normalize_mask(self, mask, H, W, device): |
|
|
""" |
|
|
Always return a single mask: [1, H, W] |
|
|
""" |
|
|
if mask is None: |
|
|
return torch.zeros(1, H, W, device=device) |
|
|
|
|
|
|
|
|
if not isinstance(mask, torch.Tensor): |
|
|
mask = torch.from_numpy(mask) |
|
|
|
|
|
mask = mask.to(device) |
|
|
|
|
|
|
|
|
if mask.dim() == 4: |
|
|
mask = mask[:, 0] |
|
|
|
|
|
mask = mask.max(dim=0, keepdim=True)[0] |
|
|
|
|
|
elif mask.dim() == 3: |
|
|
pass |
|
|
|
|
|
elif mask.dim() == 2: |
|
|
mask = mask.unsqueeze(0) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unexpected mask shape: {mask.shape}") |
|
|
|
|
|
return mask |
|
|
|
|
|
def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[ |
|
|
torch.Tensor, List[str]]: |
|
|
"""Generate GT masks using Qwen3 + Grounded SAM-2 (training only).""" |
|
|
masks, operations = [], [] |
|
|
B, _, H, W = und_images.shape |
|
|
for i in range(und_images.shape[0]): |
|
|
try: |
|
|
mask, parsed = self.mask_generator.generate(und_images[i], instructions[i], return_parsed=True) |
|
|
mask = self._normalize_mask(mask, H=H, W=W, device=und_images.device) |
|
|
masks.append(mask) |
|
|
operations.append(parsed.get("operation", "other")) |
|
|
except Exception as e: |
|
|
print(f"Mask generation failed: {e}") |
|
|
masks.append(torch.zeros(1, H, W, device=und_images.device)) |
|
|
operations.append("other") |
|
|
return torch.stack(masks).to(und_images.device), operations |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training step: LM forward + mask prediction + flow-matching diffusion loss.

        Combines three losses: the LM cross-entropy (`ce_loss`), a BCE loss on
        the predicted edit mask (`mask_pred_loss`), and the diffusion
        flow-matching loss (`diff_loss`).

        NOTE(review): `categories` and `cache_position` are accepted but unused
        in this body.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # Hidden states are always needed below (mask predictor + diffusion
        # connector), so the caller's value is overridden.
        output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Multimodal preprocessing also returns the target image latents.
        # NOTE(review): `latents` is only bound in this branch — if a caller
        # passes `inputs_embeds`, the `assert latents is not None` below raises
        # NameError rather than AssertionError. Confirm callers never do that.
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values, labels, gen_image, und_image)

        output = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels,
            use_cache=use_cache, output_attentions=output_attentions,
            output_hidden_states=output_hidden_states, return_dict=return_dict
        )

        # NOTE(review): `ce_loss` is None when `labels` is None; the
        # `0.2 * ce_loss` term below would then fail. Confirm labels are
        # always provided in training.
        ce_loss = output.loss
        hidden_states = output.hidden_states
        logits = output.logits
        # The full tuple of per-layer hidden states is passed to the
        # diffusion connector below.
        img_hidden_states = hidden_states

        assert latents is not None

        # Generate GT edit masks on the fly when the batch did not ship any.
        if edit_mask is None and instructions is not None and self.training:
            if getattr(self.config, 'mask_generator_enabled', True):
                edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        # --- Mask prediction branch ---
        mask_pred_loss = torch.tensor(0.0, device=latents.device)
        predicted_mask = None
        mask_logits = None
        gt_mask_resized = None

        if self.mask_predictor is not None:
            # Predict from the final transformer layer's states.
            last_hidden = hidden_states[-1]

            mask_logits = self.mask_predictor(last_hidden, return_logits=True)

            # Resize the predicted logits to the latent spatial resolution.
            mask_logits = F.interpolate(
                mask_logits.float(),
                size=(latents.shape[2], latents.shape[3]),
                mode='bilinear',
                align_corners=False
            )

            # NOTE(review): `predicted_mask` is computed but not used again in
            # this body.
            predicted_mask = torch.sigmoid(mask_logits)

            # Supervise against the (resized) ground-truth mask when available.
            if edit_mask is not None and self.training:
                gt_mask_resized = F.interpolate(
                    edit_mask.float().to(latents.device),
                    size=(latents.shape[2], latents.shape[3]),
                    mode='nearest'
                )

                # Skip the BCE term entirely if either side contains NaN.
                if not torch.isnan(mask_logits).any() and not torch.isnan(gt_mask_resized).any():
                    mask_pred_loss = F.binary_cross_entropy_with_logits(
                        mask_logits,
                        gt_mask_resized,
                        reduction='mean'
                    )
                else:
                    print("WARNING: NaN in mask_logits or gt_mask, skipping mask_pred_loss")
                    mask_pred_loss = torch.tensor(0.0, device=latents.device)

        # --- Flow-matching timestep sampling ---
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
        timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        # --- Spatial conditioning from the reference image's VAE latents ---
        if self.spatial_ref_encoder is not None:
            vae = self.get_model().get_sana_vae()
            ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor
            ref_latents = ref_latents.to(latents.device)
            spatial_cond = self.spatial_ref_encoder(ref_latents)
            # CFG-style dropout of the conditioning signal.
            spatial_cond = self.mask_drop(spatial_cond, getattr(self.config, 'spatial_drop_prob', 0.1))
        else:
            spatial_cond = 0

        # --- Mask conditioning ---
        if self.mask_encoder is not None and edit_mask is not None:
            mask_latent = F.interpolate(
                edit_mask.float().to(latents.device),
                size=(latents.shape[2], latents.shape[3]),
                mode='nearest'
            )
            mask_latent = mask_latent.clamp(0.0, 1.0)

            # Run the encoder layer-by-layer with FP32 weights to keep the
            # conv/norm math in full precision regardless of model dtype.
            mask_cond = mask_latent
            for layer in self.mask_encoder:
                if isinstance(layer, nn.Conv2d):
                    mask_cond = F.conv2d(mask_cond, layer.weight.float(),
                                         layer.bias.float() if layer.bias is not None else None,
                                         layer.stride, layer.padding)
                elif isinstance(layer, nn.GroupNorm):
                    mask_cond = F.group_norm(mask_cond, layer.num_groups,
                                             layer.weight.float(), layer.bias.float(), layer.eps)
                else:
                    mask_cond = layer(mask_cond)

            mask_cond = mask_cond.to(latents.dtype)
            mask_cond = self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))
        else:
            mask_cond = 0
            mask_latent = None

        # --- Operation-type conditioning ---
        if self.operation_embedding is not None and operations is not None:
            op_indices = torch.tensor([self.get_operation_index(op) for op in operations], device=latents.device)
            op_embed = self.operation_embedding(op_indices)[:, :, None, None]
            # Localize the embedding to the edit region when a mask exists.
            op_cond = op_embed * mask_latent if mask_latent is not None else op_embed.expand(-1, -1, latents.shape[2],
                                                                                            latents.shape[3])
        else:
            op_cond = 0

        # --- Assemble the DiT input ---
        # Flow-matching interpolation between clean latents and noise.
        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents

        # NOTE(review): only mask_cond is ever added here; `spatial_cond` and
        # `op_cond` are computed above but never injected into the DiT input.
        # Confirm whether that is intentional.
        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        # Project LLM hidden states into the DiT's cross-attention space.
        fused_features = self.get_model().diffusion_connector(img_hidden_states)

        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        # Flow-matching target: velocity from noise to data.
        target = latents - noise

        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        # Weighted sum of the three objectives (CE is fixed at 0.2).
        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        total_loss = diff_loss + 0.2 * ce_loss + mask_pred_weight * mask_pred_loss

        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item() if isinstance(mask_pred_loss, torch.Tensor) else 0:.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def generate_edited_image( |
|
|
self, |
|
|
und_image: torch.Tensor, |
|
|
input_ids: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
num_inference_steps: int = 50, |
|
|
guidance_scale: float = 7.5, |
|
|
spatial_guidance_scale: float = 1.0, |
|
|
mask_guidance_scale: float = 1.0, |
|
|
generator: Optional[torch.Generator] = None, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Lightweight inference - uses learned mask predictor instead of SAM-2. |
|
|
|
|
|
Args: |
|
|
und_image: Input image tensor [B, C, H, W] |
|
|
input_ids: Tokenized prompt [B, seq_len] |
|
|
attention_mask: Attention mask [B, seq_len] |
|
|
num_inference_steps: Denoising steps |
|
|
guidance_scale: CFG scale for text |
|
|
spatial_guidance_scale: Scale for spatial conditioning |
|
|
mask_guidance_scale: Scale for predicted mask conditioning |
|
|
generator: Random generator for reproducibility |
|
|
|
|
|
Returns: |
|
|
Edited image latents [B, C, H, W] |
|
|
""" |
|
|
|
|
|
device = und_image.device |
|
|
dtype = und_image.dtype |
|
|
batch_size = und_image.shape[0] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
(input_ids_mm, position_ids, attention_mask_mm, _, |
|
|
inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal( |
|
|
input_ids, None, attention_mask, None, None, None, und_image |
|
|
) |
|
|
|
|
|
output = Qwen2ForCausalLM.forward( |
|
|
self, |
|
|
input_ids=input_ids_mm, |
|
|
attention_mask=attention_mask_mm, |
|
|
position_ids=position_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
output_hidden_states=True, |
|
|
return_dict=True |
|
|
) |
|
|
|
|
|
hidden_states = output.hidden_states |
|
|
img_hidden_states = hidden_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.mask_predictor is not None: |
|
|
last_hidden = hidden_states[-1] |
|
|
predicted_mask = self.mask_predictor(last_hidden) |
|
|
else: |
|
|
predicted_mask = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vae = self.get_model().get_sana_vae() |
|
|
ref_latents = vae.encode(und_image.to(vae.device)).latent * vae.config.scaling_factor |
|
|
ref_latents = ref_latents.to(device) |
|
|
|
|
|
latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3] |
|
|
latent_channels = ref_latents.shape[1] |
|
|
|
|
|
|
|
|
if predicted_mask is not None: |
|
|
predicted_mask = F.interpolate( |
|
|
predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.spatial_ref_encoder is not None: |
|
|
spatial_cond = self.spatial_ref_encoder(ref_latents) |
|
|
else: |
|
|
spatial_cond = torch.zeros_like(ref_latents) |
|
|
|
|
|
|
|
|
if self.mask_encoder is not None and predicted_mask is not None: |
|
|
mask_cond = self.mask_encoder(predicted_mask.to(dtype=self.mask_encoder[0].weight.dtype)) |
|
|
else: |
|
|
mask_cond = torch.zeros_like(ref_latents) |
|
|
|
|
|
|
|
|
fused_features = self.get_model().diffusion_connector(img_hidden_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if guidance_scale > 1.0: |
|
|
|
|
|
spatial_cond_uncond = torch.zeros_like(spatial_cond) |
|
|
mask_cond_uncond = torch.zeros_like(mask_cond) |
|
|
fused_features_uncond = torch.zeros_like(fused_features) |
|
|
|
|
|
|
|
|
spatial_cond_cfg = torch.cat([spatial_cond_uncond, spatial_cond]) |
|
|
mask_cond_cfg = torch.cat([mask_cond_uncond, mask_cond]) |
|
|
fused_features_cfg = torch.cat([fused_features_uncond, fused_features]) |
|
|
else: |
|
|
spatial_cond_cfg = spatial_cond |
|
|
mask_cond_cfg = mask_cond |
|
|
fused_features_cfg = fused_features |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latents = randn_tensor( |
|
|
(batch_size, latent_channels, latent_h, latent_w), |
|
|
generator=generator, device=device, dtype=dtype |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scheduler = self.get_model().noise_scheduler |
|
|
scheduler.set_timesteps(num_inference_steps, device=device) |
|
|
timesteps = scheduler.timesteps |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for t in timesteps: |
|
|
|
|
|
if guidance_scale > 1.0: |
|
|
latent_model_input = torch.cat([latents] * 2) |
|
|
t_input = torch.cat([t.unsqueeze(0)] * 2 * batch_size) |
|
|
else: |
|
|
latent_model_input = latents |
|
|
t_input = t.unsqueeze(0).expand(batch_size) |
|
|
|
|
|
|
|
|
combined_input = latent_model_input |
|
|
if self.spatial_weight is not None: |
|
|
combined_input = combined_input + spatial_guidance_scale * self.spatial_weight * spatial_cond_cfg |
|
|
if self.mask_weight is not None: |
|
|
combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg |
|
|
|
|
|
|
|
|
noise_pred = self.get_model().dit( |
|
|
hidden_states=combined_input, |
|
|
timestep=t_input, |
|
|
encoder_hidden_states=fused_features_cfg, |
|
|
).sample |
|
|
|
|
|
|
|
|
if guidance_scale > 1.0: |
|
|
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) |
|
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
|
|
|
|
|
|
|
latents = scheduler.step(noise_pred, t, latents).prev_sample |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
latents = latents / vae.config.scaling_factor |
|
|
image = vae.decode(latents.to(vae.device)).sample |
|
|
|
|
|
return image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the custom "llava_qwen2" model type with the Hugging Face Auto
# factories so that AutoConfig.from_pretrained / AutoModelForCausalLM
# .from_pretrained can resolve checkpoints of this type to the classes
# defined in this file. Config must be registered before the model class.
AutoConfig.register("llava_qwen2", blip3oFastConfig)


AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)
|
|
|