import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationConfig,
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from ola.constants import IGNORE_INDEX

from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM
from .conversation import get_conv_template

# LLaVA-style negative sentinel spliced into input_ids wherever speech embeddings
# will be inserted; it must never reach the embedding table directly.
SPEECH_TOKEN_INDEX = -200
def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt containing <speech> placeholders, mirroring OLA's implementation.

    Each '<speech>' occurrence is replaced by `speech_token_index` in the
    resulting token id sequence.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<speech>')]

    def insert_separator(X, sep):
        # Interleave sep between consecutive elements: [a, b, c] -> [a, sep, b, sep, c]
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # Keep a single BOS token at the start and skip the duplicate BOS the
    # tokenizer prepends to every chunk.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
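
# Worked example (illustrative; actual ids depend on the tokenizer):
#   tokenizer_speech_token("User: <speech>\nTranscribe.", tokenizer)
# splits the prompt on '<speech>', tokenizes the two chunks, and splices
# SPEECH_TOKEN_INDEX between them, e.g. [..., 25, -200, 198, ...]. The sentinel
# is later swapped for projected speech embeddings in
# prepare_inputs_labels_for_speech_text_for_internvl.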
class Qwen3Model(Qwen3Model):
    """Shadowing wrapper around transformers' Qwen3Model: accepts the outer Ola
    config plus a nested LLM config, and builds the backbone from the latter."""

    def __init__(self, config: Qwen3Config, llm_config: Qwen3Config):
        super(Qwen3Model, self).__init__(llm_config)
class OlaConfigQwen3(Qwen3Config):
    model_type = "ola_internvl"
class OlaQwen3Model(OlaMetaModel, Qwen3Model):
    config_class = OlaConfigQwen3

    def __init__(self, config: Qwen3Config):
        # OlaMetaModel builds the speech/vision modules from `config`; the wrapped
        # Qwen3Model backbone above is constructed from the nested `config.llm_config`.
        super(OlaQwen3Model, self).__init__(config, config.llm_config)
class OlaQwen3ForCausalLM(Qwen3ForCausalLM, OlaMetaForCausalLM):
    config_class = OlaConfigQwen3
    # No checkpoint conversion mapping is needed when initializing from scratch:
    # _checkpoint_conversion_mapping = {
    #     "^language_model.lm_head": "lm_head",
    #     "^language_model.model": "model.model",
    #     "^vision_model": "model.vision_tower",
    # }
    def __init__(self, config):
        super(Qwen3ForCausalLM, self).__init__(config)
        # RoPE scaling is disabled; the backbone runs at its native context length.
        config.rope_scaling = None
        self.model = OlaQwen3Model(config)
        self.vocab_size = config.vocab_size
        self.ps_version = config.ps_version
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.template = "plm_v"
        self.select_layer = config.select_layer
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
        # Tokens per image tile: (image_size / patch_size)^2 ViT patches, scaled
        # by downsample_ratio^2 after pixel shuffle.
        self.num_image_token = int((config.vision_config.image_size // config.vision_config.patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        # Initialize weights and apply final processing
        self.post_init()
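
    # Construction sketch (hypothetical path; assumes a config saved with
    # model_type "ola_internvl" and nested llm_config/vision_config fields):
    #   config = AutoConfig.from_pretrained("path/to/ola_internvl")
    #   model = OlaQwen3ForCausalLM(config)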
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
speech: Optional[torch.FloatTensor] = None,
speech_lengths: Optional[torch.LongTensor] = None,
speech_chunks: Optional[torch.LongTensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
images_highres: Optional[List[torch.FloatTensor]] = None,
image_sizes: Optional[List[List[int]]] = None,
modalities: Optional[List[str]] = ["image"],
image_flags: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_speech_text_for_internvl(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
speech,
speech_lengths,
speech_chunks,
speech_wav,
modalities,
)
if labels is None:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
else:
return self.forward_llm_efficient(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
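
    # Routing note: without labels (inference) we defer to the stock
    # Qwen3ForCausalLM forward; with labels we take forward_llm_efficient below,
    # which filters out IGNORE_INDEX positions *before* the lm_head projection so
    # logits are only materialized for supervised tokens.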
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_dim = hidden_states.size(-1)
shift_labels = labels[..., 1:].contiguous().reshape(-1)
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
assert shift_labels.size(0) == shift_hidden_states.size(0)
mask = shift_labels != IGNORE_INDEX
assert mask.float().sum() > 0, f"No valid tokens found! Total: {shift_labels.size(0)}, Valid: {mask.float().sum().item()}"
shift_labels = shift_labels[mask]
shift_hidden_states = shift_hidden_states[mask, :]
logits = self.lm_head(shift_hidden_states)
logits = logits.float()
        # Repair non-finite logits rather than letting the loss go to nan.
        if torch.isnan(logits).any():
            warnings.warn("Found nan values in logits, replacing with zeros")
            logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits)
        if torch.isinf(logits).any():
            warnings.warn("Found inf values in logits, clamping to finite range")
            logits = torch.clamp(logits, min=-1e4, max=1e4)
        # Fallback: if logits are still non-finite after the fixes above, zero them out.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            warnings.warn("Logits still contain nan/inf after fixing, using fallback")
            logits = torch.zeros_like(logits)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
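
    # Memory note (rough arithmetic, assuming fp32 logits and Qwen's ~151k-entry
    # vocabulary): full logits for a 4096-token sequence take about
    # 4096 * 151000 * 4 B ≈ 2.5 GB per sample, while filtering to, say, 500
    # supervised tokens before the projection shrinks that to roughly 300 MB.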
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
'which results in a transposed image.')
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
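
    # Shape walk-through for scale_factor=0.5: an input of (N, 32, 32, C) becomes
    # (N, 32, 16, 2C) -> (N, 16, 32, 2C) -> (N, 16, 16, 4C), i.e. 4x fewer
    # spatial tokens with 4x wider channels, before the projector maps them to
    # the LLM hidden size.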
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.get_vision_tower()(
pixel_values=pixel_values,
output_hidden_states=False,
return_dict=True).last_hidden_state
else:
vit_embeds = self.get_vision_tower()(
pixel_values=pixel_values,
output_hidden_states=True,
return_dict=True).hidden_states[self.select_layer]
        # Drop the leading CLS token; keep only patch embeddings.
        vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.get_vision_projector()(vit_embeds)
return vit_embeds
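
    # Token budget (illustrative, assuming a 448px image with 14px patches and
    # downsample_ratio=0.5): (448/14)^2 = 1024 ViT patches, reduced by pixel
    # shuffle to 1024 * 0.5^2 = 256 tokens per tile, matching the
    # num_image_token formula in __init__.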
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
visual_features: Optional[torch.FloatTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
speech: Optional[torch.FloatTensor] = None,
speech_lengths: Optional[torch.LongTensor] = None,
speech_chunks: Optional[torch.LongTensor] = None,
speech_wav: Optional[torch.FloatTensor] = None,
modalities: Optional[List[str]] = ["image"],
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
if speech is not None:
(
_,
position_ids,
attention_mask,
_,
input_embeds,
_
) = self.prepare_inputs_labels_for_speech_text_for_internvl(
input_ids,
position_ids,
attention_mask,
None,
None, # labels
speech,
speech_lengths,
speech_chunks,
speech_wav,
modalities,
)
        else:
            # InternVL-style image path: embed the text, then splice ViT features
            # into the IMG_CONTEXT positions.
            assert self.img_context_token_id is not None
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.get_model().get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids = input_ids.reshape(B * N)
                # Scatter ViT embeddings into the positions marked by IMG_CONTEXT tokens.
                selected = (input_ids == self.img_context_token_id)
                assert selected.sum() != 0
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
input_embeds = input_embeds.reshape(B, N, C)
else:
input_embeds = self.get_model().get_input_embeddings()(input_ids)
return super().generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
use_cache=True,
**kwargs,
)
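
    # Because generation is driven by inputs_embeds rather than input_ids, the
    # returned sequences contain only newly generated tokens (no prompt echo),
    # which is why chat() below decodes generation_output directly.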
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
verbose=False, speech=None, speech_lengths=None, speech_wav=None, speech_chunks=None):
if history is None and pixel_values is not None and '<image>' not in question:
question = '<image>\n' + question
if num_patches_list is None:
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
self.img_context_token_id = img_context_token_id
template = get_conv_template(self.template)
template.system_message = self.system_message
eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
history = [] if history is None else history
for (old_question, old_answer) in history:
template.append_message(template.roles[0], old_question)
template.append_message(template.roles[1], old_answer)
template.append_message(template.roles[0], question)
template.append_message(template.roles[1], None)
query = template.get_prompt()
if verbose and pixel_values is not None:
image_bs = pixel_values.shape[0]
print(f'dynamic ViT batch size: {image_bs}')
# Replace image tokens
for num_patches in num_patches_list:
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
query = query.replace('<image>', image_tokens, 1)
from ola.conversation import conv_templates, SeparatorStyle
from ola.mm_utils import KeywordsStoppingCriteria
conv_mode = "plm_v"
conv = conv_templates[conv_mode].copy()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
# Use OLA-style tokenization for speech inputs
if speech is not None and '<speech>' in query:
# Use OLA-style tokenization directly with <speech> tokens
input_ids = tokenizer_speech_token(query, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
            # Fall back to Qwen's <|endoftext|> id (151643) when pad_token_id is unset.
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
attention_mask = input_ids.ne(pad_token_id).long().to(self.device)
else:
model_inputs = tokenizer(query, return_tensors='pt')
input_ids = model_inputs['input_ids'].to(self.device)
attention_mask = model_inputs['attention_mask'].to(self.device)
generation_config['eos_token_id'] = eos_token_id
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        # Note: stopping_criteria is currently unused; generation stops via eos_token_id.
        # generation_config["stopping_criteria"] = stopping_criteria
generation_output = self.generate(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
speech=speech,
speech_lengths=speech_lengths,
speech_chunks=speech_chunks,
speech_wav=speech_wav,
**generation_config
)
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
response = response.split(template.sep.strip())[0].strip()
history.append((question, response))
if return_history:
return response, history
else:
query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
if verbose:
print(query_to_print, response)
return response
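
    # Usage sketch (hypothetical inputs; generation_config is a plain dict that
    # chat() mutates with the template's eos id):
    #   response, history = model.chat(tokenizer, pixel_values,
    #                                  "<image>\nDescribe this image.",
    #                                  {"max_new_tokens": 128}, return_history=True)
    #   response = model.chat(tokenizer, None, "Summarize it in one word.",
    #                         {"max_new_tokens": 16}, history=history)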
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
speech = kwargs.pop("speech", None)
speech_lengths = kwargs.pop("speech_lengths", None)
speech_chunks = kwargs.pop("speech_chunks", None)
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if speech is not None:
inputs['speech'] = speech
inputs['speech_lengths'] = speech_lengths
inputs['speech_chunks'] = speech_chunks
if images is not None:
inputs["images"] = images
if image_sizes is not None:
inputs["image_sizes"] = image_sizes
return inputs
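
    # The method above runs once per decoding step; re-attaching the speech/image
    # tensors there keeps them visible to forward() across the whole
    # autoregressive loop rather than only on the first step.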
AutoConfig.register("ola_internvl", OlaConfigQwen3)
AutoModelForCausalLM.register(OlaConfigQwen3, OlaQwen3ForCausalLM)
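
# Loading sketch (hypothetical checkpoint path; the registrations above make the
# "ola_internvl" model_type resolvable through the Auto classes):
#   model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)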