# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import time
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, is_flash_attn_2_available, logging
from peft import LoraConfig, get_peft_model

from .configuration_locateanything import LocateAnythingConfig
from .modeling_qwen2 import Qwen2ForCausalLM
from .modeling_vit import MoonVitPretrainedModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from .mask_sdpa_utils import *
from .mask_magi_utils import *
from .configuration_qwen2 import Qwen2Config
from .generate_utils import (
    sample_tokens,
    handle_pattern,
    get_token_ids_from_config,
)

logger = logging.get_logger(__name__)

LOCATEANYTHING_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LocateAnythingConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LocateAnything Model outputting raw hidden-states without any specific head on top.",
    LOCATEANYTHING_START_DOCSTRING,
)
class LocateAnythingPreTrainedModel(PreTrainedModel):
    """Shared base: config binding, attention-backend selection and weight init."""

    config_class = LocateAnythingConfig
    base_model_prefix = "model"
    main_input_name = 'input_ids'
    supports_gradient_checkpointing = True
    # Default; the concrete model overrides this per instance based on the LLM architecture.
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True
    _supports_sdpa = True

    @classmethod
    def _autoset_attn_implementation(cls, config, *args, **kwargs):
        """Let the custom 'magi' attention backend bypass transformers' own selection/validation."""
        if getattr(config, '_attn_implementation', None) == 'magi':
            return config
        return super()._autoset_attn_implementation(config, *args, **kwargs)

    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check=False):
        """Accept 'magi' as-is; defer every other backend name to the parent implementation."""
        if attn_implementation == "magi":
            return "magi"
        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)

    def _init_weights(self, module):
        """Normal(0, std) init for Linear/Conv2d/Embedding weights; zero biases and padding rows.

        ``std`` is the top-level ``initializer_range`` when set (and truthy), otherwise the text config's.
        """
        std = getattr(self.config, 'initializer_range', None) or self.config.text_config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class LocateAnythingForConditionalGeneration(LocateAnythingPreTrainedModel, GenerationMixin):
    """MoonViT vision tower + MLP projector + Qwen2/Qwen3 causal LM.

    Image features are projected by ``mlp1`` and written into the embedding rows of
    ``image_token_index`` tokens. ``generate`` implements multi-token-prediction (MTP),
    auto-regressive (AR) and hybrid decoding.
    """

    config_class = LocateAnythingConfig

    def __init__(self, config: LocateAnythingConfig, vision_model=None, language_model=None):
        """Build (or adopt) the vision tower and LLM, the projector, and optional LoRA wrappers.

        Args:
            config: Model configuration. Its ``vision_config`` / ``text_config`` ``_attn_implementation``
                fields are written back with the backend actually chosen.
            vision_model: Optional pre-built vision tower; otherwise a MoonViT is built from config.
            language_model: Optional pre-built LLM; otherwise built from ``text_config.architectures[0]``.

        Raises:
            ValueError: unsupported vision model type or LLM architecture.
        """
        super().__init__(config)

        self.template = config.template
        self.mlp_checkpoint = config.mlp_checkpoint
        logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == 'moonvit':
                vision_attn_impl = getattr(config.vision_config, '_attn_implementation', None) or 'flash_attention_2'
                if vision_attn_impl == 'flash_attention_2' and not is_flash_attn_2_available():
                    logger.warning_once(
                        "flash_attn is not available for MoonViT inference; falling back to sdpa."
                    )
                    vision_attn_impl = 'sdpa'
                config.vision_config._attn_implementation = vision_attn_impl
                self.vision_model = MoonVitPretrainedModel(config.vision_config)
            else:
                raise ValueError(f'Unsupported vision model type: {config.vision_config.model_type}. Only moonvit is supported.')

        # Text backend priority: text_config setting, then top-level setting, then the custom 'magi' backend.
        text_attn_impl = (
            getattr(config.text_config, '_attn_implementation', None)
            or getattr(config, '_attn_implementation', None)
            or 'magi'
        )
        config.text_config._attn_implementation = text_attn_impl

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.text_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.text_config)
            elif config.text_config.architectures[0] == 'Qwen3ForCausalLM':
                self.language_model = Qwen3ForCausalLM(config.text_config)
            else:
                raise ValueError(f'Unsupported language model architecture: {config.text_config.architectures[0]}. Only Qwen2ForCausalLM and Qwen3ForCausalLM are supported.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # MLP for moonvit (without pixel_shuffle_back, direct mapping).
        # Input width is 4 * vit_hidden_size -- presumably 2x2 merged patches; TODO confirm in MoonViT.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * 4),
            nn.Linear(vit_hidden_size * 4, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        self.image_token_index = config.image_token_index
        self.neftune_alpha = None

        # use_backbone_lora / use_llm_lora double as the LoRA rank (alpha = 2 * rank); falsy disables.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)

        self.use_llm_lora = config.use_llm_lora
        if config.use_llm_lora:
            self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)

        self.token_ids = get_token_ids_from_config(config)

        # Set _no_split_modules dynamically based on the actual LLM architecture
        arch = config.text_config.architectures[0] if hasattr(config.text_config, 'architectures') and config.text_config.architectures else 'Qwen2ForCausalLM'
        if 'Qwen3' in arch:
            self._no_split_modules = ["Qwen3DecoderLayer"]
        else:
            self._no_split_modules = ["Qwen2DecoderLayer"]

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the vision tower with PEFT LoRA adapters on its attention and MLP projections."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
                            'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap the language model with PEFT LoRA adapters (CAUSAL_LM task) and mark LLM LoRA as enabled."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            task_type='CAUSAL_LM'
        )
        self.language_model = get_peft_model(self.language_model, lora_config)
        # Needed so gradients reach the adapters when the embedding layer itself is frozen.
        self.language_model.enable_input_require_grads()
        self.language_model.print_trainable_parameters()
        self.use_llm_lora = True

    def forward(
        self,
        pixel_values: List[torch.FloatTensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        image_flags: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training/teacher-forcing forward pass.

        Vision features are projected by ``mlp1`` and written into the embedding rows whose
        token id equals ``image_token_index``; the result is fed to the LLM as ``inputs_embeds``.

        Args:
            image_flags: When given with a positive sum, each entry ``k`` keeps the next ``k``
                vision-feature entries (``k == 0`` drops one entry). NOTE(review): semantics
                inferred from the filtering loop below -- confirm against the data pipeline.
            labels: Optional target ids; shifted by one for the next-token cross-entropy loss.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``(loss?, logits, ...)`` when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        has_images = image_flags is not None and image_flags.sum() > 0
        vit_embeds = self.extract_feature(pixel_values, image_grid_hws)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if has_images:
            filtered_vit_embeds = []
            idx = 0
            for flag in image_flags:
                flag_val = flag.item()
                if flag_val != 0:
                    filtered_vit_embeds.extend(vit_embeds[idx:idx + flag_val])
                    idx += flag_val
                else:
                    idx += 1
            vit_embeds = filtered_vit_embeds

            vit_embeds = torch.cat(vit_embeds, dim=0)
            vit_embeds = self.mlp1(vit_embeds)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.image_token_index)
            # `* 0.0 +` keeps the original embedding rows in the autograd graph.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:selected.sum()]
        else:
            if vit_embeds:
                vit_embeds = torch.cat(vit_embeds, dim=0)
                vit_embeds = self.mlp1(vit_embeds)
                input_ids = input_ids.reshape(B * N)
                selected = (input_ids == self.image_token_index)
                if selected.sum() > 0:
                    input_embeds[selected] = vit_embeds[:selected.sum()]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, image_grid_hws):
        """Run the vision tower; callers ``torch.cat`` the result along dim 0 before ``mlp1``."""
        vit_embeds = self.vision_model(pixel_values=pixel_values, grid_hws=image_grid_hws)
        return vit_embeds

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        image_grid_hws: Optional[torch.Tensor] = None,
        tokenizer=None,
        n_future_tokens: int = 6,
        **generate_kwargs,
    ) -> Union[str, Tuple[str, list, str]]:
        """Decode a response with MTP, AR or hybrid decoding (batch size 1, KV cache required).

        Args:
            pixel_values: Image pixels; optional when ``visual_features`` is supplied.
            input_ids: Prompt ids, shape ``(1, seq_len)``.
            attention_mask: Accepted for API compatibility; not used by this loop.
            visual_features: Precomputed vision-tower output, used instead of ``pixel_values``.
            image_grid_hws: Per-image grid sizes (tensor or numpy array). When not None, the visual
                features are concatenated and projected through ``mlp1``.
            tokenizer: Used for ``model_max_length`` and decoding.
            n_future_tokens: Tokens predicted per MTP step.
            **generate_kwargs: ``use_cache`` (must be truthy), ``max_new_tokens`` (default 2048),
                ``generation_mode`` ('fast' | 'slow' | 'hybrid', default 'hybrid'), ``verbose``;
                everything except ``verbose`` is also forwarded to ``sample_tokens``.

        Returns:
            The decoded response string, or ``(response, sampling_history, stats_str)`` when verbose.
        """
        verbose = generate_kwargs.pop('verbose', False)
        start_time = time.time()
        prefill_time = None

        # pixel_values may legitimately be None when visual_features are supplied.
        if pixel_values is not None:
            pixel_values = pixel_values.to(self.language_model.dtype)

        # Convert numpy array to tensor if needed
        if isinstance(image_grid_hws, np.ndarray):
            target_device = pixel_values.device if pixel_values is not None else input_ids.device
            image_grid_hws = torch.from_numpy(image_grid_hws).to(target_device, dtype=torch.int32)

        batch_size, seq_len = input_ids.shape
        assert batch_size == 1, 'only batch size = 1 is supported now'
        assert generate_kwargs.get('use_cache', False), "Only use_cache=True is supported."

        generated = input_ids.clone()
        total_gen_length = min(tokenizer.model_max_length, seq_len + generate_kwargs.get('max_new_tokens', 2048))
        iter_round = 0
        past_key_values = None

        # Extract visual features once before the loop
        if visual_features is not None:
            vit_embeds = visual_features
        elif pixel_values is not None:
            vit_embeds = self.extract_feature(pixel_values, image_grid_hws)
        else:
            vit_embeds = None
        if image_grid_hws is not None:
            vit_embeds = torch.cat(vit_embeds, dim=0)
            vit_embeds = self.mlp1(vit_embeds)

        # ==================== Generation Mode ====================
        # 'fast'   : MTP only, never fall back to AR
        # 'slow'   : AR only, pure auto-regressive decoding
        # 'hybrid' : MTP first, fall back to AR on error, switch back on box_end
        generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
        assert generation_mode in ('fast', 'slow', 'hybrid'), \
            f"Unsupported generation_mode='{generation_mode}'. Use 'fast', 'slow', or 'hybrid'."

        sampling_history = []
        use_mtp = generation_mode in ('fast', 'hybrid')
        switch_to_ar_count = 0

        # Pre-allocate mask tokens and position ids
        default_mask_token_id = self.token_ids['default_mask_token_id']
        pre_mask_tokens = torch.full(
            (batch_size, n_future_tokens - 1), default_mask_token_id,
            dtype=generated.dtype, device=generated.device
        )
        max_possible_len = total_gen_length + n_future_tokens
        full_position_ids = torch.arange(0, max_possible_len, device=generated.device).unsqueeze(0)

        def _prepare_inputs_in_mtp(generated):
            """Append [last token, n_future_tokens-1 mask tokens] and build cache-aware position ids."""
            generated_with_mask = torch.cat(
                (
                    generated,
                    generated[:, -1].unsqueeze(1),
                    pre_mask_tokens
                ), dim=1
            )  # [batch_size, seq_len + 1 + n_future_tokens - 1]

            # Update pe for kvcache
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            # clone(): the in-place shift below must not mutate the shared position table.
            position_ids = full_position_ids[:, start_idx: generated_with_mask.size(1)].clone()
            position_ids[0, -n_future_tokens:] -= 1

            return self.language_model.prepare_inputs_for_generation(
                generated_with_mask, past_key_values, None, inputs_embeds=None,
                use_cache=True, position_ids=position_ids
            )

        def _prepare_input_in_ar(generated):
            """Build AR inputs with position ids covering only the positions not yet in the cache."""
            start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
            position_ids = full_position_ids[:, start_idx: generated.size(1)]
            return self.language_model.prepare_inputs_for_generation(
                generated, past_key_values, None, inputs_embeds=None,
                use_cache=True, position_ids=position_ids
            )

        def _sample_token_in_mtp(generated, outputs):
            """Sample tokens using MTP (Multi-Token Prediction) mode."""
            next_token_logits = outputs.logits[:, -n_future_tokens:, :]
            _, _, x0, box_avg = sample_tokens(
                next_token_logits, generated, self.token_ids, keep_k=5, **generate_kwargs
            )
            is_box_empty = (box_avg[0] == 0).all()
            new_tokens = x0[0] if is_box_empty else box_avg[0]

            out_pattern = handle_pattern(new_tokens, self.token_ids, generation_mode)
            out_type = out_pattern['type']
            out_token = torch.tensor(out_pattern['tokens'], dtype=x0.dtype, device=x0.device)
            return out_type, out_token

        def _sample_token_in_ar(generated, outputs):
            """Sample a single token using AR (Auto-Regressive) mode."""
            next_token_logits = outputs.logits[:, -1:, :]
            _, _, x0, _ = sample_tokens(
                next_token_logits, generated, self.token_ids, **generate_kwargs
            )
            out_token = x0[0]
            out_type = 'continue_ar'
            token_val = out_token[0].item()

            box_end_token_id = self.token_ids['box_end_token_id']
            coord_start_token_id = self.token_ids['coord_start_token_id']
            coord_end_token_id = self.token_ids['coord_end_token_id']
            none_token_id = self.token_ids['none_token_id']
            im_end_token_id = self.token_ids['im_end_token_id']

            if generation_mode == 'hybrid':
                # Hybrid AR phase: detect box boundaries to switch back to MTP.
                # NOTE(review): any token that is not box_end / coord / none ends generation here.
                if token_val == box_end_token_id:
                    out_type = 'box_end_ar'
                elif coord_start_token_id <= token_val <= coord_end_token_id or token_val == none_token_id:
                    out_type = 'coord_ar'
                else:
                    out_type = 'im_end'
            else:
                # Slow mode: pure AR, only stop on im_end
                if token_val == im_end_token_id:
                    out_type = 'im_end'
            return out_type, out_token

        # Generate loop
        while generated.size(1) < total_gen_length:
            iter_round += 1
            step_mode = 'mtp' if use_mtp else 'ar'

            # Step 1: Prepare inputs
            if use_mtp:
                prepare_inputs = _prepare_inputs_in_mtp(generated)
            else:
                prepare_inputs = _prepare_input_in_ar(generated)
            if iter_round == 1:
                # Visual features are injected only on the prefill step; later steps reuse the cache.
                prepare_inputs.update({
                    'visual_features': vit_embeds,
                    'image_token_index': self.config.image_token_index,
                })

            # Step 2: Model forward & update KV cache (already under @torch.no_grad()).
            outputs = self.language_model(**prepare_inputs)
            # Keep only committed positions; drop cache entries for the appended/mask tokens.
            past_key_values = tuple(
                (kv[0][:, :, :generated.shape[1], :], kv[1][:, :, :generated.shape[1], :])
                for kv in outputs.past_key_values
            )

            # Step 3: Sample tokens
            if use_mtp:
                out_type, out_token = _sample_token_in_mtp(generated, outputs)
            else:
                out_type, out_token = _sample_token_in_ar(generated, outputs)
            if verbose:
                # Label by the decoding mode actually used for this step.
                sampling_history.append((step_mode, tokenizer.decode(out_token, skip_special_tokens=False)))

            generated = torch.cat([generated, out_token.unsqueeze(0)], dim=1)

            # Record before any early exit so a first-step im_end still yields a prefill time.
            if prefill_time is None:
                prefill_time = time.time() - start_time

            # Step 4: Mode switching & termination
            if out_type == 'im_end':
                break
            if generation_mode == 'hybrid':
                if out_type == 'error_box':
                    use_mtp = False
                    switch_to_ar_count += 1
                elif out_type == 'box_end_ar':
                    use_mtp = True
            # fast mode: use_mtp stays True always
            # slow mode: use_mtp stays False always

        # Decode and return
        generated_ids = generated[:, seq_len:]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

        if verbose:
            end_time = time.time()
            num_tokens = generated_ids.size(1)
            # Count boxes via their end token (str.count("") would return len(text) + 1).
            num_boxes = int((generated_ids == self.token_ids['box_end_token_id']).sum().item())
            total_time = end_time - start_time
            # The loop may not run at all (prompt already at the length limit).
            prefill_seconds = prefill_time if prefill_time is not None else 0.0
            out_info = f"\nStatistic Info, num_tokens={num_tokens}; " + \
                f"generate_time(s)={total_time:.4f}; " + \
                f"tps={(num_tokens / total_time):.4f}; " + \
                f"forward_step={iter_round}; " + \
                f"num_boxes={num_boxes}; " + \
                f"bps={(num_boxes / total_time):.4f}; " + \
                f"prefill_time={prefill_seconds:.4f}; " + \
                f"switch_to_ar={switch_to_ar_count}\n"
            print(out_info)
            return response[0], sampling_history, out_info

        return response[0]