import math

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers.cache_utils import DynamicCache
from xtuner.utils import IMAGE_TOKEN_INDEX

from src.builder import BUILDER

from .viewpoint_mlp import ViewpointTokenMLP


def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a two-layer projector: Linear -> SiLU -> Linear.

    Args:
        hidden_size: Input feature size of the first linear layer.
        projector_dim: Width of the hidden layer.
        z_dim: Output feature size of the second linear layer.

    Returns:
        An ``nn.Sequential`` holding the three modules.
    """
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )


def mask_by_order(mask_len, order, bsz, seq_len):
    """Build a boolean mask that is True at the first ``mask_len`` indices of ``order``.

    Args:
        mask_len: One-element tensor; its integer value is the number of
            leading columns of ``order`` to mark.
        order: (bsz, seq_len) tensor of position indices.
        bsz: Batch size of the returned mask.
        seq_len: Sequence length of the returned mask.

    Returns:
        Boolean tensor of shape (bsz, seq_len) on ``order.device``.
    """
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = torch.scatter(
        masking,
        dim=-1,
        index=order[:, :mask_len.long()],
        src=torch.ones(bsz, seq_len, device=order.device),
    ).bool()
    return masking


class Harmon(nn.Module):
    """Wrapper combining a VAE, an LLM and a MAR model, with optional viewpoint tokens.

    The VAE is frozen in ``__init__``. ``proj_in`` / ``proj_out`` map between the
    MAR encoder width and the LLM hidden size.
    """

    def __init__(self,
                 vae,
                 vae_scale,
                 llm,
                 mar,
                 tokenizer,
                 prompt_template,
                 # Viewpoint conditioning
                 use_viewpoint_tokens=False,
                 num_view_tokens=8,
                 viewpoint_mlp_config=None):
        """Build all sub-modules from their ``BUILDER`` configs.

        Args:
            vae: Config passed to ``BUILDER.build`` for the VAE.
            vae_scale: Scalar applied to VAE latents in ``encode`` / ``decode``.
            llm: Config passed to ``BUILDER.build`` for the LLM.
            mar: Config passed to ``BUILDER.build`` for the MAR model.
            tokenizer: Config passed to ``BUILDER.build`` for the tokenizer.
            prompt_template: Mapping that provides an ``'INSTRUCTION'`` format string.
            use_viewpoint_tokens: Enable the viewpoint token system.
            num_view_tokens: Number of viewpoint tokens (ignored when disabled).
            viewpoint_mlp_config: Keyword arguments for ``ViewpointTokenMLP``;
                a default is filled in by ``_init_viewpoint_tokens`` when ``None``.
        """
        super().__init__()
        # VAE
        self.vae = BUILDER.build(vae)
        self.vae.requires_grad_(False)
        self.vae_scale = vae_scale

        # LLM
        self.llm = BUILDER.build(llm)
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template

        # MAR
        self.mar = BUILDER.build(mar)

        # projection layers
        self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
                                 projector_dim=self.llm.config.hidden_size,
                                 z_dim=self.llm.config.hidden_size)
        self.proj_out = build_mlp(hidden_size=self.llm.config.hidden_size,
                                  projector_dim=self.llm.config.hidden_size,
                                  z_dim=self.mar.encoder_embed_dim)

        # Viewpoint token system
        self.use_viewpoint_tokens = use_viewpoint_tokens
        self.num_view_tokens = num_view_tokens if use_viewpoint_tokens else 0
        self.viewpoint_mlp_config = viewpoint_mlp_config

        # Note: Actual viewpoint token initialization is deferred to _init_viewpoint_tokens()
        # This allows loading pretrained checkpoints before resizing embeddings
        if use_viewpoint_tokens:
            # NOTE(review): the literal here was an empty f-string that never used
            # `i`, so every view token was the same empty string. The token text
            # looks like stripped angle-bracket markup; `<view_{i}>` is assumed.
            # Confirm it matches the strings used by the code that builds prompts.
            self.view_tokens = [f"<view_{i}>" for i in range(num_view_tokens)]
            self.view_token_ids = []  # Will be populated in _init_viewpoint_tokens()
        else:
            self.view_tokens = []
            self.view_token_ids = []

    def _init_viewpoint_tokens(self):
        """
        Initialize viewpoint tokens after loading pretrained checkpoint.

        This method should be called after load_state_dict() in HarmonDev.

        Note: New view tokens will have randomly initialized embeddings since they
        don't exist in the pretrained checkpoint. These will be trained from scratch.

        Side effects: may add tokens to ``self.tokenizer``, may resize the LLM
        embeddings, fills ``self.view_token_ids`` and creates ``self.viewpoint_mlp``.
        Does nothing when viewpoint tokens are disabled.
        """
        if not self.use_viewpoint_tokens:
            return

        print("\n[VIEWPOINT INIT DEBUG]", flush=True)
        print(" Tokenizer before adding view tokens:", flush=True)
        print(f" Type: {type(self.tokenizer)}", flush=True)
        print(f" Vocab size: {len(self.tokenizer)}", flush=True)
        print(f" LLM embedding size: {self.llm.get_input_embeddings().weight.shape[0]}", flush=True)

        # Check for mismatch (pre-existing in Harmon, not a bug introduced by viewpoint tokens)
        tokenizer_vocab_size = len(self.tokenizer)
        embedding_size = self.llm.get_input_embeddings().weight.shape[0]
        if tokenizer_vocab_size != embedding_size:
            print(f" NOTE: Vocab size mismatch ({tokenizer_vocab_size} vs {embedding_size})", flush=True)
            print(" This is expected if using a different HF tokenizer version than checkpoint.", flush=True)
            print(f" View tokens will use IDs starting from {tokenizer_vocab_size}", flush=True)

        # Check if tokens already exist (e.g., from resumed checkpoint).
        # The vocabulary dict is fetched once instead of once per token.
        vocab = self.tokenizer.get_vocab()
        existing_tokens = [token for token in self.view_tokens if token in vocab]
        missing_tokens = [token for token in self.view_tokens if token not in vocab]

        if existing_tokens and not missing_tokens:
            print(f"\n [Viewpoint] Found {len(existing_tokens)} existing view tokens in tokenizer", flush=True)
            # Tokens already exist, just create ID mapping
            self.view_token_ids = [
                self.tokenizer.convert_tokens_to_ids(token)
                for token in self.view_tokens
            ]
            print(f" View token IDs: {self.view_token_ids}", flush=True)
        else:
            # Also taken when only some tokens exist, so the missing ones are
            # added rather than silently mapped to an unknown id.
            print(f"\n [Viewpoint] Adding {len(self.view_tokens)} new view tokens...", flush=True)
            # Add special viewpoint tokens to tokenizer
            num_added = self.tokenizer.add_tokens(self.view_tokens, special_tokens=True)
            print(f" [Viewpoint] Added {num_added} new view tokens to tokenizer", flush=True)

            # Resize embeddings: use LLM's configured vocab_size + new tokens
            # This ensures compatibility with pretrained checkpoints
            old_vocab_size = self.llm.config.vocab_size
            # Never resize below the tokenizer length, otherwise the ids of the
            # new tokens could fall outside the embedding table.
            target_vocab_size = max(old_vocab_size + num_added, len(self.tokenizer))
            self.llm.resize_token_embeddings(target_vocab_size)
            print(f" [Viewpoint] Resized LLM embeddings: {old_vocab_size} -> {target_vocab_size}", flush=True)
            print(" [Viewpoint] New tokens randomly initialized", flush=True)

            # Create view token ID mapping
            self.view_token_ids = [
                self.tokenizer.convert_tokens_to_ids(token)
                for token in self.view_tokens
            ]
            print(f" View token IDs: {self.view_token_ids}", flush=True)

        print("\n Tokenizer after adding view tokens:", flush=True)
        print(f" Vocab size: {len(self.tokenizer)}", flush=True)
        print(f" LLM embedding size: {self.llm.get_input_embeddings().weight.shape[0]}", flush=True)
        print("[END VIEWPOINT INIT DEBUG]\n", flush=True)

        # Initialize viewpoint MLP
        if self.viewpoint_mlp_config is None:
            self.viewpoint_mlp_config = dict(
                hidden_dim=1024,
                output_dim=self.llm.config.hidden_size,
                num_layers=5,
                num_view_tokens=self.num_view_tokens,
                num_params=2,  # Will be set by viewpoint_param_type ('spherical': 2, 'rotation_translation': 9)
                num_freqs=16,
                viewpoint_param_type='spherical',  # Default to spherical mode
            )
        self.viewpoint_mlp = ViewpointTokenMLP(**self.viewpoint_mlp_config)

        # Move MLP to same device and dtype as model
        self.viewpoint_mlp = self.viewpoint_mlp.to(device=self.device, dtype=self.dtype)
        print(f" Viewpoint MLP moved to device={self.device}, dtype={self.dtype}", flush=True)

    @property
    def llm_model(self):
        """Return ``self._llm_base_model`` when that attribute exists, else ``self.llm.model``."""
        if hasattr(self, '_llm_base_model'):
            return self._llm_base_model
        return self.llm.model

    @property
    def device(self):
        """Device reported by the LLM."""
        return self.llm.device

    @property
    def dtype(self):
        """Dtype reported by the LLM."""
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        """Sequence length reported by the MAR model (``self.mar.seq_len``)."""
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        """Per-token latent width: VAE embed dim times the squared MAR patch size."""
        return self.vae.embed_dim * (self.mar.patch_size ** 2)

    def inject_viewpoint_embeddings(self, input_ids, viewpoint_params, inputs_embeds,
                                    valid_mask=None, num_objects=None):
        """
        Replace viewpoint token IDs with learned embeddings.

        ``inputs_embeds`` is modified in place and also returned.

        Args:
            input_ids: (batch, seq_len) token IDs
            viewpoint_params: (batch, num_params) or (batch, num_objects * num_params) viewpoint parameters
                Spherical mode: [azimuth, elevation] or [az1, el1, az2, el2] for multi-object
                Rotation_translation mode: [rot_9d, trans_3d] or flattened for multi-object
            inputs_embeds: (batch, seq_len, hidden_dim) input embeddings to modify
            valid_mask: (batch, num_params) or (batch, num_objects * num_params) boolean mask
                If None, assumes all parameters are valid
            num_objects: (batch,) tensor or None
                Number of objects per sample (1 or 2). If None, assumes single object.

        Returns:
            inputs_embeds: (batch, seq_len, hidden_dim) with viewpoint tokens injected

        Raises:
            ValueError: If viewpoint tokens are disabled, or a view token occurs
                an unexpected number of times in a sample.
        """
        if not self.use_viewpoint_tokens:
            raise ValueError("Viewpoint tokens are not enabled in this model")

        batch_size = input_ids.shape[0]

        # Generate viewpoint token embeddings (should be in self.dtype from MLP)
        # For multi-object: returns (batch, num_objects * num_view_tokens, hidden_dim)
        # For single-object: returns (batch, num_view_tokens, hidden_dim)
        view_embeddings = self.viewpoint_mlp(viewpoint_params, valid_mask, num_objects=num_objects)

        # Determine if multi-object mode
        is_multi_object = num_objects is not None and num_objects.max() > 1

        # Find positions of viewpoint tokens and replace
        for token_idx, token_id in enumerate(self.view_token_ids):
            mask = (input_ids == token_id)  # (batch, seq_len)
            if not mask.any():
                continue

            # For each sample in batch, replace the token embedding
            for batch_idx in range(batch_size):
                if not mask[batch_idx].any():
                    continue

                # Find positions where this token appears
                token_positions = mask[batch_idx].nonzero(as_tuple=True)[0]

                if is_multi_object:
                    # Multi-object mode: each token can appear multiple times (once per object)
                    expected_occurrences = num_objects[batch_idx].item()
                    if len(token_positions) != expected_occurrences:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected {expected_occurrences}). "
                            f"For multi-object mode, each token should appear once per object."
                        )

                    # Replace each occurrence with the correct embedding
                    for occurrence_idx, pos in enumerate(token_positions):
                        # Calculate embedding index: object_idx * num_view_tokens + token_idx
                        embedding_idx = occurrence_idx * self.num_view_tokens + token_idx
                        inputs_embeds[batch_idx, pos] = view_embeddings[batch_idx, embedding_idx, :]
                else:
                    # Single-object mode: each view token should appear exactly once
                    if len(token_positions) != 1:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected 1). "
                            f"This may indicate the token accidentally appeared in the caption."
                        )

                    # Replace with the corresponding view embedding
                    inputs_embeds[batch_idx, token_positions] = view_embeddings[batch_idx, token_idx, :]

        return inputs_embeds

    @torch.no_grad()
    def encode(self, x):
        """Encode images to patchified, scaled VAE latents of shape (b, m, n, c*p*q)."""
        posterior = self.vae.encode(x)
        z = posterior.mode().mul_(self.vae_scale)
        z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        return z

    @torch.no_grad()
    def decode(self, z):
        """Decode patchified latents of shape (b, m, n, c*p*q) back through the VAE.

        The input tensor is left untouched: the division by ``vae_scale`` is
        done out of place so the caller's buffer is not rescaled.
        """
        z = z / self.vae_scale
        z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        x = self.vae.decode(z)
        return x

    def prepare_forward_input(self,
                              x,
                              inputs_embeds=None,
                              input_ids=None,
                              attention_mask=None,
                              past_key_values=None):
        """Assemble LLM inputs for a context followed by ``x``.

        Args:
            x: (b, l, hidden) embeddings appended after the context.
            inputs_embeds: Optional context embeddings; computed from
                ``input_ids`` when ``None`` and no cache is given.
            input_ids: Context token ids, used only when ``inputs_embeds`` is ``None``.
            attention_mask: (b, context_len) mask for the context; required.
            past_key_values: Optional cache. When given, only ``x`` is fed as
                embeddings and only the last ``l`` position ids are kept.

        Returns:
            Dict with ``inputs_embeds``, ``attention_mask``, ``position_ids``
            and ``past_key_values``.
        """
        b, l, _ = x.shape
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        # The appended ``x`` positions are always attended to.
        attention_mask = torch.cat([
            attention_mask, attention_mask.new_ones(b, l)
        ], dim=1)
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        position_ids[position_ids < 0] = 0

        # prepare context
        if past_key_values is not None:
            inputs_embeds = x
            position_ids = position_ids[:, -l:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        return dict(inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values)

    def extract_visual_feature(self, x, mask=None, detach=False):
        """Run the MAR MAE encoder on latents and project the result to the LLM width.

        Args:
            x: (b, m, n, c) latents.
            mask: Optional (b, m*n) mask; all zeros when ``None``.
            detach: Detach both outputs from the autograd graph.

        Returns:
            ``(x_enc, z_enc)``: the encoder output, and its ``proj_in`` projection
            with the first ``mar.buffer_size`` tokens moved to the end.
        """
        b, m, n, _ = x.shape
        x = x.view(b, m*n, -1)
        # x: b mn c
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(x.shape[0], -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))

        z_enc = self.proj_in(x_enc)
        # Move buffers to the end of the image sequence
        z_enc = torch.cat([
            z_enc[:, self.mar.buffer_size:],
            z_enc[:, :self.mar.buffer_size]], dim=1)

        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()

        return x_enc, z_enc

    def forward_mae_encoder(self, x, mask, detach=False, inputs_embeds=None, **context):
        """Encode latents, pass them through the LLM and add the result as a residual.

        Args:
            x: (b, m, n, c) latents.
            mask: Mask forwarded to ``extract_visual_feature``.
            detach: Forwarded to ``extract_visual_feature``.
            inputs_embeds: Optional context embeddings for ``prepare_forward_input``.
            **context: Remaining keyword arguments for ``prepare_forward_input``
                (``input_ids``, ``attention_mask``, ``past_key_values``).

        Returns:
            ``x_enc + proj_out(z_llm)`` where ``z_llm`` is the LLM hidden state of
            the image positions with the buffer tokens moved back to the front.
        """
        b, m, n, _ = x.shape
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        inputs = self.prepare_forward_input(x=z_enc, inputs_embeds=inputs_embeds, **context)
        output = self.llm_model(**inputs, return_dict=True)

        z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]

        # move buffers back to the start of the image sequence
        z_llm = torch.cat([
            z_llm[:, -self.mar.buffer_size:],
            z_llm[:, :-self.mar.buffer_size]], dim=1)

        # residual learning
        x_enc = x_enc + self.proj_out(z_llm)

        return x_enc

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        """Truncate every cached key/value tensor to ``cur_len`` along dim 2, in place.

        NOTE(review): assumes iterating ``past_key_values`` yields
        ``(keys, values)`` pairs whose sequence axis is dim 2 - confirm for the
        installed transformers version.
        """
        for past_key_values_ in past_key_values:
            keys, values = past_key_values_
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    @torch.no_grad()
    def prepare_text_conditions(self, prompt, cfg_prompt='Generate an image.'):
        """Tokenize a prompt and its CFG counterpart into one right-padded batch of two.

        Padding uses ``tokenizer.eos_token_id`` and is masked out.

        Returns:
            Dict with ``input_ids`` and boolean ``attention_mask`` on ``self.device``.
        """
        all_prompts = [self.prompt_template['INSTRUCTION'].format(input=prompt),
                       self.prompt_template['INSTRUCTION'].format(input=cfg_prompt)]

        input_ids = [self.tokenizer.encode(p, add_special_tokens=True, return_tensors='pt')[0]
                     for p in all_prompts]
        valid_lens = [len(input_ids_) for input_ids_ in input_ids]
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.eos_token_id)
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, valid_len in enumerate(valid_lens):
            attention_mask[i, :valid_len] = True

        return dict(input_ids=input_ids.to(self.device),
                    attention_mask=attention_mask.to(self.device))

    def _padding_token_id(self):
        """Return the tokenizer's pad id, falling back to its eos id when no pad id is set."""
        pad_id = self.tokenizer.pad_token_id
        return self.tokenizer.eos_token_id if pad_id is None else pad_id

    def _build_uncond_embeds(self, seq_len, batch_size, dtype):
        """Embed the unconditional CFG prompt, right-padded to ``seq_len``.

        Returns:
            ``(embeds, valid_len)``: embeddings of shape (batch_size, seq_len, hidden)
            cast to ``dtype``, and the number of non-padding tokens.
        """
        uncond_prompt = self.prompt_template['INSTRUCTION'].format(input="Generate an image.")
        uncond_ids = self.tokenizer.encode(
            uncond_prompt, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        valid_len = uncond_ids.shape[0]
        padding = torch.full((seq_len - valid_len,), self._padding_token_id(),
                             dtype=uncond_ids.dtype, device=self.device)
        uncond_ids = torch.cat([uncond_ids, padding], dim=0)
        uncond_embeds = self.llm.get_input_embeddings()(uncond_ids.unsqueeze(0))
        uncond_embeds = uncond_embeds.expand(batch_size, -1, -1).to(dtype=dtype)
        return uncond_embeds, valid_len

    @torch.no_grad()
    def sample(self,
               input_ids=None, inputs_embeds=None,
               attention_mask=None,
               num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
               progress=False, mask=None, past_key_values=None,
               image_shape=None, x_con=None, **kwargs):
        """Iteratively generate image latents and decode them to images.

        Args:
            input_ids: Context token ids; used only when ``inputs_embeds`` is ``None``.
            inputs_embeds: Context embeddings.
            attention_mask: (bsz, context_len) context mask; required.
            num_iter: Number of unmasking iterations.
            cfg: Guidance scale. When not 1.0 the batch must hold the conditional
                samples in its first half and the unconditional ones in its second.
            cfg_schedule: ``"linear"`` or ``"constant"``.
            temperature: Passed to ``mar.diffloss.sample``.
            progress: Wrap the iteration in ``tqdm``.
            mask: Optional initial mask, viewed as (bsz, m*n); all ones when ``None``.
            past_key_values: Optional pre-computed context cache.
            image_shape: Optional ``(m, n)`` token grid; square from ``gen_seq_len`` when ``None``.
            x_con: Forwarded to ``mar.forward_mae_decoder``.

        Returns:
            Decoded images; only the first half of the batch when ``cfg != 1.0``.

        Raises:
            ValueError: If ``attention_mask`` is ``None``.
            NotImplementedError: For an unknown ``cfg_schedule``.
        """
        if attention_mask is None:
            # The batch size and the LLM inputs are both derived from the mask.
            raise ValueError("attention_mask is required for sampling")
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        bsz = attention_mask.shape[0]
        if cfg != 1.0:
            assert bsz % 2 == 0

        if image_shape is None:
            m = n = int(self.gen_seq_len ** 0.5)
        else:
            m, n = image_shape

        if mask is None:
            mask = torch.ones(bsz, m*n, device=self.device, dtype=self.dtype)
        else:
            mask = mask.view(bsz, m*n)
        tokens = torch.zeros(bsz, m*n, self.token_embed_dim,
                             device=self.device, dtype=self.dtype)
        orders = self.mar.sample_orders(bsz, seq_len=m*n)
        if cfg != 1.0:
            # Conditional and unconditional halves share the same generation order.
            orders[bsz//2:] = orders[:bsz//2]

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # past key values can be prepared outside (usually in multi-turn editing)
        if past_key_values is None:
            output = self.llm_model(inputs_embeds=inputs_embeds,
                                    attention_mask=None,
                                    position_ids=None,
                                    past_key_values=DynamicCache.from_legacy_cache(),
                                    return_dict=True,
                                    use_cache=True)
            past_key_values = output.past_key_values

        # generate latents
        for step in indices:
            cur_tokens = tokens.clone()
            x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
                                             mask.to(self.dtype),
                                             past_key_values=past_key_values,
                                             attention_mask=attention_mask)
            # Drop the image positions the forward pass appended to the cache.
            self.curtail_cache(past_key_values, inputs_embeds.shape[1])

            z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype),
                                             image_shape=(m, n), x_con=x_con)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(m*n * mask_ratio)]).to(self.device)

            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, m*n).to(self.device)
            if cfg != 1.0:
                mask_next[bsz//2:] = mask_next[:bsz//2]
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next

            # sample token latents for this step
            z = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (m*n - mask_len[0]) / (m*n)
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)

            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            if cfg != 1.0:
                # Keep both halves of the batch on identical tokens.
                cur_tokens[bsz//2:] = cur_tokens[:bsz//2]
            tokens = cur_tokens.clone()

        pred = self.decode(tokens.view(bsz, m, n, -1))

        if cfg != 1.0:
            pred = pred[:bsz//2]
        return pred

    @torch.no_grad()
    def sample_recon(self, image, prompt="Describe this image in details.",
                     temperature=1.0, num_iter=32, cfg=3.0, progress=True):
        """Re-generate an image conditioned on its own visual features and a text prompt.

        NOTE(review): the unconditional half is always appended, so with
        ``cfg=1.0`` ``sample`` returns ``2 * b`` images - confirm this is intended.

        Args:
            image: Batch of images accepted by ``encode``.
            prompt: Text inserted into the instruction template.
            temperature: Sampling temperature.
            num_iter: Number of sampling iterations.
            cfg: Classifier-free guidance scale.
            progress: Whether to show a progress bar.

        Returns:
            The images returned by ``sample``.
        """
        x = self.encode(image)
        b, m, n, _ = x.shape
        _, z_enc = self.extract_visual_feature(x, detach=True)

        text_prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(
            text_prompt, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        input_embeds = self.llm.get_input_embeddings()(input_ids.unsqueeze(0))
        input_embeds = input_embeds.expand(b, -1, -1)
        # Image features are spliced in after the first 3 prompt tokens.
        # NOTE(review): 3 is presumably the template prefix length - confirm.
        input_embeds = torch.cat([
            input_embeds[:, :3, :],
            z_enc.reshape(b, -1, input_embeds.shape[-1]),
            input_embeds[:, 3:, :],
        ], dim=1)

        # padding uncond to the same length
        seq_len = input_embeds.shape[1]
        input_embeds_uncond, uncond_len = self._build_uncond_embeds(
            seq_len, b, input_embeds.dtype)
        input_embeds = torch.cat([input_embeds, input_embeds_uncond], dim=0)

        # padding attention mask; slicing from the valid length (rather than
        # with a negative pad count) stays correct when no padding was needed
        attention_mask = torch.ones((b * 2, seq_len), dtype=torch.bool, device=self.device)
        attention_mask[b:, uncond_len:] = False

        recon_image = self.sample(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n)
        )
        return recon_image

    @torch.no_grad()
    def sample_relpose(self, src_image, viewpoint_params, input_ids,
                       num_iter=64, cfg=3.0, temperature=1.0, progress=False):
        """
        Generate target image from source image + viewpoint parameters.

        Args:
            src_image: (batch, 3, H, W) source images in [-1, 1]
            viewpoint_params: (batch, 2) [azimuth, elevation] in radians
            input_ids: (batch, seq_len) tokenized prompt with IMAGE_TOKEN and view tokens
            num_iter: Number of sampling iterations
            cfg: Classifier-free guidance scale
            temperature: Sampling temperature
            progress: Whether to show progress bar

        Returns:
            generated_images: (batch, 3, H, W) in [-1, 1]
        """
        # Encode source image
        src_latents = self.encode(src_image)  # (batch, m, n, c)
        batch_size, m, n, _ = src_latents.shape

        # Extract visual features from source (no masking, detached)
        _, z_src = self.extract_visual_feature(src_latents, detach=True)
        # The image-token count is read from the features instead of being
        # hard-coded (it was fixed at 1088), so other grid sizes also work.
        num_image_tokens = z_src.shape[1]

        # The ids are used to index tensors living on the model device.
        input_ids = input_ids.to(self.device)

        # === STEP 1: Expand IMAGE_TOKEN from 1 to num_image_tokens tokens ===
        # Find IMAGE_TOKEN position; the first match is used for the whole batch,
        # so every sample is assumed to carry it at the same position.
        image_token_pos = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[1][0].item()

        # Expand IMAGE_TOKEN in input_ids
        expanded_input_ids = torch.cat([
            input_ids[:, :image_token_pos],
            IMAGE_TOKEN_INDEX * torch.ones(batch_size, num_image_tokens,
                                           dtype=torch.long, device=input_ids.device),
            input_ids[:, image_token_pos + 1:],
        ], dim=1)

        # === STEP 2: Build inputs_embeds (same pattern as training) ===
        # Create zero embeddings
        inputs_embeds = z_src.new_zeros(batch_size, expanded_input_ids.shape[1],
                                        self.llm.config.hidden_size)

        is_image = expanded_input_ids == IMAGE_TOKEN_INDEX
        # Fill IMAGE_TOKEN positions with source visual features
        inputs_embeds[is_image] = z_src.flatten(0, 1)

        # Fill non-IMAGE positions with text embeddings
        inputs_embeds[~is_image] = self.llm.get_input_embeddings()(
            expanded_input_ids[~is_image]
        )

        # === STEP 3: Inject viewpoint embeddings ===
        inputs_embeds = self.inject_viewpoint_embeddings(
            expanded_input_ids, viewpoint_params, inputs_embeds
        )

        seq_len = inputs_embeds.shape[1]

        # Create unconditional embeddings for CFG
        if cfg != 1.0:
            # Unconditional: "Generate an image.", padded to the conditional length
            uncond_embeds, uncond_len = self._build_uncond_embeds(
                seq_len, batch_size, inputs_embeds.dtype)

            # Concatenate conditional and unconditional
            inputs_embeds = torch.cat([inputs_embeds, uncond_embeds], dim=0)

            # Create attention masks. The concatenation doubles the batch
            # dimension only, so the masks keep the full sequence length.
            attention_mask = torch.ones(
                (2 * batch_size, seq_len),
                dtype=torch.bool, device=self.device
            )
            attention_mask[batch_size:, uncond_len:] = False
        else:
            attention_mask = torch.ones(
                (batch_size, seq_len),
                dtype=torch.bool, device=self.device
            )

        # Generate target image
        generated = self.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n)
        )

        return generated