Image-Text-to-Text
Transformers
Safetensors
English
locateanything
feature-extraction
nvidia
eagle
vision
object-detection
grounding
conversational
custom_code
Instructions to use nvidia/LocateAnything-3B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nvidia/LocateAnything-3B with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="nvidia/LocateAnything-3B", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
pipe(text=messages)
```
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/LocateAnything-3B", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use nvidia/LocateAnything-3B with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "nvidia/LocateAnything-3B"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/LocateAnything-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
```
Use Docker
docker model run hf.co/nvidia/LocateAnything-3B
- SGLang
How to use nvidia/LocateAnything-3B with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "nvidia/LocateAnything-3B" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/LocateAnything-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
```
Use Docker images
```bash
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "nvidia/LocateAnything-3B" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "nvidia/LocateAnything-3B",
        "messages": [
            {
                "role": "user",
                "content": [
                    { "type": "text", "text": "Describe this image in one sentence." },
                    { "type": "image_url", "image_url": { "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" } }
                ]
            }
        ]
    }'
```
- Docker Model Runner
How to use nvidia/LocateAnything-3B with Docker Model Runner:
docker model run hf.co/nvidia/LocateAnything-3B
| # -------------------------------------------------------- | |
| # NVIDIA | |
| # Copyright (c) 2025 NVIDIA | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # -------------------------------------------------------- | |
| import time | |
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.nn import CrossEntropyLoss | |
| from transformers.generation import GenerationMixin | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.utils import add_start_docstrings, is_flash_attn_2_available, logging | |
| from peft import LoraConfig, get_peft_model | |
| from .configuration_locateanything import LocateAnythingConfig | |
| from .modeling_qwen2 import Qwen2ForCausalLM | |
| from .modeling_vit import MoonVitPretrainedModel | |
| from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM | |
| from .mask_sdpa_utils import * | |
| from .mask_magi_utils import * | |
| from .configuration_qwen2 import Qwen2Config | |
| from .generate_utils import ( | |
| sample_tokens, | |
| handle_pattern, | |
| get_token_ids_from_config, | |
| ) | |
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
# Shared docstring preamble describing the base class and the `config` parameter.
# NOTE(review): this constant is not referenced in the visible part of the file;
# presumably it is meant to be attached via `add_start_docstrings` - confirm.
LOCATEANYTHING_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LocateAnythingConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
class LocateAnythingPreTrainedModel(PreTrainedModel):
    """Common base class for LocateAnything models.

    It binds :class:`LocateAnythingConfig` as the configuration class, sets
    the class-level flags below, lets the custom ``'magi'`` attention
    implementation pass through the attention-selection hooks untouched, and
    provides the weight initializer.
    """

    config_class = LocateAnythingConfig
    base_model_prefix = "model"
    main_input_name = 'input_ids'
    supports_gradient_checkpointing = True
    # Default; LocateAnythingForConditionalGeneration overrides this per
    # instance depending on the language-model architecture.
    _no_split_modules = ["Qwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True
    _supports_sdpa = True

    @classmethod
    def _autoset_attn_implementation(cls, config, *args, **kwargs):
        """Return ``config`` unchanged for ``'magi'``, else defer to the parent.

        The parent hook is a classmethod and is invoked on the class, so this
        override must be one as well; without the decorator ``config`` would
        be bound to ``cls`` and the call would fail with a missing-argument
        error.
        """
        if getattr(config, '_attn_implementation', None) == 'magi':
            return config
        return super()._autoset_attn_implementation(config, *args, **kwargs)

    def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check=False):
        """Accept ``'magi'`` as-is; delegate every other value to the parent."""
        if attn_implementation == "magi":
            return "magi"
        return super()._check_and_adjust_attn_implementation(attn_implementation, is_init_check)

    def _init_weights(self, module):
        """Initialize ``module`` weights from a zero-mean normal distribution.

        The standard deviation is ``config.initializer_range`` when that is
        set to a truthy value, otherwise ``config.text_config.initializer_range``.
        Linear/Conv2d biases and the embedding padding row are zeroed.
        """
        std = getattr(self.config, 'initializer_range', None) or self.config.text_config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
| class LocateAnythingForConditionalGeneration(LocateAnythingPreTrainedModel, GenerationMixin): | |
| config_class = LocateAnythingConfig | |
def __init__(self, config: LocateAnythingConfig, vision_model=None, language_model=None):
    """Assemble the vision tower, the language model and the projector MLP.

    Args:
        config: Model configuration. Its ``vision_config`` and
            ``text_config`` sub-configs get their ``_attn_implementation``
            written in place while the sub-models are resolved.
        vision_model: Optional pre-built vision model; when given, no
            vision model is constructed from ``config.vision_config``.
        language_model: Optional pre-built language model; when given, no
            language model is constructed from ``config.text_config``.

    Raises:
        ValueError: If a sub-model has to be built and the vision model
            type is not ``'moonvit'`` or the text architecture is neither
            ``Qwen2ForCausalLM`` nor ``Qwen3ForCausalLM``.
    """
    super().__init__(config)
    self.template = config.template
    self.mlp_checkpoint = config.mlp_checkpoint
    logger.info(f'mlp_checkpoint: {self.mlp_checkpoint}')

    # ---- Vision tower ----
    if vision_model is None:
        vision_cfg = config.vision_config
        if vision_cfg.model_type != 'moonvit':
            raise ValueError(f'Unsupported vision model type: {config.vision_config.model_type}. Only moonvit is supported.')
        # Prefer flash-attention; fall back to sdpa when it is not installed.
        chosen_attn = getattr(vision_cfg, '_attn_implementation', None) or 'flash_attention_2'
        if chosen_attn == 'flash_attention_2' and not is_flash_attn_2_available():
            logger.warning_once(
                "flash_attn is not available for MoonViT inference; falling back to sdpa."
            )
            chosen_attn = 'sdpa'
        vision_cfg._attn_implementation = chosen_attn
        vision_model = MoonVitPretrainedModel(vision_cfg)
    self.vision_model = vision_model

    # ---- Language model ----
    text_cfg = config.text_config
    # Resolution order: text config, then top-level config, then 'magi'.
    text_cfg._attn_implementation = (
        getattr(text_cfg, '_attn_implementation', None)
        or getattr(config, '_attn_implementation', None)
        or 'magi'
    )
    if language_model is None:
        declared_arch = text_cfg.architectures[0]
        if declared_arch == 'Qwen2ForCausalLM':
            language_model = Qwen2ForCausalLM(text_cfg)
        elif declared_arch == 'Qwen3ForCausalLM':
            language_model = Qwen3ForCausalLM(text_cfg)
        else:
            raise ValueError(f'Unsupported language model architecture: {config.text_config.architectures[0]}. Only Qwen2ForCausalLM and Qwen3ForCausalLM are supported.')
    self.language_model = language_model

    # ---- Projector: 4x ViT hidden size -> LLM hidden size ----
    projector_in = config.vision_config.hidden_size * 4
    projector_out = config.text_config.hidden_size
    self.mlp1 = nn.Sequential(
        nn.LayerNorm(projector_in),
        nn.Linear(projector_in, projector_out),
        nn.GELU(),
        nn.Linear(projector_out, projector_out),
    )

    self.image_token_index = config.image_token_index
    self.neftune_alpha = None

    # ---- Optional LoRA wrapping; the config value doubles as the rank ----
    backbone_rank = config.use_backbone_lora
    if backbone_rank:
        self.wrap_backbone_lora(r=backbone_rank, lora_alpha=2 * backbone_rank)
    self.use_llm_lora = config.use_llm_lora
    llm_rank = config.use_llm_lora
    if llm_rank:
        self.wrap_llm_lora(r=llm_rank, lora_alpha=2 * llm_rank)

    self.token_ids = get_token_ids_from_config(config)

    # Pick the decoder-layer class name matching the actual LLM architecture.
    declared = getattr(config.text_config, 'architectures', None)
    arch = declared[0] if declared else 'Qwen2ForCausalLM'
    self._no_split_modules = ["Qwen3DecoderLayer"] if 'Qwen3' in arch else ["Qwen2DecoderLayer"]
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
    """Replace ``self.vision_model`` with a LoRA-wrapped version of itself.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA alpha value.
        lora_dropout: Dropout value passed to ``LoraConfig``.
    """
    vision_targets = [
        'self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.out_proj',
        'mlp.fc1', 'mlp.fc2',
    ]
    peft_cfg = LoraConfig(
        r=r,
        target_modules=vision_targets,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    self.vision_model = get_peft_model(self.vision_model, peft_cfg)
    # Report how many parameters remain trainable after wrapping.
    self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
    """Replace ``self.language_model`` with a LoRA-wrapped causal-LM version.

    After wrapping, input gradients are enabled on the wrapped model, its
    trainable-parameter summary is printed, and ``self.use_llm_lora`` is set
    to ``True``.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA alpha value.
        lora_dropout: Dropout value passed to ``LoraConfig``.
    """
    llm_targets = [
        'self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
        'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj',
    ]
    peft_cfg = LoraConfig(
        r=r,
        target_modules=llm_targets,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        task_type='CAUSAL_LM',
    )
    self.language_model = get_peft_model(self.language_model, peft_cfg)
    self.language_model.enable_input_require_grads()
    self.language_model.print_trainable_parameters()
    self.use_llm_lora = True
def forward(
    self,
    pixel_values: List[torch.FloatTensor],
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    image_grid_hws: Optional[torch.Tensor] = None,
    image_flags: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the language model on text embeddings with image features spliced in.

    Token embeddings are looked up for ``input_ids``; positions equal to
    ``self.image_token_index`` are overwritten with vision features that
    were extracted from ``pixel_values`` and projected through ``self.mlp1``.
    When ``labels`` is given, a shifted next-token cross-entropy loss is
    computed on the language-model logits.

    Args:
        pixel_values: Vision inputs forwarded to ``extract_feature``.
        input_ids: Token ids, reshaped here as ``(B, N)``.
        attention_mask: Forwarded to the language model.
        position_ids: Forwarded to the language model.
        image_grid_hws: Forwarded to ``extract_feature``.
        image_flags: Per-entry counts used to select which vision features
            are kept; a zero entry skips one feature.
        past_key_values: Forwarded to the language model.
        labels: Optional targets for the loss.
        use_cache, output_attentions, output_hidden_states: Forwarded to
            the language model.
        return_dict: Falls back to ``self.config.use_return_dict`` when
            ``None``.

    Returns:
        A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy,
        otherwise a tuple ``(loss?, logits, *rest)``.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    embed_tokens = self.language_model.get_input_embeddings()
    input_embeds = embed_tokens(input_ids)
    has_images = image_flags is not None and image_flags.sum() > 0
    vit_embeds = self.extract_feature(pixel_values, image_grid_hws)

    # Work on a flat (B*N, C) view so image positions can be mask-indexed.
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    if has_images:
        # Keep only the features selected by image_flags.
        kept_feats = []
        cursor = 0
        for flag in image_flags:
            count = flag.item()
            if count == 0:
                cursor += 1
                continue
            kept_feats.extend(vit_embeds[cursor:cursor + count])
            cursor += count
        projected = self.mlp1(torch.cat(kept_feats, dim=0))
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.image_token_index)
        input_embeds[selected] = input_embeds[selected] * 0.0 + projected[:selected.sum()]
    elif vit_embeds:
        projected = self.mlp1(torch.cat(vit_embeds, dim=0))
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.image_token_index)
        if selected.sum() > 0:
            input_embeds[selected] = projected[:selected.sum()]

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )
    logits = outputs.logits

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n, then flatten.
        vocab_size = self.language_model.config.vocab_size
        flat_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1)
        # Enable model parallelism
        flat_labels = flat_labels.to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    output = (logits,) + outputs[1:]
    return output if loss is None else (loss,) + output
def extract_feature(self, pixel_values, image_grid_hws):
    """Return the vision model's output for the given pixels and grid sizes."""
    return self.vision_model(pixel_values=pixel_values, grid_hws=image_grid_hws)
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped language model."""
    llm = self.language_model
    return llm.get_input_embeddings()
def set_input_embeddings(self, value):
    """Install ``value`` as the wrapped language model's input embeddings."""
    llm = self.language_model
    llm.set_input_embeddings(value)
def get_output_embeddings(self):
    """Return the output embedding module of the wrapped language model."""
    llm = self.language_model
    return llm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the wrapped language model's output embeddings."""
    llm = self.language_model
    llm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
    """Forward ``decoder`` to the wrapped language model's ``set_decoder``."""
    llm = self.language_model
    llm.set_decoder(decoder)
def get_decoder(self):
    """Return whatever the wrapped language model's ``get_decoder`` returns."""
    llm = self.language_model
    return llm.get_decoder()
def generate(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    visual_features: Optional[torch.FloatTensor] = None,
    image_grid_hws: Optional[torch.Tensor] = None,
    tokenizer=None,
    n_future_tokens: int = 6,
    **generate_kwargs,
) -> Union[str, Tuple[str, list, str]]:
    """Decode a response with multi-token prediction (MTP) and/or AR steps.

    Args:
        pixel_values: Image input for the vision model. May be ``None``
            when ``visual_features`` is supplied or there is no image.
        input_ids: Prompt token ids of shape ``(1, seq_len)``; required.
        attention_mask: Accepted for interface compatibility; not used by
            this method.
        visual_features: Pre-extracted vision features; when given, the
            vision model is not run.
        image_grid_hws: Grid sizes for the vision model (tensor or numpy
            array). When not ``None`` and vision features exist, they are
            concatenated and projected through ``self.mlp1``.
        tokenizer: Tokenizer used for ``model_max_length`` and decoding;
            required.
        n_future_tokens: Number of trailing logit positions sampled per
            MTP step.
        **generate_kwargs: ``verbose`` (popped), ``use_cache`` (must be
            truthy), ``max_new_tokens`` (default 2048) and
            ``generation_mode`` (``'fast'``, ``'slow'`` or ``'hybrid'``,
            default ``'hybrid'``) are read here; the remaining mapping is
            also forwarded to ``sample_tokens``.

    Returns:
        The decoded response string. With ``verbose=True`` a tuple
        ``(response, sampling_history, out_info)`` is returned instead.

    Raises:
        ValueError: If ``input_ids`` or ``tokenizer`` is missing.
        AssertionError: If batch size is not 1, ``use_cache`` is falsy, or
            ``generation_mode`` is unsupported.
    """
    # Both are dereferenced unconditionally below; fail with a clear message.
    if input_ids is None or tokenizer is None:
        raise ValueError("generate() requires both `input_ids` and `tokenizer`.")

    verbose = generate_kwargs.pop('verbose', False)
    start_time = time.time()
    prefill_time = None

    # pixel_values is optional (visual_features may replace it), so only
    # cast it when it is actually present.
    if pixel_values is not None:
        pixel_values = pixel_values.to(self.language_model.dtype)

    # Convert numpy array to tensor if needed
    if isinstance(image_grid_hws, np.ndarray):
        grid_device = pixel_values.device if pixel_values is not None else input_ids.device
        image_grid_hws = torch.from_numpy(image_grid_hws).to(grid_device, dtype=torch.int32)

    batch_size, seq_len = input_ids.shape
    assert batch_size == 1, 'only batch size = 1 is supported now'
    assert generate_kwargs.get('use_cache', False), "Only use_cache=True is supported."

    generated = input_ids.clone()
    total_gen_length = min(tokenizer.model_max_length, seq_len + generate_kwargs.get('max_new_tokens', 2048))
    iter_round = 0
    past_key_values = None

    # Extract visual features once before the loop
    if visual_features is not None:
        vit_embeds = visual_features
    elif pixel_values is not None:
        vit_embeds = self.extract_feature(pixel_values, image_grid_hws)
    else:
        vit_embeds = None

    # Skip projection when there are no vision features to project.
    if image_grid_hws is not None and vit_embeds is not None:
        vit_embeds = torch.cat(vit_embeds, dim=0)
        vit_embeds = self.mlp1(vit_embeds)

    # ==================== Generation Mode ====================
    # 'fast'   : MTP only, never fall back to AR
    # 'slow'   : AR only, pure auto-regressive decoding
    # 'hybrid' : MTP first, fall back to AR on error, switch back on box_end
    generation_mode = generate_kwargs.get('generation_mode', 'hybrid')
    assert generation_mode in ('fast', 'slow', 'hybrid'), \
        f"Unsupported generation_mode='{generation_mode}'. Use 'fast', 'slow', or 'hybrid'."

    sampling_history = []
    use_mtp = generation_mode in ('fast', 'hybrid')
    switch_to_ar_count = 0

    # Pre-allocate mask tokens and position ids
    default_mask_token_id = self.token_ids['default_mask_token_id']
    pre_mask_tokens = torch.full(
        (batch_size, n_future_tokens - 1),
        default_mask_token_id,
        dtype=generated.dtype,
        device=generated.device
    )
    max_possible_len = total_gen_length + n_future_tokens
    full_position_ids = torch.arange(0, max_possible_len, device=generated.device).unsqueeze(0)

    def _prepare_inputs_in_mtp(generated):
        """Build LM inputs for an MTP step: sequence + repeated last token + masks."""
        generated_with_mask = torch.cat(
            (
                generated,
                generated[:, -1].unsqueeze(1),
                pre_mask_tokens
            ),
            dim=1
        )  # [batch_size, seq_len + 1 + n_future_tokens - 1]
        # Update pe for kvcache
        start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
        position_ids = full_position_ids[:, start_idx : generated_with_mask.size(1)].clone()
        # The appended n_future_tokens positions are shifted back by one.
        position_ids[0, -n_future_tokens:] -= 1
        prepare_inputs = self.language_model.prepare_inputs_for_generation(
            generated_with_mask,
            past_key_values,
            None,
            inputs_embeds=None,
            use_cache=True,
            position_ids=position_ids
        )
        return prepare_inputs

    def _prepare_input_in_ar(generated):
        """Build LM inputs for a single auto-regressive step."""
        start_idx = past_key_values[0][0].size(2) if past_key_values is not None else 0
        position_ids = full_position_ids[:, start_idx : generated.size(1)]
        prepare_inputs = self.language_model.prepare_inputs_for_generation(
            generated,
            past_key_values,
            None,
            inputs_embeds=None,
            use_cache=True,
            position_ids=position_ids
        )
        return prepare_inputs

    def _sample_token_in_mtp(generated, outputs):
        """Sample tokens using MTP (Multi-Token Prediction) mode."""
        next_token_logits = outputs.logits[:, -n_future_tokens:, :]
        probs, confidence, x0, box_avg = sample_tokens(
            next_token_logits, generated, self.token_ids, keep_k=5, **generate_kwargs
        )
        # Use box_avg unless it is all zeros, in which case fall back to x0.
        is_box_empty = (box_avg[0] == 0).all()
        new_tokens = x0[0] if is_box_empty else box_avg[0]
        out_pattern = handle_pattern(new_tokens, self.token_ids, generation_mode)
        out_type = out_pattern['type']
        out_token = torch.tensor(out_pattern['tokens'], dtype=x0.dtype, device=x0.device)
        return out_type, out_token

    def _sample_token_in_ar(generated, outputs):
        """Sample a single token using AR (Auto-Regressive) mode."""
        next_token_logits = outputs.logits[:, -1:, :]
        probs, confidence, x0, _ = sample_tokens(
            next_token_logits, generated, self.token_ids, **generate_kwargs
        )
        out_token = x0[0]
        out_type = 'continue_ar'
        token_val = out_token[0].item()
        box_end_token_id = self.token_ids['box_end_token_id']
        coord_start_token_id = self.token_ids['coord_start_token_id']
        coord_end_token_id = self.token_ids['coord_end_token_id']
        none_token_id = self.token_ids['none_token_id']
        im_end_token_id = self.token_ids['im_end_token_id']
        if generation_mode == 'hybrid':
            # Hybrid AR phase: detect box boundaries to switch back to MTP
            if token_val == box_end_token_id:
                out_type = 'box_end_ar'
            elif coord_start_token_id <= token_val <= coord_end_token_id or token_val == none_token_id:
                out_type = 'coord_ar'
            else:
                out_type = 'im_end'
        else:
            # Slow mode: pure AR, only stop on im_end
            if token_val == im_end_token_id:
                out_type = 'im_end'
        return out_type, out_token

    # Generate loop
    while generated.size(1) < total_gen_length:
        iter_round += 1

        # Step 1: Prepare inputs
        if use_mtp:
            prepare_inputs = _prepare_inputs_in_mtp(generated)
        else:
            prepare_inputs = _prepare_input_in_ar(generated)
        if iter_round == 1:
            # Visual features are only supplied on the first (prefill) step.
            prepare_inputs.update({
                'visual_features': vit_embeds,
                'image_token_index': self.config.image_token_index,
            })

        # Step 2: Model forward & update KV cache
        with torch.no_grad():
            outputs = self.language_model(**prepare_inputs)
        # Keep only cache entries for the tokens already in `generated`.
        past_key_values = tuple(
            (kv[0][:, :, :generated.shape[1], :], kv[1][:, :, :generated.shape[1], :])
            for kv in outputs.past_key_values
        )

        # Step 3: Sample tokens
        if use_mtp:
            out_type, out_token = _sample_token_in_mtp(generated, outputs)
        else:
            out_type, out_token = _sample_token_in_ar(generated, outputs)
        if verbose:
            sampling_history.append(('ar' if 'ar' in out_type else 'mtp', tokenizer.decode(out_token, skip_special_tokens=False)))
        generated = torch.cat([generated, out_token.unsqueeze(0)], dim=1)

        # Record the first-step time before the termination check so it is
        # set even when generation ends on the very first step.
        if prefill_time is None:
            prefill_time = time.time() - start_time

        # Step 4: Mode switching & termination
        if out_type == 'im_end':
            break
        if generation_mode == 'hybrid':
            if out_type == 'error_box':
                use_mtp = False
                switch_to_ar_count += 1
            elif out_type == 'box_end_ar':
                use_mtp = True
        # fast mode: use_mtp stays True always
        # slow mode: use_mtp stays False always

    # Decode and return
    generated_ids = generated[:, seq_len:]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

    if verbose:
        end_time = time.time()
        num_tokens = generated_ids.size(1)
        num_boxes = response[0].count("<box>")
        total_time = end_time - start_time
        # The loop may not have run at all (prompt already at the length
        # limit); report 0 rather than failing to format None.
        prefill_shown = prefill_time if prefill_time is not None else 0.0
        out_info = f"\nStatistic Info, num_tokens={num_tokens}; " + \
            f"generate_time(s)={total_time:.4f}; " + \
            f"tps={(num_tokens / total_time):.4f}; " + \
            f"forward_step={iter_round}; " + \
            f"num_boxes={num_boxes}; " + \
            f"bps={(num_boxes / total_time):.4f}; " + \
            f"prefill_time={(prefill_shown):.4f}; " + \
            f"switch_to_ar={switch_to_ar_count}\n"
        print(out_info)
        return response[0], sampling_history, out_info
    return response[0]