Spaces:
Running on Zero
Running on Zero
| import torch | |
| import math | |
| import numpy as np | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from transformers.cache_utils import DynamicCache | |
| from src.builder import BUILDER | |
| from tqdm import tqdm | |
| from torch.nn.utils.rnn import pad_sequence | |
| from xtuner.utils import IMAGE_TOKEN_INDEX | |
| from .viewpoint_mlp import ViewpointTokenMLP | |
def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a two-layer projector: Linear -> SiLU -> Linear.

    Args:
        hidden_size: input feature size.
        projector_dim: width of the intermediate layer.
        z_dim: output feature size.

    Returns:
        nn.Sequential whose submodules are indexed 0 (Linear), 1 (SiLU) and
        2 (Linear), so state-dict keys are ``0.*`` and ``2.*``.
    """
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
def mask_by_order(mask_len, order, bsz, seq_len):
    """Mask the first ``mask_len`` positions of each row's generation order.

    Args:
        mask_len: number of positions to mask per row. May be a Python number
            or a single-element tensor (e.g. shape ``()`` or ``(1,)``).
        order: (bsz, seq_len) integer tensor; row ``i`` lists positions in the
            order they are generated.
        bsz: batch size.
        seq_len: number of positions per row.

    Returns:
        (bsz, seq_len) bool tensor on ``order.device`` that is True at
        ``order[i, :mask_len]`` for every row ``i``.
    """
    # int() accepts both plain numbers and 1-element tensors (the original
    # `.long()` only worked for tensors); slicing needs a Python int anyway.
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking.scatter_(-1, order[:, :num_masked], 1.0)
    return masking.bool()
class Harmon(nn.Module):
    """Harmon model: frozen VAE + LLM backbone + MAR encoder/decoder.

    Image latents from the VAE are encoded by the MAR encoder, projected into
    the LLM hidden space (``proj_in``), contextualised by the LLM together with
    the text prompt, projected back (``proj_out``) and decoded by MAR.
    Optionally, learned viewpoint tokens are injected into the prompt to
    condition generation on camera parameters.
    """

    def __init__(self,
                 vae,
                 vae_scale,
                 llm,
                 mar,
                 tokenizer,
                 prompt_template,
                 # Viewpoint conditioning
                 use_viewpoint_tokens=False,
                 num_view_tokens=8,
                 viewpoint_mlp_config=None):
        """Build sub-modules from their configs.

        Args:
            vae, llm, mar, tokenizer: configs passed to ``BUILDER.build``.
            vae_scale: multiplier applied to VAE latents in ``encode`` and
                divided out again in ``decode``.
            prompt_template: dict whose ``'INSTRUCTION'`` entry is a format
                string with an ``{input}`` placeholder.
            use_viewpoint_tokens: enable viewpoint-token conditioning.
            num_view_tokens: number of ``<view_token_i>`` tokens.
            viewpoint_mlp_config: kwargs for ``ViewpointTokenMLP``; when None a
                default is filled in by ``_init_viewpoint_tokens``.
        """
        super().__init__()
        # VAE (frozen)
        self.vae = BUILDER.build(vae)
        self.vae.requires_grad_(False)
        self.vae_scale = vae_scale
        # LLM
        self.llm = BUILDER.build(llm)
        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = prompt_template
        # MAR
        self.mar = BUILDER.build(mar)
        # Projection layers between the MAR encoder space and the LLM hidden space
        llm_hidden = self.llm.config.hidden_size
        self.proj_in = build_mlp(hidden_size=self.mar.encoder_embed_dim,
                                 projector_dim=llm_hidden,
                                 z_dim=llm_hidden)
        self.proj_out = build_mlp(hidden_size=llm_hidden,
                                  projector_dim=llm_hidden,
                                  z_dim=self.mar.encoder_embed_dim)
        # Viewpoint token system. Registering the tokens, resizing embeddings
        # and building the MLP are deferred to _init_viewpoint_tokens() so a
        # pretrained checkpoint can be loaded before the embeddings are resized.
        self.use_viewpoint_tokens = use_viewpoint_tokens
        self.num_view_tokens = num_view_tokens if use_viewpoint_tokens else 0
        self.viewpoint_mlp_config = viewpoint_mlp_config
        if use_viewpoint_tokens:
            self.view_tokens = [f"<view_token_{i}>" for i in range(num_view_tokens)]
        else:
            self.view_tokens = []
        self.view_token_ids = []  # populated in _init_viewpoint_tokens()

    def _init_viewpoint_tokens(self):
        """Register viewpoint tokens in the tokenizer/LLM and build the viewpoint MLP.

        Meant to be called after ``load_state_dict()`` (e.g. from HarmonDev).
        Rows added by ``resize_token_embeddings`` are randomly initialised and
        trained from scratch. Does nothing when viewpoint tokens are disabled.
        """
        if not self.use_viewpoint_tokens:
            return

        tokenizer_vocab_size = len(self.tokenizer)
        embedding_size = self.llm.get_input_embeddings().weight.shape[0]
        print(f"\n[VIEWPOINT INIT DEBUG]", flush=True)
        print(f" Tokenizer before adding view tokens:", flush=True)
        print(f" Type: {type(self.tokenizer)}", flush=True)
        print(f" Vocab size: {tokenizer_vocab_size}", flush=True)
        print(f" LLM embedding size: {embedding_size}", flush=True)
        # A tokenizer/embedding size mismatch pre-dates viewpoint tokens.
        if tokenizer_vocab_size != embedding_size:
            print(f" NOTE: Vocab size mismatch ({tokenizer_vocab_size} vs {embedding_size})", flush=True)
            print(f" This is expected if using a different HF tokenizer version than checkpoint.", flush=True)
            print(f" View tokens will use IDs starting from {tokenizer_vocab_size}", flush=True)

        # Tokens may already exist, e.g. from a resumed run. Only add the
        # missing ones, so a partially-registered set is completed instead of
        # being silently mapped to unknown IDs.
        vocab = self.tokenizer.get_vocab()
        existing_tokens = [tok for tok in self.view_tokens if tok in vocab]
        missing_tokens = [tok for tok in self.view_tokens if tok not in vocab]
        if existing_tokens:
            print(f"\n [Viewpoint] Found {len(existing_tokens)} existing view tokens in tokenizer", flush=True)
        if missing_tokens:
            print(f"\n [Viewpoint] Adding {len(missing_tokens)} new view tokens...", flush=True)
            num_added = self.tokenizer.add_tokens(missing_tokens, special_tokens=True)
            print(f" [Viewpoint] Added {num_added} new view tokens to tokenizer", flush=True)
            # Keep the historical target (configured vocab + added tokens) so
            # checkpoints from earlier runs still match. Never go below
            # len(tokenizer), otherwise new IDs would index past the table.
            old_vocab_size = self.llm.config.vocab_size
            target_vocab_size = max(old_vocab_size + num_added, len(self.tokenizer))
            self.llm.resize_token_embeddings(target_vocab_size)
            print(f" [Viewpoint] Resized LLM embeddings: {old_vocab_size} -> {target_vocab_size}", flush=True)
            print(f" [Viewpoint] New tokens randomly initialized", flush=True)

        self.view_token_ids = [self.tokenizer.convert_tokens_to_ids(tok) for tok in self.view_tokens]
        print(f" View token IDs: {self.view_token_ids}", flush=True)
        print(f"\n Tokenizer after adding view tokens:", flush=True)
        print(f" Vocab size: {len(self.tokenizer)}", flush=True)
        print(f" LLM embedding size: {self.llm.get_input_embeddings().weight.shape[0]}", flush=True)
        print(f"[END VIEWPOINT INIT DEBUG]\n", flush=True)

        # Viewpoint MLP
        if self.viewpoint_mlp_config is None:
            self.viewpoint_mlp_config = dict(
                hidden_dim=1024,
                output_dim=self.llm.config.hidden_size,
                num_layers=5,
                num_view_tokens=self.num_view_tokens,
                num_params=2,  # set by viewpoint_param_type ('spherical': 2, 'rotation_translation': 9)
                num_freqs=16,
                viewpoint_param_type='spherical',
            )
        self.viewpoint_mlp = ViewpointTokenMLP(**self.viewpoint_mlp_config)
        # Match the rest of the model's device and dtype.
        self.viewpoint_mlp = self.viewpoint_mlp.to(device=self.device, dtype=self.dtype)
        print(f" Viewpoint MLP moved to device={self.device}, dtype={self.dtype}", flush=True)

    @property
    def llm_model(self):
        """Backbone used for hidden states: ``_llm_base_model`` if set, else ``self.llm.model``."""
        if hasattr(self, '_llm_base_model'):
            return self._llm_base_model
        return self.llm.model

    @property
    def device(self):
        """Device of the LLM (used as the model's device)."""
        return self.llm.device

    @property
    def dtype(self):
        """Dtype of the LLM (used as the model's compute dtype)."""
        return self.llm.dtype

    @property
    def gen_seq_len(self):
        """Number of latent tokens per generated image (``mar.seq_len``)."""
        return self.mar.seq_len

    @property
    def token_embed_dim(self):
        """Channel size of one latent token: VAE channels * patch_size**2 (see ``encode``)."""
        return self.vae.embed_dim * (self.mar.patch_size ** 2)

    def inject_viewpoint_embeddings(self, input_ids, viewpoint_params, inputs_embeds, valid_mask=None, num_objects=None):
        """Replace viewpoint-token embeddings with embeddings produced by ``viewpoint_mlp``.

        Args:
            input_ids: (batch, seq_len) token IDs.
            viewpoint_params: (batch, num_params) or (batch, num_objects * num_params).
                Spherical mode: [azimuth, elevation] or [az1, el1, az2, el2] for multi-object.
                Rotation_translation mode: [rot_9d, trans_3d] or flattened for multi-object.
            inputs_embeds: (batch, seq_len, hidden_dim) embeddings, modified in place.
            valid_mask: optional boolean mask with the same layout as viewpoint_params.
            num_objects: optional (batch,) tensor of objects per sample; None means single-object.

        Returns:
            ``inputs_embeds`` with every viewpoint token position overwritten.

        Raises:
            ValueError: if viewpoint tokens are disabled, or a view token occurs an
                unexpected number of times in a sample.
        """
        if not self.use_viewpoint_tokens:
            raise ValueError("Viewpoint tokens are not enabled in this model")

        batch_size = input_ids.shape[0]
        # (batch, num_objects * num_view_tokens, hidden_dim) for multi-object,
        # (batch, num_view_tokens, hidden_dim) otherwise.
        view_embeddings = self.viewpoint_mlp(viewpoint_params, valid_mask, num_objects=num_objects)
        is_multi_object = num_objects is not None and num_objects.max() > 1

        for token_idx, token_id in enumerate(self.view_token_ids):
            token_mask = input_ids == token_id  # (batch, seq_len)
            if not token_mask.any():
                continue
            for batch_idx in range(batch_size):
                if not token_mask[batch_idx].any():
                    continue
                token_positions = token_mask[batch_idx].nonzero(as_tuple=True)[0]
                if is_multi_object:
                    # Each token appears once per object; occurrence k belongs to object k.
                    expected_occurrences = num_objects[batch_idx].item()
                    if len(token_positions) != expected_occurrences:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected {expected_occurrences}). "
                            f"For multi-object mode, each token should appear once per object."
                        )
                    for occurrence_idx, pos in enumerate(token_positions):
                        embedding_idx = occurrence_idx * self.num_view_tokens + token_idx
                        inputs_embeds[batch_idx, pos] = view_embeddings[batch_idx, embedding_idx, :]
                else:
                    if len(token_positions) != 1:
                        raise ValueError(
                            f"View token {self.view_tokens[token_idx]} appears "
                            f"{len(token_positions)} times in sequence (expected 1). "
                            f"This may indicate the token accidentally appeared in the caption."
                        )
                    inputs_embeds[batch_idx, token_positions] = view_embeddings[batch_idx, token_idx, :]
        return inputs_embeds

    def encode(self, x):
        """Encode images to scaled VAE latents laid out as (b, m, n, c * p * q) patch tokens."""
        posterior = self.vae.encode(x)
        z = posterior.mode() * self.vae_scale
        z = rearrange(z, 'b c (m p) (n q) -> b m n (c p q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        return z

    def decode(self, z):
        """Inverse of ``encode``: unscale, unpatchify and decode with the VAE.

        The input tensor is not modified (the original divided it in place).
        """
        z = z / self.vae_scale
        z = rearrange(z, 'b m n (c p q) -> b c (m p) (n q)',
                      p=self.mar.patch_size, q=self.mar.patch_size)
        return self.vae.decode(z)

    def prepare_forward_input(self,
                              x,
                              inputs_embeds=None,
                              input_ids=None,
                              attention_mask=None,
                              past_key_values=None):
        """Build LLM kwargs for appending the ``l`` image tokens in ``x`` (b, l, d) to a prompt.

        The attention mask is extended with ones for the image tokens, and
        position ids are the cumulative count of attended tokens. With a KV
        cache only ``x`` is fed (and only the last ``l`` position ids);
        otherwise the prompt embeddings (from ``inputs_embeds`` or
        ``input_ids``) are prepended to ``x``.
        """
        b, l, _ = x.shape
        attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)
        position_ids = torch.cumsum(attention_mask, dim=1) - 1
        position_ids[position_ids < 0] = 0

        if past_key_values is not None:
            inputs_embeds = x
            position_ids = position_ids[:, -l:]
        else:
            if inputs_embeds is None:
                input_ids = input_ids.to(self.device)
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([inputs_embeds, x], dim=1)

        return dict(inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values)

    def extract_visual_feature(self, x, mask=None, detach=False):
        """Run the MAR encoder on latents ``x`` (b, m, n, c) and project the result to the LLM space.

        Returns:
            (x_enc, z_enc): the MAR encoder output, and its ``proj_in``
            projection with the first ``mar.buffer_size`` tokens moved to the
            end of the sequence. Both are detached when ``detach`` is True.
        """
        b, m, n, _ = x.shape
        x = x.view(b, m * n, -1)
        if mask is None:
            mask = torch.zeros_like(x[..., 0])
        null_embeds = self.mar.fake_latent.expand(b, -1)
        x_enc = self.mar.forward_mae_encoder(x, mask, null_embeds, image_shape=(m, n))
        z_enc = self.proj_in(x_enc)
        # Move buffers to the end of the image sequence
        buf = self.mar.buffer_size
        z_enc = torch.cat([z_enc[:, buf:], z_enc[:, :buf]], dim=1)
        if detach:
            x_enc = x_enc.detach()
            z_enc = z_enc.detach()
        return x_enc, z_enc

    def forward_mae_encoder(self, x, mask, detach=False, inputs_embeds=None, **context):
        """Encode latents with MAR, contextualise them with the LLM, and add the result back residually.

        ``context`` is forwarded to ``prepare_forward_input`` (``input_ids``,
        ``attention_mask``, ``past_key_values``).
        """
        x_enc, z_enc = self.extract_visual_feature(x, mask=mask, detach=detach)
        inputs = self.prepare_forward_input(x=z_enc, inputs_embeds=inputs_embeds, **context)
        output = self.llm_model(**inputs, return_dict=True)
        z_llm = output.last_hidden_state[:, -z_enc.shape[1]:]
        # Move buffers back to the start of the image sequence
        buf = self.mar.buffer_size
        z_llm = torch.cat([z_llm[:, -buf:], z_llm[:, :-buf]], dim=1)
        # Residual learning
        return x_enc + self.proj_out(z_llm)

    @staticmethod
    def curtail_cache(past_key_values, cur_len):
        """Truncate every layer of a KV cache to its first ``cur_len`` positions (in place).

        Uses the cache's own ``crop`` when available (e.g. transformers'
        ``DynamicCache``), which also keeps its bookkeeping consistent.
        Otherwise it falls back to slicing dim 2 of each layer's (key, value)
        tensors in place.
        """
        if hasattr(past_key_values, 'crop'):
            past_key_values.crop(cur_len)
            return
        for keys, values in past_key_values:
            keys.data = keys.data[:, :, :cur_len]
            values.data = values.data[:, :, :cur_len]

    def prepare_text_conditions(self, prompt, cfg_prompt='Generate an image.'):
        """Tokenize ``[prompt, cfg_prompt]`` as a right-padded batch of 2 for CFG sampling.

        Returns:
            dict with ``input_ids`` (padded with the EOS id) and a bool
            ``attention_mask`` that is True on real tokens.
        """
        template = self.prompt_template['INSTRUCTION']
        all_prompts = [template.format(input=prompt), template.format(input=cfg_prompt)]
        input_ids = [self.tokenizer.encode(p, add_special_tokens=True, return_tensors='pt')[0]
                     for p in all_prompts]
        valid_lens = [len(ids) for ids in input_ids]
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.eos_token_id)
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, valid_len in enumerate(valid_lens):
            attention_mask[i, :valid_len] = True
        return dict(input_ids=input_ids.to(self.device),
                    attention_mask=attention_mask.to(self.device))

    @torch.no_grad()
    def sample(self,
               input_ids=None, inputs_embeds=None,
               attention_mask=None, num_iter=64, cfg=1.0, cfg_schedule="constant", temperature=1.0,
               progress=False, mask=None, past_key_values=None, image_shape=None, x_con=None, **kwargs):
        """Generate images by iterative masked prediction of latent tokens (MAR / MaskGIT style).

        The text prefix is run through the LLM once, and its KV cache is reused
        (and cropped back to the prefix) at every step. Each step predicts the
        tokens leaving the mask, following a random per-sample order; the
        masked fraction follows ``cos(pi/2 * (step + 1) / num_iter)``.

        With ``cfg != 1.0`` the batch must be [conditional; unconditional]
        halves. Orders, masks and tokens of the second half mirror the first,
        and only the first half of the decoded images is returned.

        Raises:
            ValueError: if none of input_ids, inputs_embeds or past_key_values is given.
        """
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)

        bsz = attention_mask.shape[0]
        if cfg != 1.0:
            assert bsz % 2 == 0, "CFG sampling expects a [conditional; unconditional] batch"

        if image_shape is None:
            m = n = int(self.gen_seq_len ** 0.5)
        else:
            m, n = image_shape
        num_tokens = m * n

        if mask is None:
            mask = torch.ones(bsz, num_tokens, device=self.device, dtype=self.dtype)
        else:
            mask = mask.view(bsz, num_tokens)
        tokens = torch.zeros(bsz, num_tokens, self.token_embed_dim,
                             device=self.device, dtype=self.dtype)
        orders = self.mar.sample_orders(bsz, seq_len=num_tokens)
        if cfg != 1.0:
            orders[bsz // 2:] = orders[:bsz // 2]

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)

        # The KV cache can be prepared outside (usually in multi-turn editing).
        if past_key_values is None:
            if inputs_embeds is None:
                raise ValueError("sample() needs input_ids, inputs_embeds or past_key_values")
            output = self.llm_model(inputs_embeds=inputs_embeds,
                                    attention_mask=None,
                                    position_ids=None,
                                    past_key_values=DynamicCache.from_legacy_cache(),
                                    return_dict=True,
                                    use_cache=True)
            past_key_values = output.past_key_values
        # Length of the text prefix in the cache; image tokens appended by each
        # forward pass are cropped back to this length.
        if inputs_embeds is not None:
            prefix_len = inputs_embeds.shape[1]
        else:
            prefix_len = past_key_values.get_seq_length()

        for step in indices:
            cur_tokens = tokens.clone()
            x_enc = self.forward_mae_encoder(tokens.view(bsz, m, n, -1),
                                             mask.to(self.dtype),
                                             past_key_values=past_key_values,
                                             attention_mask=attention_mask)
            self.curtail_cache(past_key_values, prefix_len)
            z = self.mar.forward_mae_decoder(x_enc, mask.to(self.dtype), image_shape=(m, n), x_con=x_con)

            # Mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(num_tokens * mask_ratio)]).to(self.device)
            # Keep at least one token masked for the next iteration.
            mask_len = torch.maximum(torch.Tensor([1]).to(self.device),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdim=True) - 1, mask_len))

            # Masking for the next iteration and the positions predicted now.
            mask_next = mask_by_order(mask_len[0], orders, bsz, num_tokens).to(self.device)
            if cfg != 1.0:
                mask_next[bsz // 2:] = mask_next[:bsz // 2]
            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            mask = mask_next

            # Sample token latents for this step.
            pred_index = mask_to_pred.nonzero(as_tuple=True)
            z = z[pred_index]
            # CFG schedule follows Muse.
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (num_tokens - mask_len[0]) / num_tokens
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.mar.diffloss.sample(z, temperature, cfg_iter).to(self.dtype)

            cur_tokens[pred_index] = sampled_token_latent
            if cfg != 1.0:
                cur_tokens[bsz // 2:] = cur_tokens[:bsz // 2]
            tokens = cur_tokens.clone()

        pred = self.decode(tokens.view(bsz, m, n, -1))
        if cfg != 1.0:
            pred = pred[:bsz // 2]
        return pred

    def _build_uncond_inputs(self, seq_len, batch_size, dtype, uncond_prompt="Generate an image."):
        """Embed the unconditional CFG prompt, right-padded to ``seq_len``.

        Padding uses the tokenizer's pad id, or the EOS id when no pad token is
        defined. Padded positions are excluded via the returned mask.

        Returns:
            (embeds, mask): ``embeds`` is (batch_size, seq_len, hidden) in
            ``dtype``; ``mask`` is a (batch_size, seq_len) bool tensor that is
            True on the real prompt tokens only.

        Raises:
            ValueError: if the unconditional prompt is longer than ``seq_len``.
        """
        text = self.prompt_template['INSTRUCTION'].format(input=uncond_prompt)
        ids = self.tokenizer.encode(text, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        valid_len = ids.shape[0]
        pad_len = seq_len - valid_len
        if pad_len < 0:
            raise ValueError(
                f"Unconditional prompt ({valid_len} tokens) is longer than the "
                f"conditional sequence ({seq_len} tokens)")
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        ids = torch.cat([ids, ids.new_full((pad_len,), pad_id)], dim=0)
        embeds = self.llm.get_input_embeddings()(ids.unsqueeze(0))
        embeds = embeds.expand(batch_size, -1, -1).to(dtype=dtype)
        # Explicit prefix mask; slicing with -pad_len would mask the whole row
        # when pad_len == 0.
        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=self.device)
        mask[:, :valid_len] = True
        return embeds, mask

    @torch.no_grad()
    def sample_recon(self,
                     image,
                     prompt="Describe this image in details.",
                     temperature=1.0,
                     num_iter=32,
                     cfg=3.0,
                     progress=True):
        """Regenerate ``image`` conditioned on its own visual features plus a text prompt.

        With ``cfg != 1.0`` an unconditional "Generate an image." batch is
        appended for classifier-free guidance; with ``cfg == 1.0`` only the
        conditional batch is sampled.
        """
        x = self.encode(image)
        b, m, n, _ = x.shape
        _, z_enc = self.extract_visual_feature(x, detach=True)

        text_prompt = self.prompt_template['INSTRUCTION'].format(input=prompt)
        input_ids = self.tokenizer.encode(text_prompt, add_special_tokens=True, return_tensors='pt')[0].to(self.device)
        text_embeds = self.llm.get_input_embeddings()(input_ids.unsqueeze(0)).expand(b, -1, -1)
        # NOTE(review): image features are spliced in after the first 3 prompt
        # tokens; presumably this matches the template's prefix. Confirm.
        inputs_embeds = torch.cat([
            text_embeds[:, :3, :],
            z_enc.reshape(b, -1, text_embeds.shape[-1]),
            text_embeds[:, 3:, :],
        ], dim=1)
        seq_len = inputs_embeds.shape[1]
        attention_mask = torch.ones((b, seq_len), dtype=torch.bool, device=self.device)

        if cfg != 1.0:
            uncond_embeds, uncond_mask = self._build_uncond_inputs(seq_len, b, inputs_embeds.dtype)
            inputs_embeds = torch.cat([inputs_embeds, uncond_embeds], dim=0)
            attention_mask = torch.cat([attention_mask, uncond_mask], dim=0)

        return self.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n),
        )

    @torch.no_grad()
    def sample_relpose(self,
                       src_image,
                       viewpoint_params,
                       input_ids,
                       num_iter=64,
                       cfg=3.0,
                       temperature=1.0,
                       progress=False):
        """Generate a target image from a source image and viewpoint parameters.

        Args:
            src_image: (batch, 3, H, W) source images in [-1, 1].
            viewpoint_params: (batch, 2) [azimuth, elevation] in radians.
            input_ids: (batch, seq_len) tokenized prompt; each row must contain
                exactly one IMAGE_TOKEN_INDEX placeholder plus the view tokens.
            num_iter: number of sampling iterations.
            cfg: classifier-free guidance scale.
            temperature: sampling temperature.
            progress: whether to show a progress bar.

        Returns:
            generated_images: (batch, 3, H, W) in [-1, 1].

        Raises:
            ValueError: if a row of ``input_ids`` does not contain exactly one
                image token.
        """
        src_latents = self.encode(src_image)  # (batch, m, n, c)
        batch_size, m, n, _ = src_latents.shape
        _, z_src = self.extract_visual_feature(src_latents, detach=True)  # (batch, num_image_tokens, hidden)
        num_image_tokens = z_src.shape[1]

        # Step 1: expand each row's single IMAGE_TOKEN_INDEX into num_image_tokens placeholders.
        is_image = input_ids == IMAGE_TOKEN_INDEX
        if not bool((is_image.sum(dim=1) == 1).all()):
            raise ValueError(
                f"Each row of input_ids must contain exactly one image token (IMAGE_TOKEN_INDEX={IMAGE_TOKEN_INDEX})")
        repeats = torch.ones_like(input_ids)
        repeats[is_image] = num_image_tokens
        expanded_input_ids = torch.stack([
            torch.repeat_interleave(row, row_repeats)
            for row, row_repeats in zip(input_ids, repeats)
        ])

        # Step 2: image slots get source visual features, all other slots get text embeddings.
        image_slots = expanded_input_ids == IMAGE_TOKEN_INDEX
        inputs_embeds = z_src.new_zeros(batch_size, expanded_input_ids.shape[1], self.llm.config.hidden_size)
        inputs_embeds[image_slots] = z_src.flatten(0, 1)
        inputs_embeds[~image_slots] = self.llm.get_input_embeddings()(expanded_input_ids[~image_slots])

        # Step 3: inject viewpoint embeddings.
        inputs_embeds = self.inject_viewpoint_embeddings(expanded_input_ids, viewpoint_params, inputs_embeds)

        seq_len = inputs_embeds.shape[1]
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=self.device)
        if cfg != 1.0:
            # Unconditional half for CFG; its mask has the full sequence length.
            uncond_embeds, uncond_mask = self._build_uncond_inputs(seq_len, batch_size, inputs_embeds.dtype)
            inputs_embeds = torch.cat([inputs_embeds, uncond_embeds], dim=0)
            attention_mask = torch.cat([attention_mask, uncond_mask], dim=0)

        generated = self.sample(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_iter=num_iter,
            temperature=temperature,
            progress=progress,
            cfg=cfg,
            image_shape=(m, n),
        )
        return generated