# Author: XinxuanLu
# Commit message: "Initial demo"
# Commit: becf13a (verified)
import torch
import math
import numpy as np
import torch.nn as nn
from einops import rearrange
from transformers.cache_utils import DynamicCache
from src.builder import BUILDER
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from xtuner.utils import IMAGE_TOKEN_INDEX
from .viewpoint_mlp import ViewpointTokenMLP
def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a two-layer projector MLP.

    Layout: Linear(hidden_size -> projector_dim), SiLU, Linear(projector_dim -> z_dim).

    The layers are registered positionally ("0", "1", "2"), so state-dict keys match
    checkpoints saved with this projector layout.

    Args:
        hidden_size: input feature width.
        projector_dim: width of the hidden layer.
        z_dim: output feature width.

    Returns:
        An ``nn.Sequential`` with the three layers above.
    """
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
def mask_by_order(mask_len, order, bsz, seq_len):
    """Mark the first ``mask_len`` positions of each row's ordering as masked.

    Args:
        mask_len: number of positions to mask per row. Either a Python int or a
            one-element tensor; ``Harmon.sample`` passes a one-element float tensor.
            Fractional values are truncated toward zero.
        order: (bsz, seq_len) long tensor of position indices. Presumably each row is
            a permutation (e.g. from ``mar.sample_orders``); TODO confirm.
        bsz: number of rows.
        seq_len: number of positions per row.

    Returns:
        (bsz, seq_len) bool tensor on ``order.device``. Entry [b, j] is True iff j
        appears in ``order[b, :mask_len]``.
    """
    # int() accepts both Python numbers and one-element tensors. For a one-element
    # tensor it truncates exactly like the previous ``.long()`` slice bound did.
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = torch.scatter(masking, dim=-1, index=order[:, :num_masked],
                            src=torch.ones(bsz, seq_len, device=order.device))
    return masking.bool()
class Harmon(nn.Module):
    """Image generation / reconstruction model built from a VAE, an LLM and a MAR module.

    Components (each built from a config via ``BUILDER``):
      * ``vae``: frozen latent autoencoder used by :meth:`encode` / :meth:`decode`.
      * ``llm`` + ``tokenizer``: language model. Its base model (:attr:`llm_model`)
        processes text embeddings followed by image-token embeddings.
      * ``mar``: MAR module. It provides the MAE encoder/decoder, the sampling orders
        and the ``diffloss`` sampler used by :meth:`sample`.
      * ``proj_in`` / ``proj_out``: MLPs mapping between the MAR encoder width and the
        LLM hidden size.

    When ``use_viewpoint_tokens`` is set, ``num_view_tokens`` tokens named
    ``<view_token_i>`` are used as placeholders. At run time their input embeddings
    are replaced by the output of a ``ViewpointTokenMLP`` (see
    :meth:`inject_viewpoint_embeddings`). Token registration and MLP creation happen
    in :meth:`_init_viewpoint_tokens`.
    """

    def __init__(self,
                 vae,
                 vae_scale,
                 llm,
                 mar,
                 tokenizer,
                 prompt_template,
                 # Viewpoint conditioning
                 use_viewpoint_tokens=False,
                 num_view_tokens=8,
                 viewpoint_mlp_config=None):
        """Build the sub-modules.

        Args:
            vae, llm, mar, tokenizer: configs passed to ``BUILDER.build``.
            vae_scale: factor applied to VAE latents in :meth:`encode` and removed in
                :meth:`decode`.
            prompt_template: dict whose ``'INSTRUCTION'`` entry is a format string with
                an ``{input}`` field.
            use_viewpoint_tokens: enable the viewpoint-token conditioning path.
            num_view_tokens: number of ``<view_token_i>`` tokens. Stored as 0 when the
                path is disabled.
            viewpoint_mlp_config: kwargs for ``ViewpointTokenMLP``. When None, a
                default is filled in by :meth:`_init_viewpoint_tokens`.
        """
        super().__init__()
        # VAE (frozen)
        self.vae = BUILDER.build(vae)
        self.vae.requires_grad_(False)
        self.vae_scale = vae_scale
        # LLM
        self.llm = BUILDER.build(llm)
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        # MAR
        self.mar = BUILDER.build(mar)
        # projection layers
        self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
                                 projector_dim=self.llm.config.hidden_size,
                                 z_dim=self.llm.config.hidden_size)
        self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
                                  projector_dim=self.llm.config.hidden_size,
                                  z_dim=self.mar.encoder_embed_dim)
        # Viewpoint token system
        self.use_viewpoint_tokens = use_viewpoint_tokens
        self.num_view_tokens = num_view_tokens if use_viewpoint_tokens else 0
        self.viewpoint_mlp_config = viewpoint_mlp_config
        # Only the token strings are fixed here. Ids, the embedding resize and the MLP
        # are deferred to _init_viewpoint_tokens(), so a pretrained checkpoint can be
        # loaded before the embedding matrix is resized.
        if use_viewpoint_tokens:
            self.view_tokens = [f"<view_token_{i}>" for i in range(num_view_tokens)]
        else:
            self.view_tokens = []
        self.view_token_ids = []  # populated by _init_viewpoint_tokens()

    def _init_viewpoint_tokens(self):
        """Register the view tokens with the tokenizer/LLM and build the viewpoint MLP.

        Call this after ``load_state_dict()`` (e.g. from HarmonDev), so the pretrained
        embedding matrix is loaded before it is resized. Does nothing when viewpoint
        tokens are disabled.

        Behaviour:
          * If every ``<view_token_i>`` is already in the tokenizer vocabulary (e.g. a
            resumed run), only ``view_token_ids`` is rebuilt.
          * Otherwise the missing tokens are added as special tokens. The LLM input
            embeddings are then resized to
            ``max(config.vocab_size + num_added, len(tokenizer))``, so every token id
            has an embedding row.

        Side effects:
          * Sets ``view_token_ids`` and, when it was None, ``viewpoint_mlp_config``.
          * (Re)creates ``viewpoint_mlp`` on the model's device/dtype. Calling this
            twice discards the previous MLP weights.
          * Prints debug information to stdout.
        """
        if not self.use_viewpoint_tokens:
            return
        print(f"\n[VIEWPOINT INIT DEBUG]", flush=True)
        print(f" Tokenizer before adding view tokens:", flush=True)
        print(f" Type: {type(self.tokenizer)}", flush=True)
        print(f" Vocab size: {len(self.tokenizer)}", flush=True)
        print(f" LLM embedding size: {self.llm.get_input_embeddings().weight.shape[0]}", flush=True)
        # Check for mismatch (pre-existing in Harmon, not a bug introduced by viewpoint tokens)
        tokenizer_vocab_size = len(self.tokenizer)
        embedding_size = self.llm.get_input_embeddings().weight.shape[0]
        if tokenizer_vocab_size != embedding_size:
            print(f" NOTE: Vocab size mismatch ({tokenizer_vocab_size} vs {embedding_size})", flush=True)
            print(f" This is expected if using a different HF tokenizer version than checkpoint.", flush=True)
            print(f" View tokens will use IDs starting from {tokenizer_vocab_size}", flush=True)

        vocab = self.tokenizer.get_vocab()
        missing_tokens = [token for token in self.view_tokens if token not in vocab]
        if not missing_tokens:
            # All tokens already exist (e.g. from a resumed checkpoint): only map ids.
            print(f"\n [Viewpoint] Found {len(self.view_tokens)} existing view tokens in tokenizer", flush=True)
        else:
            # Add only the missing tokens. With a partially populated vocabulary, the
            # absent ones would otherwise be converted to the unk id.
            print(f"\n [Viewpoint] Adding {len(missing_tokens)} new view tokens...", flush=True)
            num_added = self.tokenizer.add_tokens(missing_tokens, special_tokens=True)
            print(f" [Viewpoint] Added {num_added} new view tokens to tokenizer", flush=True)
            # Resize from the LLM's configured vocab_size (compatible with pretrained
            # checkpoints). Never go below len(tokenizer): new ids are assigned from the
            # tokenizer's length and each one needs an embedding row.
            old_vocab_size = self.llm.config.vocab_size
            target_vocab_size = max(old_vocab_size + num_added, len(self.tokenizer))
            self.llm.resize_token_embeddings(target_vocab_size)
            print(f" [Viewpoint] Resized LLM embeddings: {old_vocab_size} -> {target_vocab_size}", flush=True)
            # NOTE(review): ids below the previous embedding size reuse existing rows, and
            # the init of genuinely new rows is decided by resize_token_embeddings. Confirm
            # before relying on "random" init.
            print(f" [Viewpoint] New tokens randomly initialized", flush=True)

        self.view_token_ids = [
            self.tokenizer.convert_tokens_to_ids(token)
            for token in self.view_tokens
        ]
        print(f" View token IDs: {self.view_token_ids}", flush=True)
        print(f"\n Tokenizer after adding view tokens:", flush=True)
        print(f" Vocab size: {len(self.tokenizer)}", flush=True)
        print(f" LLM embedding size: {self.llm.get_input_embeddings().weight.shape[0]}", flush=True)
        print(f"[END VIEWPOINT INIT DEBUG]\n", flush=True)

        # Initialize viewpoint MLP
        if self.viewpoint_mlp_config is None:
            self.viewpoint_mlp_config = dict(
                hidden_dim=1024,
                output_dim=self.llm.config.hidden_size,
                num_layers=5,
                num_view_tokens=self.num_view_tokens,
                num_params=2,  # 'spherical': 2, 'rotation_translation': 9
                num_freqs=16,
                viewpoint_param_type='spherical',  # Default to spherical mode
            )
        self.viewpoint_mlp = ViewpointTokenMLP(**self.viewpoint_mlp_config)
        # Move MLP to same device and dtype as model
        self.viewpoint_mlp = self.viewpoint_mlp.to(device=self.device, dtype=self.dtype)
        print(f" Viewpoint MLP moved to device={self.device}, dtype={self.dtype}", flush=True)

    @property
    def llm_model(self):
        """Base (headless) LLM used for hidden-state forwards.

        Returns ``self._llm_base_model`` when that attribute exists. This class never
        sets it; presumably it is assigned externally, e.g. when the LLM is wrapped.
        Otherwise returns ``self.llm.model``.
        """
        if hasattr(self, '_llm_base_model'):
            return self._llm_base_model
        return self.llm.model

    @property
    def device(self):
        """Device of the LLM (used as the model's device)."""
        return self.llm.device

    @property
    def dtype(self):
        """Dtype of the LLM (used as the model's compute dtype)."""
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        """Number of latent tokens MAR generates by default (``mar.seq_len``)."""
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        """Width of one patchified latent token: ``vae.embed_dim * patch_size**2``."""
        return self.vae.embed_dim * (self.mar.patch_size ** 2)

    def inject_viewpoint_embeddings(self, input_ids, viewpoint_params, inputs_embeds, valid_mask=None, num_objects=None):
        """Replace the embeddings at view-token positions with learned viewpoint embeddings.

        Args:
            input_ids: (batch, seq_len) token ids locating the view tokens.
            viewpoint_params: (batch, num_params) or (batch, num_objects * num_params)
                viewpoint parameters, forwarded to ``viewpoint_mlp``.
                Spherical mode: [azimuth, elevation], or [az1, el1, az2, el2] for
                multi-object. Rotation_translation mode: [rot_9d, trans_3d], or
                flattened for multi-object.
            inputs_embeds: (batch, seq_len, hidden_dim) embeddings. Modified in place
                and also returned.
            valid_mask: optional boolean mask with the same layout as
                ``viewpoint_params``, forwarded to ``viewpoint_mlp``.
            num_objects: optional (batch,) tensor of objects per sample. When its max
                exceeds 1, the k-th occurrence of view token i in a sample receives MLP
                output row ``k * num_view_tokens + i``.

        Returns:
            ``inputs_embeds`` with the view-token positions overwritten. Samples that do
            not contain a given view token are left unchanged for that token.

        Raises:
            ValueError: viewpoint tokens are disabled, or a view token occurs a wrong
                number of times in a sample (once in single-object mode,
                ``num_objects[b]`` times in multi-object mode).
            RuntimeError: :meth:`_init_viewpoint_tokens` has not created the MLP yet.
        """
        if not self.use_viewpoint_tokens:
            raise ValueError("Viewpoint tokens are not enabled in this model")
        if not hasattr(self, 'viewpoint_mlp'):
            raise RuntimeError(
                "viewpoint_mlp is not initialized; call _init_viewpoint_tokens() first")
        # Multi-object: (batch, num_objects * num_view_tokens, hidden_dim).
        # Single-object: (batch, num_view_tokens, hidden_dim).
        view_embeddings = self.viewpoint_mlp(viewpoint_params, valid_mask, num_objects=num_objects)
        is_multi_object = num_objects is not None and bool(num_objects.max() > 1)

        for token_idx, token_id in enumerate(self.view_token_ids):
            token_mask = input_ids == token_id  # (batch, seq_len)
            if not token_mask.any():
                continue
            for batch_idx in range(input_ids.shape[0]):
                token_positions = token_mask[batch_idx].nonzero(as_tuple=True)[0]
                if len(token_positions) == 0:
                    continue
                if is_multi_object:
                    expected_occurrences = num_objects[batch_idx].item()
                    if len(token_positions) != expected_occurrences:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected {expected_occurrences}). "
                            f"For multi-object mode, each token should appear once per object."
                        )
                    # The k-th occurrence belongs to object k.
                    for occurrence_idx, pos in enumerate(token_positions):
                        embedding_idx = occurrence_idx * self.num_view_tokens + token_idx
                        inputs_embeds[batch_idx, pos] = view_embeddings[batch_idx, embedding_idx, :]
                else:
                    if len(token_positions) != 1:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected 1). "
                            f"This may indicate the token accidentally appeared in the caption."
                        )
                    inputs_embeds[batch_idx, token_positions] = view_embeddings[batch_idx, token_idx, :]
        return inputs_embeds

    @torch.no_grad()
    def encode(self, x):
        """Encode images into scaled, patchified VAE latents.

        Args:
            x: image batch accepted by ``self.vae.encode``. Presumably (b, 3, H, W) in
                [-1, 1], as documented for :meth:`sample_relpose`.

        Returns:
            (b, m, n, c*p*p) tensor, where p = ``mar.patch_size``. It is the posterior
            mode multiplied by ``vae_scale``, with each p x p spatial patch flattened
            into the last axis.
        """
        posterior = self.vae.encode(x)
        z = posterior.mode().mul_(self.vae_scale)
        z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        return z

    @torch.no_grad()
    def decode(self, z):
        """Decode patchified latents (the layout produced by :meth:`encode`) into images.

        Args:
            z: (b, m, n, c*p*p) latents scaled by ``vae_scale``. Not modified.

        Returns:
            The output of ``self.vae.decode`` on the un-scaled (b, c, m*p, n*p) latent.
        """
        # Out-of-place on purpose: the caller's tensor must not be rescaled.
        z = z / self.vae_scale
        z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        return self.vae.decode(z)

    def prepare_forward_input(self,
                              x,
                              inputs_embeds=None,
                              input_ids=None,
                              attention_mask=None,
                              past_key_values=None):
        """Build ``llm_model`` kwargs for a forward over a text prefix plus ``x`` image tokens.

        Args:
            x: (b, l, hidden) image-token embeddings appended after the prefix.
            inputs_embeds: prefix embeddings. Used when no cache is given.
            input_ids: prefix ids, embedded when ``inputs_embeds`` is None and no cache
                is given.
            attention_mask: (b, prefix_len) mask of the prefix. Required.
            past_key_values: cache already holding the prefix. When given, only ``x``
                is fed to the model.

        Returns:
            dict with ``inputs_embeds``, ``attention_mask`` (prefix mask extended with
            ones for the l image tokens), ``position_ids`` (cumulative count of attended
            tokens minus one, clamped at 0) and ``past_key_values``.
        """
        b, l, _ = x.shape
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        attention_mask = torch.cat([
            attention_mask, attention_mask.new_ones(b, l)
        ], dim=1)
        # Padded positions do not advance the position counter.
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        position_ids[position_ids < 0] = 0
        if past_key_values is not None:
            # The prefix lives in the cache; feed only the image tokens.
            inputs_embeds = x
            position_ids = position_ids[:, -l:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)
        return dict(inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values)

    def extract_visual_feature(self, x, mask=None, detach=False):
        """Run the MAR MAE encoder on latents and project the result to the LLM width.

        Args:
            x: (b, m, n, c) patchified latents.
            mask: optional (b, m*n) mask forwarded to the MAE encoder. Defaults to
                all zeros.
            detach: detach both outputs from the autograd graph.

        Returns:
            (x_enc, z_enc):
              * x_enc is the MAE encoder output.
              * z_enc is ``proj_in(x_enc)``, with its first ``mar.buffer_size`` tokens
                moved to the end of the sequence.
        """
        b, m, n, _ = x.shape
        x = x.view(b, m * n, -1)
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
        z_enc = self.proj_in(x_enc)
        # Move buffers to the end of the image sequence
        z_enc = torch.cat([
            z_enc[:, self.mar.buffer_size:],
            z_enc[:, :self.mar.buffer_size]], dim=1)
        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()
        return x_enc, z_enc

    def forward_mae_encoder(self, x, mask, detach=False, inputs_embeds=None, **context):
        """Encode latents with MAR, refine them through the LLM, and add the result as a residual.

        Args:
            x: (b, m, n, c) patchified latents.
            mask: mask forwarded to :meth:`extract_visual_feature`.
            detach: forwarded to :meth:`extract_visual_feature`.
            inputs_embeds: optional prefix embeddings (see :meth:`prepare_forward_input`).
            **context: remaining :meth:`prepare_forward_input` kwargs (``input_ids``,
                ``attention_mask``, ``past_key_values``).

        Returns:
            ``x_enc + proj_out(z_llm)``. Here ``z_llm`` is the LLM's last hidden state
            over the image tokens, with the buffer tokens moved back to the front.
        """
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        inputs = self.prepare_forward_input(x=z_enc, inputs_embeds=inputs_embeds, **context)
        output = self.llm_model(**inputs, return_dict=True)
        z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]
        # move buffers back to the start of the image sequence
        z_llm = torch.cat([
            z_llm[:, -self.mar.buffer_size:],
            z_llm[:, :-self.mar.buffer_size]], dim=1)
        # residual learning
        return x_enc + self.proj_out(z_llm)

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        """Truncate every cached key/value tensor to its first ``cur_len`` positions, in place.

        Args:
            past_key_values: iterable of (keys, values) pairs whose third axis is the
                sequence axis (sliced as ``[:, :, :cur_len]``).
            cur_len: number of positions to keep.
        """
        for keys, values in past_key_values:
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    @staticmethod
    def _cfg_attention_mask(batch_size, seq_len, uncond_len, device=None):
        """Build the attention mask for a [conditional; unconditional] CFG batch.

        Returns:
            (2 * batch_size, seq_len) bool tensor. The first ``batch_size`` rows are all
            True. The remaining rows are True only on their first ``uncond_len``
            positions; the right padding is masked.
        """
        attention_mask = torch.ones((2 * batch_size, seq_len), dtype=torch.bool, device=device)
        # Slice from uncond_len (not ``-padded_length``): a zero-length pad then masks nothing.
        attention_mask[batch_size:, uncond_len:] = False
        return attention_mask

    def _embed_uncond_prompt(self, batch_size, seq_len, prompt="Generate an image."):
        """Embed the unconditional CFG prompt, right-padded to ``seq_len`` tokens.

        Returns:
            (embeds, valid_len):
              * embeds is (batch_size, seq_len, hidden), an expanded view of a single row.
              * valid_len is the number of non-pad tokens.

        Raises:
            ValueError: the templated prompt is longer than ``seq_len`` tokens.
        """
        text = self.prompt_template['INSTRUCTION'].format(input=prompt)
        ids = self.tokenizer.encode(text, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        valid_len = ids.shape[0]
        if valid_len > seq_len:
            raise ValueError(
                f"Unconditional prompt ({valid_len} tokens) is longer than the "
                f"conditional sequence ({seq_len} tokens)")
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Some tokenizers define no pad token. The pad positions are excluded by the
            # CFG attention mask, so their identity does not matter.
            pad_id = self.tokenizer.eos_token_id
        ids = torch.cat([ids, ids.new_full((seq_len - valid_len,), pad_id)], dim=0)
        embeds = self.llm.get_input_embeddings()(ids.unsqueeze(0)).expand(batch_size, -1, -1)
        return embeds, valid_len

    @torch.no_grad()
    def prepare_text_conditions(self, prompt, cfg_prompt='Generate an image.'):
        """Tokenize a prompt and its CFG counterpart as a 2-row, right-padded batch.

        Both strings are wrapped with ``prompt_template['INSTRUCTION']``. Row 0 is
        ``prompt``, row 1 is ``cfg_prompt``. Padding uses the tokenizer's EOS id.

        Returns:
            dict with ``input_ids`` (2, L) and a bool ``attention_mask`` (2, L) that is
            True on real tokens, both on ``self.device``.
        """
        all_prompts = [self.prompt_template['INSTRUCTION'].format(input=prompt),
                       self.prompt_template['INSTRUCTION'].format(input=cfg_prompt)]
        input_ids = [self.tokenizer.encode(p, add_special_tokens=True, return_tensors='pt')[0]
                     for p in all_prompts]
        valid_lens = [len(ids) for ids in input_ids]
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.eos_token_id)
        attention_mask = torch.zeros_like(input_ids).bool()
        for row, valid_len in enumerate(valid_lens):
            attention_mask[row, :valid_len] = True
        return dict(input_ids=input_ids.to(self.device),
                    attention_mask=attention_mask.to(self.device))

    @torch.no_grad()
    def sample(self,
               input_ids=None, inputs_embeds=None,
               attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
               progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
        """Generate images by iterative masked decoding conditioned on a text/image prefix.

        Each step does the following:
          1. Re-encode the current tokens through MAR + LLM (:meth:`forward_mae_encoder`).
          2. Decode with ``mar.forward_mae_decoder``.
          3. Sample the positions scheduled for this step with ``mar.diffloss.sample``.

        The number of still-masked positions follows a cosine schedule: at least one
        position stays masked until the final step, which fills the rest.

        Args:
            input_ids / inputs_embeds: prefix. ``inputs_embeds`` wins when both are
                given. The prefix is run once into a KV cache unless ``past_key_values``
                is supplied.
            attention_mask: (bsz, prefix_len) prefix mask. Required; it also defines bsz.
            num_iter: number of decoding steps.
            cfg: guidance scale. When != 1.0 the batch must be
                [conditional; unconditional] halves. Orders, masks and tokens of the
                second half mirror the first, and only the first half is returned.
            cfg_schedule: "constant" or "linear" (guidance grows as fewer positions
                remain masked).
            temperature: forwarded to ``mar.diffloss.sample``.
            progress: wrap the step loop in tqdm.
            mask: optional initial mask, reshaped to (bsz, m*n). Defaults to all ones.
            past_key_values: optional precomputed prefix cache (e.g. multi-turn editing).
            image_shape: (m, n) latent grid. Defaults to a square of ``gen_seq_len``.
            x_con: forwarded to ``mar.forward_mae_decoder``.
            **kwargs: accepted and ignored.

        Returns:
            Decoded images from :meth:`decode`; bsz // 2 of them when cfg != 1.0.

        Raises:
            ValueError: cfg != 1.0 with an odd batch size.
            NotImplementedError: unknown ``cfg_schedule``.
        """
        if cfg_schedule not in ("linear", "constant"):
            raise NotImplementedError
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)
        bsz = attention_mask.shape[0]
        if cfg != 1.0 and bsz % 2 != 0:
            raise ValueError(f"CFG sampling needs an even batch (cond + uncond halves), got {bsz}")
        if image_shape is None:
            m = n = int(self.gen_seq_len ** 0.5)
        else:
            m, n = image_shape
        seq_len = m * n
        if mask is None:
            mask = torch.ones(bsz, seq_len, device=self.device, dtype=self.dtype)
        else:
            mask = mask.view(bsz, seq_len)
        tokens = torch.zeros(bsz, seq_len, self.token_embed_dim,
                             device=self.device, dtype=self.dtype)
        orders = self.mar.sample_orders(bsz, seq_len=seq_len)
        if cfg != 1.0:
            orders[bsz // 2:] = orders[:bsz // 2]

        steps = range(num_iter)
        if progress:
            steps = tqdm(steps)

        # past key values can be prepared outside (usually in multi-turn editing)
        if past_key_values is None:
            output = self.llm_model(inputs_embeds=inputs_embeds,
                                    attention_mask=None,
                                    position_ids=None,
                                    past_key_values=DynamicCache.from_legacy_cache(),
                                    return_dict=True,
                                    use_cache=True)
            past_key_values = output.past_key_values
        prefix_len = inputs_embeds.shape[1]

        for step in steps:
            cur_tokens = tokens.clone()
            x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
                                             mask.to(self.dtype),
                                             past_key_values=past_key_values,
                                             attention_mask=attention_mask)
            # Drop the image-token entries this forward appended to the cache, so the next
            # step starts from the prefix alone.
            self.curtail_cache(past_key_values, prefix_len)
            z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, seq_len).to(self.device)
            if cfg != 1.0:
                mask_next[bsz // 2:] = mask_next[:bsz // 2]
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next

            # sample token latents for this step
            pred_positions = mask_to_pred.nonzero(as_tuple=True)
            z = z[pred_positions]
            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (seq_len - mask_len[0]) / seq_len
            else:
                cfg_iter = cfg
            sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)
            cur_tokens[pred_positions] = sampled_token_latent
            if cfg != 1.0:
                cur_tokens[bsz // 2:] = cur_tokens[:bsz // 2]
            tokens = cur_tokens.clone()

        pred = self.decode(tokens.view(bsz, m, n, -1))
        if cfg != 1.0:
            pred = pred[:bsz // 2]
        return pred

    @torch.no_grad()
    def sample_recon(self,
                     image,
                     prompt="Describe this image in details.",
                     temperature=1.0,
                     num_iter=32,
                     cfg=3.0,
                     progress=True):
        """Regenerate ``image`` conditioned on its own visual features plus ``prompt``.

        The image features are spliced into the templated prompt after its first three
        tokens. This assumes the template header is three tokens long; TODO confirm for
        other templates. When ``cfg != 1.0``, a right-padded "Generate an image." batch
        is appended as the unconditional half.

        Returns:
            Images from :meth:`sample`, one per input image.
        """
        x = self.encode(image)
        b, m, n, _ = x.shape
        _, z_enc = self.extract_visual_feature(x, detach=True)
        text_prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(text_prompt, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        input_embeds = self.llm.get_input_embeddings()(input_ids.unsqueeze(0))
        input_embeds = input_embeds.expand(b, -1, -1)
        input_embeds = torch.cat([
            input_embeds[:, :3, :],
            z_enc.reshape(b, -1, input_embeds.shape[-1]),
            input_embeds[:, 3:, :],
        ], dim=1)
        seq_len = input_embeds.shape[1]
        if cfg != 1.0:
            uncond_embeds, uncond_len = self._embed_uncond_prompt(b, seq_len)
            input_embeds = torch.cat([input_embeds, uncond_embeds], dim=0)
            attention_mask = self._cfg_attention_mask(b, seq_len, uncond_len, device=self.device)
        else:
            # Without guidance no unconditional half is needed (it would be returned as
            # extra images by sample()).
            attention_mask = torch.ones((b, seq_len), dtype=torch.bool, device=self.device)
        return self.sample(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n)
        )

    @torch.no_grad()
    def sample_relpose(self,
                       src_image,
                       viewpoint_params,
                       input_ids,
                       num_iter=64,
                       cfg=3.0,
                       temperature=1.0,
                       progress=False):
        """Generate a target image from a source image and viewpoint parameters.

        Args:
            src_image: (batch, 3, H, W) source images in [-1, 1].
            viewpoint_params: (batch, 2) [azimuth, elevation] in radians. Moved to the
                model device.
            input_ids: (batch, seq_len) tokenized prompt containing exactly one
                IMAGE_TOKEN_INDEX per row (at the same position in every row) plus the
                view tokens. Moved to the model device.
            num_iter: number of sampling iterations.
            cfg: classifier-free guidance scale. The unconditional branch is the
                "Generate an image." prompt.
            temperature: sampling temperature.
            progress: whether to show a progress bar.

        Returns:
            generated_images: (batch, 3, H, W) in [-1, 1].

        Raises:
            ValueError: ``input_ids`` has no image placeholder, or placeholders are not
                exactly one per row at a shared position.
        """
        input_ids = input_ids.to(self.device)
        viewpoint_params = viewpoint_params.to(self.device)

        image_token_mask = input_ids == IMAGE_TOKEN_INDEX
        if not image_token_mask.any():
            raise ValueError("input_ids must contain an IMAGE_TOKEN_INDEX placeholder")
        image_token_pos = image_token_mask.nonzero(as_tuple=True)[1][0].item()
        if (not image_token_mask[:, image_token_pos].all()
                or int(image_token_mask.sum()) != input_ids.shape[0]):
            raise ValueError(
                "Expected exactly one IMAGE_TOKEN_INDEX per sample, at the same position in every row")

        # Encode the source image and extract its (detached, unmasked) visual features.
        src_latents = self.encode(src_image)  # (batch, m, n, c)
        batch_size, m, n, _ = src_latents.shape
        _, z_src = self.extract_visual_feature(src_latents, detach=True)  # (batch, num_image_tokens, hidden)
        num_image_tokens = z_src.shape[1]

        # Expand the single placeholder into one slot per visual feature token.
        expanded_input_ids = torch.cat([
            input_ids[:, :image_token_pos],
            torch.full((batch_size, num_image_tokens), IMAGE_TOKEN_INDEX,
                       dtype=torch.long, device=input_ids.device),
            input_ids[:, image_token_pos + 1:],
        ], dim=1)

        # Image slots take the source features; every other slot takes its text embedding.
        is_image = expanded_input_ids == IMAGE_TOKEN_INDEX
        inputs_embeds = z_src.new_zeros(batch_size, expanded_input_ids.shape[1], self.llm.config.hidden_size)
        inputs_embeds[is_image] = z_src.flatten(0, 1)
        inputs_embeds[~is_image] = self.llm.get_input_embeddings()(expanded_input_ids[~is_image])

        inputs_embeds = self.inject_viewpoint_embeddings(
            expanded_input_ids,
            viewpoint_params,
            inputs_embeds
        )

        seq_len = inputs_embeds.shape[1]
        if cfg != 1.0:
            uncond_embeds, uncond_len = self._embed_uncond_prompt(batch_size, seq_len)
            inputs_embeds = torch.cat(
                [inputs_embeds, uncond_embeds.to(dtype=inputs_embeds.dtype)], dim=0)
            # One row per batch element and half; every row spans the full sequence length.
            attention_mask = self._cfg_attention_mask(batch_size, seq_len, uncond_len, device=self.device)
        else:
            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=self.device)

        return self.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n)
        )