import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import logging

from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
from .conversation import get_conv_template
from ola.constants import IGNORE_INDEX

logger = logging.get_logger(__name__)

# Placeholder id spliced into input_ids wherever a '<speech>' tag appears; the
# multimodal pre-processing later swaps these positions for audio features.
SPEECH_TOKEN_INDEX = -200
|
|
| def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): |
| """Tokenize prompt with speech tokens, similar to OLA's implementation""" |
| prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')] |
|
|
| def insert_separator(X, sep): |
| return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] |
|
|
| input_ids = [] |
| offset = 0 |
| if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: |
| offset = 1 |
| input_ids.append(prompt_chunks[0][0]) |
|
|
| for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): |
| input_ids.extend(x[offset:]) |
|
|
| if return_tensors is not None: |
| if return_tensors == 'pt': |
| return torch.tensor(input_ids, dtype=torch.long) |
| raise ValueError(f'Unsupported tensor type: {return_tensors}') |
| return input_ids |
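
# A minimal usage sketch (a HF tokenizer with a BOS token is assumed; the ids
# shown are illustrative, not real Qwen3 ids):
#
#     ids = tokenizer_speech_token('<speech>\nWhat was said?', tokenizer,
#                                  return_tensors='pt')
#     # Each '<speech>' tag contributes one SPEECH_TOKEN_INDEX (-200) entry,
#     # e.g. tensor([bos, -200, 198, ...]); the multimodal prepare_* step
#     # later replaces those positions with projected audio features.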
|
|
|
|
class Qwen3BackboneModel(Qwen3Model):
    """Qwen3 backbone built from the nested ``llm_config`` rather than the
    top-level composite config."""

    def __init__(self, config: Qwen3Config, llm_config: Qwen3Config):
        super().__init__(llm_config)


class OlaConfigQwen3(Qwen3Config):
    model_type = "ola_internvl"


class OlaQwen3Model(OlaMetaModel, Qwen3BackboneModel):
    config_class = OlaConfigQwen3

    def __init__(self, config: Qwen3Config):
        super().__init__(config, config.llm_config)
|
|
|
|
class OlaQwen3ForCausalLM(Qwen3ForCausalLM, OlaMetaForCausalLM):
    config_class = OlaConfigQwen3

    def __init__(self, config):
        # Skip Qwen3ForCausalLM.__init__ on purpose: it would build a plain
        # Qwen3Model, whereas we construct the multimodal OlaQwen3Model below.
        super(Qwen3ForCausalLM, self).__init__(config)

        config.rope_scaling = None  # disable rope scaling for this model

        self.model = OlaQwen3Model(config)
        self.vocab_size = config.vocab_size

        self.ps_version = config.ps_version  # pixel-shuffle version ('v1' keeps the transposed layout)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.template = "plm_v"
        self.select_layer = config.select_layer  # which ViT hidden layer to take features from (-1 = last)
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
        # Tokens per image tile: (image_size / patch_size)^2 patches, reduced by
        # downsample_ratio^2 through pixel shuffling.
        self.num_image_token = int(
            (config.vision_config.image_size // config.vision_config.patch_size) ** 2
            * (config.downsample_ratio ** 2)
        )
        self.downsample_ratio = config.downsample_ratio

        self.post_init()
| |
|
|
| def get_model(self): |
| return self.model |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| speech: Optional[torch.FloatTensor] = None, |
| speech_lengths: Optional[torch.LongTensor] = None, |
| speech_chunks: Optional[torch.LongTensor] = None, |
| speech_wav: Optional[torch.FloatTensor] = None, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| images_highres: Optional[List[torch.FloatTensor]] = None, |
| image_sizes: Optional[List[List[int]]] = None, |
        modalities: Optional[List[str]] = None,
| image_flags: Optional[torch.LongTensor] = None, |
| return_dict: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| |
        if modalities is None:
            modalities = ["image"]
        if inputs_embeds is None:
| ( |
| input_ids, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| inputs_embeds, |
| labels |
| ) = self.prepare_inputs_labels_for_speech_text_for_internvl( |
| input_ids, |
| position_ids, |
| attention_mask, |
| past_key_values, |
| labels, |
| speech, |
| speech_lengths, |
| speech_chunks, |
| speech_wav, |
| modalities, |
| ) |
| |
| if labels is None: |
| return super().forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict |
| ) |
| else: |
| return self.forward_llm_efficient( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| labels=labels, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict |
| ) |
| |
|
|
    def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values,
                              inputs_embeds, labels, use_cache, output_attentions,
                              output_hidden_states, return_dict):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logger.debug(
            "input embeddings: range %s..%s, nan=%s, inf=%s",
            inputs_embeds.min().item(), inputs_embeds.max().item(),
            torch.isnan(inputs_embeds).any().item(), torch.isinf(inputs_embeds).any().item(),
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logger.debug(
            "hidden states: range %s..%s, nan=%s, inf=%s",
            hidden_states.min().item(), hidden_states.max().item(),
            torch.isnan(hidden_states).any().item(), torch.isinf(hidden_states).any().item(),
        )

        # Shift so that position t predicts token t+1, then drop IGNORE_INDEX
        # positions *before* the lm_head projection; projecting only the kept
        # positions is what makes this path memory-efficient.
        hidden_dim = hidden_states.size(-1)
        shift_labels = labels[..., 1:].contiguous().reshape(-1)
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
        assert shift_labels.size(0) == shift_hidden_states.size(0)
        mask = shift_labels != IGNORE_INDEX
        logger.debug(
            "tokens: total=%d, valid=%d, ignored=%d",
            shift_labels.size(0), int(mask.sum().item()), int((~mask).sum().item()),
        )
        assert mask.sum() > 0, (
            f"No valid tokens found! Total: {shift_labels.size(0)}, Valid: {mask.sum().item()}"
        )
        shift_labels = shift_labels[mask]
        shift_hidden_states = shift_hidden_states[mask, :]

        logits = self.lm_head(shift_hidden_states)
        logits = logits.float()

        # Sanitize non-finite logits so one bad step cannot produce a nan loss:
        # zero out nans first, then clamp any remaining infs to a finite range.
        if torch.isnan(logits).any():
            logger.warning("Found nan values in logits, replacing with zeros")
            logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
        if torch.isinf(logits).any():
            logger.warning("Found inf values in logits, clamping to finite range")
            logits = torch.clamp(logits, min=-1e4, max=1e4)

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, shift_labels)
        logger.debug("loss=%s, nan=%s", loss.item(), torch.isnan(loss).item())

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        # x arrives as an (N, W, H, C) grid of ViT features.
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C / scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C / scale --> N, H * scale, W, C / scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C / scale --> N, H * scale, W * scale, C / (scale^2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
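
    # Shape-only sketch of pixel_shuffle (a 32x32 grid of 1024-dim features
    # and an instantiated `model` are assumed for illustration):
    #
    #     x = torch.randn(1, 32, 32, 1024)
    #     y = model.pixel_shuffle(x, scale_factor=0.5)
    #     assert y.shape == (1, 16, 16, 4096)  # 4x fewer tokens, 4x wider channels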
|
|
    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.get_vision_tower()(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.get_vision_tower()(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]  # drop the CLS token, keep patch tokens

        # Fold the patch sequence back into a square grid, pixel-shuffle it to
        # trade spatial tokens for channel width, then flatten and project into
        # the LLM embedding space.
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.get_vision_projector()(vit_embeds)
        return vit_embeds
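
    # Illustrative token accounting (448px tiles with 14px patches and
    # downsample_ratio 0.5 are assumed; the real values live in
    # config.vision_config):
    #
    #     patches = (448 // 14) ** 2          # 1024 patch tokens per tile
    #     tokens = int(patches * 0.5 ** 2)    # 256 tokens after pixel_shuffle
    #     # matches self.num_image_token, so each '<image>' in chat() expands
    #     # to exactly num_image_token IMG_CONTEXT tokens per tile
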
| @torch.no_grad() |
| def generate( |
| self, |
| pixel_values: Optional[torch.FloatTensor] = None, |
        input_ids: Optional[torch.LongTensor] = None,
| attention_mask: Optional[torch.LongTensor] = None, |
| visual_features: Optional[torch.FloatTensor] = None, |
| generation_config: Optional[GenerationConfig] = None, |
| output_hidden_states: Optional[bool] = None, |
| speech: Optional[torch.FloatTensor] = None, |
| speech_lengths: Optional[torch.LongTensor] = None, |
| speech_chunks: Optional[torch.LongTensor] = None, |
| speech_wav: Optional[torch.FloatTensor] = None, |
        modalities: Optional[List[str]] = None,
| **kwargs, |
| ) -> Union[GenerateOutput, torch.LongTensor]: |
        position_ids = kwargs.pop("position_ids", None)
        if modalities is None:
            modalities = ["image"]
| |
| if speech is not None: |
| ( |
| _, |
| position_ids, |
| attention_mask, |
| _, |
| input_embeds, |
| _ |
| ) = self.prepare_inputs_labels_for_speech_text_for_internvl( |
| input_ids, |
| position_ids, |
| attention_mask, |
| None, |
| None, |
| speech, |
| speech_lengths, |
| speech_chunks, |
| speech_wav, |
| modalities, |
| ) |
        else:
            # InternVL-style image path: scatter projected ViT features into the
            # embedding positions that hold the IMG_CONTEXT placeholder token.
            assert self.img_context_token_id is not None
            if pixel_values is not None:
                if visual_features is not None:
                    vit_embeds = visual_features
                else:
                    vit_embeds = self.extract_feature(pixel_values)
                input_embeds = self.get_model().get_input_embeddings()(input_ids)
                B, N, C = input_embeds.shape
                input_embeds = input_embeds.reshape(B * N, C)
                input_ids = input_ids.reshape(B * N)
                selected = (input_ids == self.img_context_token_id)
                assert selected.sum() != 0
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
                input_embeds = input_embeds.reshape(B, N, C)
            else:
                input_embeds = self.get_model().get_input_embeddings()(input_ids)
| return super().generate( |
| inputs_embeds=input_embeds, |
| attention_mask=attention_mask, |
| generation_config=generation_config, |
| output_hidden_states=output_hidden_states, |
| use_cache=True, |
| **kwargs, |
| ) |
|
|
|
|
| def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, |
| num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', |
| verbose=False, speech=None, speech_lengths=None, speech_wav=None, speech_chunks=None): |
| if history is None and pixel_values is not None and '<image>' not in question: |
| question = '<image>\n' + question |
|
|
| if num_patches_list is None: |
| num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] |
| assert pixel_values is None or len(pixel_values) == sum(num_patches_list) |
|
|
| img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) |
| self.img_context_token_id = img_context_token_id |
|
|
| template = get_conv_template(self.template) |
| template.system_message = self.system_message |
| eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) |
|
|
| history = [] if history is None else history |
| for (old_question, old_answer) in history: |
| template.append_message(template.roles[0], old_question) |
| template.append_message(template.roles[1], old_answer) |
| template.append_message(template.roles[0], question) |
| template.append_message(template.roles[1], None) |
| query = template.get_prompt() |
|
|
| if verbose and pixel_values is not None: |
| image_bs = pixel_values.shape[0] |
| print(f'dynamic ViT batch size: {image_bs}') |
| |
|
|
| |
        # Expand each '<image>' placeholder into IMG_START, num_image_token
        # IMG_CONTEXT placeholders per patch, and IMG_END, one image at a time.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        from ola.conversation import conv_templates, SeparatorStyle
        from ola.mm_utils import KeywordsStoppingCriteria
        conv_mode = "plm_v"
        conv = conv_templates[conv_mode].copy()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
|
|
| |
        if speech is not None and '<speech>' in query:
            # Speech path: '<speech>' tags become SPEECH_TOKEN_INDEX placeholders.
            input_ids = tokenizer_speech_token(query, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
            # Fall back to Qwen's '<|endoftext|>' id (151643) when no pad token is set.
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
            attention_mask = input_ids.ne(pad_token_id).long().to(self.device)
        else:
            model_inputs = tokenizer(query, return_tensors='pt')
            input_ids = model_inputs['input_ids'].to(self.device)
            attention_mask = model_inputs['attention_mask'].to(self.device)
        generation_config['eos_token_id'] = eos_token_id
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
| |
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            speech=speech,
            speech_lengths=speech_lengths,
            speech_chunks=speech_chunks,
            speech_wav=speech_wav,
            stopping_criteria=[stopping_criteria],  # stop on template separator keywords
            **generation_config
        )
| response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] |
| response = response.split(template.sep.strip())[0].strip() |
| history.append((question, response)) |
| if return_history: |
| return response, history |
| else: |
| query_to_print = query.replace(IMG_CONTEXT_TOKEN, '') |
| query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>') |
| if verbose: |
| print(query_to_print, response) |
| return response |
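
    # End-to-end sketch of chat() (the checkpoint path and preprocessed inputs
    # are placeholders; generation_config must be a plain dict here):
    #
    #     model = AutoModelForCausalLM.from_pretrained('/path/to/ckpt').cuda().eval()
    #     gen_cfg = dict(max_new_tokens=256, do_sample=False)
    #     answer = model.chat(tokenizer, pixel_values,
    #                         '<image>\nDescribe this image.', gen_cfg)
    #     answer, history = model.chat(tokenizer, pixel_values, 'And the colors?',
    #                                  gen_cfg,
    #                                  history=[('<image>\nDescribe this image.', answer)],
    #                                  return_history=True)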
|
|
| def prepare_inputs_for_generation(self, input_ids, past_key_values=None, |
| inputs_embeds=None, **kwargs): |
| speech = kwargs.pop("speech", None) |
| speech_lengths = kwargs.pop("speech_lengths", None) |
| speech_chunks = kwargs.pop("speech_chunks", None) |
| images = kwargs.pop("images", None) |
| image_sizes = kwargs.pop("image_sizes", None) |
| inputs = super().prepare_inputs_for_generation( |
| input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs |
| ) |
| if speech is not None: |
| inputs['speech'] = speech |
| inputs['speech_lengths'] = speech_lengths |
| inputs['speech_chunks'] = speech_chunks |
| if images is not None: |
| inputs["images"] = images |
| if image_sizes is not None: |
| inputs["image_sizes"] = image_sizes |
| return inputs |
|
|
| AutoConfig.register("ola_internvl", OlaConfigQwen3) |
| AutoModelForCausalLM.register(OlaConfigQwen3, OlaQwen3ForCausalLM) |
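
# After registration, checkpoints whose config carries model_type
# "ola_internvl" load through the Auto classes. A hedged sketch (the path is a
# placeholder):
#
#     config = AutoConfig.from_pretrained('/path/to/ola-qwen3-ckpt')
#     model = AutoModelForCausalLM.from_pretrained('/path/to/ola-qwen3-ckpt',
#                                                  config=config,
#                                                  torch_dtype=torch.bfloat16)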
|
|