|
|
""" |
|
|
BLIP3o Fast - Unified Image Understanding and Generation with Mask Prediction |
|
|
|
|
|
This module provides: |
|
|
- Training: Diffusion-based image editing with mask supervision from SAM |
|
|
- Inference: Lightweight editing that needs no user-supplied mask and no SAM; edit regions come from the learned MaskPredictor
|
|
|
|
|
Key Components (from llava_arch.py): |
|
|
- MaskPredictor: Learns to predict edit regions from LLM hidden states |
|
|
- MaskEncoder: Encodes masks for diffusion conditioning |
|
|
- mask_weight/spatial_weight: Learnable conditioning scales (SAVED with model!) |
|
|
|
|
|
Training Flow: |
|
|
1. LLM processes image + instruction → hidden states |
|
|
2. MaskPredictor predicts edit mask (supervised by SAM-generated ground-truth masks)
|
|
3. Diffusion generates edited image with mask conditioning |
|
|
|
|
|
Inference Flow: |
|
|
1. LLM processes image + instruction → hidden states |
|
|
2. MaskPredictor predicts edit mask (NO SAM needed!) |
|
|
3. Diffusion generates edited image |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
import json |
|
|
import re |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
Qwen2Config, |
|
|
Qwen2Model, |
|
|
Qwen2ForCausalLM |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from diffusers.training_utils import ( |
|
|
compute_density_for_timestep_sampling, |
|
|
compute_loss_weighting_for_sd3 |
|
|
) |
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
|
|
|
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Qwen3InstructionParser:
    """Parse free-form edit instructions into a structured dict using a Qwen3 LLM.

    Used only during training. Results are memoized per instruction string in
    an in-process dict, so a repeated instruction does not trigger another
    generation call.
    """

    # Keys (and fallback values) that every parsed result is guaranteed to contain.
    _DEFAULT_FIELDS: Dict[str, Any] = {
        "operation": "other",
        "source_object": None,
        "target_object": None,
        "location": None,
        "attributes": None,
    }

    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-1.7B",
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16
    ):
        """Load the tokenizer and the causal LM used for parsing.

        Args:
            model_name: Hugging Face hub id or local path of the parser LLM.
            device: Device the model is placed on; tokenized inputs are moved here too.
            torch_dtype: dtype the model weights are loaded in.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device
        )
        self.model.eval()
        self._cache: Dict[str, Dict[str, Any]] = {}

    @torch.no_grad()
    def parse(self, instruction: str) -> Dict[str, Any]:
        """Return the structured parse of ``instruction``.

        The result always contains the keys of ``_DEFAULT_FIELDS``. A fresh copy
        is returned, so callers may mutate it without corrupting the cache.
        """
        cached = self._cache.get(instruction)
        if cached is not None:
            return dict(cached)

        prompt = self._build_prompt(instruction)
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Greedy decoding. `temperature` has no effect when do_sample=False,
        # so it is not passed (it would only trigger a generation-config warning).
        outputs = self.model.generate(
            **inputs, max_new_tokens=256,
            do_sample=False, pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        parsed = self._parse_response(response)
        self._cache[instruction] = parsed
        return dict(parsed)

    def _build_prompt(self, instruction: str) -> str:
        """Build the few-shot JSON-extraction prompt for ``instruction``."""
        return f"""You are an image editing instruction parser. Extract structured information.

Respond ONLY with valid JSON:
{{"operation": "<type>", "source_object": "<object or null>", "target_object": "<object or null>", "location": "<location or null>", "attributes": "<attributes or null>"}}

Operation types: remove, replace, add, extract, style, adjust, compose, action, other

Examples:
"Remove the red car" -> {{"operation": "remove", "source_object": "red car", "target_object": null, "location": null, "attributes": null}}
"Replace the dog with a cat" -> {{"operation": "replace", "source_object": "dog", "target_object": "cat", "location": null, "attributes": null}}
"Make the dress blue" -> {{"operation": "adjust", "source_object": "dress", "target_object": null, "location": null, "attributes": "blue"}}

Input: "{instruction}"
Output:"""

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Extract the JSON object from the LLM response.

        The whole stripped response is tried first. If that fails, the first
        flat ``{...}`` span is tried, which covers JSON wrapped in prose or
        code fences. The defaults are returned when nothing decodes, or when
        the decoded value is not a JSON object (e.g. a list or bare string).
        Decoded keys override the defaults.
        """
        default = dict(self._DEFAULT_FIELDS)
        try:
            parsed = json.loads(response.strip())
        except json.JSONDecodeError:
            match = re.search(r'\{[^{}]*\}', response, re.DOTALL)
            if match is None:
                return default
            try:
                parsed = json.loads(match.group())
            except json.JSONDecodeError:
                return default
        if not isinstance(parsed, dict):
            return default
        return {**default, **parsed}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EditMaskGenerator:
    """Generate ground-truth edit masks using Qwen3 + SAM2. Training only.

    NOTE(review): SAM2 is prompted with a single positive point at the image
    centre. The parsed ``source_object`` only decides *whether* a mask is
    produced (no source object -> all-zero mask), not *where* the mask is.
    Localising the named object would need a text-grounding model.
    """

    def __init__(
        self,
        qwen_model: str = "Qwen/Qwen3-1.7B",
        sam_model: str = "facebook/sam2.1-hiera-large",
        device: str = "cuda",
        enabled: bool = True
    ):
        """Load SAM2 and the Qwen3 instruction parser, unless disabled.

        Args:
            qwen_model: Model id for ``Qwen3InstructionParser``.
            sam_model: Model id passed to ``SAM2ImagePredictor.from_pretrained``.
            device: Device for both models and for the returned masks.
            enabled: When False nothing is loaded and ``generate`` returns zero masks.

        If SAM2 cannot be imported, the generator disables itself. SAM2 is
        loaded first so that the parser LLM is not loaded only to sit unused.
        """
        self.enabled = enabled
        self.device = device
        if not enabled:
            return

        try:
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.sam = SAM2ImagePredictor.from_pretrained(sam_model, device=device)
        except ImportError:
            print("WARNING: SAM2 not installed. Mask generation disabled.")
            self.enabled = False
            return

        self.parser = Qwen3InstructionParser(model_name=qwen_model, device=device)

    def _empty_mask(self, image: torch.Tensor) -> torch.Tensor:
        """Return an all-zero [1, H, W] mask matching ``image``'s last two dims."""
        H, W = image.shape[-2:]
        return torch.zeros(1, H, W, device=self.device)

    @staticmethod
    def _to_uint8_hwc(image: torch.Tensor):
        """Convert a [C, H, W] / [B, C, H, W] tensor in [-1, 1] to an H x W x C uint8 array.

        Only the first batch element is used. The tensor is detached and
        upcast to float32 first, because ``Tensor.numpy()`` rejects bfloat16
        and tensors that require grad. Values are clipped before the cast so
        out-of-range pixels saturate instead of wrapping around.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        hwc = image[0].detach().float().permute(1, 2, 0).cpu()
        return ((hwc + 1) * 127.5).clamp(0, 255).numpy().astype("uint8")

    def generate(self, image: torch.Tensor, instruction: str, return_parsed: bool = False):
        """Generate an edit mask from an image and an instruction.

        Args:
            image: [C, H, W] or [B, C, H, W] tensor, assumed to be in [-1, 1].
                Only the first batch element is segmented.
            instruction: Free-form edit instruction.
            return_parsed: Also return the parser output.

        Returns:
            A float mask of shape [1, H, W] on ``self.device``. It is all zeros
            when the generator is disabled or no source object was parsed.
            When ``return_parsed`` is True, the mask is paired with the parsed
            instruction dict.
        """
        if not self.enabled:
            mask = self._empty_mask(image)
            return (mask, {"operation": "other"}) if return_parsed else mask

        parsed = self.parser.parse(instruction)
        if not parsed.get("source_object"):
            mask = self._empty_mask(image)
            return (mask, parsed) if return_parsed else mask

        image_np = self._to_uint8_hwc(image)
        with torch.inference_mode():
            self.sam.set_image(image_np)

        # Single positive click at the image centre (see class NOTE).
        H, W = image_np.shape[:2]
        masks, scores, _ = self.sam.predict(
            point_coords=[[W // 2, H // 2]],
            point_labels=[1],
            multimask_output=True
        )

        # Keep the candidate SAM scores highest.
        best_idx = scores.argmax()
        mask = torch.from_numpy(masks[best_idx]).float().unsqueeze(0).to(self.device)
        return (mask, parsed) if return_parsed else mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class blip3oFastConfig(Qwen2Config):
    """Qwen2 configuration extended with BLIP3o-Fast mask/conditioning options.

    All standard language-model hyper-parameters are forwarded to
    ``Qwen2Config`` via ``**kwargs``. The BLIP3o-specific options below are
    stored as plain attributes on the config.
    """

    model_type = "llava_qwen2"

    def __init__(
        self,
        use_mask_predictor: bool = True,
        use_mask_conditioning: bool = True,
        use_spatial_conditioning: bool = False,
        use_operation_embedding: bool = False,
        mask_predictor_loss_weight: float = 0.5,
        latent_channels: int = 32,
        latent_size: int = 32,
        num_operation_types: int = 10,
        **kwargs
    ):
        # Let the base config consume the standard Qwen2 fields first.
        super().__init__(**kwargs)

        # BLIP3o-specific options, assigned in declaration order.
        blip3o_fields = {
            "use_mask_predictor": use_mask_predictor,
            "use_mask_conditioning": use_mask_conditioning,
            "use_spatial_conditioning": use_spatial_conditioning,
            "use_operation_embedding": use_operation_embedding,
            "mask_predictor_loss_weight": mask_predictor_loss_weight,
            "latent_channels": latent_channels,
            "latent_size": latent_size,
            "num_operation_types": num_operation_types,
        }
        for field_name, field_value in blip3o_fields.items():
            setattr(self, field_name, field_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class blip3oFastModel(LlavaMetaModel, Qwen2Model):
    """Backbone that combines the LLaVA multimodal mixin with the Qwen2 decoder.

    Construction is delegated along the MRO (``LlavaMetaModel`` then
    ``Qwen2Model``). This class only pins the config class.
    """

    config_class = blip3oFastConfig

    def __init__(self, config: Qwen2Config):
        # Zero-argument super() resolves to the same MRO lookup as
        # super(blip3oFastModel, self).
        super().__init__(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class blip3oFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """
    BLIP3o Fast model for training and mask-guided image editing.

    All mask-related components (mask_predictor, mask_encoder, mask_weight, etc.)
    are defined in LlavaMetaModel and accessed via properties in LlavaMetaForCausalLM.
    This ensures they are saved/loaded with the model.

    Flow-matching convention shared by ``forward`` and ``generate_edited_image``:
    ``noisy = (1 - sigma) * x0 + sigma * noise``, and the DiT is trained to
    predict ``x0 - noise``.
    """

    config_class = blip3oFastConfig

    def __init__(self, config):
        super(blip3oFastForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen2"

        # Replace the plain Qwen2 backbone/head built by the parent class.
        self.model = blip3oFastModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Closed vocabulary of edit operations; unknown names map to "other".
        self.operation_types = ["remove", "replace", "add", "extract", "style",
                                "adjust", "compose", "action", "inpaint", "other"]

        # The Qwen3 + SAM mask generator is training-only and heavy, so it is
        # created lazily on first access of `mask_generator`.
        self._mask_generator = None
        self._mask_generator_initialized = False

        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``blip3oFastModel``)."""
        return self.model

    @property
    def mask_generator(self) -> EditMaskGenerator:
        """Lazy init mask generator (training only).

        The enabled/disabled decision is made once, on first access. It is
        enabled only if ``config.mask_generator_enabled`` (default True) is set
        and the model is in training mode at that moment.
        """
        if not self._mask_generator_initialized:
            enabled = getattr(self.config, 'mask_generator_enabled', True) and self.training
            if enabled:
                self._mask_generator = EditMaskGenerator(
                    qwen_model=getattr(self.config, 'qwen_model', "Qwen/Qwen3-1.7B"),
                    device=str(self.device),
                    enabled=True
                )
            else:
                self._mask_generator = EditMaskGenerator(enabled=False)
            self._mask_generator_initialized = True
        return self._mask_generator

    def get_operation_index(self, operation: str) -> int:
        """Index of ``operation`` in ``operation_types``; unknown names map to "other"."""
        if self.operation_types is None:
            return 0
        return self.operation_types.index(operation) if operation in self.operation_types else self.operation_types.index("other")

    def _normalize_mask(self, mask, H, W, device):
        """Normalize a mask to a float [1, H, W] tensor on ``device``.

        Accepted inputs:
            * None: returns zeros.
            * A numpy array or tensor of rank 2 ([H, W]), 3 ([N, H, W]) or
              4 ([N, C, H, W]; channel 0 is used).

        Multiple masks (N > 1) are merged with an element-wise max (union).
        A mask whose spatial size differs from (H, W) is resized with
        nearest-neighbour sampling, so masks from different samples can be
        stacked.
        """
        if mask is None:
            return torch.zeros(1, H, W, device=device)

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)
        mask = mask.to(device=device, dtype=torch.float32)

        if mask.dim() == 4:
            mask = mask[:, 0]
        elif mask.dim() == 2:
            mask = mask.unsqueeze(0)
        elif mask.dim() != 3:
            raise ValueError(f"Unexpected mask shape: {mask.shape}")

        if mask.shape[0] != 1:
            mask = mask.amax(dim=0, keepdim=True)
        if tuple(mask.shape[-2:]) != (H, W):
            mask = F.interpolate(mask.unsqueeze(0), size=(H, W), mode='nearest').squeeze(0)
        return mask

    def _generate_masks_on_fly(self, und_images: torch.Tensor, instructions: List[str]) -> Tuple[torch.Tensor, List[str]]:
        """Generate GT masks using Qwen3 + SAM (training only).

        Returns:
            A tuple ``(masks, operations)``:
            * masks: float [B, 1, H, W] tensor on ``und_images.device``.
            * operations: one parsed operation name per sample.

        This is best effort: a failing sample is reported and replaced by an
        all-zero mask with operation "other".
        """
        _, _, H, W = und_images.shape
        device = und_images.device
        masks, operations = [], []
        for i in range(und_images.shape[0]):
            try:
                mask, parsed = self.mask_generator.generate(und_images[i], instructions[i], return_parsed=True)
                mask = self._normalize_mask(mask, H=H, W=W, device=device)
                operation = parsed.get("operation", "other")
            except Exception as e:
                print(f"Mask generation failed: {e}")
                mask = torch.zeros(1, H, W, device=device)
                operation = "other"
            masks.append(mask)
            operations.append(operation)
        return torch.stack(masks), operations

    @staticmethod
    def _as_mask_batch(edit_mask: torch.Tensor, device) -> torch.Tensor:
        """Return ``edit_mask`` as a float 4-D tensor on ``device``.

        A 3-D ``[B, H, W]`` mask gets a channel axis. Without it,
        ``F.interpolate`` would reject the 2-D target size.
        """
        edit_mask = edit_mask.float().to(device)
        if edit_mask.dim() == 3:
            edit_mask = edit_mask.unsqueeze(1)
        return edit_mask

    def _compute_ce_loss(self, logits: torch.Tensor, labels: Optional[torch.Tensor]) -> torch.Tensor:
        """Next-token cross-entropy over positions whose label is not -100.

        Returns zero when ``labels`` is None or every shifted label is -100.
        In the latter case ``F.cross_entropy`` (mean reduction) would return
        NaN and poison the total loss.
        """
        if labels is None:
            return torch.tensor(0.0, device=logits.device)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        if not (shift_labels != -100).any():
            return torch.tensor(0.0, device=logits.device)
        return F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            shift_labels.view(-1),
            ignore_index=-100
        )

    def _compute_mask_pred_loss(self, last_hidden: torch.Tensor, edit_mask, device) -> torch.Tensor:
        """BCE between the MaskPredictor logits and the ground-truth edit mask.

        The GT mask is resized (nearest) to the predictor's own output
        resolution. ``generate_edited_image`` explicitly resizes predictions
        to the latent size, so that resolution is not guaranteed to equal the
        latent size. The loss is computed in float32.

        Returns zero when any of the following holds: there is no predictor,
        there is no GT mask, the model is not training, or either tensor
        contains NaN.
        """
        zero = torch.tensor(0.0, device=device)
        if self.mask_predictor is None or edit_mask is None or not self.training:
            return zero

        mask_logits = self.mask_predictor(last_hidden, return_logits=True).float()
        gt_mask = F.interpolate(
            self._as_mask_batch(edit_mask, mask_logits.device),
            size=tuple(mask_logits.shape[-2:]),
            mode='nearest'
        )
        if torch.isnan(mask_logits).any() or torch.isnan(gt_mask).any():
            return zero
        return F.binary_cross_entropy_with_logits(mask_logits, gt_mask, reduction='mean')

    def _encode_mask_condition(self, edit_mask, latents: torch.Tensor):
        """Encode the GT edit mask into an additive latent condition (training).

        Returns the plain int 0 when there is no mask encoder or no mask; the
        caller uses this to skip the addition. Otherwise the mask is resized
        to the latent grid, clamped to [0, 1], and encoded. The result is then
        passed through ``self.mask_drop`` with probability
        ``config.mask_drop_prob`` (default 0.1), presumably as condition
        dropout.
        """
        if self.mask_encoder is None or edit_mask is None:
            return 0
        mask_latent = F.interpolate(
            self._as_mask_batch(edit_mask, latents.device),
            size=(latents.shape[2], latents.shape[3]),
            mode='nearest'
        ).clamp(0.0, 1.0)
        mask_cond = self.mask_encoder(mask_latent)
        return self.mask_drop(mask_cond, getattr(self.config, 'mask_drop_prob', 0.1))

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gen_image: Optional[torch.FloatTensor] = None,
        und_image: Optional[torch.FloatTensor] = None,
        edit_mask: Optional[torch.FloatTensor] = None,
        operations: Optional[List[str]] = None,
        instructions: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training forward pass combining LM, flow-matching and mask losses.

        If ``prepare_inputs_labels_for_multimodal`` yields no latents (or
        ``inputs_embeds`` is supplied), only the LM cross-entropy is returned.
        Otherwise the loss is::

            diff_loss + ce_loss_weight * ce_loss
                      + mask_predictor_loss_weight * mask_pred_loss

        ``ce_loss_weight`` and ``mask_predictor_loss_weight`` are read from the
        config with defaults 0.2 and 0.5.

        Args beyond the usual causal-LM ones:
            gen_image: Presumably the edit target image. It is turned into
                ``latents`` by ``prepare_inputs_labels_for_multimodal``.
            und_image: Input/reference image.
            edit_mask: Optional GT edit mask, [B, H, W] or [B, 1, H, W]. When
                it is absent in training mode, ``instructions`` is given and
                ``und_image`` is available, it is generated with Qwen3 + SAM.
            operations, categories: Accepted for API compatibility; they do
                not affect the loss.
            output_hidden_states, return_dict: Ignored. Hidden states are
                always computed, and a ``CausalLMOutputWithPast`` is always
                returned.
        """
        if inputs_embeds is None:
            (input_ids, position_ids, attention_mask, past_key_values,
             inputs_embeds, labels, latents) = self.prepare_inputs_labels_for_multimodal(
                input_ids, position_ids, attention_mask, past_key_values,
                labels, gen_image, und_image
            )
        else:
            latents = None

        # Hidden states are always needed (mask predictor + diffusion connector).
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True
        )
        logits = output.logits
        img_hidden_states = output.hidden_states

        ce_loss = self._compute_ce_loss(logits, labels)

        if latents is None:
            return CausalLMOutputWithPast(
                loss=ce_loss, logits=logits, past_key_values=output.past_key_values,
                hidden_states=output.hidden_states, attentions=output.attentions
            )

        if edit_mask is None and instructions is not None and und_image is not None and self.training:
            edit_mask, operations = self._generate_masks_on_fly(und_image, instructions)

        mask_pred_loss = self._compute_mask_pred_loss(img_hidden_states[-1], edit_mask, latents.device)

        # Flow-matching noising with uniformly sampled training timesteps.
        noise = torch.randn_like(latents)
        weighting_scheme = "uniform"
        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme, batch_size=latents.shape[0],
            logit_mean=0.0, logit_std=1.0, mode_scale=1.29
        )
        noise_scheduler = self.get_model().noise_scheduler
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=latents.device)
        sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype)

        mask_cond = self._encode_mask_condition(edit_mask, latents)

        noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
        combined_input = noisy_latents
        if self.mask_weight is not None and isinstance(mask_cond, torch.Tensor):
            combined_input = combined_input + self.mask_weight * mask_cond

        fused_features = self.get_model().diffusion_connector(img_hidden_states)
        diffusion_pred = self.get_model().dit(
            hidden_states=combined_input, timestep=timesteps,
            encoder_hidden_states=fused_features, encoder_attention_mask=attention_mask
        ).sample

        # Target convention of this model: x0 - noise (see class docstring;
        # generate_edited_image relies on it).
        target = latents - noise
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=weighting_scheme, sigmas=sigmas)
        diff_loss = torch.mean(
            (weighting.float() * (diffusion_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
        ).mean()

        mask_pred_weight = getattr(self.config, 'mask_predictor_loss_weight', 0.5)
        ce_weight = getattr(self.config, 'ce_loss_weight', 0.2)
        total_loss = diff_loss + ce_weight * ce_loss + mask_pred_weight * mask_pred_loss

        if self.training:
            print(f"Loss - diff: {diff_loss.item():.4f}, ce: {ce_loss.item():.4f}, mask_pred: {mask_pred_loss.item():.4f}")

        return CausalLMOutputWithPast(
            loss=total_loss, logits=logits, past_key_values=output.past_key_values,
            hidden_states=output.hidden_states, attentions=output.attentions
        )

    @torch.no_grad()
    def generate_edited_image(
        self,
        und_image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        mask_guidance_scale: float = 1.0,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Generate edited image using learned mask predictor.
        NO external segmentation model needed!

        Args:
            und_image: Input image batch; its batch size, device and dtype
                drive the sampling.
            input_ids, attention_mask: Tokenized instruction prompt.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance scale. CFG is used only
                when it is > 1. The unconditional branch uses zeroed LLM
                features and a zeroed mask condition.
            mask_guidance_scale: Extra multiplier on ``mask_weight * mask_cond``.
            generator: Optional RNG for the initial latents.

        Returns:
            The VAE-decoded image tensor.
        """
        device = und_image.device
        dtype = und_image.dtype
        batch_size = und_image.shape[0]

        (input_ids_mm, position_ids, attention_mask_mm, _,
         inputs_embeds, _, _) = self.prepare_inputs_labels_for_multimodal(
            input_ids, None, attention_mask, None, None, None, und_image
        )
        output = Qwen2ForCausalLM.forward(
            self,
            input_ids=input_ids_mm,
            attention_mask=attention_mask_mm,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict=True
        )
        hidden_states = output.hidden_states

        predicted_mask = None
        if self.mask_predictor is not None:
            predicted_mask = self.mask_predictor(hidden_states[-1])

        # The reference latents only determine the shape of the generated
        # latents. Inputs are cast to the VAE's device and dtype.
        vae = self.get_model().get_sana_vae()
        ref_latents = vae.encode(und_image.to(device=vae.device, dtype=vae.dtype)).latent * vae.config.scaling_factor
        ref_latents = ref_latents.to(device)
        latent_h, latent_w = ref_latents.shape[2], ref_latents.shape[3]
        latent_channels = ref_latents.shape[1]

        if predicted_mask is not None:
            predicted_mask = F.interpolate(
                predicted_mask, size=(latent_h, latent_w), mode='bilinear', align_corners=False
            )

        mask_cond = torch.zeros_like(ref_latents)
        if self.mask_encoder is not None and predicted_mask is not None:
            mask_cond = self.mask_encoder(predicted_mask.to(dtype))

        fused_features = self.get_model().diffusion_connector(hidden_states)

        # Batch layout under CFG: [unconditional, conditional].
        do_cfg = guidance_scale > 1.0
        if do_cfg:
            mask_cond_cfg = torch.cat([torch.zeros_like(mask_cond), mask_cond])
            fused_features_cfg = torch.cat([torch.zeros_like(fused_features), fused_features])
            encoder_mask_cfg = (
                None if attention_mask_mm is None
                else torch.cat([attention_mask_mm, attention_mask_mm])
            )
        else:
            mask_cond_cfg = mask_cond
            fused_features_cfg = fused_features
            encoder_mask_cfg = attention_mask_mm

        latents = randn_tensor(
            (batch_size, latent_channels, latent_h, latent_w),
            generator=generator, device=device, dtype=dtype
        )

        # Sample with a private scheduler instance. set_timesteps() overwrites
        # the timesteps/sigmas tables that forward() indexes with training
        # timestep indices, so it must not touch the shared scheduler.
        train_scheduler = self.get_model().noise_scheduler
        scheduler = type(train_scheduler).from_config(train_scheduler.config)
        scheduler.set_timesteps(num_inference_steps, device=device)

        for t in scheduler.timesteps:
            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents
            t_input = t.reshape(1).repeat(latent_model_input.shape[0])

            combined_input = latent_model_input
            if self.mask_weight is not None:
                combined_input = combined_input + mask_guidance_scale * self.mask_weight * mask_cond_cfg

            model_pred = self.get_model().dit(
                hidden_states=combined_input,
                timestep=t_input,
                encoder_hidden_states=fused_features_cfg,
                encoder_attention_mask=encoder_mask_cfg,
            ).sample

            if do_cfg:
                pred_uncond, pred_cond = model_pred.chunk(2)
                model_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # forward() trains the DiT to predict (x0 - noise). Diffusers
            # flow-matching schedulers instead step with the velocity
            # (noise - x0). Negating keeps sampling moving toward the data.
            latents = scheduler.step(-model_pred, t, latents).prev_sample

        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
        return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make the custom config/model resolvable through the Auto* factories.
# blip3oFastConfig.model_type is the class attribute "llava_qwen2".
AutoConfig.register(blip3oFastConfig.model_type, blip3oFastConfig)
AutoModelForCausalLM.register(blip3oFastConfig, blip3oFastForCausalLM)
|
|
|