diff --git a/inference/infer.py b/inference/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..9e101e0f29d6d83cd0ca7cdc3ae541c98104b85e --- /dev/null +++ b/inference/infer.py @@ -0,0 +1,317 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' + +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/model/THUdyh/Ola-7b') +parser.add_argument('--text', type=str, default="What does the speech say?") +parser.add_argument('--audio_path', type=str, default="/data1/cxy/dataset/english.mp3") +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None) +model = model.to('cuda').eval() +model = model.bfloat16() +# breakpoint() +USE_SPEECH=False +cur_dir = os.path.dirname(os.path.abspath(__file__)) + +def load_audio(audio_file_name): + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + +def extract_audio(videos_file_path): + my_clip = mp.VideoFileClip(videos_file_path) + return my_clip.audio + +image_path = args.image_path +audio_path = args.audio_path +video_path = args.video_path +text = args.text + +if video_path is not None: + modality = "video" + visual = video_path + assert image_path is None + +elif image_path is not None: + visual = image_path + modality = "image" + assert video_path is None + +elif audio_path is not None: + modality = "text" + + +# input audio and video, do not parse audio in the video, else parse audio in the video +if audio_path: + USE_SPEECH = True +elif modality == "video": + USE_SPEECH = True +else: + USE_SPEECH = False + +speechs = [] +speech_lengths = [] +speech_wavs = [] +speech_chunks = [] +if modality == "video": + vr = VideoReader(visual, ctx=cpu(0)) + total_frame_num = len(vr) + fps = round(vr.get_avg_fps()) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + spare_frames = vr.get_batch(frame_idx).asnumpy() + video = [Image.fromarray(frame) for frame in spare_frames] +elif modality == "image": + image = [Image.open(visual)] + image_sizes = [image[0].size] +else: + images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + image_sizes = [(224, 224)] + + +if USE_SPEECH and audio_path: + audio_path = audio_path + speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) + print('load audio') +elif USE_SPEECH and not audio_path: + # parse audio in the video + audio = extract_audio(visual) + audio.write_audiofile("./video_audio.wav") + video_audio_path = './video_audio.wav' + speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) +else: + speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')] + speech_lengths = [torch.LongTensor([3000]).to('cuda')] + speech_wavs = [torch.zeros([1, 480000]).to('cuda')] + speech_chunks = [torch.LongTensor([1]).to('cuda')] + +conv_mode = "qwen_1_5" +if text: + qs = text +else: + qs = '' + +if USE_SPEECH and audio_path and image_path: # image + speech instruction + qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n' +elif USE_SPEECH and video_path: # video + audio + qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs +elif USE_SPEECH and audio_path: # audio + text + qs = DEFAULT_SPEECH_TOKEN + "\n" + qs +elif image_path or video_path: # image / video + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs +elif text: # text + qs = qs + +conv = conv_templates[conv_mode].copy() +conv.append_message(conv.roles[0], qs) +conv.append_message(conv.roles[1], None) +prompt = conv.get_prompt() +if USE_SPEECH and audio_path and image_path: # image + speech instruction + input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +elif USE_SPEECH and video_path: # video + audio + input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +elif USE_SPEECH and audio_path: # audio + text + # breakpoint() + input_ids = tokenizer_speech_token(prompt, tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +else: + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') + +if modality == "video": + video_processed = [] + for idx, frame in enumerate(video): + image_processor.do_resize = False + image_processor.do_center_crop = False + frame = process_anyres_video(frame, image_processor) + + if frame_idx is not None and idx in frame_idx: + video_processed.append(frame.unsqueeze(0)) + elif frame_idx is None: + video_processed.append(frame.unsqueeze(0)) + + if frame_idx is None: + frame_idx = np.arange(0, len(video_processed), dtype=int).tolist() + + video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda") + video_processed = (video_processed, video_processed) + + video_data = (video_processed, (384, 384), "video") +elif modality == "image": + image_processor.do_resize = False + image_processor.do_center_crop = False + image_tensor, image_highres_tensor = [], [] + for visual in image: + image_tensor_, image_highres_tensor_ = process_anyres_highres_image(visual, image_processor) + image_tensor.append(image_tensor_) + image_highres_tensor.append(image_highres_tensor_) + if all(x.shape == image_tensor[0].shape for x in image_tensor): + image_tensor = torch.stack(image_tensor, dim=0) + if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor): + image_highres_tensor = torch.stack(image_highres_tensor, dim=0) + if type(image_tensor) is list: + image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor] + else: + image_tensor = image_tensor.bfloat16().to("cuda") + if type(image_highres_tensor) is list: + image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor] + else: + image_highres_tensor = image_highres_tensor.bfloat16().to("cuda") + +pad_token_ids = 151643 + +attention_masks = input_ids.ne(pad_token_ids).long().to('cuda') +stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 +keywords = [stop_str] +stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + +gen_kwargs = {} + +if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 +if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0.2 +if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None +if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 +# breakpoint() +with torch.inference_mode(): + if modality == "video": + output_ids = model.generate( + inputs=input_ids, + images=video_data[0][0], + images_highres=video_data[0][1], + modalities=video_data[2], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif modality == "image": + output_ids = model.generate( + inputs=input_ids, + images=image_tensor, + images_highres=image_highres_tensor, + image_sizes=image_sizes, + modalities=['image'], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif modality == "text": + output_ids = model.generate( + input_ids, + images=images, + images_highres=images_highres, + image_sizes=image_sizes, + modalities=['text'], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + +outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] +outputs = outputs.strip() +if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] +outputs = outputs.strip() + +print(outputs) \ No newline at end of file diff --git a/inference/infer_ola_internvl.py b/inference/infer_ola_internvl.py new file mode 100644 index 0000000000000000000000000000000000000000..623b275994384e3639d07c05a4fd4c04fd0f0c10 --- /dev/null +++ b/inference/infer_ola_internvl.py @@ -0,0 +1,448 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +import os +import sys +from pathlib import Path +import math +import numpy as np +import torch +import torchvision.transforms as T +from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试 +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +from contextlib import redirect_stdout +import io +import librosa +import whisper +import moviepy as mp +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModel + +# pure text +# image + text +# video + text +# audio + text +# video + audio + text + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B') +parser.add_argument('--text', type=str, default="What does the speech say?") +parser.add_argument('--audio_path', type=str, default=None) +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None) +model = model.to('cuda').eval() +model = model.bfloat16() + +resource_path = "/data1/cxy/plm-v/modeling/example/" +# set the max number of tiles in `max_num` +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=12): + image = Image.open(image_file).convert('RGB') + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + + +pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda() +# breakpoint() +generation_config = dict(max_new_tokens=1024, do_sample=True) + + +# breakpoint() +question = 'Hello, who are you?' +response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True) +print(f'User: {question}\nAssistant: {response}') + + + +# 多模态推理测试 +print("\n" + "="*80) +print("🧪 开始多模态推理测试") +print("="*80) + +def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None): + """统一的推理测试函数""" + print(f"\n{'='*60}") + print(f"🧪 测试: {test_name}") + print(f"📝 问题: {question}") + print(f"{'='*60}") + + try: + # 准备参数 + chat_kwargs = { + 'tokenizer': tokenizer, + 'pixel_values': pixel_values_input, + 'question': question, + 'generation_config': generation_config, + 'verbose': True + } + + # 如果有视频数据,添加num_patches_list参数 + if num_patches_list is not None: + chat_kwargs['num_patches_list'] = num_patches_list + + # 如果有speech数据,添加speech参数 + if speech_input is not None: + chat_kwargs.update({ + 'speech': speech_input, # mel 谱图,用于 Whisper + 'speech_lengths': speech_lengths_input, + 'speech_wav': speech_wavs, # 原始音频波形,用于 BEATs + }) + + # 执行推理 + # breakpoint() + response = model.chat(**chat_kwargs) + + print(f"✅ 推理成功!") + print(f"🤖 回复: {response}") + + return True, response + + except Exception as e: + print(f"❌ 推理失败: {str(e)}") + import traceback + traceback.print_exc() + return False, str(e) + +# 测试1: Pure Text (应该正常,使用训练好的InternVL) +success1, response1 = test_inference( + test_name="Pure Text", + question="Hello, who are you? Please introduce yourself briefly.", + pixel_values_input=None, + speech_input=None, + speech_lengths_input=None +) + +# 测试2: Text & Image - Visual only (应该正常,使用训练好的InternVL) +# success2, response2 = test_inference( +# test_name="Text & Image (Visual only)", +# question="\nPlease describe this image in detail.", +# pixel_values_input=pixel_values, +# speech_input=None, +# speech_lengths_input=None +# ) + +print("\n" + "="*60) +print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)") +print("="*60) +def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_segments) + ]) + return frame_indices + +def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + + pixel_values_list, num_patches_list = [], [] + transform = build_transform(input_size=input_size) + frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') + img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(tile) for tile in img] + pixel_values = torch.stack(pixel_values) + num_patches_list.append(pixel_values.shape[0]) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list) + return pixel_values, num_patches_list + +def load_audio(audio_file_name): + """ + 加载音频文件,使用Ola风格的mel谱图预处理 + 这与原始的Ola load_audio函数保持一致 + """ + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav_chunk = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav_chunk).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + + # 生成mel谱图 + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + + return mels, speech_length, speech_chunks, speech_wavs + +def extract_audio(videos_file_path): + my_clip = mp.VideoFileClip(videos_file_path) + return my_clip.audio + +# 加载视频数据用于视频测试 +print("\n📥 加载视频数据...") +try: + video_path = f'{resource_path}red-panda.mp4' + if os.path.exists(video_path): + video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1) + video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda() + video_loaded = True + print(f"✅ 视频加载成功:") + print(f" - 视频帧数: {len(video_num_patches_list)}") + print(f" - 视频像素值形状: {video_pixel_values.shape}") + print(f" - 每帧patch数: {video_num_patches_list}") + else: + print(f"⚠️ 视频文件不存在: {video_path}") + video_loaded = False + video_pixel_values = None + video_num_patches_list = None +except Exception as e: + print(f"❌ 视频加载失败: {e}") + video_loaded = False + video_pixel_values = None + video_num_patches_list = None + + + +audio_path = f'/data1/cxy/dataset/english.mp3' + +# 加载音频数据用于后续测试 +print("\n📥 加载音频数据...") +try: + # 加载音频文件 - 使用Ola风格的mel谱图预处理 + mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path) + print(f"✅ 音频加载成功:") + print(f" - mel谱图形状: {mels.shape}") + print(f" - 音频长度: {speech_lengths}") + print(f" - 音频块数: {speech_chunks}") + print(f" - 原始音频波形形状: {speech_wavs.shape}") + + # 将音频数据转换为适当的格式并移到GPU + mels = mels.to(torch.bfloat16).cuda() + speech_lengths = speech_lengths.cuda() + speech_chunks = speech_chunks.cuda() + speech_wavs = speech_wavs.cuda() + + audio_loaded = True + +except Exception as e: + print(f"❌ 音频加载失败: {e}") + audio_loaded = False + mels = None + speech_lengths = None + +# 测试3: Audio only (可能乱码,speech部分未训练) +if audio_loaded: + success3, response3 = test_inference( + test_name="Audio only (预期乱码)", + question="\nPlease transcribe and summarize what you heard in the audio.", + pixel_values_input=None, + speech_input=mels, + speech_lengths_input=speech_lengths + ) +else: + print("⚠️ 跳过Audio only测试 (音频加载失败)") + success3 = False + +# # 测试4: Audio + Image (可能乱码,speech部分未训练) +# if audio_loaded: +# success4, response4 = test_inference( +# test_name="Audio + Image (预期乱码)", +# question="\nUser's question in speech: \n", +# pixel_values_input=pixel_values, +# speech_input=mels, +# speech_lengths_input=speech_lengths +# ) +# else: +# print("⚠️ 跳过Audio + Image测试 (音频加载失败)") +# success4 = False + +# 测试5: Video + Text (应该正常,使用训练好的InternVL) +# if video_loaded: +# # 构建视频帧前缀 +# video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(video_num_patches_list))]) +# video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.' + +# success5, response5 = test_inference( +# test_name="Video + Text", +# question=video_question, +# pixel_values_input=video_pixel_values, +# speech_input=None, +# speech_lengths_input=None, +# num_patches_list=video_num_patches_list +# ) +# else: +# print("⚠️ 跳过Video + Text测试 (视频加载失败)") +# success5 = False + +# 测试5: Video + Audio (可能乱码,speech部分未训练) +# if audio_loaded: +# success5, response5 = test_inference( +# test_name="Video + Audio (预期乱码)", +# question="\nDescribe what you hear and see in this content.", +# pixel_values_input=pixel_values, +# speech_input=mels, +# speech_lengths_input=speech_lengths +# ) +# else: +# print("⚠️ 跳过Video + Audio测试 (音频加载失败)") +# success5 = False + +# 测试总结 +print("\n" + "="*80) +print("📊 多模态推理测试总结") +print("="*80) + +test_results = [ + ("Pure Text", success1, "PASS", "应该正常 (训练好的InternVL)"), + # ("Text & Image", success2, "PASS", "应该正常 (训练好的InternVL)"), + # ("Video + Text", success5 if video_loaded else False, "PASS", "应该正常 (训练好的InternVL)"), + ("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"), + # ("Audio + Image", success4 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"), +] + +for test_name, success, expected, note in test_results: + status = "✅ PASS" if success else "❌ FAIL" + print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}") + +passed = sum(1 for _, success, _, _ in test_results if success) +total = len(test_results) +print(f"\n📈 测试统计: {passed}/{total} 通过") + +if passed >= 2: # 至少pure text、text&image、video+text中的2个应该通过 + print("🎉 基础功能正常,Speech集成架构成功!") + print("💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练") + if passed >= 3: + print("🌟 所有基础模态测试都通过了!") +else: + print("⚠️ 基础功能可能存在问题,需要进一步检查") + +print("\n=== 多模态推理测试完成 ===") + diff --git a/inference/infer_ola_internvl_audio.py b/inference/infer_ola_internvl_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..185cbe071bebc62b83e856bbdfcb0922069062eb --- /dev/null +++ b/inference/infer_ola_internvl_audio.py @@ -0,0 +1,244 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +import os +import sys +from pathlib import Path +import math +import numpy as np +import torch +import torchvision.transforms as T +from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试 +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +from contextlib import redirect_stdout +import io +import librosa +import whisper +import moviepy as mp +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModel + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B') +parser.add_argument('--text', type=str, default="What does the speech say?") +parser.add_argument('--audio_path', type=str, default=None) +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None) +model = model.to('cuda').eval() +model = model.bfloat16() + +resource_path = "/data1/cxy/plm-v/modeling/example/" + +generation_config = dict( + max_new_tokens=256, + do_sample=False, # Use greedy decoding to avoid sampling issues + temperature=0.5, + top_p=0.8, + top_k=10, +) + +# 多模态推理测试 +print("\n" + "="*80) +print("🧪 开始多模态推理测试") +print("="*80) + +def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, speech_wavs_input=None, speech_chunks_input=None, num_patches_list=None): + """统一的推理测试函数""" + print(f"\n{'='*60}") + print(f"🧪 测试: {test_name}") + print(f"📝 问题: {question}") + print(f"{'='*60}") + + try: + # 准备参数 + chat_kwargs = { + 'tokenizer': tokenizer, + 'pixel_values': pixel_values_input, + 'question': question, + 'generation_config': generation_config, + 'verbose': True + } + + # 如果有视频数据,添加num_patches_list参数 + if num_patches_list is not None: + chat_kwargs['num_patches_list'] = num_patches_list + + # 如果有speech数据,添加speech参数 + if speech_input is not None: + chat_kwargs.update({ + 'speech': speech_input, # mel 谱图,用于 Whisper + 'speech_lengths': speech_lengths_input, + 'speech_wav': speech_wavs_input, # 原始音频波形,用于 BEATs + }) + + # 如果有speech_chunks数据,添加speech_chunks参数 + if speech_chunks_input is not None: + chat_kwargs['speech_chunks'] = speech_chunks_input + + # 执行推理 + # breakpoint() + response = model.chat(**chat_kwargs) + + print(f"✅ 推理成功!") + print(f"🤖 回复: {response}") + + return True, response + + except Exception as e: + print(f"❌ 推理失败: {str(e)}") + import traceback + traceback.print_exc() + return False, str(e) + +# success1, response1 = test_inference( +# test_name="Pure Text", +# question="What is China's capital? Please introduce the city in detail.", +# pixel_values_input=None, +# speech_input=None, +# speech_lengths_input=None +# ) + +print("\n" + "="*60) +print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)") +print("="*60) + + +def load_audio(audio_file_name): + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + + + + +audio_path = f'/data1/cxy/dataset/english.mp3' + +# 加载音频数据用于后续测试 +print("\n📥 加载音频数据...") +try: + # 加载音频文件 - 使用Ola风格的mel谱图预处理 + speech, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path) + print(f"✅ 音频加载成功:") + print(f" - mel谱图形状: {speech.shape}") + print(f" - 音频长度: {speech_lengths}") + print(f" - 音频块数: {speech_chunks}") + print(f" - 原始音频波形形状: {speech_wavs.shape}") + + # 将音频数据转换为适当的格式并移到GPU + speech = speech.to(torch.bfloat16).cuda() + speech_lengths = speech_lengths.cuda() + speech_chunks = speech_chunks.cuda() + speech_wavs = speech_wavs.cuda() + + audio_loaded = True + +except Exception as e: + print(f"❌ 音频加载失败: {e}") + audio_loaded = False + mels = None + speech_lengths = None + +# 测试3: Audio only (可能乱码,speech部分未训练) +if audio_loaded: + success3, response3 = test_inference( + test_name="Audio only (预期乱码)", + question="\nPlease transcribe and summarize what you heard in the audio.", + pixel_values_input=None, + speech_input=speech, + speech_lengths_input=speech_lengths, + speech_wavs_input=speech_wavs, + speech_chunks_input=speech_chunks + ) +else: + print("⚠️ 跳过Audio only测试 (音频加载失败)") + success3 = False + + +# 测试总结 +print("\n" + "="*80) +print("📊 多模态推理测试总结") +print("="*80) + +test_results = [ + ("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"), +] + +for test_name, success, expected, note in test_results: + status = "✅ PASS" if success else "❌ FAIL" + print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}") + +passed = sum(1 for _, success, _, _ in test_results if success) +total = len(test_results) +print(f"\n📈 测试统计: {passed}/{total} 通过") + +print("\n=== 多模态推理测试完成 ===") + diff --git a/inference/infer_ola_internvl_audio_ckpt.py b/inference/infer_ola_internvl_audio_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..815fc9de4db297654819c1f1e6d61816a04703ab --- /dev/null +++ b/inference/infer_ola_internvl_audio_ckpt.py @@ -0,0 +1,245 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +import os +import sys +from pathlib import Path +import math +import numpy as np +import torch +import torchvision.transforms as T +from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试 +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +from contextlib import redirect_stdout +import io +import librosa +import whisper +import moviepy as mp +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModel + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B') +parser.add_argument('--text', type=str, default="Give the caption of the given audio or speech.") +parser.add_argument('--audio_path', type=str, default=None) +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None) +model = model.to('cuda').eval() +model = model.bfloat16() + +resource_path = "/data1/cxy/plm-v/modeling/example/" + +generation_config = dict( + max_new_tokens=256, + do_sample=False, # Use greedy decoding to avoid sampling issues + temperature=0.5, + top_p=0.8, + top_k=10, +) + +# 多模态推理测试 +print("\n" + "="*80) +print("🧪 开始多模态推理测试") +print("="*80) + +def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, speech_wavs_input=None, speech_chunks_input=None, num_patches_list=None): + """统一的推理测试函数""" + print(f"\n{'='*60}") + print(f"🧪 测试: {test_name}") + print(f"📝 问题: {question}") + print(f"{'='*60}") + + try: + # 准备参数 + chat_kwargs = { + 'tokenizer': tokenizer, + 'pixel_values': pixel_values_input, + 'question': question, + 'generation_config': generation_config, + 'verbose': True + } + + # 如果有视频数据,添加num_patches_list参数 + if num_patches_list is not None: + chat_kwargs['num_patches_list'] = num_patches_list + + # 如果有speech数据,添加speech参数 + if speech_input is not None: + chat_kwargs.update({ + 'speech': speech_input, # mel 谱图,用于 Whisper + 'speech_lengths': speech_lengths_input, + 'speech_wav': speech_wavs_input, # 原始音频波形,用于 BEATs + }) + + # 如果有speech_chunks数据,添加speech_chunks参数 + if speech_chunks_input is not None: + chat_kwargs['speech_chunks'] = speech_chunks_input + + # 执行推理 + # breakpoint() + response = model.chat(**chat_kwargs) + + print(f"✅ 推理成功!") + print(f"🤖 回复: {response}") + + return True, response + + except Exception as e: + print(f"❌ 推理失败: {str(e)}") + import traceback + traceback.print_exc() + return False, str(e) + +# success1, response1 = test_inference( +# test_name="Pure Text", +# question="What is China's capital? Please introduce the city in detail.", +# pixel_values_input=None, +# speech_input=None, +# speech_lengths_input=None +# ) + +print("\n" + "="*60) +print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)") +print("="*60) + + +def load_audio(audio_file_name): + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + + + + +# audio_path = f'/data1/cxy/dataset/english.mp3' +audio_path = "/data1/cxy/plm-v/modeling/data/Clotho/train/Leaves rustling.wav" + +# 加载音频数据用于后续测试 +print("\n📥 加载音频数据...") +try: + # 加载音频文件 - 使用Ola风格的mel谱图预处理 + speech, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path) + print(f"✅ 音频加载成功:") + print(f" - mel谱图形状: {speech.shape}") + print(f" - 音频长度: {speech_lengths}") + print(f" - 音频块数: {speech_chunks}") + print(f" - 原始音频波形形状: {speech_wavs.shape}") + + # 将音频数据转换为适当的格式并移到GPU + speech = speech.to(torch.bfloat16).cuda() + speech_lengths = speech_lengths.cuda() + speech_chunks = speech_chunks.cuda() + speech_wavs = speech_wavs.cuda() + + audio_loaded = True + +except Exception as e: + print(f"❌ 音频加载失败: {e}") + audio_loaded = False + mels = None + speech_lengths = None + +# 测试3: Audio only (可能乱码,speech部分未训练) +if audio_loaded: + success3, response3 = test_inference( + test_name="Audio only (预期乱码)", + question="\nGive the caption of the given audio or speech.", + pixel_values_input=None, + speech_input=speech, + speech_lengths_input=speech_lengths, + speech_wavs_input=speech_wavs, + speech_chunks_input=speech_chunks + ) +else: + print("⚠️ 跳过Audio only测试 (音频加载失败)") + success3 = False + + +# 测试总结 +print("\n" + "="*80) +print("📊 多模态推理测试总结") +print("="*80) + +test_results = [ + ("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"), +] + +for test_name, success, expected, note in test_results: + status = "✅ PASS" if success else "❌ FAIL" + print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}") + +passed = sum(1 for _, success, _, _ in test_results if success) +total = len(test_results) +print(f"\n📈 测试统计: {passed}/{total} 通过") + +print("\n=== 多模态推理测试完成 ===") + diff --git a/inference/infer_ola_internvl_copy.py b/inference/infer_ola_internvl_copy.py new file mode 100644 index 0000000000000000000000000000000000000000..376a5e8226551e249f1257ee2d7346bb8c0d7834 --- /dev/null +++ b/inference/infer_ola_internvl_copy.py @@ -0,0 +1,318 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' + +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B') +parser.add_argument('--text', type=str, default="What does the speech say?") +parser.add_argument('--audio_path', type=str, default=None) +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None) +breakpoint() +model = model.to('cuda').eval() +model = model.bfloat16() +# breakpoint() +USE_SPEECH=False +cur_dir = os.path.dirname(os.path.abspath(__file__)) + +def load_audio(audio_file_name): + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + return mels, speech_length, speech_chunks, speech_wavs + +def extract_audio(videos_file_path): + my_clip = mp.VideoFileClip(videos_file_path) + return my_clip.audio + +image_path = args.image_path +audio_path = args.audio_path +video_path = args.video_path +text = args.text +modality = "text" +if video_path is not None: + modality = "video" + visual = video_path + assert image_path is None + +elif image_path is not None: + visual = image_path + modality = "image" + assert video_path is None + +elif audio_path is not None: + modality = "text" + + +# input audio and video, do not parse audio in the video, else parse audio in the video +if audio_path: + USE_SPEECH = True +elif modality == "video": + USE_SPEECH = True +else: + USE_SPEECH = False + +speechs = [] +speech_lengths = [] +speech_wavs = [] +speech_chunks = [] +if modality == "video": + vr = VideoReader(visual, ctx=cpu(0)) + total_frame_num = len(vr) + fps = round(vr.get_avg_fps()) + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + spare_frames = vr.get_batch(frame_idx).asnumpy() + video = [Image.fromarray(frame) for frame in spare_frames] +elif modality == "image": + image = [Image.open(visual)] + image_sizes = [image[0].size] +else: + images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)] + image_sizes = [(224, 224)] + + +if USE_SPEECH and audio_path: + audio_path = audio_path + speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) + print('load audio') +elif USE_SPEECH and not audio_path: + # parse audio in the video + audio = extract_audio(visual) + audio.write_audiofile("./video_audio.wav") + video_audio_path = './video_audio.wav' + speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path) + speechs.append(speech.bfloat16().to('cuda')) + speech_lengths.append(speech_length.to('cuda')) + speech_chunks.append(speech_chunk.to('cuda')) + speech_wavs.append(speech_wav.to('cuda')) +else: + speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')] + speech_lengths = [torch.LongTensor([3000]).to('cuda')] + speech_wavs = [torch.zeros([1, 480000]).to('cuda')] + speech_chunks = [torch.LongTensor([1]).to('cuda')] + +conv_mode = "qwen_1_5" +if text: + qs = text +else: + qs = '' + +if USE_SPEECH and audio_path and image_path: # image + speech instruction + qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n' +elif USE_SPEECH and video_path: # video + audio + qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs +elif USE_SPEECH and audio_path: # audio + text + qs = DEFAULT_SPEECH_TOKEN + "\n" + qs +elif image_path or video_path: # image / video + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs +elif text: # text + qs = qs + +conv = conv_templates[conv_mode].copy() +conv.append_message(conv.roles[0], qs) +conv.append_message(conv.roles[1], None) +prompt = conv.get_prompt() +if USE_SPEECH and audio_path and image_path: # image + speech instruction + input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +elif USE_SPEECH and video_path: # video + audio + input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +elif USE_SPEECH and audio_path: # audio + text + # breakpoint() + input_ids = tokenizer_speech_token(prompt, tokenizer, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') +else: + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda') + +if modality == "video": + video_processed = [] + for idx, frame in enumerate(video): + image_processor.do_resize = False + image_processor.do_center_crop = False + frame = process_anyres_video(frame, image_processor) + + if frame_idx is not None and idx in frame_idx: + video_processed.append(frame.unsqueeze(0)) + elif frame_idx is None: + video_processed.append(frame.unsqueeze(0)) + + if frame_idx is None: + frame_idx = np.arange(0, len(video_processed), dtype=int).tolist() + + video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda") + video_processed = (video_processed, video_processed) + + video_data = (video_processed, (384, 384), "video") +elif modality == "image": + image_processor.do_resize = False + image_processor.do_center_crop = False + image_tensor, image_highres_tensor = [], [] + for visual in image: + image_tensor_, image_highres_tensor_ = process_anyres_highres_image(visual, image_processor) + image_tensor.append(image_tensor_) + image_highres_tensor.append(image_highres_tensor_) + if all(x.shape == image_tensor[0].shape for x in image_tensor): + image_tensor = torch.stack(image_tensor, dim=0) + if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor): + image_highres_tensor = torch.stack(image_highres_tensor, dim=0) + if type(image_tensor) is list: + image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor] + else: + image_tensor = image_tensor.bfloat16().to("cuda") + if type(image_highres_tensor) is list: + image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor] + else: + image_highres_tensor = image_highres_tensor.bfloat16().to("cuda") + +pad_token_ids = 151643 + +attention_masks = input_ids.ne(pad_token_ids).long().to('cuda') +stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 +keywords = [stop_str] +stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + +gen_kwargs = {} + +if "max_new_tokens" not in gen_kwargs: + gen_kwargs["max_new_tokens"] = 1024 +if "temperature" not in gen_kwargs: + gen_kwargs["temperature"] = 0.2 +if "top_p" not in gen_kwargs: + gen_kwargs["top_p"] = None +if "num_beams" not in gen_kwargs: + gen_kwargs["num_beams"] = 1 +# breakpoint() +with torch.inference_mode(): + if modality == "video": + output_ids = model.generate( + inputs=input_ids, + images=video_data[0][0], + images_highres=video_data[0][1], + modalities=video_data[2], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif modality == "image": + output_ids = model.generate( + inputs=input_ids, + images=image_tensor, + images_highres=image_highres_tensor, + image_sizes=image_sizes, + modalities=['image'], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + elif modality == "text": + output_ids = model.generate( + input_ids, + images=images, + images_highres=images_highres, + image_sizes=image_sizes, + modalities=['text'], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=True if gen_kwargs["temperature"] > 0 else False, + temperature=gen_kwargs["temperature"], + top_p=gen_kwargs["top_p"], + num_beams=gen_kwargs["num_beams"], + max_new_tokens=gen_kwargs["max_new_tokens"], + ) + +outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] +outputs = outputs.strip() +if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] +outputs = outputs.strip() + +print(outputs) \ No newline at end of file diff --git a/inference/infer_ola_internvl_text_visual.py b/inference/infer_ola_internvl_text_visual.py new file mode 100644 index 0000000000000000000000000000000000000000..6f9090e47ca59276e5cecd7621a289f673306c9d --- /dev/null +++ b/inference/infer_ola_internvl_text_visual.py @@ -0,0 +1,409 @@ +import os + +os.environ['LOWRES_RESIZE'] = '384x32' +os.environ['HIGHRES_BASE'] = '0x32' +os.environ['VIDEO_RESIZE'] = "0x64" +os.environ['VIDEO_MAXRES'] = "480" +os.environ['VIDEO_MINRES'] = "288" +os.environ['MAXRES'] = '1536' +os.environ['MINRES'] = '0' +os.environ['FORCE_NO_DOWNSAMPLE'] = '1' +os.environ['LOAD_VISION_EARLY'] = '1' +os.environ['PAD2STRIDE'] = '1' +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +import os +import sys +from pathlib import Path +import math +import numpy as np +import torch +import torchvision.transforms as T +from decord import VideoReader, cpu # 暂时注释掉,专注于语音功能测试 +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoTokenizer +from contextlib import redirect_stdout +import io +import librosa +import whisper +import moviepy as mp +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModel + +# pure text +# image + text +# video + text +# audio + text +# video + audio + text + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +import gradio as gr +import torch +import re +from decord import VideoReader, cpu +from PIL import Image +import numpy as np +import transformers +import moviepy as mp +from typing import Dict, Optional, Sequence, List +import librosa +import whisper +from ola.conversation import conv_templates, SeparatorStyle +from ola.model.builder import load_pretrained_model +from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token +from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image +from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--model_path', type=str, default='/data1/cxy/plm-v/modeling/internvl3_5-2B') +parser.add_argument('--text', type=str, default="What does the speech say?") +parser.add_argument('--audio_path', type=str, default=None) +parser.add_argument('--image_path', type=str, default=None) +parser.add_argument('--video_path', type=str, default=None) +args = parser.parse_args() + +model_path = args.model_path +tokenizer, model, image_processor, _ = load_pretrained_model(model_path,'ola_internvl', None) +model = model.to('cuda').eval() +model = model.bfloat16() + +resource_path = "/data1/cxy/plm-v/modeling/example/" +# set the max number of tiles in `max_num` +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=12): + image = Image.open(image_file).convert('RGB') + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + + +pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda() +# breakpoint() +generation_config = dict(max_new_tokens=1024, do_sample=True) +generation_config = dict( + max_new_tokens=256, + do_sample=False, # Use greedy decoding to avoid sampling issues + temperature=0.5, + top_p=0.8, + top_k=10, +) + + + +# 多模态推理测试 +print("\n" + "="*80) +print("🧪 开始多模态推理测试") +print("="*80) + +def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None): + """统一的推理测试函数""" + print(f"\n{'='*60}") + print(f"🧪 测试: {test_name}") + print(f"📝 问题: {question}") + print(f"{'='*60}") + + try: + # 准备参数 + chat_kwargs = { + 'tokenizer': tokenizer, + 'pixel_values': pixel_values_input, + 'question': question, + 'generation_config': generation_config, + 'verbose': True + } + + # 如果有视频数据,添加num_patches_list参数 + if num_patches_list is not None: + chat_kwargs['num_patches_list'] = num_patches_list + + # 如果有speech数据,添加speech参数 + if speech_input is not None: + chat_kwargs.update({ + 'speech': speech_input, # mel 谱图,用于 Whisper + 'speech_lengths': speech_lengths_input, + 'speech_wav': speech_wavs, # 原始音频波形,用于 BEATs + }) + + # 执行推理 + # breakpoint() + response = model.chat(**chat_kwargs) + + print(f"✅ 推理成功!") + print(f"🤖 回复: {response}") + + return True, response + + except Exception as e: + print(f"❌ 推理失败: {str(e)}") + import traceback + traceback.print_exc() + return False, str(e) + +# 测试1: Pure Text (应该正常,使用训练好的InternVL) +success1, response1 = test_inference( + test_name="Pure Text", + question="Hello, who are you? Please introduce yourself briefly.", + pixel_values_input=None, + speech_input=None, + speech_lengths_input=None +) +# breakpoint() +# 测试2: Text & Image - Visual only (应该正常,使用训练好的InternVL) +success2, response2 = test_inference( + test_name="Text & Image (Visual only)", + question="\nPlease describe this image in detail.", + pixel_values_input=pixel_values, + speech_input=None, + speech_lengths_input=None +) + +print("\n" + "="*60) +print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)") +print("="*60) +def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_segments) + ]) + return frame_indices + +def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + + pixel_values_list, num_patches_list = [], [] + transform = build_transform(input_size=input_size) + frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') + img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(tile) for tile in img] + pixel_values = torch.stack(pixel_values) + num_patches_list.append(pixel_values.shape[0]) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list) + return pixel_values, num_patches_list + +def load_audio(audio_file_name): + """ + 加载音频文件,使用Ola风格的mel谱图预处理 + 这与原始的Ola load_audio函数保持一致 + """ + speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + CHUNK_LIM = 480000 + SAMPLE_RATE = 16000 + speechs = [] + speech_wavs = [] + + if len(speech_wav) <= CHUNK_LIM: + speech = whisper.pad_or_trim(speech_wav) + speech_wav_chunk = whisper.pad_or_trim(speech_wav) + speechs.append(speech) + speech_wavs.append(torch.from_numpy(speech_wav_chunk).unsqueeze(0)) + else: + for i in range(0, len(speech_wav), CHUNK_LIM): + chunk = speech_wav[i : i + CHUNK_LIM] + if len(chunk) < CHUNK_LIM: + chunk = whisper.pad_or_trim(chunk) + speechs.append(chunk) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + + # 生成mel谱图 + mels = [] + for chunk in speechs: + chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(chunk) + + mels = torch.cat(mels, dim=0) + speech_wavs = torch.cat(speech_wavs, dim=0) + if mels.shape[0] > 25: + mels = mels[:25] + speech_wavs = speech_wavs[:25] + + speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + + return mels, speech_length, speech_chunks, speech_wavs + +def extract_audio(videos_file_path): + my_clip = mp.VideoFileClip(videos_file_path) + return my_clip.audio + +# 加载视频数据用于视频测试 +print("\n📥 加载视频数据...") +try: + video_path = f'{resource_path}red-panda.mp4' + if os.path.exists(video_path): + video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1) + video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda() + video_loaded = True + print(f"✅ 视频加载成功:") + print(f" - 视频帧数: {len(video_num_patches_list)}") + print(f" - 视频像素值形状: {video_pixel_values.shape}") + print(f" - 每帧patch数: {video_num_patches_list}") + else: + print(f"⚠️ 视频文件不存在: {video_path}") + video_loaded = False + video_pixel_values = None + video_num_patches_list = None +except Exception as e: + print(f"❌ 视频加载失败: {e}") + video_loaded = False + video_pixel_values = None + video_num_patches_list = None + + + +audio_path = f'/data1/cxy/dataset/english.mp3' + +# 加载音频数据用于后续测试 +print("\n📥 加载音频数据...") +try: + # 加载音频文件 - 使用Ola风格的mel谱图预处理 + mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path) + print(f"✅ 音频加载成功:") + print(f" - mel谱图形状: {mels.shape}") + print(f" - 音频长度: {speech_lengths}") + print(f" - 音频块数: {speech_chunks}") + print(f" - 原始音频波形形状: {speech_wavs.shape}") + + # 将音频数据转换为适当的格式并移到GPU + mels = mels.to(torch.bfloat16).cuda() + speech_lengths = speech_lengths.cuda() + speech_chunks = speech_chunks.cuda() + speech_wavs = speech_wavs.cuda() + + audio_loaded = True + +except Exception as e: + print(f"❌ 音频加载失败: {e}") + audio_loaded = False + mels = None + speech_lengths = None + + +# 测试5: Video + Text (应该正常,使用训练好的InternVL) +if video_loaded: + # 构建视频帧前缀 + video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(video_num_patches_list))]) + video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.' + + success5, response5 = test_inference( + test_name="Video + Text", + question=video_question, + pixel_values_input=video_pixel_values, + speech_input=None, + speech_lengths_input=None, + num_patches_list=video_num_patches_list + ) +else: + print("⚠️ 跳过Video + Text测试 (视频加载失败)") + success5 = False + + +print("\n" + "="*80) +print("📊 多模态推理测试总结") +print("="*80) + +test_results = [ + ("Pure Text", success1, "PASS", "应该正常 (训练好的InternVL)"), + ("Text & Image", success2, "PASS", "应该正常 (训练好的InternVL)"), + ("Video + Text", success5 if video_loaded else False, "PASS", "应该正常 (训练好的InternVL)"), +] + +for test_name, success, expected, note in test_results: + status = "✅ PASS" if success else "❌ FAIL" + print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}") + +passed = sum(1 for _, success, _, _ in test_results if success) +total = len(test_results) +print(f"\n📈 测试统计: {passed}/{total} 通过") + +if passed >= 2: # 至少pure text、text&image、video+text中的2个应该通过 + print("🎉 基础功能正常,Speech集成架构成功!") + print("💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练") + if passed >= 3: + print("🌟 所有基础模态测试都通过了!") +else: + print("⚠️ 基础功能可能存在问题,需要进一步检查") + +print("\n=== 多模态推理测试完成 ===") + diff --git a/inference/log.txt b/inference/log.txt new file mode 100644 index 0000000000000000000000000000000000000000..c4aac9eb5a0bc1d0019abec75410ae36a572254a --- /dev/null +++ b/inference/log.txt @@ -0,0 +1,480 @@ +[2025-09-15 09:15:42,098] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect) +LOAD_VISION_EARLY is set +FORCE_NO_DOWNSAMPLE is set +VIDEO_RESIZE is set as 0x64, 0, 64 +HIGHRES_BASE is set as 0x32, 0, 32 +MAXRES is set as 1536 +MINRES is set as 0 +VIDEO_MAXRES is set as 480 +VIDEO_MINRES is set as 288 +PAD2STRIDE is set +LOWRES_RESIZE is set as 384x32 +Loading OlaQwen3ForCausalLM model... +Loading BEATs Model +Missing keys: ['model.speech_encoder.whisper_model.positional_embedding', 'model.speech_encoder.whisper_model.conv1.weight', 'model.speech_encoder.whisper_model.conv1.bias', 'model.speech_encoder.whisper_model.conv2.weight', 'model.speech_encoder.whisper_model.conv2.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.bias', 'model.speech_encoder.whisper_model.ln_post.weight', 'model.speech_encoder.whisper_model.ln_post.bias', 'model.speech_encoder.beats_model.post_extract_proj.weight', 'model.speech_encoder.beats_model.post_extract_proj.bias', 'model.speech_encoder.beats_model.patch_embedding.weight', 'model.speech_encoder.beats_model.encoder.pos_conv.0.bias', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_g', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_v', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layer_norm.bias', 'model.speech_encoder.beats_model.layer_norm.weight', 'model.speech_encoder.beats_model.layer_norm.bias', 'model.speech_encoder.beats_model.predictor.weight', 'model.speech_encoder.beats_model.predictor.bias', 'model.speech_projector.speech_newline', 'model.speech_projector.speech_begin', 'model.speech_projector.speech_end', 'model.speech_projector.linear1.weight', 'model.speech_projector.linear1.bias', 'model.speech_projector.linear2.weight', 'model.speech_projector.linear2.bias'] +Unexpected keys: [] +Loading vision tower... +Loading vision tower succeeded. +User: Hello, who are you? +Assistant: Hello! I'm PLM-V, an AI assistant created to provide information, answer questions, and help with various tasks. How can I assist you today? + +================================================================================ +🧪 开始多模态推理测试 +================================================================================ + +============================================================ +🧪 测试: Pure Text +📝 问题: Hello, who are you? Please introduce yourself briefly. +============================================================ +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user +Hello, who are you? Please introduce yourself briefly.<|im_end|> +<|im_start|>assistant + Hello! I am PLM-V, an intelligent assistant designed to provide you with detailed and accurate information, answer questions, and assist with a wide range of topics. My goal is to support you in a friendly and knowledgeable manner, making your interactions with me as informative and helpful as possible. How can I assist you today? +✅ 推理成功! +🤖 回复: Hello! I am PLM-V, an intelligent assistant designed to provide you with detailed and accurate information, answer questions, and assist with a wide range of topics. My goal is to support you in a friendly and knowledgeable manner, making your interactions with me as informative and helpful as possible. How can I assist you today? + +============================================================ +🧪 测试: Text & Image (Visual only) +📝 问题: +Please describe this image in detail. +============================================================ +dynamic ViT batch size: 13 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +Please describe this image in detail.<|im_end|> +<|im_start|>assistant + The image shows a close-up of a red panda, characterized by its distinctive reddish-brown fur with white markings around its face and back. The red panda appears to be leaning on a wooden structure, possibly a ledge or a part of a platform. The subject has a mix of black and white fur around its ears, eyes, and muzzle, which contrasts with its redder head. The expression on the red panda's face seems calm and curious as it looks directly at the camera. The background features blurred greenery, suggesting that the animal is in an outdoor environment with trees and plants. The setting gives a natural, outdoor feel to the image. +✅ 推理成功! +🤖 回复: The image shows a close-up of a red panda, characterized by its distinctive reddish-brown fur with white markings around its face and back. The red panda appears to be leaning on a wooden structure, possibly a ledge or a part of a platform. The subject has a mix of black and white fur around its ears, eyes, and muzzle, which contrasts with its redder head. The expression on the red panda's face seems calm and curious as it looks directly at the camera. The background features blurred greenery, suggesting that the animal is in an outdoor environment with trees and plants. The setting gives a natural, outdoor feel to the image. + +============================================================ +🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练) +============================================================ + +📥 加载视频数据... +✅ 视频加载成功: + - 视频帧数: 8 + - 视频像素值形状: torch.Size([8, 3, 448, 448]) + - 每帧patch数: [1, 1, 1, 1, 1, 1, 1, 1] + +📥 加载音频数据... +✅ 音频加载成功: + - mel谱图形状: torch.Size([2, 3000, 128]) + - 音频长度: tensor([3000, 3000]) + - 音频块数: tensor([2]) + - 原始音频波形形状: torch.Size([2, 480000]) + +============================================================ +🧪 测试: Audio only (预期乱码) +📝 问题: +Please transcribe and summarize what you heard in the audio. +============================================================ +speech batch size: 2 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +Please transcribe and summarize what you heard in the audio.<|im_end|> +<|im_start|>assistant + The end + +The internet + +or + +or + +structure + +The overcorrection + +As年人 + +pre-hel theorem is the following: suppose that the other +There is no one knows that the + +We consider that the following is the +It is sufficient that the + +Consider that the + +As we see the answer to the question + +Let us consider the following: +The initial + +or, in the beginning of the text + +We need to consider the following: + +The answer is + +Here, we need to consider the following answers + +However, the question is: What is the +But the answer to the question is: +Or the sequence of the answer is: + +The final answer that is: + +Actually, the factor, and the issue is: The problem is: +The problem solved, and the answer is: + +However, the problem that is: The problem posed is: For the final solution + +The answer to the question is: For the following question + +But the question that is solved by: The issue here is: The question we need to solve: The problem we are considering: The question that exists: The problem that we need to address: The question that is: + +Let us consider the following: The problem that was + +The answer is: The problem is: + +We reconsider the question: The question now + +The problem is: The question we have: The problem that follows: The question is: The final question: The problem that exists: The problem that is: The question posed + +The final answer is: The problem: The question considered: The question posed: The final answer: The problem that we consider: The problem that was solved: The answer: The problem solved: The problem solved: + +After a moment of reflection, the question, and the problem is: The question, and the problem is: The answer to the problem: The question and the problem is: The problem that the question is: The problem where the question is + +However, the initial question, and the answer following these lines: + +Alternatively, the problem is solved by the following question: +The sequence of the question and answers + +The problem posed in the following +After some thought,the sequence: The problem that the system + +The actual problem: The actual answer to the question: The actual issue that is: The actual thing that is: The actual fact that is: The actual solution to the problem: The actual moment: The actual thing: The actual purpose: The actual issue: The actual event: The actual time: The actual issue resolved: The actual thing that is: The actual thing that is: The actual moment that is: The actual situation + +The end of the answer is: The end of the process is: The end of the procedure is +The end of the process that is: + +The final answer is: The final step in the process is: The final part of the procedure is: +The final note that is: The final issue is: The final issue that is: +The final step is: The final issue solved is: The final part of the question is: The final question is: The final version of the question is: The final part of the structure is: The final structure that is: The final structure of the problem is: The final structure of the consideration is: The final consideration of the structure is: The structure that is: The structure that contains the + +But the final answer is: The final answer to the question: The final answer that is: The final answer: The final answer: The final answer: The final answer: The final answer: + +The final issue is: The final answer: The initial issue is: + +The final answer +So, the final answer is: The final answer + +Suppose we consider the final answer: The final answer is: The final answer + +The final answer is: The final answer +The final answer +The final answer +The final answer + +The final answer is: The final answer + +The final answer is: The final answer + +The final answer is: The final answer: The final answer: The final answer +But the final answer to the question is: The final answer: The final answer: The final answer: +The final answer: The final answer + +The final answer + +The final answer is: The final answer + +The final solution is: The final answer + +The final answer + +The final answer is: The final answer + +The final answer + +The final step is: The final answer + +The final answer + +The final answer + +The final answer is: The final answer + +The final answer: The final answer + +The final answer + +The final answer +The final answer +The final answer +The final answer + +The final answer (a) + +But the + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer: The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer +✅ 推理成功! +🤖 回复: The end + +The internet + +or + +or + +structure + +The overcorrection + +As年人 + +pre-hel theorem is the following: suppose that the other +There is no one knows that the + +We consider that the following is the +It is sufficient that the + +Consider that the + +As we see the answer to the question + +Let us consider the following: +The initial + +or, in the beginning of the text + +We need to consider the following: + +The answer is + +Here, we need to consider the following answers + +However, the question is: What is the +But the answer to the question is: +Or the sequence of the answer is: + +The final answer that is: + +Actually, the factor, and the issue is: The problem is: +The problem solved, and the answer is: + +However, the problem that is: The problem posed is: For the final solution + +The answer to the question is: For the following question + +But the question that is solved by: The issue here is: The question we need to solve: The problem we are considering: The question that exists: The problem that we need to address: The question that is: + +Let us consider the following: The problem that was + +The answer is: The problem is: + +We reconsider the question: The question now + +The problem is: The question we have: The problem that follows: The question is: The final question: The problem that exists: The problem that is: The question posed + +The final answer is: The problem: The question considered: The question posed: The final answer: The problem that we consider: The problem that was solved: The answer: The problem solved: The problem solved: + +After a moment of reflection, the question, and the problem is: The question, and the problem is: The answer to the problem: The question and the problem is: The problem that the question is: The problem where the question is + +However, the initial question, and the answer following these lines: + +Alternatively, the problem is solved by the following question: +The sequence of the question and answers + +The problem posed in the following +After some thought,the sequence: The problem that the system + +The actual problem: The actual answer to the question: The actual issue that is: The actual thing that is: The actual fact that is: The actual solution to the problem: The actual moment: The actual thing: The actual purpose: The actual issue: The actual event: The actual time: The actual issue resolved: The actual thing that is: The actual thing that is: The actual moment that is: The actual situation + +The end of the answer is: The end of the process is: The end of the procedure is +The end of the process that is: + +The final answer is: The final step in the process is: The final part of the procedure is: +The final note that is: The final issue is: The final issue that is: +The final step is: The final issue solved is: The final part of the question is: The final question is: The final version of the question is: The final part of the structure is: The final structure that is: The final structure of the problem is: The final structure of the consideration is: The final consideration of the structure is: The structure that is: The structure that contains the + +But the final answer is: The final answer to the question: The final answer that is: The final answer: The final answer: The final answer: The final answer: The final answer: + +The final issue is: The final answer: The initial issue is: + +The final answer +So, the final answer is: The final answer + +Suppose we consider the final answer: The final answer is: The final answer + +The final answer is: The final answer +The final answer +The final answer +The final answer + +The final answer is: The final answer + +The final answer is: The final answer + +The final answer is: The final answer: The final answer: The final answer +But the final answer to the question is: The final answer: The final answer: The final answer: +The final answer: The final answer + +The final answer + +The final answer is: The final answer + +The final solution is: The final answer + +The final answer + +The final answer is: The final answer + +The final answer + +The final step is: The final answer + +The final answer + +The final answer + +The final answer is: The final answer + +The final answer: The final answer + +The final answer + +The final answer +The final answer +The final answer +The final answer + +The final answer (a) + +But the + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer: The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +The final answer + +============================================================ +🧪 测试: Audio + Image (预期乱码) +📝 问题: +User's question in speech: + +============================================================ +dynamic ViT batch size: 13 +speech batch size: 2 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +User's question in speech: +<|im_end|> +<|im_start|>assistant + Bicytude: 连们在考虑了的时候 - 衋体在当前的是一场关于亚的类型、以通过的是一支集 +在使用后端使用相同的方式的人,可能不知道,或者,其是否在何时进入,而是一次使用相同错误地管理,可能不会发生,或可能是在之前,或者是不是还有其他相关的人,或者,或者,或是否,或,或呢,或是一个,或是一个或,或是 + +I don't understand the given content: 迭取到一张图片,该是关于一个文本文件,内容和类似于的功能,或者是一个关于,或者是一个的人,或若是关于一个关于学习和关于其他主题的上的行为,它是什么原因(或者,或者,或者,或?)还是关于其他主题呢,或在关于其他生物,比如,或者,或者,或者,或呢,或,我需要在使用了,或者,或者,或然呢? - 我需要找到文件中关于“如何改进一个文本文件,以便于更趋时,或是在其他主题,或者是一个关于如何管理一个复杂的数据文件,或是在一个应用数据文件,或者一个关于如何在浏览器的“如何利用Python的自动控制中的“或,或是一个关于一个关于如何使用标签,或是一个关于其他事情,比如,或者一个关于另一个关于如何使用的或是什么,或者,或呢?等,或者??或呢? +✅ 推理成功! +🤖 回复: Bicytude: 连们在考虑了的时候 - 衋体在当前的是一场关于亚的类型、以通过的是一支集 +在使用后端使用相同的方式的人,可能不知道,或者,其是否在何时进入,而是一次使用相同错误地管理,可能不会发生,或可能是在之前,或者是不是还有其他相关的人,或者,或者,或是否,或,或呢,或是一个,或是一个或,或是 + +I don't understand the given content: 迭取到一张图片,该是关于一个文本文件,内容和类似于的功能,或者是一个关于,或者是一个的人,或若是关于一个关于学习和关于其他主题的上的行为,它是什么原因(或者,或者,或者,或?)还是关于其他主题呢,或在关于其他生物,比如,或者,或者,或者,或呢,或,我需要在使用了,或者,或者,或然呢? - 我需要找到文件中关于“如何改进一个文本文件,以便于更趋时,或是在其他主题,或者是一个关于如何管理一个复杂的数据文件,或是在一个应用数据文件,或者一个关于如何在浏览器的“如何利用Python的自动控制中的“或,或是一个关于一个关于如何使用标签,或是一个关于其他事情,比如,或者一个关于另一个关于如何使用的或是什么,或者,或呢?等,或者??或呢? + +============================================================ +🧪 测试: Video + Text +📝 问题: Frame1: +Frame2: +Frame3: +Frame4: +Frame5: +Frame6: +Frame7: +Frame8: +What is the red panda doing in this video? Please describe the actions and movements you observe. +============================================================ +dynamic ViT batch size: 8 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user +Frame1: +Frame2: +Frame3: +Frame4: +Frame5: +Frame6: +Frame7: +Frame8: +What is the red panda doing in this video? Please describe the actions and movements you observe.<|im_end|> +<|im_start|>assistant + In this video, a red panda is climbing up a branch, perched on it while holding something in its mouth, and later sitting on the ground and reaching up towards bamboo sticks suspended from a tree. At one point, one of the red pandas chews on bamboo, and at another point, the blue creature is seen on the grassy ground, looking up towards the red panda. The scene then shows the red panda still perched on the branch, holding something in its mouth, and another red panda is perched on the ground, reaching up towards the bamboo on the tree. After a few moments, the panda on the ground finishes its activity and sits down on the grassy ground. +✅ 推理成功! +🤖 回复: In this video, a red panda is climbing up a branch, perched on it while holding something in its mouth, and later sitting on the ground and reaching up towards bamboo sticks suspended from a tree. At one point, one of the red pandas chews on bamboo, and at another point, the blue creature is seen on the grassy ground, looking up towards the red panda. The scene then shows the red panda still perched on the branch, holding something in its mouth, and another red panda is perched on the ground, reaching up towards the bamboo on the tree. After a few moments, the panda on the ground finishes its activity and sits down on the grassy ground. + +================================================================================ +📊 多模态推理测试总结 +================================================================================ +✅ PASS Pure Text (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Text & Image (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Video + Text (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Audio only (预期: GARBLED ) - 可能乱码 (speech未训练) +✅ PASS Audio + Image (预期: GARBLED ) - 可能乱码 (speech未训练) + +📈 测试统计: 5/5 通过 +🎉 基础功能正常,Speech集成架构成功! +💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练 +🌟 所有基础模态测试都通过了! + +=== 多模态推理测试完成 === diff --git a/inference/log1.txt b/inference/log1.txt new file mode 100644 index 0000000000000000000000000000000000000000..808240161b31eaa21972e76d08e25bcc7afabee9 --- /dev/null +++ b/inference/log1.txt @@ -0,0 +1,339 @@ +[2025-09-15 09:26:52,568] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect) +LOAD_VISION_EARLY is set +FORCE_NO_DOWNSAMPLE is set +VIDEO_RESIZE is set as 0x64, 0, 64 +HIGHRES_BASE is set as 0x32, 0, 32 +MAXRES is set as 1536 +MINRES is set as 0 +VIDEO_MAXRES is set as 480 +VIDEO_MINRES is set as 288 +PAD2STRIDE is set +LOWRES_RESIZE is set as 384x32 +Loading OlaQwen3ForCausalLM model... +Loading BEATs Model +Missing keys: ['model.speech_encoder.whisper_model.positional_embedding', 'model.speech_encoder.whisper_model.conv1.weight', 'model.speech_encoder.whisper_model.conv1.bias', 'model.speech_encoder.whisper_model.conv2.weight', 'model.speech_encoder.whisper_model.conv2.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.0.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.0.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.0.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.1.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.1.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.1.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.2.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.2.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.2.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.3.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.3.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.3.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.4.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.4.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.4.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.5.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.5.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.5.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.6.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.6.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.6.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.7.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.7.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.7.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.8.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.8.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.8.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.9.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.9.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.9.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.10.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.10.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.10.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.11.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.11.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.11.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.12.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.12.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.12.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.13.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.13.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.13.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.14.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.14.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.14.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.15.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.15.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.15.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.16.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.16.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.16.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.17.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.17.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.17.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.18.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.18.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.18.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.19.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.19.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.19.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.20.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.20.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.20.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.21.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.21.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.21.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.22.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.22.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.22.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.23.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.23.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.23.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.24.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.24.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.24.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.25.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.25.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.25.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.26.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.26.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.26.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.27.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.27.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.27.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.28.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.28.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.28.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.29.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.29.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.29.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.30.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.30.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.30.mlp_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.query.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.query.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.key.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.value.bias', 'model.speech_encoder.whisper_model.blocks.31.attn.out.weight', 'model.speech_encoder.whisper_model.blocks.31.attn.out.bias', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.attn_ln.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.0.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp.2.bias', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.weight', 'model.speech_encoder.whisper_model.blocks.31.mlp_ln.bias', 'model.speech_encoder.whisper_model.ln_post.weight', 'model.speech_encoder.whisper_model.ln_post.bias', 'model.speech_encoder.beats_model.post_extract_proj.weight', 'model.speech_encoder.beats_model.post_extract_proj.bias', 'model.speech_encoder.beats_model.patch_embedding.weight', 'model.speech_encoder.beats_model.encoder.pos_conv.0.bias', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_g', 'model.speech_encoder.beats_model.encoder.pos_conv.0.weight_v', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.0.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.0.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.1.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.1.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.2.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.2.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.3.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.3.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.4.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.4.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.5.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.5.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.6.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.6.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.7.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.7.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.8.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.8.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.9.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.9.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.10.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.10.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_a', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.k_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.v_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.q_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.out_proj.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.grep_linear.bias', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn.relative_attention_bias.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.self_attn_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc1.bias', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.weight', 'model.speech_encoder.beats_model.encoder.layers.11.fc2.bias', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layers.11.final_layer_norm.bias', 'model.speech_encoder.beats_model.encoder.layer_norm.weight', 'model.speech_encoder.beats_model.encoder.layer_norm.bias', 'model.speech_encoder.beats_model.layer_norm.weight', 'model.speech_encoder.beats_model.layer_norm.bias', 'model.speech_encoder.beats_model.predictor.weight', 'model.speech_encoder.beats_model.predictor.bias', 'model.speech_projector.speech_newline', 'model.speech_projector.speech_begin', 'model.speech_projector.speech_end', 'model.speech_projector.linear1.weight', 'model.speech_projector.linear1.bias', 'model.speech_projector.linear2.weight', 'model.speech_projector.linear2.bias'] +Unexpected keys: [] +Loading vision tower... +Loading vision tower succeeded. +User: Hello, who are you? +Assistant: Hello! I am an AI assistant called PLM-V. How can I help you today? + +================================================================================ +🧪 开始多模态推理测试 +================================================================================ + +============================================================ +🧪 测试: Pure Text +📝 问题: Hello, who are you? Please introduce yourself briefly. +============================================================ +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user +Hello, who are you? Please introduce yourself briefly.<|im_end|> +<|im_start|>assistant + Hello! I am PLM-V, a language model created by the platform OpenAI. My primary function is to assist users by providing information, answering questions, and engaging in conversation. I aim to be helpful, accurate, and respectful in all interactions. How can I assist you today? +✅ 推理成功! +🤖 回复: Hello! I am PLM-V, a language model created by the platform OpenAI. My primary function is to assist users by providing information, answering questions, and engaging in conversation. I aim to be helpful, accurate, and respectful in all interactions. How can I assist you today? + +============================================================ +🧪 测试: Text & Image (Visual only) +📝 问题: +Please describe this image in detail. +============================================================ +dynamic ViT batch size: 13 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +Please describe this image in detail.<|im_end|> +<|im_start|>assistant + The image shows a cute red panda sitting on a wooden platform. This reddish-brown animal has distinctive black and white markings: a white face with black stripes on its cheeks and around its nose. Its fur appears soft and fluffy. The red panda has large, dark eyes and white whiskers, adding to its endearing appearance. It is resting close to a tree trunk, with its black ears perked up on either side. The background is filled with blurred green foliage, suggesting this red panda is in a natural or outdoor setting, possibly a zoo or wildlife sanctuary. The wooden platform appears to be part of a structure designed for the animal to rest or climb on comfortably. +✅ 推理成功! +🤖 回复: The image shows a cute red panda sitting on a wooden platform. This reddish-brown animal has distinctive black and white markings: a white face with black stripes on its cheeks and around its nose. Its fur appears soft and fluffy. The red panda has large, dark eyes and white whiskers, adding to its endearing appearance. It is resting close to a tree trunk, with its black ears perked up on either side. The background is filled with blurred green foliage, suggesting this red panda is in a natural or outdoor setting, possibly a zoo or wildlife sanctuary. The wooden platform appears to be part of a structure designed for the animal to rest or climb on comfortably. + +============================================================ +🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练) +============================================================ + +📥 加载视频数据... +✅ 视频加载成功: + - 视频帧数: 8 + - 视频像素值形状: torch.Size([8, 3, 448, 448]) + - 每帧patch数: [1, 1, 1, 1, 1, 1, 1, 1] + +📥 加载音频数据... +✅ 音频加载成功: + - mel谱图形状: torch.Size([2, 3000, 128]) + - 音频长度: tensor([3000, 3000]) + - 音频块数: tensor([2]) + - 原始音频波形形状: torch.Size([2, 480000]) + +============================================================ +🧪 测试: Audio only (预期乱码) +📝 问题: +Please transcribe and summarize what you heard in the audio. +============================================================ +speech batch size: 2 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +Please transcribe and summarize what you heard in the audio.<|im_end|> +<|im_start|>assistant + The秘密 is currently notPowered should be valid. + +It's possible that's a complex combination of unsub,$ + +Here's an unexpected task is to assess the����是有效的。 + +Now, the秘密 + +### 払下秘密是从加密的 + +Here的秘密 + +It seems the encrypted message has been truncated. It's a confusing combination of unworking! We're not seeing the message from a valid IP的操作未完成. We're here to assist." + +It's a sequence of 10.21.48. Here we can assist to encrypting the message. This is the result of a combination of unworking. If you're trying to determine the message that the code for the current message. It seems we're not seeing the message that the network can't work. We're too broad. + +Therefore, the final answer is that we can't help you. We're currently working to inform you. We're now able to help you. + +What did you saw?" + +It seems like the message is invalid. In fact, the message is to decrypt the message that the network can't work. Here's an encryption that's confusing. We're not really saying that the combination is not possible. We're unable to help you. + +However, it seems that the message is returning to its initial message. We can assist in determining the message that the network can't help. We're unable to assist with this message. + +To assist in the current message, we need to decrypt the message that the network is not working. We'll now assist with the message that the network can't help. + +However, it seems that the message is not working. We can't do that. + +To break the message, we need to help you with the message that the network is not working. We'll now assist with the message that the network can't assist. + +Based on the message, the network seems to be having a difficulty. We're unable to do that. We can assist in determining the message that the network is not working. We're unable to assist with the message that the network is not working. + +It seems like the message is not working. We can't assist with the message that the network is not allowed to help you. + +Please, however, we can assist with the message that the network is not working. We're currently unable to help you. + +Are you able to help with the message that the network is not working? We can't help you with the message that the network is not available to assist with the message. + +But the message is not doing that. We can't assist with the message that the network can't assist you. + +In this message, we need to help you with the message that the network is not working. We'll help the message that the network can't assist with the message. + +It seems like the message is not working. We're unable to help you with the message that the network is not working. + +Let's focus on the message that the network is not working. We cannot assist with the message that the network is not available to assist with the message. + +Given that the message is not working, we'll now assist with the message that the network can't assist. + +However, it seems that the message cannot be provided based on the message that the network is not working. + +Let's continue to assist with the message that the network can't assist. + +It seems like the message is not possible for the message to assist with the message that the network can't assist. + +We are unable to assist with the message that the network is not working. + +Please, however, we are unable to assist with the message that the network is not working. + +The message is not working. We're currently unable to assist with the message that the network is not working. + +We are unable to assist with the message that the network is not working. + +The message seems to be unreadable. We're unable to help with the message that the network is not working. + +Please, we are unable to help with the message that the message is not working. + +The task is asking for help with the message that the network is not working. + +It seems that the message is not possible for the message to assist with the message that the network is not working. + +Let's assume the message is not working. We can't assist with the message that the message is not working. + +The message is not possible to assist with the message that the network is not working. + +We're unable to help with the message that the network is not working. + +Here's the message that the network is not working. + +The message is not possible to help with the message that the network is not working. + +The message is a mystery. We're unable to assist with the message that the network is not working. + +It seems that the message is not working. We can't assist with the message that the network is not working. + +We're unable to help with the message that the +✅ 推理成功! +🤖 回复: The秘密 is currently notPowered should be valid. + +It's possible that's a complex combination of unsub,$ + +Here's an unexpected task is to assess the����是有效的。 + +Now, the秘密 + +### 払下秘密是从加密的 + +Here的秘密 + +It seems the encrypted message has been truncated. It's a confusing combination of unworking! We're not seeing the message from a valid IP的操作未完成. We're here to assist." + +It's a sequence of 10.21.48. Here we can assist to encrypting the message. This is the result of a combination of unworking. If you're trying to determine the message that the code for the current message. It seems we're not seeing the message that the network can't work. We're too broad. + +Therefore, the final answer is that we can't help you. We're currently working to inform you. We're now able to help you. + +What did you saw?" + +It seems like the message is invalid. In fact, the message is to decrypt the message that the network can't work. Here's an encryption that's confusing. We're not really saying that the combination is not possible. We're unable to help you. + +However, it seems that the message is returning to its initial message. We can assist in determining the message that the network can't help. We're unable to assist with this message. + +To assist in the current message, we need to decrypt the message that the network is not working. We'll now assist with the message that the network can't help. + +However, it seems that the message is not working. We can't do that. + +To break the message, we need to help you with the message that the network is not working. We'll now assist with the message that the network can't assist. + +Based on the message, the network seems to be having a difficulty. We're unable to do that. We can assist in determining the message that the network is not working. We're unable to assist with the message that the network is not working. + +It seems like the message is not working. We can't assist with the message that the network is not allowed to help you. + +Please, however, we can assist with the message that the network is not working. We're currently unable to help you. + +Are you able to help with the message that the network is not working? We can't help you with the message that the network is not available to assist with the message. + +But the message is not doing that. We can't assist with the message that the network can't assist you. + +In this message, we need to help you with the message that the network is not working. We'll help the message that the network can't assist with the message. + +It seems like the message is not working. We're unable to help you with the message that the network is not working. + +Let's focus on the message that the network is not working. We cannot assist with the message that the network is not available to assist with the message. + +Given that the message is not working, we'll now assist with the message that the network can't assist. + +However, it seems that the message cannot be provided based on the message that the network is not working. + +Let's continue to assist with the message that the network can't assist. + +It seems like the message is not possible for the message to assist with the message that the network can't assist. + +We are unable to assist with the message that the network is not working. + +Please, however, we are unable to assist with the message that the network is not working. + +The message is not working. We're currently unable to assist with the message that the network is not working. + +We are unable to assist with the message that the network is not working. + +The message seems to be unreadable. We're unable to help with the message that the network is not working. + +Please, we are unable to help with the message that the message is not working. + +The task is asking for help with the message that the network is not working. + +It seems that the message is not possible for the message to assist with the message that the network is not working. + +Let's assume the message is not working. We can't assist with the message that the message is not working. + +The message is not possible to assist with the message that the network is not working. + +We're unable to help with the message that the network is not working. + +Here's the message that the network is not working. + +The message is not possible to help with the message that the network is not working. + +The message is a mystery. We're unable to assist with the message that the network is not working. + +It seems that the message is not working. We can't assist with the message that the network is not working. + +We're unable to help with the message that the + +============================================================ +🧪 测试: Audio + Image (预期乱码) +📝 问题: +User's question in speech: + +============================================================ +dynamic ViT batch size: 13 +speech batch size: 2 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +User's question in speech: +<|im_end|> +<|im_start|>assistant + In the previous text 10个 + +在是在一个建筑中,一下的人们在使用前一张“我们来学习的是该地区的植物学10月5月的浏览器:这是一个关于在是在一个大的是一张复杂的、我们无法想象的,我需要找出数组中的是否是什么,如何使用异,这个代码看起来不太的人群在该地区进行了8年,然后在使用了4个的用户,我需要将一个的是一张更大的网站,以便确保我的眼睛在没有的,我需要在进行中的人群中,我的网络里,以确保我的眼睛,然后将帮助我理解一下,以便确保我在学习中,需要考虑在使用中的使用的人群之间出现的获取的代码,以便确保,我会帮助您确认,我的代码,或者我需要确认,我没有在不使用的,因为我总是不知道,因为我们需要在的,因为我的困惑,或者因为,因为我的困惑感到不安)等等,我需要帮助控制,或者因为个人在,我需要确保我的眼睛,或者因为我求如何在进行了8年,我需要帮助我的人在或其他方面,我需要确保,我可以尝试以其他的,确保我能提供或不确定,我需要确保,或者为了确保我的眼睛,确保我的眼睛,或者我不确定,我不计算的,或因为我的眼睛,因为我经常在的用户,或者如果使用其他方法,但我需要确保,我需要帮助我,或者因为某些原因,我如何处理使用与我使用Python编程语言使用一个示的用户,我需要确认,我,我、确保,或者在什么时候要使用不同的技巧,但可能,或者我想要确认,或者因为,我需要确保,我无法确保,我需要确保,或者因为,我需要确认,或者我不会知道更多信息,或者我需要确定,或者我需要,我没有,或者可能,或者因为,或者我需要确保,我需要确认,我没有,或者我认为我是否需要,我需要确保,我需要确保,或者我需要帮助,我需要找到,或者我需要在使用了“inferred - �国我需要确保,或者我需要确认,我需要考虑,或者我需要帮助我,我需要知道,我没有,或者我需要,我用一下如何提高或我需要做到的,我需要确认,或者我需要确认,或者我需要,或者我没有,或可能通过这些代码或其他方法,我需要帮助,或者我需要知道,或我需要知道,或我需要确保,我需要,或我需要确保,我需要知道,或者我需要确保,或者其他类似任何事情,或者我需要确保,等等。我需要帮助我如何确保,或者我需要知道,我需要确保,我需要确定,或者我需要处理,或者我需要帮助,或者我不会知道,或者我需要确保,或者我需要帮助,或者我需要,我需要确认,或我需要,我需要确如何确保,我需要确保,或我需要确保,或者我需要保证,我需要确保我需要,或我需要,即使,我需要,我需要确保,或可能需要帮助,我需要确保,或者我需要确保,我需要确保我需要确保,我需要确保,或我需要确定,或我需要确保,或者我需要,我需要确保,我需要,我需要确保,或者我需要其他方式,或者我需要,或我需要,或者我需要,或我需要,或者我我需要,或我我需要,或我需要,我需要,我需要,或我需要,或我需要,或其他,我需要——,或我需要,我需要确认,或者我需要,我需要确保,我需要,或我需要,或我需要,或我在使用了,我需要,但即使,我需要,或我需要,或者我需要,或我需要,或我在通过使用于,或我需要,或者我需要,我需要,或我需要,我需要,或我需要,但即使,我需要,或者我需要,或我需要,或者我的担心,我需要,或我需要,或我需要,我们,或我需要,或者我需要,或我需要,或我我需要,我需要,或我需要,或我需要,或我在使用了,在我需要,或我需要,且我需要,或者我需要,我需要,我需要,或我需要,或者,我需要,让我们需要,或我需要,或我需要,或我需要,或我需要,那么我需要,或者我需要,或我是否需要,或我需要,或我需要,或我需要,你能够帮助我,我需要,或我需要,我需要,或我需要,或我需要是否,或我需要,我需要,让我需要,或我需要 +✅ 推理成功! +🤖 回复: In the previous text 10个 + +在是在一个建筑中,一下的人们在使用前一张“我们来学习的是该地区的植物学10月5月的浏览器:这是一个关于在是在一个大的是一张复杂的、我们无法想象的,我需要找出数组中的是否是什么,如何使用异,这个代码看起来不太的人群在该地区进行了8年,然后在使用了4个的用户,我需要将一个的是一张更大的网站,以便确保我的眼睛在没有的,我需要在进行中的人群中,我的网络里,以确保我的眼睛,然后将帮助我理解一下,以便确保我在学习中,需要考虑在使用中的使用的人群之间出现的获取的代码,以便确保,我会帮助您确认,我的代码,或者我需要确认,我没有在不使用的,因为我总是不知道,因为我们需要在的,因为我的困惑,或者因为,因为我的困惑感到不安)等等,我需要帮助控制,或者因为个人在,我需要确保我的眼睛,或者因为我求如何在进行了8年,我需要帮助我的人在或其他方面,我需要确保,我可以尝试以其他的,确保我能提供或不确定,我需要确保,或者为了确保我的眼睛,确保我的眼睛,或者我不确定,我不计算的,或因为我的眼睛,因为我经常在的用户,或者如果使用其他方法,但我需要确保,我需要帮助我,或者因为某些原因,我如何处理使用与我使用Python编程语言使用一个示的用户,我需要确认,我,我、确保,或者在什么时候要使用不同的技巧,但可能,或者我想要确认,或者因为,我需要确保,我无法确保,我需要确保,或者因为,我需要确认,或者我不会知道更多信息,或者我需要确定,或者我需要,我没有,或者可能,或者因为,或者我需要确保,我需要确认,我没有,或者我认为我是否需要,我需要确保,我需要确保,或者我需要帮助,我需要找到,或者我需要在使用了“inferred - �国我需要确保,或者我需要确认,我需要考虑,或者我需要帮助我,我需要知道,我没有,或者我需要,我用一下如何提高或我需要做到的,我需要确认,或者我需要确认,或者我需要,或者我没有,或可能通过这些代码或其他方法,我需要帮助,或者我需要知道,或我需要知道,或我需要确保,我需要,或我需要确保,我需要知道,或者我需要确保,或者其他类似任何事情,或者我需要确保,等等。我需要帮助我如何确保,或者我需要知道,我需要确保,我需要确定,或者我需要处理,或者我需要帮助,或者我不会知道,或者我需要确保,或者我需要帮助,或者我需要,我需要确认,或我需要,我需要确如何确保,我需要确保,或我需要确保,或者我需要保证,我需要确保我需要,或我需要,即使,我需要,我需要确保,或可能需要帮助,我需要确保,或者我需要确保,我需要确保我需要确保,我需要确保,或我需要确定,或我需要确保,或者我需要,我需要确保,我需要,我需要确保,或者我需要其他方式,或者我需要,或我需要,或者我需要,或我需要,或者我我需要,或我我需要,或我需要,我需要,我需要,或我需要,或我需要,或其他,我需要——,或我需要,我需要确认,或者我需要,我需要确保,我需要,或我需要,或我需要,或我在使用了,我需要,但即使,我需要,或我需要,或者我需要,或我需要,或我在通过使用于,或我需要,或者我需要,我需要,或我需要,我需要,或我需要,但即使,我需要,或者我需要,或我需要,或者我的担心,我需要,或我需要,或我需要,我们,或我需要,或者我需要,或我需要,或我我需要,我需要,或我需要,或我需要,或我在使用了,在我需要,或我需要,且我需要,或者我需要,我需要,我需要,或我需要,或者,我需要,让我们需要,或我需要,或我需要,或我需要,或我需要,那么我需要,或者我需要,或我是否需要,或我需要,或我需要,或我需要,你能够帮助我,我需要,或我需要,我需要,或我需要,或我需要是否,或我需要,我需要,让我需要,或我需要 + +============================================================ +🧪 测试: Video + Text +📝 问题: Frame1: +Frame2: +Frame3: +Frame4: +Frame5: +Frame6: +Frame7: +Frame8: +What is the red panda doing in this video? Please describe the actions and movements you observe. +============================================================ +dynamic ViT batch size: 8 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user +Frame1: +Frame2: +Frame3: +Frame4: +Frame5: +Frame6: +Frame7: +Frame8: +What is the red panda doing in this video? Please describe the actions and movements you observe.<|im_end|> +<|im_start|>assistant + The red panda in this video is shown eating bamboo and holding a piece of bamboo. In the beginning, the red panda is eating bamboo from the other end of the structure while the baby in front of it is reaching up to eat. They move to the right, and the adult red panda is eating bamboo while the baby continues to reach up. Towards the end, the baby starts eating a piece of bamboo while the adult is eating bamboo from the structure above. +✅ 推理成功! +🤖 回复: The red panda in this video is shown eating bamboo and holding a piece of bamboo. In the beginning, the red panda is eating bamboo from the other end of the structure while the baby in front of it is reaching up to eat. They move to the right, and the adult red panda is eating bamboo while the baby continues to reach up. Towards the end, the baby starts eating a piece of bamboo while the adult is eating bamboo from the structure above. + +============================================================ +🧪 测试: Video + Audio (预期乱码) +📝 问题: +Describe what you hear and see in this content. +============================================================ +dynamic ViT batch size: 13 +speech batch size: 2 +<|im_start|>system +You are PLM-V, a helpful assistant.<|im_end|> +<|im_start|>user + +Describe what you hear and see in this content.<|im_end|> +<|im_start|>assistant + The first step is to determine the 301445560526679335.928824157.5.5, 12.496688788434,49505993735.3846390994.45546936.455539779387". Which of the 201863188435.3179473.943878315546397753.3976.388571.7538856.569376.454694876.56974376.937831.63687.48876.388571.7764387938094875.3880948178093961.7831809387546.9378318878315.776383187673876.83187648780937883187525.3856839375469783188787831875831879838378763831831875469839188767383187583187876383187546983783187145935.429563977878318763831875831868781763831887583187638388093787831887831875469837831887831875831876383831875469837831878318783187546983783187831875831876383831875831876383831875831878318763831875831878318763831875831876.26.33.73.176.1235.32.14.1235.23.48.987659934.763839376.38.7966.8376.393776.23.876.37635.917.53.4687836.75.9356.58378788.7678393783188786.388876.4388876.4387.46.8839376.3937809487.438763831887831875.88317831.776.94699378831875938.783187583188783187583187638318758318763838318758318763831878318758318783187638318758318766.38876.3876.43876.4387839378.3187583187583187638318758318763838 +✅ 推理成功! +🤖 回复: The first step is to determine the 301445560526679335.928824157.5.5, 12.496688788434,49505993735.3846390994.45546936.455539779387". Which of the 201863188435.3179473.943878315546397753.3976.388571.7538856.569376.454694876.56974376.937831.63687.48876.388571.7764387938094875.3880948178093961.7831809387546.9378318878315.776383187673876.83187648780937883187525.3856839375469783188787831875831879838378763831831875469839188767383187583187876383187546983783187145935.429563977878318763831875831868781763831887583187638388093787831887831875469837831887831875831876383831875469837831878318783187546983783187831875831876383831875831876383831875831878318763831875831878318763831875831876.26.33.73.176.1235.32.14.1235.23.48.987659934.763839376.38.7966.8376.393776.23.876.37635.917.53.4687836.75.9356.58378788.7678393783188786.388876.4388876.4387.46.8839376.3937809487.438763831887831875.88317831.776.94699378831875938.783187583188783187583187638318758318763838318758318763831878318758318783187638318758318766.38876.3876.43876.4387839378.3187583187583187638318758318763838 + +================================================================================ +📊 多模态推理测试总结 +================================================================================ +✅ PASS Pure Text (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Text & Image (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Video + Text (预期: PASS ) - 应该正常 (训练好的InternVL) +✅ PASS Audio only (预期: GARBLED ) - 可能乱码 (speech未训练) +✅ PASS Audio + Image (预期: GARBLED ) - 可能乱码 (speech未训练) + +📈 测试统计: 5/5 通过 +🎉 基础功能正常,Speech集成架构成功! +💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练 +🌟 所有基础模态测试都通过了! + +=== 多模态推理测试完成 === diff --git a/ola.egg-info/PKG-INFO b/ola.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..c79a3fffc15e4465d72d32d3363caf106487b3e6 --- /dev/null +++ b/ola.egg-info/PKG-INFO @@ -0,0 +1,265 @@ +Metadata-Version: 2.4 +Name: ola +Version: 1.0.0 +Summary: Omni-Modal Language Model +Classifier: Programming Language :: Python :: 3 +Classifier: License :: OSI Approved :: Apache Software License +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: torch==2.1.2 +Requires-Dist: torchvision==0.16.2 +Requires-Dist: torchaudio==2.1.2 +Requires-Dist: transformers==4.43.4 +Requires-Dist: tokenizers==0.19.1 +Requires-Dist: sentencepiece==0.1.99 +Requires-Dist: shortuuid +Requires-Dist: accelerate==0.33.0 +Requires-Dist: peft==0.11.1 +Requires-Dist: bitsandbytes==0.43.1 +Requires-Dist: pydantic +Requires-Dist: markdown2[all] +Requires-Dist: numpy +Requires-Dist: scikit-learn==1.2.2 +Requires-Dist: gradio==4.43.0 +Requires-Dist: gradio_client==1.3.0 +Requires-Dist: requests +Requires-Dist: httpx==0.27.2 +Requires-Dist: uvicorn +Requires-Dist: fastapi +Requires-Dist: soundfile +Requires-Dist: einops==0.6.1 +Requires-Dist: einops-exts==0.0.4 +Requires-Dist: timm==0.9.16 +Requires-Dist: openai-whisper +Requires-Dist: setuptools==59.5.0 +Requires-Dist: omegaconf==2.0.6 +Requires-Dist: loguru +Requires-Dist: av +Requires-Dist: librosa +Provides-Extra: train +Requires-Dist: deepspeed==0.12.6; extra == "train" +Requires-Dist: ninja; extra == "train" +Requires-Dist: wandb; extra == "train" +Requires-Dist: tensorboardX; extra == "train" +Provides-Extra: build +Requires-Dist: build; extra == "build" +Requires-Dist: twine; extra == "build" +Dynamic: license-file + +

+967023137dff29e65b21544e7620e0f7.webp +

+
+ +## Ola: Pushing the Frontiers of Omni-Modal Language Model + +

+ Zuyan Liu*,1,2  + Yuhao Dong*,2,3  + Jiahui Wang1
+ Ziwei Liu3  + Winston Hu2  + Jiwen Lu1,✉  + Yongming Rao2,1,✉  +

+ + +

1Tsinghua University   2Tencent Hunyuan Research  3S-Lab, NTU 

+ +

* Equal Contribution  ✉ Corresponding Author

+ +[![Ola](https://img.shields.io/badge/Rank_1-OpenCampass(<15B)-blue)](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME) [![Ola](https://img.shields.io/badge/Rank_8-VideoMME-red)](https://video-mme.github.io/home_page.html#leaderboard) + +--- + +**Project Page:** [![Ola](https://img.shields.io/badge/Ola-project_page-orange)](https://ola-omni.github.io) + +**Weights in Huggingface:** [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_7b-green)](https://huggingface.co/THUdyh/Ola-7b) [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_Image-green)](https://huggingface.co/THUdyh/Ola-Image) [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_Video-green)](https://huggingface.co/THUdyh/Ola-Video) + +**arXiv Paper:** [![arxiv](https://img.shields.io/badge/Arxiv-2502.04328-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2502.04328) + +**Demo by Gradio:** [![demo](https://img.shields.io/badge/Ola-Demo-yellow)](https://huggingface.co/spaces/THUdyh/Ola) + +**Training Data:** [![data](https://img.shields.io/badge/Ola-Data-purple)](https://huggingface.co/datasets/THUdyh/Ola-Data) + +**中文解读**: [![chinese](https://img.shields.io/badge/Ola-机器之心-cyan)](https://mp.weixin.qq.com/s/N4bjcHOejJudtxTFZVAXmg) + +Contact: Leave an issue or contact liuzuyan19@gmail.com . We are on call to respond. + +## 📢 News + +- 🔥[28/2/2025] We release the intermediate model, Ola-Image and Ola-Video, try building your own omni-modal models! + +- 🚀[19/2/2025] We release the huggingface demo of Ola, try the advanced omni-modal model on your own! + +- 🔥[18/2/2025] The training data, training script for Ola-7b is released! + +- 🎉[07/2/2025] The Ola is released! Check our [project page](https://ola-omni.github.io), [model weights](https://huggingface.co/THUdyh/Ola-7b), [arXiv paper](https://arxiv.org/pdf/2502.04328) for the strong omni-modal understanding model! + +- 🔥[06/2/2025] [Ola-7b](https://huggingface.co/THUdyh/Ola-7b) achieves **Rank #1** on the OpenCompass Multi-modal Leaderboard among all the models under 15B parameters with average score of **72.6**. Check the impressive results [here](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME)! + +## 🚀Coming Soon + +- [x] Evaluation code on omni-modal benchmarks +- [x] Gradio Demo +- [x] Training Data (Video, Audio, Cross-Modality) + +## 🌟 Introduction + +### Roads to Ola + +

+road.png +

+
+ +**Ola** is an Omni-modal language model that achieves competitive performance across image, video, and audio understanding compared to specialized counterparts. We conduct a comprehensive exploration of architectural design, data curation, and training strategies essential for building a robust omni-modal model. + +

+teaser.png +

+
+ +### Architecture + +

+method.png +

+
+ +Ola supports omni-modal inputs including text, image, video, and audio, capable of processing the inputs simultaneously with competitive performance on understanding tasks for all these modalities. Meanwhile, Ola supports user-friendly real-time streaming decoding for texts and speeches thanks to the text detokenizer and the speech decoder. + +### Training Strategies + +

+training.png +

+
+ +We visualize the relationships among modalities in the left part. Speech acts as the connection between language and audio knowledge, while video constructs the bridge with highly relevant visual and audio information. Therefore, we design the progressive alignment training strategy from primary to periphery. Furthermore, we design the cross-modality video-audio data to better capture the relationships among modalities. + +### Performance + +

+results.png +

+
+ +Ola achieves competitive performance across major multi-modal benchmarks when compared to state-of-the-art specialist-modal LLMs. + +## Installation + + +#### 1. Clone this repository: +```bash +git clone https://github.com/Ola-Omni/Ola +cd Ola +``` + +#### 2. Install the required package: +```bash +conda create -n ola python=3.10 -y +conda activate ola +pip install --upgrade pip +pip install -e . +``` +#### 3.Install additional packages for training cases + +```bash +pip install -e ".[train]" +pip install flash-attn --no-build-isolation +``` + +## Model Zoo + +We provide our checkpoints at [Huggingface](https://huggingface.co/collections/THUdyh/ola-67b8220eb93406ec87aeec37) + +| Model | Link | Size | Modal | +|:---:|:---:|:---:|:---:| +|Ola-7b | [Huggingface](https://huggingface.co/THUdyh/Ola-7b) | 7B | Text, Image, Video, Audio | +|Ola-Image | [Huggingface](https://huggingface.co/THUdyh/Ola-Image) | 7B | Text, Image | +|Ola-Video | [Huggingface](https://huggingface.co/THUdyh/Ola-Video) | 7B | Text, Image, Video | + + +## Quick Start + +1. Download `Ola-7b` from [Huggingface](https://huggingface.co/THUdyh/Ola-7b) or skip the step to using the online weights directly. + +2. Download audio encoder from [Huggingface](https://huggingface.co/THUdyh/Ola_speech_encoders/tree/main) and put the weights `large-v3.pt` and `BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt` under repo directory `path/to/Ola/` + +3. Run `inference/infer.py` + +- Text & Image Understanding + +``` +python3 inference/infer.py --image_path *.png,jpg --text user_instruction +``` + +- Text & Video Understanding + +``` +python3 inference/infer.py --video_path *.mp4 --text user_instruction +``` + +- Text & Audio Understanding + +``` +python3 inference/infer.py --audio_path *.wav,mp3 --text user_instruction +``` + +- Audio & Image Understanding + +``` +python3 inference/infer.py --audio_path *.png,jpg --audio_path *.wav,mp3 +``` + +## Evaluation + +You can evaluate Ola model with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). + +## Training + +### Data Preparation + +Please refer to [DATA.md](https://github.com/Ola-Omni/Ola/blob/main/DATA.md) for instructions of customized finetuning or using the provided datasets. + +### Start Training + +Please follow the script below to start training. Make sure you have created the correct datasets for fine-tuning. + +1. Finetuning Ola-7b Model: + +``` +bash ./scripts/finetune_ola.sh +``` + +2. Finetuning Ola-Image Model (Ola Stage1 or Stage2) + +``` +bash ./scripts/finetune_ola_image.sh +``` + +3. Finetuning Ola-Video Model (Ola Stage3): + +``` +bash ./scripts/finetune_ola_video.sh +``` + +## Citation + +If you find it useful for your research and applications, please cite our paper using this BibTeX: +```bibtex +@article{liu2025ola, +title={Ola: Pushing the Frontiers of Omni-Modal Language Model with Progressive Modality Alignment}, +author={Liu, Zuyan and Dong, Yuhao and Wang, Jiahui and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming}, +journal={arXiv preprint arXiv:2502.04328}, +year={2025} +} +``` + +## Acknowledgement + +- Our codebase is conducted on [LLaVA](https://github.com/LLaVA-VL/LLaVA-NeXT) + +- Thanks [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) team for the evaluation system! diff --git a/ola.egg-info/SOURCES.txt b/ola.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..006db71a9024ac49f14d937e5b33d18aa7ecc370 --- /dev/null +++ b/ola.egg-info/SOURCES.txt @@ -0,0 +1,44 @@ +LICENSE +README.md +pyproject.toml +inference/infer.py +ola/arguments.py +ola/constants.py +ola/conversation.py +ola/mm_utils.py +ola/utils.py +ola.egg-info/PKG-INFO +ola.egg-info/SOURCES.txt +ola.egg-info/dependency_links.txt +ola.egg-info/requires.txt +ola.egg-info/top_level.txt +ola/datasets/__init__.py +ola/datasets/preprocess.py +ola/model/__init__.py +ola/model/builder.py +ola/model/ola_arch.py +ola/model/language_model/ola_qwen.py +ola/model/multimodal_encoder/builder.py +ola/model/multimodal_encoder/oryx_vit.py +ola/model/multimodal_projector/builder.py +ola/model/multimodal_projector/pooler_projector.py +ola/model/multimodal_resampler/builder.py +ola/model/speech_encoder/builder.py +ola/model/speech_encoder/speech_encoder.py +ola/model/speech_encoder/beats/BEATs.py +ola/model/speech_encoder/beats/Tokenizers.py +ola/model/speech_encoder/beats/__init__.py +ola/model/speech_encoder/beats/backbone.py +ola/model/speech_encoder/beats/kaldi.py +ola/model/speech_encoder/beats/modules.py +ola/model/speech_encoder/beats/quantizer.py +ola/model/speech_projector/builder.py +ola/model/speech_projector/speech_projector.py +ola/serve/__init__.py +ola/serve/controller.py +ola/serve/gradio_web_server.py +ola/serve/model_worker.py +ola/train/ola_trainer.py +ola/train/train.py +tools/convert_mp4_wav.py +tools/create_patch.py \ No newline at end of file diff --git a/ola.egg-info/dependency_links.txt b/ola.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/ola.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/ola.egg-info/requires.txt b/ola.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..01d668adc5c64c45c9f3b6e2ebc1a10b3e13f0e0 --- /dev/null +++ b/ola.egg-info/requires.txt @@ -0,0 +1,40 @@ +torch==2.1.2 +torchvision==0.16.2 +torchaudio==2.1.2 +transformers==4.43.4 +tokenizers==0.19.1 +sentencepiece==0.1.99 +shortuuid +accelerate==0.33.0 +peft==0.11.1 +bitsandbytes==0.43.1 +pydantic +markdown2[all] +numpy +scikit-learn==1.2.2 +gradio==4.43.0 +gradio_client==1.3.0 +requests +httpx==0.27.2 +uvicorn +fastapi +soundfile +einops==0.6.1 +einops-exts==0.0.4 +timm==0.9.16 +openai-whisper +setuptools==59.5.0 +omegaconf==2.0.6 +loguru +av +librosa + +[build] +build +twine + +[train] +deepspeed==0.12.6 +ninja +wandb +tensorboardX diff --git a/ola.egg-info/top_level.txt b/ola.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..0eb0169c9b7b8dfe60a2cf2fe602276326405735 --- /dev/null +++ b/ola.egg-info/top_level.txt @@ -0,0 +1,4 @@ +inference +ola +scripts +tools diff --git a/ola/__pycache__/arguments.cpython-312.pyc b/ola/__pycache__/arguments.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de2b7f13275faae347102622b9659de123761e7e Binary files /dev/null and b/ola/__pycache__/arguments.cpython-312.pyc differ diff --git a/ola/__pycache__/constants.cpython-312.pyc b/ola/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0aaf3a94fd0d1886e6eafdf98b64e358463d53 Binary files /dev/null and b/ola/__pycache__/constants.cpython-312.pyc differ diff --git a/ola/__pycache__/conversation.cpython-312.pyc b/ola/__pycache__/conversation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4101f266a3497c78de7a717e98d4e936fc2510e9 Binary files /dev/null and b/ola/__pycache__/conversation.cpython-312.pyc differ diff --git a/ola/__pycache__/mm_utils.cpython-312.pyc b/ola/__pycache__/mm_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..921418d7a5cc8b133e9cc8075e119a6cb53beb1e Binary files /dev/null and b/ola/__pycache__/mm_utils.cpython-312.pyc differ diff --git a/ola/__pycache__/utils.cpython-312.pyc b/ola/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5abece01cd1071e85b81c3b3cd43e81180d1b886 Binary files /dev/null and b/ola/__pycache__/utils.cpython-312.pyc differ diff --git a/ola/arguments.py b/ola/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..199c5d5b7912fefbe3882a0b6f774e31a5f80cfc --- /dev/null +++ b/ola/arguments.py @@ -0,0 +1,65 @@ +import transformers + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_speech_projector: bool = field(default=False) + tune_speech_encoder: bool = field(default=False) + tune_speech_generator_only: bool = field(default=False) + speech_encoder_type: Optional[str] = field(default=None) + speech_encoder: Optional[str] = field(default=None) + pretrain_speech_projector: Optional[str] = field(default=None) + speech_projector_type: Optional[str] = field(default='linear') + speech_encoder_ds_rate: int = 5 + speech_encoder_hidden_size: int = 1280 + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + is_multimodal: bool = False + input_type: str = field(default="mel") + speech_normalize: bool = False + mel_size: int = 128 + has_tgt_units: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + freeze_speech_projector: bool = field(default=False) + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + speech_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) \ No newline at end of file diff --git a/ola/constants.py b/ola/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9b903d94f9122dc8b657383f8604555aad819400 --- /dev/null +++ b/ola/constants.py @@ -0,0 +1,14 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +SPEECH_TOKEN_INDEX = -200 +DEFAULT_SPEECH_TOKEN = "" +IMAGE_TOKEN_INDEX= -300 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" \ No newline at end of file diff --git a/ola/conversation.py b/ola/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..c71e1d6d88b9aea503b272f9a09e4f5a4baf421c --- /dev/null +++ b/ola/conversation.py @@ -0,0 +1,266 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Any, Union, Tuple +import base64 +from io import BytesIO +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + TWO = auto() + PLAIN = auto() + CHATML = auto() + LLAMA_2 = auto() + LLAMA_3 = auto() + QWEN2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.PLAIN + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + tokenizer_id: str = "" + tokenizer: Any = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + + if self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg + ret = "<|begin_of_text|>" + wrap_sys(self.system) + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += message.strip() + self.sep2 + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + return ret + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + raise ValueError("Tuple not supported in CHATML") + message, images = message + message = "" * len(images) + message + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.QWEN2: + start = '<|im_start|>' + end = '<|im_end|>\n' + ret = start + 'system\n' + self.system + end + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + + if message.endswith('<|endoftext|>'): + message = message.replace('<|endoftext|>', '') + ret += start + role + "\n" + message + end + '<|endoftext|>' + else: + assert not '<|endoftext|>' in message, f"Invalid message: {message}" + ret += start + role + "\n" + message + end + else: + ret += start + role + "\n" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, speech = msg + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=[], + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llama_3 = Conversation( + system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", + roles=("user", "assistant"), + version="llama_v3", + messages=[], + offset=0, + sep_style=SeparatorStyle.LLAMA_3, + sep="", + sep2="<|eot_id|>" +) + + +conv_qwen_v1 = Conversation( + system="You are a helpful assistant.", + roles=("user", "assistant"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.QWEN2, +) + +conv_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", +) + +conv_qwen = Conversation( + system="""<|im_start|>system +You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +conv_plmv = Conversation( + system="""<|im_start|>system +You are PLM-V, developed by PLM-Team, a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="plm_v", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +default_conversation = conv_plmv +conv_templates = { + "v1": conv_vicuna_v1, + "plain": conv_plain, + "llama_2": conv_llama_2, + "llama_3": conv_llama_3, + 'v1_qwen2': conv_qwen_v1, + "qwen_1_5": conv_qwen, + "plm_v": conv_plmv, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) diff --git a/ola/datasets/__init__.py b/ola/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/datasets/__pycache__/__init__.cpython-312.pyc b/ola/datasets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e14bcd51abcc7e981c26309f1707a43fc5b1d6 Binary files /dev/null and b/ola/datasets/__pycache__/__init__.cpython-312.pyc differ diff --git a/ola/datasets/__pycache__/preprocess.cpython-312.pyc b/ola/datasets/__pycache__/preprocess.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5700489fe161c6cde721247c66501ba6b844f85 Binary files /dev/null and b/ola/datasets/__pycache__/preprocess.cpython-312.pyc differ diff --git a/ola/datasets/preprocess.py b/ola/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1723c4ae50cad23a629f75febc28c81f41610245 --- /dev/null +++ b/ola/datasets/preprocess.py @@ -0,0 +1,234 @@ +import copy +import torch +import transformers +import tokenizers + +from typing import Dict, Sequence + +from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX +from ola import conversation as conversation_lib +from ola.model import * +from ola.arguments import DataArguments +from ola.constants import SPEECH_TOKEN_INDEX + +from packaging import version + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_idx, image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("\nUser's question in speech: \n")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + nl_tokens = tokenizer("\n").input_ids + special_chunks = [image_token_index, nl_tokens, tokenizer("User's question in speech: ").input_ids, speech_token_idx, nl_tokens] + + for x in insert_separator(prompt_chunks, [special_chunks] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_speech: + input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_speech: + round_len = len(tokenizer_speech_token(rou, tokenizer)) + instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # FIXME: tokenizer bug + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + print(f"Debug - Conversation: {conversation[:200]}...") + print(f"Debug - Target shape: {target.shape}") + print(f"Debug - All labels are IGNORE_INDEX: {(target == IGNORE_INDEX).all().item()}") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_SPEECH_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_SPEECH_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_speech: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_speech=has_speech) + raise NotImplementedError \ No newline at end of file diff --git a/ola/mm_utils.py b/ola/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf2189b6554e913d66396cbc33426cadc6294b5 --- /dev/null +++ b/ola/mm_utils.py @@ -0,0 +1,271 @@ +from PIL import Image +import base64 +import math +import ast + +import torch +from transformers import StoppingCriteria +import os +import io + +if 'VIDEO_RESIZE' in os.environ: + # highresxpatch + VIDEO_RESIZE = os.environ['VIDEO_RESIZE'] + video_base, video_ps = VIDEO_RESIZE.split('x') + video_base = int(video_base) + video_ps = int(video_ps) + print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}") +else: + HIGHRES_BASE = None + +if 'HIGHRES_BASE' in os.environ: + # highresxpatch + HIGHRES_BASE = os.environ['HIGHRES_BASE'] + highres_base, highres_ps = HIGHRES_BASE.split('x') + highres_base = int(highres_base) + highres_ps = int(highres_ps) + print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}") +else: + HIGHRES_BASE = None + +if 'MAXRES' in os.environ: + # highresxpatch + MAXRES = int(os.environ['MAXRES']) + print(f"MAXRES is set as {MAXRES}") +else: + MAXRES = 1536 + +if 'MINRES' in os.environ: + # highresxpatch + MINRES = int(os.environ['MINRES']) + print(f"MINRES is set as {MINRES}") +else: + MINRES = 0 + +if 'VIDEO_MAXRES' in os.environ: + # highresxpatch + VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES']) + print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}") +else: + VIDEO_MAXRES = 1536 + +if 'VIDEO_MINRES' in os.environ: + # highresxpatch + VIDEO_MINRES = int(os.environ['VIDEO_MINRES']) + print(f"VIDEO_MINRES is set as {VIDEO_MINRES}") +else: + MINRES = 0 + +if 'PAD2STRIDE' in os.environ: + # highresxpatch + PAD2STRIDE = True + print(f"PAD2STRIDE is set") +else: + PAD2STRIDE = False + +if 'LOWRES_RESIZE' in os.environ: + LOWRES_RESIZE = os.environ['LOWRES_RESIZE'] + print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}") + if 'x' in LOWRES_RESIZE: + size, ps = LOWRES_RESIZE.split('x') + size = int(size) + ps = int(ps) + LOWRES_RESIZE = (size, ps) + else: + LOWRES_RESIZE = int(LOWRES_RESIZE) +else: + LOWRES_RESIZE = None + + +def pad_image(image, target_resolution, value=0): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + # Create a new image with the target size and paste the resized image onto it + new_image = Image.new('RGB', (target_width, target_height), (value, value, value)) + paste_x = (target_width - original_width) // 2 + paste_y = (target_height - original_height) // 2 + new_image.paste(image, (paste_x, paste_y)) + return new_image + +def resize_images(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > MAXRES * MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = MAXRES * MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < MINRES * MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = MINRES * MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + new_h = max(new_h, patch_size) + new_w = max(new_w, patch_size) + image = image.resize((new_h, new_w)) + + return image + +def resize_video(image, patch_size=14, base_size=896): + h, w = image.size + if base_size == 0: + if h * w > VIDEO_MAXRES * VIDEO_MAXRES: + # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}') + scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w) + scale = math.sqrt(scale) + elif h * w < VIDEO_MINRES * VIDEO_MINRES: + # print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}') + scale = VIDEO_MINRES * VIDEO_MINRES / (h * w) + scale = math.sqrt(scale) + else: + scale = None + else: + scale = base_size * base_size / (h * w) + scale = math.sqrt(scale) + + if scale is not None: + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + elif PAD2STRIDE: + if h % patch_size == 0: + new_h = h + else: + new_h = (h // patch_size + 1) * patch_size + + if w % patch_size == 0: + new_w = w + else: + new_w = (w // patch_size + 1) * patch_size + image = pad_image(image, (new_h, new_w), value=127) + else: + scale = 1.0 + new_h = int(h * scale / patch_size) * patch_size + new_w = int(w * scale / patch_size) * patch_size + image = image.resize((new_h, new_w)) + + return image + +def process_anyres_video(image, processor): + if VIDEO_RESIZE is not None: + image = resize_video(image, patch_size=video_ps, base_size=video_base) + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return image.unsqueeze(0) + else: + raise ValueError("VIDEO_RESIZE is not set") + +def process_anyres_highres_image(image, processor): + processor2 = None + if type(processor) is tuple: + processor, processor2 = processor[0], processor[1] + + if HIGHRES_BASE is not None: + image = resize_images(image, patch_size=highres_ps, base_size=highres_base) + + if processor2 is not None: + image_original_resize = image.resize((processor2.size['shortest_edge'], processor.size['shortest_edge'])) + image_patches = [image_original_resize] + [image_original_resize] + image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + else: + if LOWRES_RESIZE is not None: + if type(LOWRES_RESIZE) is int: + image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE) + else: + image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0]) + else: + image_original_resize = image.resize((336, 336)) + image_patches = [image_original_resize] + image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0) + +def read_image_patch(patch_info): + if 'img_path' in patch_info.keys(): + image = Image.open(patch_info['img_path']).convert('RGB') + else: + if 'image_encoing' in patch_info.keys(): + patch_info['image_encoding'] = patch_info['image_encoing'] + image_file_name = patch_info['patch'] + start_bytes = int(patch_info['start_num']) + file_size = int(patch_info['size']) + + with open(image_file_name, 'rb') as f: + f.seek(start_bytes) + if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64': + image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB") + else: + image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB") + return image + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, 3) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if output_ids[0, -keyword_id.shape[0]:] == keyword_id: + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False diff --git a/ola/model/__init__.py b/ola/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6b744c368a9a8992c4e10c798772bfdd179c61 --- /dev/null +++ b/ola/model/__init__.py @@ -0,0 +1,2 @@ +from .language_model.ola_qwen import OlaQwenForCausalLM, OlaConfigQwen +from .language_model.ola_qwen3 import OlaQwen3ForCausalLM, OlaConfigQwen3 \ No newline at end of file diff --git a/ola/model/__pycache__/__init__.cpython-312.pyc b/ola/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaf6c21ab24d09376758843aed1a655e0245fdf5 Binary files /dev/null and b/ola/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/ola/model/__pycache__/builder.cpython-312.pyc b/ola/model/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb62ba821caf8b10a76dbdaa727667a00d84c229 Binary files /dev/null and b/ola/model/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/__pycache__/ola_arch.cpython-312.pyc b/ola/model/__pycache__/ola_arch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8ea0d55a8905d07ece26cc3d3b3f34139d1ca77 Binary files /dev/null and b/ola/model/__pycache__/ola_arch.cpython-312.pyc differ diff --git a/ola/model/builder.py b/ola/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..44c18e173dcc9166c62f5c942dfde82167bd4198 --- /dev/null +++ b/ola/model/builder.py @@ -0,0 +1,97 @@ +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor +import torch +from ola.model import * +from ola.model.speech_encoder.builder import build_speech_encoder + +# 过滤掉 PyTorch 的 meta parameter 警告 +warnings.filterwarnings("ignore", message=".*copying from a non-meta parameter in the checkpoint to a meta parameter.*") + +def load_pretrained_model(model_path, model_type, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): + device = "cuda" + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.bfloat16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + if model_type == 'ola_internvl': + model_cls = OlaQwen3ForCausalLM + print('Loading OlaQwen3ForCausalLM model...') + else: + model_cls = OlaQwenForCausalLM + + # Load Ola model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." + from ola.model.language_model.ola_qwen import OlaConfigQwen + lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading Ola from base model...') + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) + print('Loading additional Ola weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False, assign=True) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + print('Loading Ola from base model...') + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) + + speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu') + speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} + model.load_state_dict(speech_projector_weights, strict=False, assign=True) + model = model.to(device=device) + else: + # model_path = "/data1/cxy/plm-v/modeling/plm_internvl3_5_ola" + model_path = "/data1/cxy/plm-v/modeling/ckpt/ola_audio_8_8gpu/checkpoint-120" + tokernizer_path = "/data1/cxy/plm-v/modeling/internvl3_5-2B" + tokenizer = AutoTokenizer.from_pretrained(tokernizer_path, use_fast=False, trust_remote_code=True) + cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + with torch.device("cuda"): + model = model_cls.from_pretrained( + model_path, + trust_remote_code=True, + config=cfg, + # device_map="auto", + **kwargs, + ) + model = model.to(device=device) + # breakpoint() + image_processor = None + model.resize_token_embeddings(len(tokenizer)) + # breakpoint() + print("Loading vision tower...") + print("Loading vision tower succeeded.") + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 16384 + image_processor = AutoProcessor.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B-HF") + + return tokenizer, model, image_processor, context_len diff --git a/ola/model/builder_back.py b/ola/model/builder_back.py new file mode 100644 index 0000000000000000000000000000000000000000..0152cd9732203ff8119111f2ee1b2d3793e3c771 --- /dev/null +++ b/ola/model/builder_back.py @@ -0,0 +1,294 @@ +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoProcessor +import torch +from ola.model import * +from ola.model.speech_encoder.builder import build_speech_encoder + +def load_pretrained_model(model_path, model_type, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): + device = "cuda" + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.bfloat16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + if model_type == 'ola_internvl': + model_cls = OlaQwen3ForCausalLM + print('Loading OlaQwen3ForCausalLM model...') + else: + model_cls = OlaQwenForCausalLM + + # Load Ola model + if is_lora: + assert model_base is not None, "model_base is required for LoRA models." + from ola.model.language_model.ola_qwen import OlaConfigQwen + lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading Ola from base model...') + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) + print('Loading additional Ola weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False, assign=True) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + print('Loading Ola from base model...') + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) + + speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu') + speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} + model.load_state_dict(speech_projector_weights, strict=False, assign=True) + model = model.to(device=device) + elif model_type == 'ola_internvl': + cfg = AutoConfig.from_pretrained("/data1/cxy/plm-v/modeling/old_ola", trust_remote_code=True) + # breakpoint() + tokenizer = AutoTokenizer.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B", use_fast=False) + with torch.device("cpu"): + # model = model_cls.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B", low_cpu_mem_usage=False, attn_implementation="eager", config=cfg, **kwargs) + # model = model_cls.from_config(config=cfg) + model = model_cls(cfg) + # breakpoint() + # model.model.layers[1].self_attn.q_proj.weight + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + with torch.device("cpu"): + model = model_cls.from_pretrained( + model_path, + **kwargs, + ) + model = model.to(device=device) + # model.resize_token_embeddings(len(tokenizer)) + from safetensors.torch import load_file + partial_state_dict = load_file(f"/data1/cxy/plm-v/modeling/internvl3_5-2B/model.safetensors") # 替换为你的部分权重路径 + mapping = { + "mlp1.0.weight": "model.mm_projector.layer_norm.weight", + "mlp1.0.bias": "model.mm_projector.layer_norm.bias", + "mlp1.1.weight": "model.mm_projector.linear_1.weight", + "mlp1.1.bias": "model.mm_projector.linear_1.bias", + "mlp1.3.weight": "model.mm_projector.linear_2.weight", + "mlp1.3.bias": "model.mm_projector.linear_2.bias", + } + +# 遍历 state_dict 并重命名 + def remap_keys(state_dict, mapping): + new_state_dict = {} + for k, v in state_dict.items(): + if k in mapping: + new_state_dict[mapping[k]] = v + else: + new_state_dict[k] = v + return new_state_dict + # merged_state_dict = {**partial_state_dict, **partial_state_dict2} + # 2. 重命名 key:multi_modal_projector -> mm_projector + # breakpoint() + rename_dict = {} + for k in list(partial_state_dict.keys()): + if k.startswith("language_model"): + new_k = k.replace("language_model.", "", 1) + rename_dict[k] = new_k + if k.startswith("vision_model"): + new_k = k.replace("vision_model", "model.vision_tower", 1) + rename_dict[k] = new_k + + # 应用重命名 + for old_k, new_k in rename_dict.items(): + partial_state_dict[new_k] = partial_state_dict.pop(old_k) + partial_state_dict = remap_keys(partial_state_dict, mapping) + + whisper_state_dict = torch.load("/data1/cxy/model/THUdyh/Ola-7b/large-v3.pt", map_location='cpu') + # breakpoint() + whisper_state_dict = whisper_state_dict["model_state_dict"] + + # Filter to keep only encoder weights + whisper_encoder_dict = {} + for key, value in whisper_state_dict.items(): + if key.startswith('encoder.'): + whisper_encoder_dict[key] = value + + print(f"Original Whisper keys: {len(whisper_state_dict)}") + print(f"Filtered encoder keys: {len(whisper_encoder_dict)}") + print("Sample encoder keys:") + for i, key in enumerate(list(whisper_encoder_dict.keys())[:5]): + print(f" {key}") + + # Create mapping for Whisper parameters to OLA format + def create_whisper_mapping(): + mapping = {} + + # Base encoder components + base_mappings = { + 'encoder.positional_embedding': 'model.speech_encoder.whisper_model.positional_embedding', + 'encoder.conv1.weight': 'model.speech_encoder.whisper_model.conv1.weight', + 'encoder.conv1.bias': 'model.speech_encoder.whisper_model.conv1.bias', + 'encoder.conv2.weight': 'model.speech_encoder.whisper_model.conv2.weight', + 'encoder.conv2.bias': 'model.speech_encoder.whisper_model.conv2.bias', + 'encoder.ln_post.weight': 'model.speech_encoder.whisper_model.ln_post.weight', + 'encoder.ln_post.bias': 'model.speech_encoder.whisper_model.ln_post.bias', + } + mapping.update(base_mappings) + + # Encoder blocks (32 blocks: 0-31) + for block_idx in range(32): + # Attention components + attn_components = [ + 'attn.query.weight', 'attn.query.bias', + 'attn.key.weight', 'attn.key.bias', + 'attn.value.weight', 'attn.value.bias', + 'attn.out.weight', 'attn.out.bias', + 'attn_ln.weight', 'attn_ln.bias' + ] + + for component in attn_components: + source_key = f'encoder.blocks.{block_idx}.{component}' + target_key = f'model.speech_encoder.whisper_model.blocks.{block_idx}.{component}' + mapping[source_key] = target_key + + # MLP components + mlp_components = [ + 'mlp.0.weight', 'mlp.0.bias', + 'mlp.2.weight', 'mlp.2.bias', + 'mlp_ln.weight', 'mlp_ln.bias' + ] + + for component in mlp_components: + source_key = f'encoder.blocks.{block_idx}.{component}' + target_key = f'model.speech_encoder.whisper_model.blocks.{block_idx}.{component}' + mapping[source_key] = target_key + + return mapping + + # Apply mapping to whisper_encoder_dict + whisper_mapping = create_whisper_mapping() + mapped_whisper_dict = {} + unmapped_whisper_keys = [] + + for key, value in whisper_encoder_dict.items(): + if key in whisper_mapping: + mapped_key = whisper_mapping[key] + mapped_whisper_dict[mapped_key] = value + else: + unmapped_whisper_keys.append(key) + print(f"Warning: No mapping found for Whisper encoder key '{key}'") + + if unmapped_whisper_keys: + print(f"Total unmapped Whisper encoder keys: {len(unmapped_whisper_keys)}") + print("First 10 unmapped Whisper encoder keys:") + for key in unmapped_whisper_keys[:10]: + print(f" {key}") + + print(f"Successfully mapped {len(mapped_whisper_dict)} encoder parameters") + + beat_state_dict = torch.load("/data1/cxy/model/THUdyh/Ola-7b//BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt", map_location='cpu') + beat_state_dict = beat_state_dict['model'] + beat_state_dict = {"model.speech_encoder.beats_model."+k: v for k, v in beat_state_dict.items()} + + # 处理 BEATs 模型中的参数化权重映射 (先pop后添加) + keys_to_process = list(beat_state_dict.keys()) + breakpoint() + processed_count = 0 + + # for key in keys_to_process: + # if 'weight_g' in key: + # # pop 原始权重并添加为 weight_g + # weight_tensor = beat_state_dict.pop(key) + # new_key = key.replace('weight_g','parametrizations.weight.original0') + # beat_state_dict[new_key] = weight_tensor + # processed_count += 1 + # elif 'weight_v' in key: + # # pop 原始权重并添加为 weight_v + # weight_tensor = beat_state_dict.pop(key) + # new_key = key.replace('weight_v', 'parametrizations.weight.original1') + # beat_state_dict[new_key] = weight_tensor + # processed_count += 1 + + print(f"Processed {processed_count} parametrized weight keys in BEATs model (pop and add)") + breakpoint() + # breakpoint() + partial_state_dict = {**partial_state_dict, **mapped_whisper_dict, **beat_state_dict} + + # Ensure all tensors in the state dict are on CPU and have proper device information + print("Moving all state dict tensors to CPU...") + for key, tensor in partial_state_dict.items(): + if torch.is_tensor(tensor): + # Ensure tensor has device information and move to CPU + if not tensor.device.type: + print(f"Warning: Tensor {key} has no device, creating on CPU") + partial_state_dict[key] = torch.tensor(tensor.detach().numpy()).cpu() + else: + partial_state_dict[key] = tensor.cpu() + + # Ensure model is on CPU before loading state dict to avoid device mismatches + print("Moving model to CPU before loading state dict...") + model = model.cpu() + + print("Loading state dict...") + breakpoint() + missing, unexpected = model.load_state_dict(partial_state_dict, strict=False, assign=True) + + print("Missing keys:", missing) + print("Unexpected keys:", unexpected) + + # Convert model to bfloat16 before saving + print("Converting model to bfloat16...") + model = model.to(torch.bfloat16) + model = model.to("cpu") + + # Save model in bfloat16 format + print("Saving model in bfloat16 format...") + model.save_pretrained("/data1/cxy/plm-v/modeling/plm_internvl3_ola", safe_serialization=False, torch_dtype=torch.bfloat16) + print("Model saved successfully in bfloat16 format!") + breakpoint() + # model.model.mm_projector.linear_1.weight:-0.0106 multi_modal_projector.linear_1.weight model.mm_projector.linear_2.bias + # model.vision_tower.encoder.layers.7.attn.proj.bias + # model.model.vision_tower.encoder.layers[0].attn.qkv.weight: -6.5613e-03 dui + # + # breakpoint() + # model.get_model().speech_encoder.load_model("") + # language_model.model.layers.9.mlp.up_proj.weight vision_model.encoder.layers + # model.layers.14.self_attn.q_proj.weight model.vision_tower.encoder.layers.23.attn.proj.bias + # model.get_model().speech_encoder = build_speech_encoder(model.config) + # model.get_model().speech_encoder.to(device=device, dtype=torch.float16) + image_processor = None + model.resize_token_embeddings(len(tokenizer)) + vision_tower = model.get_vision_tower() + print("Loading vision tower...") + # if not vision_tower.is_loaded: + # vision_tower.load_model(device_map=device) + # if device != "auto": + # vision_tower.to(device="cuda", dtype=torch.bfloat16) + # else: + # vision_tower.to(device="cuda:0", dtype=torch.bfloat16) + # image_processor = vision_tower.image_processor + print("Loading vision tower succeeded.") + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 16384 + image_processor = AutoProcessor.from_pretrained("/data1/cxy/plm-v/modeling/internvl3_5-2B-HF") + # breakpoint() + return tokenizer, model, image_processor, context_len diff --git a/ola/model/language_model/__pycache__/conversation.cpython-312.pyc b/ola/model/language_model/__pycache__/conversation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a10599b8636012923e2e9a85a17e6a88c918299a Binary files /dev/null and b/ola/model/language_model/__pycache__/conversation.cpython-312.pyc differ diff --git a/ola/model/language_model/__pycache__/ola_qwen.cpython-312.pyc b/ola/model/language_model/__pycache__/ola_qwen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93f1c3ccf56acdd02c408edd88ad73e8c5ed8838 Binary files /dev/null and b/ola/model/language_model/__pycache__/ola_qwen.cpython-312.pyc differ diff --git a/ola/model/language_model/__pycache__/ola_qwen3.cpython-312.pyc b/ola/model/language_model/__pycache__/ola_qwen3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a147a612247363f0988b1f3dd0d75df89c37956 Binary files /dev/null and b/ola/model/language_model/__pycache__/ola_qwen3.cpython-312.pyc differ diff --git a/ola/model/language_model/conversation.py b/ola/model/language_model/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..e30ddf98816986ccb0c1c3df155c7f272dc4f7bb --- /dev/null +++ b/ola/model/language_model/conversation.py @@ -0,0 +1,403 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. + +Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Dict, List, Tuple, Union + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + INTERNVL_ZH = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = '{system_message}' + # The system message + system_message: str = '' + # The names of two roles + roles: Tuple[str] = ('USER', 'ASSISTANT') + # All messages. Each item is (role, message). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = '\n' + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ': ' # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = '' if system_prompt == '' else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + else: + ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ': ' + + message.replace('\r\n', '\n').replace('\n\n', '\n') + ) + ret += '\n\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = '[INST] ' + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + ' ' + else: + ret += tag + ' ' + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == 'chatglm2' else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = '' + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f'[Round {i//2 + round_add_n}]{self.sep}' + + if message: + ret += f'{role}:{message}{self.sep}' + else: + ret += f'{role}:' + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = '' if system_prompt == '' else system_prompt + self.sep + '\n' + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + '\n' + else: + ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = '' + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + '\n' + ' ' + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + # if i % 2 == 0: + # ret += "" + if message: + ret += role + ':' + message + seps[i % 2] + '\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ':\n' + message + seps[i % 2] + if i % 2 == 1: + ret += '\n\n' + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ': ' + '' + message + '' + else: + ret += role + ': ' + '' + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ':\n' + message + self.sep + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = '' + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + + return ret + elif self.sep_style == SeparatorStyle.INTERNVL_ZH: + seps = [self.sep, self.sep2] + ret = self.system_message + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.MPT: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f'Invalid style: {self.sep_style}') + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{'role': 'system', 'content': self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({'role': 'user', 'content': msg}) + else: + if msg is not None: + ret.append({'role': 'assistant', 'content': msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + 'template_name': self.name, + 'system_message': self.system_message, + 'roles': self.roles, + 'messages': self.messages, + 'offset': self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f'{template.name} has been registered.' + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + # breakpoint() + return conv_templates[name].copy() + + +# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference +# is that during training, the preprocessing function for the Hermes-2 template doesn't add +# at the beginning of the tokenized sequence, while the internlm2-chat template does. +# Therefore, they are completely equivalent during inference. +register_conv_template( + Conversation( + name='Hermes-2', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + stop_str='<|endoftext|>', + ) +) + + +register_conv_template( + Conversation( + name='internlm2-chat', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + ) +) + + +register_conv_template( + Conversation( + name='phi3-chat', + system_template='<|system|>\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|user|>\n', '<|assistant|>\n'), + sep_style=SeparatorStyle.MPT, + sep='<|end|>', + ) +) + + +register_conv_template( + Conversation( + name='internvl2_5', + system_template='<|im_start|>system\n{system_message}', + system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>\n', + ) +) + +register_conv_template( + Conversation( + name='plm_v', + system_template='<|im_start|>system\n{system_message}', + system_message='You are PLM-V, developed by PLM-Team, a helpful assistant.', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>\n', + ) +) diff --git a/ola/model/language_model/ola_qwen.py b/ola/model/language_model/ola_qwen.py new file mode 100644 index 0000000000000000000000000000000000000000..fd88538c53603ef929abc3dee892e109e2cd0844 --- /dev/null +++ b/ola/model/language_model/ola_qwen.py @@ -0,0 +1,237 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM +from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM + + +class OlaConfigQwen(Qwen2Config): + model_type = "ola_qwen" + + +class OlaQwenModel(OlaMetaModel, Qwen2Model): + config_class = OlaConfigQwen + + def __init__(self, config: Qwen2Config): + super(OlaQwenModel, self).__init__(config) + + +class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM): + config_class = OlaConfigQwen + + def __init__(self, config): + super(Qwen2ForCausalLM, self).__init__(config) + + config.rope_scaling = None + self.model = OlaQwenModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + speech_chunks: Optional[torch.LongTensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.FloatTensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_speech_vision_text( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + if labels is None: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + else: + return self.forward_llm_efficient( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + + def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_dim = hidden_states.size(-1) + shift_labels = labels[..., 1:].contiguous().reshape(-1) + shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim) + assert shift_labels.size(0) == shift_hidden_states.size(0) + mask = shift_labels > -1 + assert mask.float().sum() > 0 + shift_labels = shift_labels[mask] + shift_hidden_states = shift_hidden_states[mask, :] + logits = self.lm_head(shift_hidden_states) + logits = logits.float() + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits, shift_labels) + + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + speech: Optional[torch.Tensor] = None, + speech_lengths: Optional[torch.Tensor] = None, + speech_chunks: Optional[torch.Tensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + images: Optional[torch.Tensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_speech_vision_text( + inputs, + position_ids, + attention_mask, + None, + None, + speech, + speech_lengths, + speech_chunks, + speech_wav, + images, + modalities, + image_sizes, + images_highres + ) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + speech_chunks = kwargs.pop("speech_chunks", None) + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if speech is not None: + inputs['speech'] = speech + inputs['speech_lengths'] = speech_lengths + inputs['speech_chunks'] = speech_chunks + if images is not None: + inputs["images"] = images + if image_sizes is not None: + inputs["image_sizes"] = image_sizes + return inputs + +AutoConfig.register("ola_qwen", OlaConfigQwen) +AutoModelForCausalLM.register(OlaConfigQwen, OlaQwenForCausalLM) diff --git a/ola/model/language_model/ola_qwen3.py b/ola/model/language_model/ola_qwen3.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6a5e592c08f626c5949ec1b439be2cd717c834 --- /dev/null +++ b/ola/model/language_model/ola_qwen3.py @@ -0,0 +1,466 @@ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +import transformers +from transformers import GenerationConfig +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig +SPEECH_TOKEN_INDEX = -200 + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..ola_arch import OlaMetaModel, OlaMetaForCausalLM +from transformers import Qwen3Config, Qwen3Model, Qwen3ForCausalLM +from .conversation import get_conv_template +from ola.constants import IGNORE_INDEX + +def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): + """Tokenize prompt with speech tokens, similar to OLA's implementation""" + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +class Qwen3Model(Qwen3Model): + def __init__(self, config: Qwen3Config, llm_config: Qwen3Config): + # breakpoint() + super(Qwen3Model, self).__init__(llm_config) + +class OlaConfigQwen3(Qwen3Config, PretrainedConfig): + model_type = "ola_internvl" + + +class OlaQwen3Model(OlaMetaModel, Qwen3Model): + config_class = OlaConfigQwen3 + + def __init__(self, config: Qwen3Config): + + super(OlaQwen3Model, self).__init__(config, config.llm_config) + + +class OlaQwen3ForCausalLM(Qwen3ForCausalLM, OlaMetaForCausalLM): + config_class = OlaConfigQwen3 + # 从零初始化时不需要 checkpoint conversion mapping + # _checkpoint_conversion_mapping = { + # "^language_model.lm_head": "lm_head", + # "^language_model.model": "model.model", + # "^vision_model": "model.vision_tower", + # } + # model.model.embed_tokens: + def __init__(self, config): + super(Qwen3ForCausalLM, self).__init__(config) + + config.rope_scaling = None + # breakpoint() + self.model = OlaQwen3Model(config) + self.vocab_size = config.vocab_size + # breakpoint() + self.ps_version = config.ps_version + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.template = "plm_v" + self.select_layer = config.select_layer + self.conv_template = get_conv_template(self.template) + self.system_message = self.conv_template.system_message + self.num_image_token = int((config.vision_config.image_size // config.vision_config.patch_size) ** 2 * (config.downsample_ratio ** 2)) + self.downsample_ratio = config.downsample_ratio + # Initialize weights and apply final processing + self.post_init() + + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + speech_chunks: Optional[torch.LongTensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + images_highres: Optional[List[torch.FloatTensor]] = None, + image_sizes: Optional[List[List[int]]] = None, + modalities: Optional[List[str]] = ["image"], + image_flags: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + # breakpoint() + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_speech_text_for_internvl( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + speech, + speech_lengths, + speech_chunks, + speech_wav, + modalities, + ) + + if labels is None: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + else: + return self.forward_llm_efficient( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + + def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Check inputs before model forward + print(f"Debug - Input embeddings range: {inputs_embeds.min().item()} to {inputs_embeds.max().item()}") + print(f"Debug - Input embeddings has nan: {torch.isnan(inputs_embeds).any().item()}") + print(f"Debug - Input embeddings has inf: {torch.isinf(inputs_embeds).any().item()}") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # Check hidden states immediately after model forward + print(f"Debug - Raw hidden states range: {hidden_states.min().item()} to {hidden_states.max().item()}") + print(f"Debug - Raw hidden states has nan: {torch.isnan(hidden_states).any().item()}") + print(f"Debug - Raw hidden states has inf: {torch.isinf(hidden_states).any().item()}") + hidden_dim = hidden_states.size(-1) + shift_labels = labels[..., 1:].contiguous().reshape(-1) + shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim) + assert shift_labels.size(0) == shift_hidden_states.size(0) + mask = shift_labels != IGNORE_INDEX + + # Debug logging + print(f"Debug - Total tokens: {shift_labels.size(0)}") + print(f"Debug - Valid tokens: {mask.float().sum().item()}") + print(f"Debug - Ignored tokens: {(~mask).float().sum().item()}") + print(f"Debug - Label range: {shift_labels.min().item()} to {shift_labels.max().item()}") + + assert mask.float().sum() > 0, f"No valid tokens found! Total: {shift_labels.size(0)}, Valid: {mask.float().sum().item()}" + shift_labels = shift_labels[mask] + shift_hidden_states = shift_hidden_states[mask, :] + + print(f"Debug - After filtering: {shift_labels.size(0)} tokens") + print(f"Debug - Hidden states shape: {shift_hidden_states.shape}") + print(f"Debug - Hidden states range: {shift_hidden_states.min().item()} to {shift_hidden_states.max().item()}") + print(f"Debug - Hidden states has nan: {torch.isnan(shift_hidden_states).any().item()}") + print(f"Debug - Hidden states has inf: {torch.isinf(shift_hidden_states).any().item()}") + + # Check lm_head weights + print(f"Debug - lm_head weight shape: {self.lm_head.weight.shape}") + print(f"Debug - lm_head weight range: {self.lm_head.weight.min().item()} to {self.lm_head.weight.max().item()}") + print(f"Debug - lm_head weight has nan: {torch.isnan(self.lm_head.weight).any().item()}") + print(f"Debug - lm_head weight has inf: {torch.isinf(self.lm_head.weight).any().item()}") + + logits = self.lm_head(shift_hidden_states) + logits = logits.float() + + print(f"Debug - Logits shape: {logits.shape}") + print(f"Debug - Logits range: {logits.min().item()} to {logits.max().item()}") + print(f"Debug - Logits has nan: {torch.isnan(logits).any().item()}") + print(f"Debug - Logits has inf: {torch.isinf(logits).any().item()}") + + # Fix nan values in logits + if torch.isnan(logits).any(): + print("WARNING: Found nan values in logits, replacing with zeros") + logits = torch.where(torch.isnan(logits), torch.zeros_like(logits), logits) + + # Fix inf values in logits + if torch.isinf(logits).any(): + print("WARNING: Found inf values in logits, clamping to finite range") + logits = torch.clamp(logits, min=-1e4, max=1e4) + + # Additional check: if logits are still problematic, use a fallback + if torch.isnan(logits).any() or torch.isinf(logits).any(): + print("ERROR: Logits still contain nan/inf after fixing, using fallback") + logits = torch.zeros_like(logits) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits, shift_labels) + + print(f"Debug - Loss: {loss.item()}") + print(f"Debug - Loss has nan: {torch.isnan(loss).item()}") + + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " + 'which results in a transposed image.') + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + if self.select_layer == -1: + # breakpoint() + vit_embeds = self.get_vision_tower()( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True).last_hidden_state + else: + vit_embeds = self.get_vision_tower()( + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=True).hidden_states[self.select_layer] + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + # breakpoint() + vit_embeds = self.get_vision_projector()(vit_embeds) + return vit_embeds + @torch.no_grad() + def generate( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + visual_features: Optional[torch.FloatTensor] = None, + generation_config: Optional[GenerationConfig] = None, + output_hidden_states: Optional[bool] = None, + speech: Optional[torch.FloatTensor] = None, + speech_lengths: Optional[torch.LongTensor] = None, + speech_chunks: Optional[torch.LongTensor] = None, + speech_wav: Optional[torch.FloatTensor] = None, + modalities: Optional[List[str]] = ["image"], + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + + if speech is not None: + ( + _, + position_ids, + attention_mask, + _, + input_embeds, + _ + ) = self.prepare_inputs_labels_for_speech_text_for_internvl( + input_ids, + position_ids, + attention_mask, + None, + None, # labels + speech, + speech_lengths, + speech_chunks, + speech_wav, + modalities, + ) + else: + # internvl + assert self.img_context_token_id is not None + if pixel_values is not None: + if visual_features is not None: + vit_embeds = visual_features + else: + vit_embeds = self.extract_feature(pixel_values) + input_embeds = self.get_model().get_input_embeddings()(input_ids) + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.img_context_token_id) + assert selected.sum() != 0 + input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) + input_embeds = input_embeds.reshape(B, N, C) + else: + input_embeds = self.get_model().get_input_embeddings()(input_ids) + return super().generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + use_cache=True, + **kwargs, + ) + + + def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, + num_patches_list=None, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='', + verbose=False, speech=None, speech_lengths=None, speech_wav=None, speech_chunks=None): + if history is None and pixel_values is not None and '' not in question: + question = '\n' + question + + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + + img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.img_context_token_id = img_context_token_id + + template = get_conv_template(self.template) + template.system_message = self.system_message + eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) + + history = [] if history is None else history + for (old_question, old_answer) in history: + template.append_message(template.roles[0], old_question) + template.append_message(template.roles[1], old_answer) + template.append_message(template.roles[0], question) + template.append_message(template.roles[1], None) + query = template.get_prompt() + + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + print(f'dynamic ViT batch size: {image_bs}') + + + # Replace image tokens + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace('', image_tokens, 1) + from ola.conversation import conv_templates, SeparatorStyle + from ola.mm_utils import KeywordsStoppingCriteria + conv_mode = "plm_v" + conv = conv_templates[conv_mode].copy() + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + + # Use OLA-style tokenization for speech inputs + if speech is not None and '' in query: + # Use OLA-style tokenization directly with tokens + input_ids = tokenizer_speech_token(query, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device) + # Handle case where pad_token_id might be None + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643 + attention_mask = input_ids.ne(pad_token_id).long().to(self.device) + + else: + model_inputs = tokenizer(query, return_tensors='pt') + input_ids = model_inputs['input_ids'].to(self.device) + attention_mask = model_inputs['attention_mask'].to(self.device) + generation_config['eos_token_id'] = eos_token_id + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + # generation_config["stopping_criteria"] = stopping_criteria + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + speech=speech, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wav, + **generation_config + ) + response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] + response = response.split(template.sep.strip())[0].strip() + history.append((question, response)) + if return_history: + return response, history + else: + query_to_print = query.replace(IMG_CONTEXT_TOKEN, '') + query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '') + if verbose: + print(query_to_print, response) + return response + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + speech = kwargs.pop("speech", None) + speech_lengths = kwargs.pop("speech_lengths", None) + speech_chunks = kwargs.pop("speech_chunks", None) + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if speech is not None: + inputs['speech'] = speech + inputs['speech_lengths'] = speech_lengths + inputs['speech_chunks'] = speech_chunks + if images is not None: + inputs["images"] = images + if image_sizes is not None: + inputs["image_sizes"] = image_sizes + return inputs + +AutoConfig.register("ola_internvl", OlaConfigQwen3) +AutoModelForCausalLM.register(OlaConfigQwen3, OlaQwen3ForCausalLM) diff --git a/ola/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc b/ola/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0997e97fa1f6acc5fda5430e0738ee4144e30689 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/configuration_intern_vit.cpython-312.pyc b/ola/model/multimodal_encoder/__pycache__/configuration_intern_vit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cd5f08fc8644ac60fa0cb8bf3ba8a5880f60105 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/configuration_intern_vit.cpython-312.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/internvl_vit.cpython-312.pyc b/ola/model/multimodal_encoder/__pycache__/internvl_vit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c39f5bc7258663968226ea7469e8cc84143031 Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/internvl_vit.cpython-312.pyc differ diff --git a/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-312.pyc b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c238f68aa457f0905d3f5750c6778185f99b3d Binary files /dev/null and b/ola/model/multimodal_encoder/__pycache__/oryx_vit.cpython-312.pyc differ diff --git a/ola/model/multimodal_encoder/builder.py b/ola/model/multimodal_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..eaba7d6fc34a321be95e945583984588df27a9a8 --- /dev/null +++ b/ola/model/multimodal_encoder/builder.py @@ -0,0 +1,16 @@ +import os +from .oryx_vit import SigLIPViTAnysizeWrapper +from .internvl_vit import InternVisionModel + +def build_vision_tower(vision_tower_cfg, **kwargs): + # breakpoint() + if vision_tower_cfg.model_type == 'intern_vit_6b': + vision_tower = InternVisionModel(vision_tower_cfg) + # breakpoint() + return vision_tower + else: + vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None)) + is_absolute_path_exists = os.path.exists(vision_tower) + print(f"Buiding OryxViTWrapper from {vision_tower}...") + # path = vision_tower.split(":")[1] + return SigLIPViTAnysizeWrapper(vision_tower, path=vision_tower, args=vision_tower_cfg, **kwargs) \ No newline at end of file diff --git a/ola/model/multimodal_encoder/configuration_intern_vit.py b/ola/model/multimodal_encoder/configuration_intern_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..32f469c4bbfee021fe19e622245d16fd9ba0aae6 --- /dev/null +++ b/ola/model/multimodal_encoder/configuration_intern_vit.py @@ -0,0 +1,119 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class InternVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to + instantiate a vision encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + Number of color channels in the input images (e.g., 3 for RGB). + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries and values in the self-attention layers. + hidden_size (`int`, *optional*, defaults to 3200): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 25): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 12800): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + qk_normalization (`bool`, *optional*, defaults to `True`): + Whether to normalize the queries and keys in the self-attention layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + use_flash_attn (`bool`, *optional*, defaults to `True`): + Whether to use flash attention mechanism. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for stochastic depth. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.1): + A factor for layer scale. + """ + + model_type = 'intern_vit_6b' + + def __init__( + self, + num_channels=3, + patch_size=14, + image_size=224, + qkv_bias=False, + hidden_size=3200, + num_attention_heads=25, + intermediate_size=12800, + qk_normalization=True, + num_hidden_layers=48, + use_flash_attn=True, + hidden_act='gelu', + norm_type='rms_norm', + layer_norm_eps=1e-6, + dropout=0.0, + drop_path_rate=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=0.1, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.dropout = dropout + self.drop_path_rate = drop_path_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.norm_type = norm_type + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.use_flash_attn = use_flash_attn + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if 'vision_config' in config_dict: + config_dict = config_dict['vision_config'] + + if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' + ) + + return cls.from_dict(config_dict, **kwargs) \ No newline at end of file diff --git a/ola/model/multimodal_encoder/internvl_vit.py b/ola/model/multimodal_encoder/internvl_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..3f97682f1ca872f7ffa715a1524b002abd5c0e17 --- /dev/null +++ b/ola/model/multimodal_encoder/internvl_vit.py @@ -0,0 +1,435 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from timm.layers import DropPath +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_intern_vit import InternVisionConfig + +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import \ + flash_attn_varlen_qkvpacked_func + has_flash_attn = True +except: + print('FlashAttention2 is not installed.') + has_flash_attn = False + +logger = logging.get_logger(__name__) + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None + + +class InternRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +try: + from apex.normalization import FusedRMSNorm + + InternRMSNorm = FusedRMSNorm # noqa + + logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm') +except ImportError: + # using the normal InternRMSNorm + pass +except Exception: + logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm') + pass + + +NORM2FN = { + 'rms_norm': InternRMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \ + reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, width) + ], dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_flash_attn = config.use_flash_attn and has_flash_attn + if config.use_flash_attn and not has_flash_attn: + print('Warning: Flash Attention is not available, use_flash_attn is set to False.') + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).' + ) + + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj_drop = nn.Dropout(config.dropout) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + if self.use_flash_attn: + self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _naive_attn(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + qkv = self.qkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + + if self.qk_normalization: + q, k, v = qkv.unbind(2) + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False + ) + outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) + outs = self.proj_drop(outs) + return outs + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states) + return x + + +class InternMLP(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.act = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + def __init__(self, config: InternVisionConfig, drop_path_rate: float): + super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: + """ + Args: + hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1) + + hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2) + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InternEncoderLayer`]. + + Args: + config (`InternConfig`): + The corresponding vision configuration for the `InternEncoder`. + """ + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + encoder_layer, + hidden_states) + else: + layer_outputs = encoder_layer( + hidden_states, + ) + hidden_states = layer_outputs + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states + ) + + +class InternVisionModel(PreTrainedModel): + main_input_name = 'pixel_values' + # _supports_flash_attn_2 = True + supports_gradient_checkpointing = True + config_class = InternVisionConfig + _no_split_modules = ['InternVisionEncoderLayer'] + # support transformers 4.51.+ + _tp_plan = '' + + def __init__(self, config: InternVisionConfig): + super().__init__(config) + self.config = config + # Force eager attention implementation to avoid scaled_dot_product_attention error + # self._attn_implementation = "eager" + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder(config) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) + pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and pixel_embeds is None: + raise ValueError('You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + if len(pixel_values.shape) == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/ola/model/multimodal_encoder/oryx_vit.py b/ola/model/multimodal_encoder/oryx_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..006ac6699dcf1f99aa56217e8537da9a1987d8c3 --- /dev/null +++ b/ola/model/multimodal_encoder/oryx_vit.py @@ -0,0 +1,1075 @@ +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +from torch.utils.checkpoint import checkpoint +import torch +import torch.nn as nn +import torch.nn.functional as F +try: + from timm.layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, + ) + from timm.models._manipulate import checkpoint_seq, named_apply +except: + print('Wrong timm version') + +from flash_attn import flash_attn_func, flash_attn_varlen_func + +from typing import Optional + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed +import os + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if cu_slens is not None: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item() + x = flash_attn_varlen_func( + q.squeeze(0), + k.squeeze(0), + v.squeeze(0), + cu_seqlens_q=cu_slens, + cu_seqlens_k=cu_slens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=self.scale, + causal=False, + ) + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + else: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c + + x = x.reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + # if self.fused_attn: + # x = F.scaled_dot_product_attention( + # q, + # k, + # v, + # dropout_p=self.attn_drop.p if self.training else 0.0, + # ) + # else: + # q = q * self.scale + # attn = q @ k.transpose(-2, -1) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + # x = attn @ v + + # x = x.transpose(1, 2).reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + strict_img_size: bool = False, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + add_patch2x2: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + strict_img_size=strict_img_size, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + + + # deepspeed.zero.register_external_parameter(self, self.pos_embed) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.weight) + # deepspeed.zero.register_external_parameter(self, self.patch_embed.proj.bias) + # print(self.patch_embed.state_dict().keys()) + + + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + + + if add_patch2x2: + if add_patch2x2 == 'v2': + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim*2, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(embed_dim*2, embed_dim*4, 1) + ) + else: + mid_dim = embed_dim * 2 + self.downsample = nn.Sequential( + nn.Conv2d(embed_dim, mid_dim, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(mid_dim, mid_dim, 1) + ) + + else: + self.downsample = None + + + # self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # # Classifier Head + # if global_pool == "map": + # AttentionPoolLatent.init_weights = init_weights + # self.attn_pool = AttentionPoolLatent( + # self.embed_dim, + # num_heads=num_heads, + # mlp_ratio=mlp_ratio, + # norm_layer=norm_layer, + # ) + # else: + # self.attn_pool = None + # self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + # self.head_drop = nn.Dropout(drop_rate) + # self.head = ( + # nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + # ) + + # if weight_init != "skip": + # self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def rescale_positional_embedding(self, out_size): + h, w = out_size + pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5) + if (h, w) == (pos_embed_shape, pos_embed_shape): + return self.pos_embed + rescaled_positional_embedding = \ + self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2]) + pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape) + pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w) + rescaled_positional_embedding[0] = pe_2d.T.contiguous() + return rescaled_positional_embedding + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). + Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features_list(self, x_list): + x_all = [] + image_sizes = [] + for x in x_list: + bs, _, h, w = x.shape + + # fix patch size=14 in datasets + pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0] + pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + + bs, _, h, w = x.shape + + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + x_all.append(x) + image_sizes.append((h, w)) + + slen = [xi.size(1) for xi in x_all] + x = torch.cat(x_all, dim=1) + + cu_indices = [0, ] + for i in slen: + cu_indices.append(cu_indices[-1] + i) + + cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device) + for idx, blk in enumerate(self.blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, cu_slens, use_reentrant=True) + else: + x = blk(x, cu_slens=cu_slens) + feats = x.split(slen, dim=1) #[(1, slen, c)] + + if self.downsample is not None: + new_feats = [] + new_sizes = [] + for f, s in zip(feats, image_sizes): + h, w = s + b, n, c = f.size() + f = f.reshape(b, h, w, c).permute(0, 3, 1, 2) + f = self.downsample(f) + b, c, h, w = f.size() + f = f.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats.append(f) + new_sizes.append((h, w)) + return new_feats, new_sizes + + + return feats, image_sizes + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + bs, _, h, w = x.shape + h = h // self.patch_embed.patch_size[0] + w = w // self.patch_embed.patch_size[1] + + x = self.patch_embed(x) + # x = self._pos_embed(x) + x = x + self.rescale_positional_embedding(out_size=(h, w)) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + if self.downsample is not None: + b, n, c = x.size() + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.downsample(x) + b, c, h, w = x.size() + x = x.permute(0, 2, 3, 1).reshape(b, h*w, c) + new_feats = x + new_sizes = (h, w) + return new_feats, new_sizes + + return x, (h, w) + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.norm(x) + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x, cal_attn_pool=False): + # import pdb;pdb.set_trace() + if type(x) is list: + x, image_sizes = self.forward_features_list(x) + return x, image_sizes, None + else: + x, image_sizes = self.forward_features(x) + return x, image_sizes, None + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 384, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'): + # interpolate position embedding + orig_size = 24 + new_size = 128 + pos_tokens = model.pos_embed + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True) + return model + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + path: str = "", + gradient_checkpointing: bool = False, + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + + + if 'patch2x2' or 'patch4x4' in path: + add_patch2x2 = True + else: + add_patch2x2 = False + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + add_patch2x2 = 'v2' + + if FORCE_NO_DOWNSAMPLE: + add_patch2x2 = False + + model = VisionTransformer( + img_size=2048, + patch_size=16, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + dynamic_img_pad=False, + strict_img_size=False, + ignore_head=kwargs.get("ignore_head", False), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + add_patch2x2=add_patch2x2 + ) + + if gradient_checkpointing: + model.set_grad_checkpointing(True) + return model + +import os +if 'LOAD_VISION_EARLY' in os.environ: + print("LOAD_VISION_EARLY is set") + LOAD_VISION_EARLY = True +else: + LOAD_VISION_EARLY = False + +if 'VIT_WITH_GRAD' in os.environ: + print("VIT_WITH_GRAD is set") + VIT_WITH_GRAD = True +else: + VIT_WITH_GRAD = False + +if 'FIX_SIZE' in os.environ: + print("FIX_SIZE is set") + FIX_SIZE = True +else: + FIX_SIZE = False + +if 'ANYRES_SPLIT' in os.environ: + ANYRES_SPLIT = int(os.environ['ANYRES_SPLIT']) + print(f"ANYRES_SPLIT is set as {ANYRES_SPLIT}") +else: + ANYRES_SPLIT = None + + +if 'FORCE_NO_DOWNSAMPLE' in os.environ: + print("FORCE_NO_DOWNSAMPLE is set") + FORCE_NO_DOWNSAMPLE = True +else: + FORCE_NO_DOWNSAMPLE = False + +from transformers import CLIPImageProcessor +import torch.distributed as dist + +class SigLIPViTAnysizeWrapper(nn.Module): + def __init__(self, vision_tower, path, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.args = args + self.path = path + + self.select_layer = -1 + if self.select_layer < -1: self.select_layer += 1 + self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') + + self.output_dim = 1152 + if not FORCE_NO_DOWNSAMPLE: + if 'patch2x2' or 'patch4x4' in path: + self.output_dim = 1152*2 + + if 'patch4x4pool' in path or 'patch2x2from4x4' in path: + self.output_dim = 1152*4 + + if not delay_load or LOAD_VISION_EARLY: + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + + def load_model(self, device_map=None): + if self.is_loaded: + print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained("/data1/cxy/model/openai/clip-vit-large-patch14") + if self.args.mm_projector_type == "conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp" or self.args.mm_projector_type == "multipath_conv_mlp_woconv": + self.image_processor.crop_size['height'] = 384 + self.image_processor.crop_size['width'] = 384 + self.image_processor.size['shortest_edge'] = 384 + print("Resizeing clip processor to 384...") + self.image_processor.image_mean = [0.5, 0.5, 0.5] + self.image_processor.image_std = [0.5, 0.5, 0.5] + print("Loading vision model...") + if VIT_WITH_GRAD: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=True) + self.vision_tower.train() + else: + self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384', + gradient_checkpointing=False) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode = True): + self.training = mode + + if self.is_loaded and not VIT_WITH_GRAD: + self.vision_tower.eval() + + def split_images(self, images, split_res=512, base_size=32): + split_images = [] + sub_images_info = [] + for image in images: + now_sub_images = [] + _, c, h, w = image.shape + if h * w <= split_res * split_res: + split_images.append(image) + sub_images_info.append( + ( + 1, 1, 1, h // base_size, w // base_size, [(0, h // base_size, 0, w // base_size)] + ) + ) + continue + nsplit_h = math.ceil(h / split_res) + nsplit_w = math.ceil(w / split_res) + sub_h = int(h / nsplit_h / base_size) * base_size + sub_w = int(w / nsplit_w / base_size) * base_size + crop_infos = [] + for i in range(nsplit_h): + for j in range(nsplit_w): + begin_h = i * sub_h + begin_w = j * sub_w + + if i == nsplit_h - 1: + end_h = h + else: + end_h = (i + 1) * sub_h + + if j == nsplit_w - 1: + end_w = w + else: + end_w = (j + 1) * sub_w + + assert (end_h - begin_h) % base_size == 0 and (end_w - begin_w) % base_size == 0 + + sub_image = image[:, :, begin_h:end_h, begin_w:end_w] + now_sub_images.append(sub_image) + crop_infos.append( + (begin_h // base_size, end_h // base_size, begin_w // base_size, end_w // base_size) + ) + + split_images += now_sub_images + sub_images_info.append( + ( + len(now_sub_images), nsplit_h, nsplit_w, h // base_size, w // base_size, crop_infos + ) + ) + + return split_images, sub_images_info + + + def unsplit_images(self, features, sizes, sub_images_info): + new_features = [] + for feature, size in zip(features, sizes): + h, w = size + new_features.append( + feature.reshape(1, h, w, -1) + ) + + fused_images = [] + images_sizes = [] + sub_count = 0 + for n_split, nsplit_h, nsplit_w, total_h, total_w, crop_infos in sub_images_info: + sub_features = new_features[sub_count:sub_count+n_split] + sub_count += n_split + + total_feature = new_features[0].new_zeros(1, total_h, total_w, self.hidden_size) + for feature, (begin_h, end_h, begin_w, end_w) in zip(sub_features, crop_infos): + total_feature[:, begin_h:end_h, begin_w:end_w] += feature + + fused_images.append(total_feature.reshape(1, total_h * total_w, self.hidden_size)) + images_sizes.append((total_h, total_w)) + + return fused_images, images_sizes + + + + def forward_func(self, images, force_fix_size=False, cal_attn_pool=False): + if type(images) is list: + xs = [x.to(self.dtype) for x in images] + image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool) + image_features = [x.to(images[0].dtype) for x in image_features] + + else: + image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool) + image_features = image_forward_outs.to(images.dtype) + + return image_features, img_size, cls_token + + def forward(self, images, cal_attn_pool=False): + if VIT_WITH_GRAD: + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + else: + with torch.no_grad(): + image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool) + return image_features, img_size + + + @property + def dummy_feature(self): + return torch.zeros(1, 1152, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.pos_embed.dtype + + @property + def device(self): + return self.vision_tower.pos_embed.device + + @property + def hidden_size(self): + return self.output_dim + + @property + def config(self): + return type('LLaVAConfigWrapper', (), { + # 'image_size': 224, + 'patch_size': 16, + })() diff --git a/ola/model/multimodal_projector/__pycache__/builder.cpython-312.pyc b/ola/model/multimodal_projector/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08851294f904c5fd5417d192ffbad50c8da53715 Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/multimodal_projector/__pycache__/internvl_projector.cpython-312.pyc b/ola/model/multimodal_projector/__pycache__/internvl_projector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fafd6727de40ccde784d69fc55c0d051e7a847a Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/internvl_projector.cpython-312.pyc differ diff --git a/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-312.pyc b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89dea68f63aa7d5db5c4c21ea2c54a714be44e34 Binary files /dev/null and b/ola/model/multimodal_projector/__pycache__/pooler_projector.cpython-312.pyc differ diff --git a/ola/model/multimodal_projector/builder.py b/ola/model/multimodal_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..439a2ce1866d2264531bd1ba21e54d1dbd7c3975 --- /dev/null +++ b/ola/model/multimodal_projector/builder.py @@ -0,0 +1,177 @@ +import torch +import torch.nn as nn +import re + +import math + +from .pooler_projector import NormalizedDwPooler +from .internvl_projector import InternVLMultiModalProjector +import os +import math + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + +class OlaMLP(nn.Module): + def __init__(self, in_channels, out_channels, twoview=False): + super().__init__() + + self.proj1 = nn.Linear(in_channels, out_channels) + self.proj2 = nn.Linear(out_channels, out_channels) + self.act = nn.GELU() + self.pooler = NormalizedDwPooler(out_channels) + + embed_std = 1 / math.sqrt(out_channels) + self.image_newline = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_begin = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + self.image_end = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + if twoview: + self.image_sep = nn.Parameter( + torch.randn(out_channels) * embed_std + ) + + def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'): + + if modalities in ['image', 'text']: + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x = self.proj1(x) + x = self.pooler(x, forward_type='2x') + x = self.act(x) + x = self.proj2(x) + + + b, h, w, c = x.shape + x = torch.cat([ + x, + self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype) + ], dim=2) + x = x.reshape(b, -1, c) + + if x2 is not None: + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type='2x') + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c2 = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype) + ], dim=2) + x2 = x2.reshape(b, -1, c) + sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype) + x = torch.cat([x, sep, x2], dim=1) + + begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype) + x = torch.cat([begin, x, end], dim=1) + return x + elif modalities in ['video']: + # x2 is the true feature, ignore x + h, w = size + dtype = x.dtype + x = x.reshape(x.shape[0], h, w, -1) + x1 = self.proj1(x) + x1 = self.pooler(x1, forward_type='2x') + x1 = self.proj2(x1).mean() * 0.0 + + h2, w2 = size2 + x2 = x2.reshape(x2.shape[0], h2, w2, -1) + x2 = self.proj1(x2) + x2 = self.pooler(x2, forward_type='2x') + x2 = self.act(x2) + x2 = self.proj2(x2) + + b2, h2, w2, c = x2.shape + x2 = torch.cat([ + x2, + self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype) + ], dim=2) + + x2 = x2.reshape(b2, -1, c) + + sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype) + x2 = torch.cat([x2, sep], dim=1) + + x2 = x2.flatten(0, 1) + + begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype) + end = self.image_end.reshape(1, -1).expand(1, c).to(dtype) + x2 = torch.cat([begin, x2, end], dim=0) + x2 = x2.unsqueeze(0) + return x2 + else: + raise ValueError(f'Unknown modalities: {modalities}') + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + elif projector_type == 'ola_mlp': + return OlaMLP(config.mm_hidden_size, config.hidden_size, twoview=True) + + elif projector_type == 'ola_internvl': + # breakpoint() + return InternVLMultiModalProjector(config) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type) + if mlp_gelu_resnet_match: + mlp_depth = int(mlp_gelu_resnet_match.group(1)) + res_depth = int(mlp_gelu_resnet_match.group(2)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + for _ in range(res_depth): + modules.append(SimpleResBlock(config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/ola/model/multimodal_projector/internvl_projector.py b/ola/model/multimodal_projector/internvl_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fcadf56bcf0283c9b85e943aec2f0346294416 --- /dev/null +++ b/ola/model/multimodal_projector/internvl_projector.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn + +from transformers.models.internvl.modeling_internvl import ACT2FN +from transformers.models.internvl.configuration_internvl import InternVLConfig + +class InternVLMultiModalProjector(nn.Module): + def __init__(self, config: InternVLConfig): + super().__init__() + # breakpoint() + self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, config.llm_config.hidden_size + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.llm_config.hidden_size, config.llm_config.hidden_size) + + def forward(self, image_features): + hidden_states = self.layer_norm(image_features) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states \ No newline at end of file diff --git a/ola/model/multimodal_projector/pooler_projector.py b/ola/model/multimodal_projector/pooler_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8aaaf7ccf33dfe5c3564682d9564fe39714d5a --- /dev/null +++ b/ola/model/multimodal_projector/pooler_projector.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from transformers.models.clip.modeling_clip import CLIPVisionModel +import os + +class PoolerProjector(nn.Module): + def __init__(self, config, vision_cfg): + super().__init__() + self._config = config + self.hw = vision_cfg.image_size // vision_cfg.patch_size + + self.conv_pool = nn.Conv2d( + config.mm_hidden_size, config.hidden_size, + kernel_size=2, stride=2 + ) + + self.proj = nn.Sequential( + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ) + + def forward(self, x, *args, **kwargs): + height = width = self.hw + assert height * width == x.shape[1] + x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2) + x = self.conv_pool(x) + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + @property + def config(self): + return {"mm_projector_type": 'pooler'} + + +class NormalizedDwPooler(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + self.predictor = nn.Sequential( + nn.Linear(dim*2, dim), + nn.GELU(), + nn.Linear(dim, dim), + ) + + def forward(self, x, forward_type='2x'): + B, H, W, C = x.shape + + if forward_type == '2x': + new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + elif forward_type == '1x': + new_x = x.reshape(B, H, W, 1, C) + fused_x = torch.cat([new_x, new_x], dim=-1) + elif forward_type == '4x': + new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C) + pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1) + fused_x = torch.cat([new_x, pooled_x], dim=-1) + + score = self.predictor(fused_x) + normalized_score = F.softmax(score, dim=-2) + new_x = (new_x * normalized_score).sum(dim=-2) + return new_x diff --git a/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc b/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2504b6fa10ef9697cfd360ce35a4961fcf625c55 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/builder.cpython-312.pyc b/ola/model/multimodal_resampler/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7b9b51ce9271916adc44f2d7e563a414ac54521 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc b/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1db0df7ee7fd8deedebb963f9d123d1df23836f Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/builder.cpython-38.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3adcab5b4fcad2aa843e24417c081a77c0415b8 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc differ diff --git a/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd70c00230a9c9332e1f8fc76062c018d06efd2 Binary files /dev/null and b/ola/model/multimodal_resampler/__pycache__/perceiver.cpython-38.pyc differ diff --git a/ola/model/multimodal_resampler/builder.py b/ola/model/multimodal_resampler/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..994f816cde2d79e802c846d624af23fdcec762a4 --- /dev/null +++ b/ola/model/multimodal_resampler/builder.py @@ -0,0 +1,20 @@ +import torch + +class IdentityMap(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_resampler_type": None} + +def build_vision_resampler(model_args, delay_load=False, **kwargs): + # import pdb;pdb.set_trace() + resampler_type = getattr(model_args, 'mm_resampler_type', None) + if resampler_type is None: + return IdentityMap() + else: + raise ValueError(f'Unknown resampler type: {resampler_type}') diff --git a/ola/model/ola_arch copy.py b/ola/model/ola_arch copy.py new file mode 100644 index 0000000000000000000000000000000000000000..61c625b5a8a0b7d500c9039d723fda307a1b0ffc --- /dev/null +++ b/ola/model/ola_arch copy.py @@ -0,0 +1,611 @@ +from abc import ABC, abstractmethod + +import torch + +from .speech_encoder.builder import build_speech_encoder +from .speech_projector.builder import build_speech_projector +from ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX +from ola.utils import lengths_to_padding_mask + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_resampler.builder import build_vision_resampler +from .multimodal_projector.builder import build_vision_projector + +from ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +class OlaMetaModel: + + def __init__(self, config, llm_config=None): + super(OlaMetaModel, self).__init__(config, llm_config) + + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + # breakpoint() + if hasattr(config, "vision_config"): + self.vision_tower = build_vision_tower(config.vision_config, delay_load=True) + self.vision_resampler = build_vision_resampler(config.vision_config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + # breakpoint() + + def get_speech_encoder(self): + speech_encoder = getattr(self, 'speech_encoder', None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def get_vision_projector(self): + vision_projector = getattr(self, 'mm_projector', None) + if type(vision_projector) is list: + vision_projector = vision_projector[0] + return vision_projector + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) + self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') + self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) + self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) + self.config.music_encoder = getattr(model_args, 'music_encoder', None) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + + if getattr(self, 'speech_projector', None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + print('Loading pretrain speech projector weights') + + msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False, assign=True) + print(msg) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size) + + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + else: + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + print('Loading pretrain mm projector weights') + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False, assign=True) + print(incompatible_keys) + +class OlaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_vision_projector(self): + return self.get_model().get_vision_projector() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def encode_speech(self, speech, speech_lengths, speech_wav): + # import pdb; pdb.set_trace() + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav) + speech_lengths = (speech_lengths + 1) // 2 + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f'Unknown speech projector: {speech_projector_type}') + # speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] + return encoder_outs + + def prepare_inputs_labels_for_speech_text_for_ola( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None + ): + # breakpoint() + speech_encoder = self.get_speech_encoder() + vision_tower = self.get_vision_tower() + + if speech_encoder is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if vision_tower is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + # encode speech + if not isinstance(speech, list): + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + speech_features = [] + for idx in range(len(speech)): + speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])) + + # encode vision + if isinstance(modalities, str): + modalities = [modalities] + + video_idx_in_batch = [] + for modal in range(len(modalities)): + if 'video' in modalities[modal]: + video_idx_in_batch.append(modal) + + aimg = images[-1] + lowres_img = [] + for idx, img_feat in enumerate(images): + if idx in video_idx_in_batch: + img_feat = aimg.new(1, 3, 128, 128).fill_(0) + lowres_img.append(img_feat) + + lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img) + highres_img_features = [] + highres_img_sizes = [] + for idx, img_feat in enumerate(images_highres): + if img_feat.ndim == 5: + img_feat = img_feat.squeeze(1) + highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat) + highres_img_features.append(highres_img_feature) + highres_img_sizes.append(highres_img_size) + image_features = [] + for idx in range(len(modalities)): + img_feat = self.get_model().mm_projector(lowres_img_features[idx], + lowres_img_sizes[idx], + highres_img_features[idx], + highres_img_sizes[idx], + modalities[idx]) + image_features.append(img_feat.flatten(0, 1)) + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + + num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_images_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]] + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image)) + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + if i < num_speech_images: + if i < num_images: + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_images == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0) + cur_image_idx += 1 + + if num_speech == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + + def prepare_inputs_labels_for_speech_text_for_internvl( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, modalities + ): + speech_encoder = self.get_speech_encoder() + if speech_encoder is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + # Encode speech + if speech is not None: + if not isinstance(speech, list): + if speech_chunks is not None: + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + else: + speech = [speech] + speech_lengths = [speech_lengths] + speech_wav = [speech_wav] + + speech_features = [] + for idx in range(len(speech)): + speech_feat = self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx]) + if speech_feat is not None: + speech_features.append(speech_feat) + else: + speech_features = [] + + # Encode vision (skip if doing speech-only training) + if isinstance(modalities, str): + modalities = [modalities] + + image_features = [] + # Skip vision processing if doing speech-only training + if (pixel_values is not None and ("audio" not in modalities)): + if image_flags is not None: + image_flags = image_flags.squeeze(-1) + vit_embeds = self.extract_feature(pixel_values) + # Only process features where image_flags == 1 (actual images) + valid_indices = (image_flags == 1).nonzero(as_tuple=True)[0] + if len(valid_indices) > 0: + vit_embeds = vit_embeds[valid_indices] + # Apply vision projector + for idx in range(len(valid_indices)): + img_feat = self.get_vision_projector()(vit_embeds[idx:idx+1]) + image_features.append(img_feat.flatten(0, 1)) + else: + vit_embeds = self.extract_feature(pixel_values) + for idx in range(vit_embeds.shape[0]): + img_feat = vit_embeds[idx:idx+1] + image_features.append(img_feat.flatten(0, 1)) + + # Save original values + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # Remove padding using attention_mask + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + + for batch_idx, cur_input_ids in enumerate(input_ids): + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + num_speech_images = num_images + num_speech + + if num_speech_images == 0: + # No speech or image tokens - just use text embeddings + cur_input_embeds = self.get_model().get_input_embeddings()(cur_input_ids) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + continue + + # Handle speech and image tokens + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]] + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().get_input_embeddings()(torch.cat(cur_input_ids_nospeech_image)) + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + # Process tokens in order, similar to OLA's approach + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + + if i < num_speech_images: + # Determine which token type comes next based on position + if i < len(speech_image_token_indices) - 1: + token_pos = speech_image_token_indices[i + 1] + token_type = cur_input_ids[token_pos].item() + + if token_type == SPEECH_TOKEN_INDEX: + # Process speech token + if cur_speech_idx < len(speech_features): + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + # No more speech features available - this shouldn't happen in normal cases + print(f"Warning: No speech features available for speech token at position {token_pos}") + + elif token_type == IMAGE_TOKEN_INDEX: + # Process image token + if cur_image_idx < len(image_features): + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + # No more image features available - this shouldn't happen in normal cases + print(f"Warning: No image features available for image token at position {token_pos}") + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine and pad + max_len = max(x.shape[0] for x in new_input_embeds) if new_input_embeds else 0 + batch_size = len(new_input_embeds) + + if max_len > 0: + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + else: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/ola/model/ola_arch.py b/ola/model/ola_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1a192e00719ec147d22c2324c4d63364b7bcf8aa --- /dev/null +++ b/ola/model/ola_arch.py @@ -0,0 +1,595 @@ +from abc import ABC, abstractmethod + +import torch + +from .speech_encoder.builder import build_speech_encoder +from .speech_projector.builder import build_speech_projector +from ola.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX +from ola.utils import lengths_to_padding_mask + +from .multimodal_encoder.builder import build_vision_tower +from .multimodal_resampler.builder import build_vision_resampler +from .multimodal_projector.builder import build_vision_projector + +from ola.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + +class OlaMetaModel: + + def __init__(self, config, llm_config=None): + super(OlaMetaModel, self).__init__(config, llm_config) + + if hasattr(config, "speech_encoder"): + self.speech_encoder = build_speech_encoder(config) + self.speech_projector = build_speech_projector(config) + # breakpoint() + if hasattr(config, "vision_config"): + self.vision_tower = build_vision_tower(config.vision_config, delay_load=True) + self.vision_resampler = build_vision_resampler(config.vision_config, vision_tower=self.vision_tower) + self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config) + # breakpoint() + + def get_speech_encoder(self): + speech_encoder = getattr(self, 'speech_encoder', None) + if type(speech_encoder) is list: + speech_encoder = speech_encoder[0] + return speech_encoder + + def get_vision_projector(self): + vision_projector = getattr(self, 'mm_projector', None) + if type(vision_projector) is list: + vision_projector = vision_projector[0] + return vision_projector + + def get_vision_tower(self): + vision_tower = getattr(self, 'vision_tower', None) + if type(vision_tower) is list: + vision_tower = vision_tower[0] + return vision_tower + + def initialize_speech_modules(self, model_args, fsdp=None): + self.config.speech_encoder = getattr(model_args, "speech_encoder", None) + self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) + self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') + self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) + self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) + self.config.music_encoder = getattr(model_args, 'music_encoder', None) + + if self.get_speech_encoder() is None: + speech_encoder = build_speech_encoder(self.config) + if fsdp is not None and len(fsdp) > 0: + self.speech_encoder = [speech_encoder] + else: + self.speech_encoder = speech_encoder + + if getattr(self, 'speech_projector', None) is None: + self.speech_projector = build_speech_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.speech_projector.parameters(): + p.requires_grad = True + + if model_args.pretrain_speech_projector is not None: + pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + print('Loading pretrain speech projector weights') + + msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False, assign=True) + print(msg) + + def initialize_vision_modules(self, model_args, fsdp=None): + vision_tower = model_args.vision_tower + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + + self.config.mm_vision_tower = vision_tower + + if self.get_vision_tower() is None: + vision_tower = build_vision_tower(model_args) + vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower) + ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride + for k, v in vision_resampler.config.items(): + setattr(self.config, k, v) + + if fsdp is not None and len(fsdp) > 0: + self.vision_tower = [vision_tower] + self.vision_resampler = [vision_resampler] + else: + self.vision_tower = vision_tower + self.vision_resampler = vision_resampler + else: + if fsdp is not None and len(fsdp) > 0: + vision_resampler = self.vision_resampler[0] + vision_tower = self.vision_tower[0] + else: + vision_resampler = self.vision_resampler + vision_tower = self.vision_tower + vision_tower.load_model() + + # In case it is frozen by LoRA + for p in self.vision_resampler.parameters(): + p.requires_grad = True + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') + self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size) + + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + + if getattr(self, 'mm_projector', None) is None: + self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) + else: + for p in self.mm_projector.parameters(): + p.requires_grad = True + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + print('Loading pretrain mm projector weights') + incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False, assign=True) + print(incompatible_keys) + +class OlaMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_speech_encoder(self): + return self.get_model().get_speech_encoder() + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_vision_projector(self): + return self.get_model().get_vision_projector() + + def get_speech_projector(self): + return self.get_model().speech_projector + + def encode_speech(self, speech, speech_lengths, speech_wav): + # import pdb; pdb.set_trace() + print(f"Debug - Input speech range: {speech.min().item()} to {speech.max().item()}") + print(f"Debug - Input speech has nan: {torch.isnan(speech).any().item()}") + print(f"Debug - Input speech has inf: {torch.isinf(speech).any().item()}") + print(f"Debug - Input speech shape: {speech.shape}") + + print(f"Debug - Input speech_wav range: {speech_wav.min().item()} to {speech_wav.max().item()}") + print(f"Debug - Input speech_wav has nan: {torch.isnan(speech_wav).any().item()}") + print(f"Debug - Input speech_wav has inf: {torch.isinf(speech_wav).any().item()}") + print(f"Debug - Input speech_wav shape: {speech_wav.shape}") + + speech_encoder_type = self.config.speech_encoder_type + speech_encoder = self.get_speech_encoder() + if "whisper" in speech_encoder_type.lower(): + encoder_outs = speech_encoder(speech.permute(0, 2, 1)) + speech_lengths = (speech_lengths + 1) // 2 + else: + encoder_outs = speech_encoder(speech.permute(0, 2, 1), raw_wav=speech_wav) + speech_lengths = (speech_lengths + 1) // 2 + + print(f"Debug - After speech encoder range: {encoder_outs.min().item()} to {encoder_outs.max().item()}") + print(f"Debug - After speech encoder has nan: {torch.isnan(encoder_outs).any().item()}") + print(f"Debug - After speech encoder has inf: {torch.isinf(encoder_outs).any().item()}") + print(f"Debug - After speech encoder shape: {encoder_outs.shape}") + + speech_projector_type = self.config.speech_projector_type + speech_projector = self.get_speech_projector() + if speech_projector_type == "linear": + encoder_outs = speech_projector(encoder_outs) + speech_lengths = speech_lengths // speech_projector.k + else: + raise ValueError(f'Unknown speech projector: {speech_projector_type}') + + print(f"Debug - After speech projector range: {encoder_outs.min().item()} to {encoder_outs.max().item()}") + print(f"Debug - After speech projector has nan: {torch.isnan(encoder_outs).any().item()}") + print(f"Debug - After speech projector has inf: {torch.isinf(encoder_outs).any().item()}") + print(f"Debug - After speech projector shape: {encoder_outs.shape}") + + # speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] + return encoder_outs + + def prepare_inputs_labels_for_speech_text_for_ola( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None + ): + # breakpoint() + speech_encoder = self.get_speech_encoder() + vision_tower = self.get_vision_tower() + + if speech_encoder is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if vision_tower is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + # encode speech + if not isinstance(speech, list): + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + speech_features = [] + for idx in range(len(speech)): + speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])) + + # encode vision + if isinstance(modalities, str): + modalities = [modalities] + + video_idx_in_batch = [] + for modal in range(len(modalities)): + if 'video' in modalities[modal]: + video_idx_in_batch.append(modal) + + aimg = images[-1] + lowres_img = [] + for idx, img_feat in enumerate(images): + if idx in video_idx_in_batch: + img_feat = aimg.new(1, 3, 128, 128).fill_(0) + lowres_img.append(img_feat) + + lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img) + highres_img_features = [] + highres_img_sizes = [] + for idx, img_feat in enumerate(images_highres): + if img_feat.ndim == 5: + img_feat = img_feat.squeeze(1) + highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat) + highres_img_features.append(highres_img_feature) + highres_img_sizes.append(highres_img_size) + image_features = [] + for idx in range(len(modalities)): + img_feat = self.get_model().mm_projector(lowres_img_features[idx], + lowres_img_sizes[idx], + highres_img_features[idx], + highres_img_sizes[idx], + modalities[idx]) + image_features.append(img_feat.flatten(0, 1)) + + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + + num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_images_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + cur_image_idx += 1 + continue + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]] + + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image)) + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + if i < num_speech_images: + if i < num_images: + cur_images_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_images_features) + cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + else: + cur_speech_features = speech_features[cur_speech_idx] + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_images == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0) + cur_image_idx += 1 + + if num_speech == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as speech features can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + + def prepare_inputs_labels_for_speech_text_for_internvl( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + speech, speech_lengths, speech_chunks, speech_wav, modalities + ): + # encode speech + # breakpoint() + if input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if not isinstance(speech, list): + speech = torch.split(speech, speech_chunks.tolist(), dim=0) + speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0) + speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0) + speech_features = [] + for idx in range(len(speech)): + speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx])) + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_speech_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() + num_speech_images = num_speech + + if num_speech_images == 0: + cur_speech_features = speech_features[cur_speech_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_speech_idx += 1 + continue + + # Handle speech and image tokens + speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_nospeech_image = [] + cur_labels = labels[batch_idx] + cur_labels_nospeech_image = [] + + for i in range(len(speech_image_token_indices) - 1): + cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]]) + + split_sizes = [x.shape[0] for x in cur_labels_nospeech_image] + + # Debug: Check text embeddings + text_input_ids = torch.cat(cur_input_ids_nospeech_image) + print(f"Debug - Text input_ids range: {text_input_ids.min().item()} to {text_input_ids.max().item()}") + print(f"Debug - Text input_ids has nan: {torch.isnan(text_input_ids.float()).any().item()}") + + cur_input_embeds = self.get_model().get_input_embeddings()(text_input_ids) + print(f"Debug - Text embeddings range: {cur_input_embeds.min().item()} to {cur_input_embeds.max().item()}") + print(f"Debug - Text embeddings has nan: {torch.isnan(cur_input_embeds).any().item()}") + print(f"Debug - Text embeddings has inf: {torch.isinf(cur_input_embeds).any().item()}") + + cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + # Process tokens in order, similar to OLA's approach + for i in range(num_speech_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i]) + cur_new_labels.append(cur_labels_nospeech_image[i]) + + if i < num_speech_images: + # Determine which token type comes next based on position + cur_speech_features = speech_features[cur_speech_idx] + print(f"Debug - Speech features range: {cur_speech_features.min().item()} to {cur_speech_features.max().item()}") + print(f"Debug - Speech features has nan: {torch.isnan(cur_speech_features).any().item()}") + print(f"Debug - Speech features has inf: {torch.isinf(cur_speech_features).any().item()}") + print(f"Debug - Speech features shape: {cur_speech_features.shape}") + + cur_speech_idx += 1 + cur_new_input_embeds.append(cur_speech_features) + cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + if num_speech == 0: + cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0) + cur_speech_idx += 1 + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine and pad + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + # Debug: Check the generated embeddings + print(f"Debug - Generated input_embeds range: {new_input_embeds.min().item()} to {new_input_embeds.max().item()}") + print(f"Debug - Generated input_embeds has nan: {torch.isnan(new_input_embeds).any().item()}") + print(f"Debug - Generated input_embeds has inf: {torch.isinf(new_input_embeds).any().item()}") + print(f"Debug - Generated input_embeds shape: {new_input_embeds.shape}") + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False \ No newline at end of file diff --git a/ola/model/speech_encoder/__pycache__/builder.cpython-312.pyc b/ola/model/speech_encoder/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f12b72df36b4512a7b4c366b5eeb0484b00524b8 Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-312.pyc b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea4a3fb11ef059d3ac098b781a79cfe39684366b Binary files /dev/null and b/ola/model/speech_encoder/__pycache__/speech_encoder.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/BEATs.py b/ola/model/speech_encoder/beats/BEATs.py new file mode 100644 index 0000000000000000000000000000000000000000..6de98b079d519f407ec32078a2c90e70d87b091f --- /dev/null +++ b/ola/model/speech_encoder/beats/BEATs.py @@ -0,0 +1,204 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. + self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + feature_only=False, + ): + print(f"Debug - BEATs input source range: {source.min().item()} to {source.max().item()}") + print(f"Debug - BEATs input source has nan: {torch.isnan(source).any().item()}") + + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32) + print(f"Debug - BEATs fbank range: {fbank.min().item()} to {fbank.max().item()}") + print(f"Debug - BEATs fbank has nan: {torch.isnan(fbank).any().item()}") + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + print(f"Debug - BEATs fbank after unsqueeze range: {fbank.min().item()} to {fbank.max().item()}") + print(f"Debug - BEATs fbank after unsqueeze has nan: {torch.isnan(fbank).any().item()}") + + features = self.patch_embedding(fbank) + print(f"Debug - BEATs after patch_embedding range: {features.min().item()} to {features.max().item()}") + print(f"Debug - BEATs after patch_embedding has nan: {torch.isnan(features).any().item()}") + + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + print(f"Debug - BEATs after reshape/transpose range: {features.min().item()} to {features.max().item()}") + print(f"Debug - BEATs after reshape/transpose has nan: {torch.isnan(features).any().item()}") + + features = self.layer_norm(features) + print(f"Debug - BEATs after layer_norm range: {features.min().item()} to {features.max().item()}") + print(f"Debug - BEATs after layer_norm has nan: {torch.isnan(features).any().item()}") + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + print(f"Debug - BEATs after post_extract_proj range: {features.min().item()} to {features.max().item()}") + print(f"Debug - BEATs after post_extract_proj has nan: {torch.isnan(features).any().item()}") + + x = self.dropout_input(features) + print(f"Debug - BEATs after dropout_input range: {x.min().item()} to {x.max().item()}") + print(f"Debug - BEATs after dropout_input has nan: {torch.isnan(x).any().item()}") + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + print(f"Debug - BEATs after encoder range: {x.min().item()} to {x.max().item()}") + print(f"Debug - BEATs after encoder has nan: {torch.isnan(x).any().item()}") + + if not feature_only and self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask \ No newline at end of file diff --git a/ola/model/speech_encoder/beats/Tokenizers.py b/ola/model/speech_encoder/beats/Tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..597c8902493b38689136b7153f114842c8fd66a3 --- /dev/null +++ b/ola/model/speech_encoder/beats/Tokenizers.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +# import torchaudio.compliance.kaldi as ta_kaldi + +from .kaldi import fbank as kaldi_fbank + +from .backbone import ( + TransformerEncoder, +) +from .quantizer import ( + NormEMAVectorQuantizer, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenizersConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # quantizer + self.quant_n: int = 1024 # codebook number in quantizer + self.quant_dim: int = 256 # codebook dimension in quantizer + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class Tokenizers(nn.Module): + def __init__( + self, + cfg: TokenizersConfig, + ) -> None: + super().__init__() + logger.info(f"Tokenizers Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.quantize = NormEMAVectorQuantizer( + n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, + ) + self.quant_n = cfg.quant_n + self.quantize_layer = nn.Sequential( + nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), + nn.Tanh(), + nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize + ) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_labels( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + quantize_input = self.quantize_layer(x) + quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) + + return embed_ind diff --git a/ola/model/speech_encoder/beats/__init__.py b/ola/model/speech_encoder/beats/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f97c85ac24172414259a0b1b9bc04161df3b32 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-312.pyc b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdda76d662cd64c74131fd9df236ed62a74f51a2 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d67c59eaf976f823216c63dce0cff30c1dfba007 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e3210d91220b69ba199ac36b55cd6213428fef2 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-312.pyc b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73c83a5fce04803e2b3ba7caddc34e3e31a946d3 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68353622f642405b85709154b35efaf31a95db2c Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce7a8103f591533bc5c5c8ca86ab6b4a5692354b Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-312.pyc b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736504b7700f6cf74a265456684d60a427ebaff5 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e9f96cae92d6bda7025478d7111317fdc5333f5 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b1dade1f45362a523241ce7b4483c67018a9f0 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-312.pyc b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..726f176330be5c9863512e58e88a6b57d538bcb0 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd57dbd423cc3d825bf7aabbd4aa4f14f3a220ed Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c663dc7565996e5c4be111a1ce01d5ed18286efb Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/modules.cpython-312.pyc b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3066a8d38dd50f9f8af79b890e37934591be2a87 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-312.pyc differ diff --git a/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74a0a28e7d2f7bf1661ff4bd8ed10cb83cd7adb2 Binary files /dev/null and b/ola/model/speech_encoder/beats/__pycache__/modules.cpython-38.pyc differ diff --git a/ola/model/speech_encoder/beats/backbone.py b/ola/model/speech_encoder/beats/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..103b5b69cc8555a4fcfa09ec520dfc150fada509 --- /dev/null +++ b/ola/model/speech_encoder/beats/backbone.py @@ -0,0 +1,786 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + # Check if data is a meta tensor + if data.is_meta: + # Skip initialization for meta tensors + return + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None and not module.bias.data.is_meta: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None and not module.weight.data.is_meta: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) \ No newline at end of file diff --git a/ola/model/speech_encoder/beats/kaldi.py b/ola/model/speech_encoder/beats/kaldi.py new file mode 100644 index 0000000000000000000000000000000000000000..f97fa85308e785af571bfdc912d6f12bb092e10e --- /dev/null +++ b/ola/model/speech_encoder/beats/kaldi.py @@ -0,0 +1,813 @@ +import math +from typing import Tuple + +import torch +# import torchaudio +from torch import Tensor + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + "vtln_warp_mel_freq", +] + +# numeric_limits::epsilon() 1.1920928955078125e-07 +EPSILON = torch.tensor(torch.finfo(torch.float).eps) +# 1 milliseconds = 0.001 seconds +MILLISECONDS_TO_SECONDS = 0.001 + +# window types +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + r"""Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor: + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. + + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.size(0) + strides = (window_shift * waveform.stride(0), waveform.stride(0)) + + if snip_edges: + if num_samples < window_size: + return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = torch.flip(waveform, [0]) + m = (num_samples + (window_shift // 2)) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect' + # but we want [2, 1, 0, 0, 1, 2] + pad_left = reversed_waveform[-pad:] + waveform = torch.cat((pad_left, waveform, pad_right), dim=0) + else: + # pad is negative so we want to trim the waveform at the front + waveform = torch.cat((waveform[-pad:], pad_right), dim=0) + + sizes = (m, window_size) + return waveform.as_strided(sizes, strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: torch.device, + dtype: int, +) -> Tensor: + r"""Returns a window function with the given type and size""" + if window_type == HANNING: + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype) + elif window_type == HAMMING: + return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype) + elif window_type == POVEY: + # like hanning but goes to zero at edges + return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85) + elif window_type == RECTANGULAR: + return torch.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = torch.arange(window_size, device=device, dtype=dtype) + # can't use torch.blackman_window as they use different coefficients + return ( + blackman_coeff + - 0.5 * torch.cos(a * window_function) + + (0.5 - blackman_coeff) * torch.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + if energy_floor == 0.0: + return log_energy + return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties( + waveform: Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, int, int, int]: + r"""Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0)) + waveform = waveform[channel, :] # size (n) + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size + + assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format( + window_size, len(waveform) + ) + assert 0 < window_shift, "`window_shift` must be greater than 0" + assert padded_window_size % 2 == 0, ( + "the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`" + ) + assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[Tensor, Tensor]: + r"""Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # size (m, window_size) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + + if dither != 0.0: + rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + + if remove_dc_offset: + # Subtract each row/frame by its mean + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + if raw_energy: + # Compute the log energy of each row/frame before applying preemphasis and + # window function + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + if preemphasis_coefficient != 0.0: + # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze( + 0 + ) # size (m, window_size + 1) + strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + + # Apply window_function to each row/frame + window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze( + 0 + ) # size (1, window_size) + strided_input = strided_input * window_function # size (m, window_size) + + # Pad columns with zero until we reach size (m, padded_window_size) + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = torch.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + + # Compute energy after window function (not the raw one) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m) + + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor: + # subtracts the column mean of the tensor size (m, n) if subtract_mean=True + # it returns size (m, n) + if subtract_mean: + col_means = torch.mean(tensor, dim=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + window_type: str = POVEY, +) -> Tensor: + r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0) + + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1, 2) + fft = torch.fft.rfft(strided_input) + + # Convert the FFT into a power spectrum + power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum[:, 0] = signal_log_energy + + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: Tensor) -> Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: Tensor) -> Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: Tensor, +) -> Tensor: + r"""This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). + If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq" + assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l # F(l) + Fh = scale * h # F(h) + assert l > low_freq and h < high_freq + # slope of left part of the 3-piece linear function + scale_left = (Fl - low_freq) / (l - low_freq) + # [slope of center part is just "scale"] + + # slope of right part of the 3-piece linear function + scale_right = (high_freq - Fh) / (high_freq - h) + + res = torch.empty_like(freq) + + outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq + before_l = torch.lt(freq, l) # freq < l + before_h = torch.lt(freq, h) # freq < h + after_h = torch.ge(freq, h) # freq >= h + + # order of operations matter here (since there is overlapping frequency regions) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: Tensor, +) -> Tensor: + r""" + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq) + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, +) -> Tuple[Tensor, Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + + if high_freq <= 0.0: + high_freq += nyquist + + assert ( + (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq) + ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist) + + # fft-bin width [think of it as Nyquist-freq / half-window-length] + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # divide by num_bins+1 in next line because of end-effects where the bins + # spread out to the sides. + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + if vtln_high < 0.0: + vtln_high += nyquist + + assert vtln_warp_factor == 1.0 or ( + (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) + ), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + + bin = torch.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1) + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1) + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1) + + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel) + center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel) + right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel) + + center_freqs = inverse_mel_scale(center_mel) # size (num_bins) + # size(1, num_fft_bins) + mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0) + + # size (num_bins, num_fft_bins) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + if vtln_warp_factor == 1.0: + # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values + bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + else: + # warping can move the order of left_mel, center_mel, right_mel anywhere + bins = torch.zeros_like(up_slope) + up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel + down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + + return bins, center_freqs + + +def fbank( + waveform: Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's + compute-fbank-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features + (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( + waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient + ) + + if len(waveform) < min_duration * sample_frequency: + # signal is too short + return torch.empty(0, device=device, dtype=dtype) + + # strided_input, size (m, padded_window_size) and signal_log_energy, size (m) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + + # size (m, padded_window_size // 2 + 1) + spectrum = torch.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + + # size (num_mel_bins, padded_window_size // 2) + mel_energies, _ = get_mel_banks( + num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp + ) + mel_energies = mel_energies.to(device=device, dtype=dtype) + + # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) + mel_energies = torch.mm(spectrum, mel_energies.T) + if use_log_fbank: + # avoid log of zero (which should be prevented anyway by dithering) + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # if use_energy then add it as the last column for htk_compat == true else first column + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1) + # returns size (m, num_mel_bins + 1) + if htk_compat: + mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1) + else: + mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1) + + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor: + # returns a dct matrix of size (num_mel_bins, num_ceps) + # size (num_mel_bins, num_mel_bins) + dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho") + # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins) + # this would be the first column in the dct_matrix for torchaudio as it expects a + # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi + # expects a left multiply e.g. dct_matrix * vector). + dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins)) + dct_matrix = dct_matrix[:, :num_ceps] + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor: + # returns size (num_ceps) + # Compute liftering coefficients (scaling on cepstral coeffs) + # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected. + i = torch.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> Tensor: + r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). (Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins) + + device, dtype = waveform.device, waveform.dtype + + # The mel_energies should not be squared (use_power=True), not have mean subtracted + # (subtract_mean=False), and use log (use_log_fbank=True). + # size (m, num_mel_bins + use_energy) + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + + if use_energy: + # size (m) + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + # offset is 0 if htk_compat==True else 1 + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : (num_mel_bins + mel_offset)] + + # size (num_mel_bins, num_ceps) + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + + # size (m, num_ceps) + feature = feature.matmul(dct_matrix) + + if cepstral_lifter != 0.0: + # size (1, num_ceps) + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + + # if use_energy then replace the last column for htk_compat == true else first column + if use_energy: + feature[:, 0] = signal_log_energy + + if htk_compat: + energy = feature[:, 0].unsqueeze(1) # size (m, 1) + feature = feature[:, 1:] # size (m, num_ceps - 1) + if not use_energy: + # scale on C0 (actually removing a scale we previously added that's + # part of one common definition of the cosine transform.) + energy *= math.sqrt(2) + + feature = torch.cat((feature, energy), dim=1) + + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/ola/model/speech_encoder/beats/modules.py b/ola/model/speech_encoder/beats/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..18e2d2066b93139acc9427f0edcdd96b12769f25 --- /dev/null +++ b/ola/model/speech_encoder/beats/modules.py @@ -0,0 +1,218 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/ola/model/speech_encoder/beats/quantizer.py b/ola/model/speech_encoder/beats/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..704be4c357bce7ee425ea2b6737b536333a5a63c --- /dev/null +++ b/ola/model/speech_encoder/beats/quantizer.py @@ -0,0 +1,215 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on VQGAN code bases +# https://github.com/CompVis/taming-transformers +# --------------------------------------------------------' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed + +try: + from einops import rearrange, repeat +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + # self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + # z = rearrange(z, 'b c h w -> b h w c') + # z = z.transpose(1, 2) + z = l2norm(z) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + # z_q = rearrange(z_q, 'b h w c -> b c h w') + # z_q = z_q.transpose(1, 2) + return z_q, loss, encoding_indices \ No newline at end of file diff --git a/ola/model/speech_encoder/builder.py b/ola/model/speech_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..75595dcade66ab6bf6be9be4f54e31ab9a91bbb1 --- /dev/null +++ b/ola/model/speech_encoder/builder.py @@ -0,0 +1,15 @@ +from .speech_encoder import WhisperWrappedEncoder, DualWrappedEncoder +import torch.nn as nn + +def build_speech_encoder(config): + speech_encoder_type = getattr(config, 'speech_encoder_type', None) + + print(f"Building speech encoder: {speech_encoder_type}") + if "whisper" in speech_encoder_type.lower(): + return WhisperWrappedEncoder.load(config) + elif "dual" in speech_encoder_type.lower(): + return DualWrappedEncoder(config) + elif "none" in speech_encoder_type.lower(): + return None + + raise ValueError(f'Unknown speech encoder: {speech_encoder_type}') diff --git a/ola/model/speech_encoder/speech_encoder.py b/ola/model/speech_encoder/speech_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d6886fb38393ba839748109de03c71da7d6c5bbf --- /dev/null +++ b/ola/model/speech_encoder/speech_encoder.py @@ -0,0 +1,295 @@ +import types +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import WhisperFeatureExtractor +import whisper +import torch +try: + torch.set_default_device("cpu") +except Exception: + pass +import accelerate +from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs + +class WhisperWrappedEncoder: + + @classmethod + def load(cls, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + # Check if any parameter is a meta tensor + has_meta = any(p.is_meta for p in child.parameters()) + if has_meta: + # For meta tensors, create new layer norm with same shape + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + else: + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + # Load whisper model, handling both file paths and model names + speech_encoder_path = model_config.speech_encoder + + # First try loading directly (works for both file paths and model names) + try: + encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder + except (NotImplementedError, RuntimeError) as e: + if "meta tensor" in str(e): + # Meta tensor issue - load model without device specification + print(f"Detected meta tensor issue, using alternative loading approach...") + + # Load checkpoint directly to avoid device issues + import os + if os.path.isfile(speech_encoder_path): + # Load from file + checkpoint = torch.load(speech_encoder_path, map_location='cpu') + + # Create model from checkpoint + from whisper.model import ModelDimensions, Whisper + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + + # Load state dict without moving to device + model.load_state_dict(checkpoint["model_state_dict"]) + + # Get encoder without device movement + encoder = model.encoder + else: + # Try loading as model name without device + import whisper.model as whisper_model + # This is a fallback - may need adjustment based on actual model + raise RuntimeError(f"Cannot load model {speech_encoder_path} due to meta tensor issues") + else: + raise e + + replace_layer_norm(encoder) + return encoder + +class DualWrappedEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.whisper_model = self.load_whisper(config) + self.beats_model = self.load_beats(config) + + def load_whisper(self, model_config): + + def replace_layer_norm(module): + from whisper.model import LayerNorm + for name, child in module.named_children(): + if isinstance(child, LayerNorm): + # Check if any parameter is a meta tensor + has_meta = any(p.is_meta for p in child.parameters()) + if has_meta: + # For meta tensors, create new layer norm with same shape + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + else: + old_params = child.state_dict() + new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) + new_layer_norm.load_state_dict(old_params) + setattr(module, name, new_layer_norm) + else: + replace_layer_norm(child) + + # Load whisper model, handling both file paths and model names + speech_encoder_path = model_config.speech_encoder + + # First try loading directly (works for both file paths and model names) + # try: + # breakpoint() + import torch + from whisper.model import Whisper, ModelDimensions + + # 1) Load checkpoint to CPU (weights are real tensors here) + ckpt = torch.load("/data1/cxy/model/THUdyh/Ola-7b/large-v3.pt", map_location="cpu") + dims = ModelDimensions(**ckpt["dims"]) + + # 2) Build the module skeleton, then MATERIALIZE tensors on CPU + model = Whisper(dims) + model.to_empty(device="cpu") # <-- crucial when meta is involved + + # 3) Load weights + missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=True) + model.eval() + + encoder = model.encoder + print("missing:", missing) + print("unexpected:", unexpected) + # with accelerate.init_empty_weights(): + # encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder + # state = torch.load("/data1/cxy/model/THUdyh/Ola-7b/large-v3.pt", map_location='cpu')['model_state_dict']['encoder.positional_embedding'] + # breakpoint() + # except (NotImplementedError, RuntimeError) as e: + # if "meta tensor" in str(e): + # # Meta tensor issue - load model without device specification + # print(f"Detected meta tensor issue, using alternative loading approach...") + + # # Load checkpoint directly to avoid device issues + # import os + # if os.path.isfile(speech_encoder_path): + # # Load from file + # checkpoint = torch.load(speech_encoder_path, map_location='cpu') + + # # Create model from checkpoint + # # breakpoint() + # from whisper.model import ModelDimensions, Whisper + # dims = ModelDimensions(**checkpoint["dims"]) + # model = Whisper(dims) + + # # Load state dict without moving to device + # model.load_state_dict(checkpoint["model_state_dict"]) + + # # Get encoder without device movement + # encoder = model.encoder + # else: + # # Try loading as model name without device + # import whisper.model as whisper_model + # # This is a fallback - may need adjustment based on actual model + # raise RuntimeError(f"Cannot load model {speech_encoder_path} due to meta tensor issues") + # else: + # raise e + + replace_layer_norm(encoder) + return encoder + + def load_beats(self, model_config): + beats_path = model_config.music_encoder + print("Loading BEATs Model") + beats_ckpt = torch.load(beats_path, map_location='cpu') + beats_cfg = BEATsConfig(beats_ckpt['cfg']) + beats = BEATs(beats_cfg) + beats = beats.to_empty(device='cpu') + # Load state dict + beats.load_state_dict(beats_ckpt['model'], strict=True) + # breakpoint() + # 检查BEATs模型权重是否有问题 + print("Checking BEATs model weights for NaN/Inf values...") + nan_count = 0 + inf_count = 0 + for name, param in beats.named_parameters(): + if torch.isnan(param).any(): + print(f"ERROR - BEATs parameter {name} contains NaN values!") + print(f"Debug - Parameter shape: {param.shape}") + print(f"Debug - Parameter dtype: {param.dtype}") + print(f"Debug - Parameter device: {param.device}") + print(f"Debug - NaN count: {torch.isnan(param).sum().item()}") + nan_count += 1 + if torch.isinf(param).any(): + print(f"ERROR - BEATs parameter {name} contains Inf values!") + print(f"Debug - Parameter shape: {param.shape}") + print(f"Debug - Inf count: {torch.isinf(param).sum().item()}") + inf_count += 1 + + if nan_count > 0 or inf_count > 0: + print(f"ERROR - Found NaN values in {nan_count} parameters and Inf values in {inf_count} parameters") + print("This indicates the BEATs model weights are corrupted!") + raise ValueError(f"BEATs model weights are corrupted: {nan_count} NaN parameters, {inf_count} Inf parameters") + else: + print("BEATs model weights are clean (no NaN or Inf values)") + + return beats + + def forward(self, x, raw_wav=None, audio_padding_mask=None): + with torch.no_grad(): + self.beats_model = self.beats_model.float() + + # Debug: Check input data + print(f"Debug - Speech encoder input x range: {x.min().item()} to {x.max().item()}") + print(f"Debug - Speech encoder input x has nan: {torch.isnan(x).any().item()}") + print(f"Debug - Speech encoder input raw_wav range: {raw_wav.min().item()} to {raw_wav.max().item()}") + print(f"Debug - Speech encoder input raw_wav has nan: {torch.isnan(raw_wav).any().item()}") + + # Check Whisper model + print(f"Debug - Whisper model device: {next(self.whisper_model.parameters()).device}") + print(f"Debug - Input x device: {x.device}") + + speech_embeds = self.whisper_model(x) + print(f"Debug - Whisper output range: {speech_embeds.min().item()} to {speech_embeds.max().item()}") + print(f"Debug - Whisper output has nan: {torch.isnan(speech_embeds).any().item()}") + + # Check BEATs model + print(f"Debug - BEATs model device: {next(self.beats_model.parameters()).device}") + print(f"Debug - Input raw_wav device: {raw_wav.device}") + + # Check if BEATs model has nan weights (should be fixed now) + has_nan_weights = False + for name, param in self.beats_model.named_parameters(): + if torch.isnan(param).any(): + print(f"WARNING - BEATs parameter {name} still has nan values after fix!") + has_nan_weights = True + if not has_nan_weights: + print("Debug - BEATs model weights are clean (no nan)") + + try: + # 详细检查BEATs模型输入 + raw_wav_float = raw_wav.float() + print(f"Debug - BEATs input raw_wav_float range: {raw_wav_float.min().item()} to {raw_wav_float.max().item()}") + print(f"Debug - BEATs input raw_wav_float shape: {raw_wav_float.shape}") + print(f"Debug - BEATs input raw_wav_float has nan: {torch.isnan(raw_wav_float).any().item()}") + print(f"Debug - BEATs input raw_wav_float has inf: {torch.isinf(raw_wav_float).any().item()}") + print(f"Debug - BEATs input raw_wav_float dtype: {raw_wav_float.dtype}") + print(f"Debug - BEATs input raw_wav_float device: {raw_wav_float.device}") + + # 检查输入是否在BEATs期望的范围内 [-1, 1] + if raw_wav_float.min().item() < -1.0 or raw_wav_float.max().item() > 1.0: + print(f"WARNING - BEATs input out of expected range [-1, 1]! Clipping to valid range.") + raw_wav_float = torch.clamp(raw_wav_float, -1.0, 1.0) + print(f"Debug - After clipping range: {raw_wav_float.min().item()} to {raw_wav_float.max().item()}") + else: + print("Debug - BEATs input is within expected range [-1, 1]") + + if audio_padding_mask is not None: + print(f"Debug - BEATs input padding_mask range: {audio_padding_mask.min().item()} to {audio_padding_mask.max().item()}") + print(f"Debug - BEATs input padding_mask shape: {audio_padding_mask.shape}") + print(f"Debug - BEATs input padding_mask has nan: {torch.isnan(audio_padding_mask).any().item()}") + print(f"Debug - BEATs input padding_mask dtype: {audio_padding_mask.dtype}") + else: + print("Debug - BEATs input padding_mask is None") + + # 在调用BEATs之前,让我们检查模型状态 + print("Debug - BEATs model training mode:", self.beats_model.training) + print("Debug - BEATs model device:", next(self.beats_model.parameters()).device) + + # 让我们逐步调试BEATs的内部处理 + print("Debug - Calling BEATs extract_features...") + audio_embeds, _ = self.beats_model.extract_features(raw_wav_float, padding_mask=audio_padding_mask, feature_only=True) + print(f"Debug - BEATs output range: {audio_embeds.min().item()} to {audio_embeds.max().item()}") + print(f"Debug - BEATs output has nan: {torch.isnan(audio_embeds).any().item()}") + print(f"Debug - BEATs output shape: {audio_embeds.shape}") + print(f"Debug - BEATs output dtype: {audio_embeds.dtype}") + + # 检查BEATs输出是否有NaN值 + if torch.isnan(audio_embeds).any(): + print("ERROR - BEATs output contains NaN values!") + print(f"Debug - NaN positions: {torch.isnan(audio_embeds).sum().item()} out of {audio_embeds.numel()}") + print(f"Debug - NaN ratio: {torch.isnan(audio_embeds).float().mean().item():.4f}") + # 不替换,直接抛出异常来找出根本原因 + raise ValueError("BEATs model produced NaN values - this indicates a bug in the model or input data") + except Exception as e: + print(f"ERROR - BEATs model failed: {e}") + print("Falling back to Whisper-only mode") + # Create zero audio embeddings with the same shape as expected + audio_embeds = torch.zeros(speech_embeds.shape[0], speech_embeds.shape[1], 1024, device=speech_embeds.device, dtype=speech_embeds.dtype) + + if audio_embeds.size(1) < speech_embeds.size(1): + audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))) + elif audio_embeds.size(1) > speech_embeds.size(1): + speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))) + speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) + speech_embeds = speech_embeds.to(torch.bfloat16) + + # 最终检查是否有NaN值 + if torch.isnan(speech_embeds).any(): + print("ERROR - Final speech embeddings contain NaN values!") + print(f"Debug - NaN positions: {torch.isnan(speech_embeds).sum().item()} out of {speech_embeds.numel()}") + print(f"Debug - NaN ratio: {torch.isnan(speech_embeds).float().mean().item():.4f}") + raise ValueError("Final speech embeddings contain NaN values - this indicates a bug in the speech encoder") + + return speech_embeds \ No newline at end of file diff --git a/ola/model/speech_projector/__pycache__/builder.cpython-312.pyc b/ola/model/speech_projector/__pycache__/builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70b1bff2aab095a98263a63ea8ff02e56eaae878 Binary files /dev/null and b/ola/model/speech_projector/__pycache__/builder.cpython-312.pyc differ diff --git a/ola/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc b/ola/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4d2787b19503c09caf787bce5f875650b99a6b9 Binary files /dev/null and b/ola/model/speech_projector/__pycache__/speech_projector.cpython-312.pyc differ diff --git a/ola/model/speech_projector/builder.py b/ola/model/speech_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..bf55a32d6839766945f05f40ed03fa655a579ccb --- /dev/null +++ b/ola/model/speech_projector/builder.py @@ -0,0 +1,11 @@ +from .speech_projector import EncoderProjectorConcat + + +def build_speech_projector(config): + projector_type = getattr(config, 'speech_projector_type', 'linear') + if projector_type == 'linear': + return EncoderProjectorConcat(config) + elif projector_type == 'none': + return None + + raise ValueError(f'Unknown projector type: {projector_type}') diff --git a/ola/model/speech_projector/speech_projector.py b/ola/model/speech_projector/speech_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0156065bc238807c57918aeeed48df575ff533 --- /dev/null +++ b/ola/model/speech_projector/speech_projector.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import math + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.speech_encoder_ds_rate + self.encoder_dim = config.speech_encoder_hidden_size + self.llm_dim = config.hidden_size + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.hidden_size) + + embed_std = 1 / math.sqrt(config.hidden_size) + self.speech_newline = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_begin = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + self.speech_end = nn.Parameter( + torch.randn(config.hidden_size) * embed_std + ) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + x = torch.cat([ + x, + self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype) + ], dim=1) + begin = self.speech_begin.reshape(1, -1).to(x.dtype) + end = self.speech_end.reshape(1, -1).to(x.dtype) + x = x.flatten(0, 1) + x = torch.cat([begin, x, end], dim=0) + # x = x.flatten(0, 1) + return x \ No newline at end of file diff --git a/ola/serve/__init__.py b/ola/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ola/serve/controller.py b/ola/serve/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..0e58de52fd71cff951fc7ca2cde75fded460e3f7 --- /dev/null +++ b/ola/serve/controller.py @@ -0,0 +1,298 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from omni_speech.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from omni_speech.utils import build_logger, server_error_msg + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,), daemon=True) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, + worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], + check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], + data.get("worker_status", None)) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat( + data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument("--dispatch-method", type=str, choices=[ + "lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/ola/serve/gradio_web_server.py b/ola/serve/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..19015a12e5ba09112a2cf49c8e234adca7c994de --- /dev/null +++ b/ola/serve/gradio_web_server.py @@ -0,0 +1,348 @@ +import argparse +import datetime +import json +import os +import time +import torch +import torchaudio + +import gradio as gr +import numpy as np +import requests +import soundfile as sf + +from omni_speech.conversation import default_conversation, conv_templates +from omni_speech.constants import LOGDIR +from omni_speech.utils import build_logger, server_error_msg +from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +vocoder = None + +headers = {"User-Agent": "LLaMA-Omni Client"} + +no_change_btn = gr.Button() +enable_btn = gr.Button(interactive=True) +disable_btn = gr.Button(interactive=False) + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + logger.info(f"Models: {models}") + return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown(value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + + +def load_demo_refresh_model_list(request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + models = get_model_list() + state = default_conversation.copy() + dropdown_update = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "" + ) + return state, dropdown_update + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, None, "", "", None) + + +def add_speech(state, speech, request: gr.Request): + text = "Please directly answer the questions in the user's speech." + text = '\n' + text + text = (text, speech) + state = default_conversation.copy() + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + return (state) + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens, chunk_size, request: gr.Request): + logger.info(f"http_bot. ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, "", "", None) + return + + if len(state.messages) == state.offset + 2: + # First round of conversation + template_name = "llama_3" + new_state = conv_templates[template_name].copy() + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", + json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, "", "", None) + return + + # Construct prompt + prompt = state.get_prompt() + + sr, audio = state.messages[0][1][1] + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) + audio = torch.tensor(audio.astype(np.float32)).unsqueeze(0) + audio = resampler(audio).squeeze(0).numpy() + audio /= 32768.0 + audio = audio.tolist() + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1500), + "stop": state.sep2, + "audio": audio, + } + + yield (state, "", "", None) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + try: + # Stream output + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=10) + num_generated_units = 0 + wav_list = [] + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt):].strip() + output_unit = list(map(int, data["unit"].strip().split())) + state.messages[-1][-1] = (output, data["unit"].strip()) + + # vocoder + new_units = output_unit[num_generated_units:] + if len(new_units) >= chunk_size: + num_generated_units = len(output_unit) + x = {"code": torch.LongTensor(new_units).view(1, -1).cuda()} + wav = vocoder(x, True) + wav_list.append(wav.detach().cpu().numpy()) + + if len(wav_list) > 0: + wav_full = np.concatenate(wav_list) + return_value = (16000, wav_full) + else: + return_value = None + + yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value) + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, "", "", None) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, "", "", None) + return + + if num_generated_units < len(output_unit): + new_units = output_unit[num_generated_units:] + num_generated_units = len(output_unit) + x = { + "code": torch.LongTensor(new_units).view(1, -1).cuda() + } + wav = vocoder(x, True) + wav_list.append(wav.detach().cpu().numpy()) + + if len(wav_list) > 0: + wav_full = np.concatenate(wav_list) + return_value = (16000, wav_full) + else: + return_value = None + + yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value) + + finish_tstamp = time.time() + logger.info(f"{output}") + logger.info(f"{output_unit}") + + +title_markdown = (""" +# 🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models +""") + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +""" + +def build_demo(embed_mode, vocoder, cur_dir=None, concurrency_count=10): + with gr.Blocks(title="LLaMA-Omni Speech Chatbot", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + if not embed_mode: + gr.Markdown(title_markdown) + + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False) + + with gr.Row(): + audio_input_box = gr.Audio(sources=["upload", "microphone"], label="Speech Input") + with gr.Accordion("Parameters", open=True) as parameter_row: + temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max Output Tokens",) + chunk_size = gr.Slider(minimum=10, maximum=500, value=40, step=10, interactive=True, label="Chunk Size",) + + if cur_dir is None: + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples(examples=[ + [f"{cur_dir}/examples/vicuna_1.wav"], + [f"{cur_dir}/examples/vicuna_2.wav"], + [f"{cur_dir}/examples/vicuna_3.wav"], + [f"{cur_dir}/examples/vicuna_4.wav"], + [f"{cur_dir}/examples/vicuna_5.wav"], + [f"{cur_dir}/examples/helpful_base_1.wav"], + [f"{cur_dir}/examples/helpful_base_2.wav"], + [f"{cur_dir}/examples/helpful_base_3.wav"], + [f"{cur_dir}/examples/helpful_base_4.wav"], + [f"{cur_dir}/examples/helpful_base_5.wav"], + ], inputs=[audio_input_box]) + + with gr.Row(): + submit_btn = gr.Button(value="Send", variant="primary") + clear_btn = gr.Button(value="Clear") + + text_output_box = gr.Textbox(label="Text Output", type="text") + unit_output_box = gr.Textbox(label="Unit Output", type="text") + audio_output_box = gr.Audio(label="Speech Output") + + url_params = gr.JSON(visible=False) + + submit_btn.click( + add_speech, + [state, audio_input_box], + [state] + ).then( + http_bot, + [state, model_selector, temperature, top_p, max_output_tokens, chunk_size], + [state, text_output_box, unit_output_box, audio_output_box], + concurrency_limit=concurrency_count + ) + + clear_btn.click( + clear_history, + None, + [state, audio_input_box, text_output_box, unit_output_box, audio_output_box], + queue=False + ) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [state, model_selector], + js=get_window_url_params + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_refresh_model_list, + None, + [state, model_selector], + queue=False + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +def build_vocoder(args): + global vocoder + if args.vocoder is None: + return None + with open(args.vocoder_cfg) as f: + vocoder_cfg = json.load(f) + vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).cuda() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=16) + parser.add_argument("--model-list-mode", type=str, default="once", + choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + parser.add_argument("--vocoder", type=str) + parser.add_argument("--vocoder-cfg", type=str) + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + build_vocoder(args) + + logger.info(args) + demo = build_demo(args.embed, vocoder, concurrency_count=args.concurrency_count) + demo.queue( + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/ola/serve/model_worker.py b/ola/serve/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..e41f98b9e6400cc59846516da1e159aa599eadf2 --- /dev/null +++ b/ola/serve/model_worker.py @@ -0,0 +1,292 @@ +""" +A model worker executes the model. +""" +import argparse +import asyncio +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +import whisper +import numpy as np +from functools import partial + +from transformers import PreTrainedTokenizer + +from omni_speech.constants import WORKER_HEART_BEAT_INTERVAL +from omni_speech.utils import (build_logger, server_error_msg, + pretty_print_semaphore) +from omni_speech.model.builder import load_pretrained_model +from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN +from omni_speech.datasets.preprocess import tokenizer_speech_token +from transformers import TextIteratorStreamer +from threading import Thread + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +def load_speech(audio, input_type, mel_size, speech_normalize): + speech = np.array(audio, dtype=np.float32) + if input_type == "raw": + speech = torch.from_numpy(speech) + if speech_normalize: + speech = torch.nn.functional.layer_norm(speech, speech.shape) + elif input_type == "mel": + speech = whisper.pad_or_trim(speech) + speech = whisper.log_mel_spectrogram(speech, n_mels=mel_size).permute(1, 0) + return speech + + +def build_unit_tokenizer(vocab_size): + import os + from transformers import BertTokenizer + with open("unit_vocab.txt", "w") as f: + for i in range(vocab_size + 1): + f.write(str(i) + "\n") + tokenizer = BertTokenizer(vocab_file="unit_vocab.txt") + os.remove("unit_vocab.txt") + return tokenizer + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, + worker_id, no_register, + model_path, model_base, model_name, + load_8bit, load_4bit, device, input_type, mel_size, s2s, is_lora, use_flash_attn=False): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + self.device = device + self.model_name = model_name + self.input_type = input_type + self.mel_size = mel_size + self.tokenizer, self.model, self.context_len = load_pretrained_model( + model_path, model_base, is_lora=is_lora, s2s=s2s, load_8bit=load_8bit, load_4bit=load_4bit, device=self.device, use_flash_attn=use_flash_attn) + self.unit_tokenizer = build_unit_tokenizer(self.model.config.unit_vocab_size) + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,), daemon=True) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " + f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + @torch.inference_mode() + def generate_stream(self, params): + tokenizer, model = self.tokenizer, self.model + + prompt = params["prompt"] + ori_prompt = prompt + audio = params.get("audio", None) + if audio is not None and len(audio) > 0: + speech = load_speech(audio, self.input_type, self.mel_size, self.model.config.speech_normalize) + speech_length = torch.LongTensor([speech.shape[0]]).unsqueeze(0).to(self.device) + speech_tensor = speech.unsqueeze(0).to(self.device, dtype=torch.float16) + speech_args = {"speech": speech_tensor, "speech_lengths": speech_length} + else: + speech = None + speech_args = {} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + do_sample = True if temperature > 0.001 else False + + input_ids = tokenizer_speech_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + streamer_unit = TextIteratorStreamer(self.unit_tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15) + + # max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + thread = Thread(target=model.generate, kwargs=dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + streamer_unit=streamer_unit, + streaming_unit_gen=True, + use_cache=True, + **speech_args + )) + thread.start() + + generated_text = ori_prompt + for new_text in streamer: + generated_text += new_text + generated_unit = " ".join(map(str, streamer_unit.token_cache)) + if generated_text.endswith(stop_str): + generated_text = generated_text[:-len(stop_str)] + yield json.dumps({"text": generated_text, "unit": generated_unit, "error_code": 0}).encode() + b"\0" + + def generate_stream_gate(self, params): + try: + for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError as e: + print("Caught torch.cuda.CudaError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, + default="http://localhost:21001") + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--model-name", type=str) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--use-flash-attn", action="store_true") + parser.add_argument("--input-type", type=str, default="mel") + parser.add_argument("--mel-size", type=int, default=128) + parser.add_argument("--s2s", action="store_true", default=False) + parser.add_argument("--is-lora", action="store_true", default=False) + args = parser.parse_args() + logger.info(f"args: {args}") + + worker = ModelWorker(args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_base, + args.model_name, + args.load_8bit, + args.load_4bit, + args.device, + args.input_type, + args.mel_size, + args.s2s, + args.is_lora, + use_flash_attn=args.use_flash_attn) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/ola/train/__pycache__/ola_trainer.cpython-312.pyc b/ola/train/__pycache__/ola_trainer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48601e938ea53280f3295ab70a365ee216e81c78 Binary files /dev/null and b/ola/train/__pycache__/ola_trainer.cpython-312.pyc differ diff --git a/ola/train/internvl_finetune.py b/ola/train/internvl_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc6b0b1ea0e1162aee5e4900612020bddeefca1 --- /dev/null +++ b/ola/train/internvl_finetune.py @@ -0,0 +1,1072 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import logging +import math +import os +import random +import sys +import traceback +import warnings +from copy import deepcopy +from dataclasses import dataclass, field +from functools import partial +from typing import Dict, Literal, Optional + +import numpy as np + +try: + import orjson as json +except: + import json + +import torch +import torch.distributed as dist +import transformers +from internvl.dist_utils import init_dist +from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM +from internvl.model.internvl_chat import (InternVisionConfig, + InternVisionModel, + InternVLChatConfig, + InternVLChatModel) +from internvl.patch import (concat_pad_data_collator, + replace_internlm2_attention_class, + replace_llama_attention_class, + replace_llama_rmsnorm_with_fused_rmsnorm, + replace_phi3_attention_class, + replace_qwen2_attention_class, + replace_train_dataloader, replace_train_sampler) +from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN, + IMG_CONTEXT_TOKEN, IMG_END_TOKEN, + IMG_START_TOKEN, QUAD_END_TOKEN, + QUAD_START_TOKEN, REF_END_TOKEN, + REF_START_TOKEN) +from internvl.train.dataset import (ConcatDataset, TCSLoader, + WeightedConcatDataset, build_transform, + check_conversations_repetition, + dynamic_preprocess, preprocess, + preprocess_internlm, + preprocess_internvl2_5, preprocess_mpt, + preprocess_phi3) +from internvl.train.dataset_packed import PackedDataset, packed_collate_fn +from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError +from torch.utils.data import Dataset +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + HfArgumentParser, Trainer, TrainingArguments, + set_seed) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils.logging import (enable_default_handler, + enable_explicit_format, set_verbosity) + +# Try to import petrel_client for image loading, fallback to PIL if unavailable +try: + from petrel_client.client import Client + from petrel_client.common.config import Config + has_tcs_loader = True +except ImportError as E: + print('petrel_client is not installed. Using PIL to load images.') + has_tcs_loader = False + +# Set constants for image processing and logging +IGNORE_INDEX = -100 +Image.MAX_IMAGE_PIXELS = None +ImageFile.LOAD_TRUNCATED_IMAGES = True +MaximumDecompressedSize = 1024 +MegaByte = 2 ** 20 +PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte + +warnings.filterwarnings('ignore') +logger = logging.getLogger(__name__) + +os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + +@dataclass +class ModelArguments: + """ + Arguments for specifying model, tokenizer, and configurations. + """ + model_name_or_path: Optional[str] = field( + default=None, + metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'} + ) + vision_path: Optional[str] = field( + default=None, + metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'} + ) + llm_path: Optional[str] = field( + default=None, + metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'} + ) + mlp_path: Optional[str] = field( + default=None, + metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'} + ) + freeze_llm: bool = field( + default=False, + metadata={'help': 'Set to True to freeze the LLM. Default is False.'}, + ) + freeze_backbone: bool = field( + default=False, + metadata={'help': 'Set to True to freeze the ViT. Default is False.'}, + ) + freeze_mlp: bool = field( + default=False, + metadata={'help': 'Set to True to freeze the MLP. Default is False.'}, + ) + unfreeze_vit_layers: int = field( + default=0, + metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'}, + ) + vision_select_layer: int = field( + default=-1, + metadata={'help': 'Specify the layer of ViT feature map to use. Default is -1 for the last layer.'}, + ) + use_backbone_lora: int = field( + default=0, + metadata={'help': 'Set the LoRA adapter rank for the ViT. Default is 0.'} + ) + use_llm_lora: int = field( + default=0, + metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'} + ) + unfreeze_lm_head: bool = field( + default=False, + metadata={'help': 'Set to True to unfreeze the head of LLM. Default is False.'}, + ) + grad_checkpoint: bool = field( + default=True, + metadata={'help': 'Set to True to use gradient checkpointing. Default is True.'}, + ) + drop_path_rate: float = field( + default=0.0, + metadata={'help': 'Set the drop path rate for the ViT. Default is 0.'}, + ) + ps_version: Literal['v1', 'v2'] = field( + default='v2', + metadata={'help': 'Specify the version of pixel shuffle implementation. Default is v2.'} + ) + use_fast_tokenizer: bool = field( + default=False, + metadata={'help': 'Set to True to use the fast mode of the tokenizer.'} + ) + use_liger: bool = field( + default=False, + metadata={'help': 'Set to True to use the liger kernel.'} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments for specifying data input for training and evaluation. + """ + max_seq_length: int = field( + default=8192, + metadata={ + 'help': ( + 'The maximum total input sequence length after tokenization. Sequences longer ' + 'than this will be truncated, sequences shorter will be padded.' + ) + }, + ) + force_image_size: int = field( + default=448, + metadata={'help': 'Set the desired size for the image. Default is 448.'}, + ) + down_sample_ratio: float = field( + default=0.5, + metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'}, + ) + pad2square: bool = field( + default=False, + metadata={'help': 'Pad the image to a square shape if set to True. Default is False.'}, + ) + conv_style: str = field( + default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'} + ) + meta_path: str = field( + default=None, + metadata={'help': 'The path of the meta file of datasets.'}, + ) + use_data_resampling: bool = field( + default=False, + metadata={'help': 'Set to True to use data resampling. Default is False.'}, + ) + dynamic_image_size: bool = field( + default=False, + metadata={'help': 'Set to True to use dynamic high resolution strategy. Default is False.'}, + ) + use_thumbnail: bool = field( + default=False, + metadata={'help': 'Set to True to add a thumbnail image. Default is False.'}, + ) + min_dynamic_patch: int = field( + default=1, + metadata={'help': 'The minimum number of dynamic patches. Default is 1.'}, + ) + max_dynamic_patch: int = field( + default=12, + metadata={'help': 'The maximum number of dynamic patches. Default is 12.'}, + ) + min_num_frame: int = field( + default=8, + metadata={'help': 'The minimum number of frames for video data. Default is 8.'}, + ) + max_num_frame: int = field( + default=32, + metadata={'help': 'The maximum number of frames for video data. Default is 32.'}, + ) + normalize_type: Literal['imagenet', 'clip', 'siglip'] = field( + default='imagenet', + metadata={'help': 'The normalization type for the image. Default is imagenet.'}, + ) + use_packed_ds: bool = field( + default=False, + metadata={'help': 'Whether to use packed dataset for efficient training. Default is False.'}, + ) + num_images_expected: int = field( + default=40, + metadata={'help': 'The maximum number of images per packed sample. Default is 40.'}, + ) + max_packed_tokens: int = field( + default=8192, + metadata={'help': 'The required token length of per packed sample. Default is 8192.'}, + ) + max_buffer_size: int = field( + default=20, + metadata={'help': 'The buffer size of the packed dataset. Default is 20.'}, + ) + log_freq: int = field( + default=1000, + metadata={'help': 'The log frequency of the packed dataset. Default is 1000.'}, + ) + strict_mode: bool = field( + default=True, + metadata={'help': 'Whether to pad the number of images to satisfy num_images_expected. Default is True.'}, + ) + replacement: bool = field( + default=False, + metadata={'help': 'Whether to restart the dataset after it is exhausted. Default is False.'}, + ) + allow_overflow: bool = field( + default=False, + metadata={'help': 'Whether to drop the sample over the specified max_packed_tokens. Default is False.'}, + ) + loss_reduction: str = field( + default='token', + metadata={'help': 'Loss reduction method. Default is token.'}, + ) + loss_reduction_all_gather: bool = field( + default=False, + metadata={'help': 'Whether to gather all during loss reduction. Default is False.'}, + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + template_name, + meta, + tokenizer, + tcs_loader, + ds_name, + num_image_token, + image_size=448, + is_train=True, + pad2square=False, + group_by_length=False, + dynamic_image_size=False, + use_thumbnail=False, + min_dynamic_patch=1, + max_dynamic_patch=12, + min_num_frame=8, # for video data + max_num_frame=32, # for video data + sampling_method='rand', # for video data + repeat_time=1, + normalize_type='imagenet', + # hyperparameters for packed training + use_packed_ds=False, + data_rank=0, + data_world_size=1, + distributed_mode=False, + force_shuffle=False, + random_seed=0, + ): + super(LazySupervisedDataset, self).__init__() + self.ds_name = ds_name + self.tokenizer = tokenizer + self.template_name = template_name + self.num_image_token = num_image_token + logger.info(f'[Dataset] num_image_token: {num_image_token}') + logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}') + logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}') + logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}') + + self.image_size = image_size + self.is_train = is_train + self.pad2square = pad2square + self.max_num_frame = max_num_frame + self.min_num_frame = min_num_frame + self.sampling_method = sampling_method + + # hyperparameters for distributed training + self.use_packed_ds = use_packed_ds + self.data_rank = data_rank + self.data_world_size = data_world_size + self.worker_id = None + self.worker_state_key = None + self.worker_distributed = False + self.distributed_mode = distributed_mode + # hyperparameters for packed dataset + self.dataset_type = 'pair' + self.max_num_images = 1 + self.max_tokens = tokenizer.model_max_length + self.force_shuffle = force_shuffle + # TODO: quick resume + self._state_dict = {} + + logger.info('Formatting inputs...Skip in lazy mode') + assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}' + + with open(meta['annotation'], 'r') as f: + self.raw_data = f.readlines() + if repeat_time < 1: + # If repeat_time is less than 1, select a portion of the data + self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)] + if repeat_time > 1: + assert isinstance(repeat_time, int) + # Repeat the list if repeat_time is greater than 1 + self.raw_data = self.raw_data * repeat_time + + self.rng = np.random.default_rng(seed=random_seed) + if self.force_shuffle: + self.rng.shuffle(self.raw_data) + + self.root = meta['root'] + self.cached_data_dict = {} + self.tcs_loader = tcs_loader + self.group_by_length = group_by_length + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.normalize_type = normalize_type + + # If the precomputed length does not exist, roughly estimate the length of + # each sample to improve the efficiency of group_by_length. + if self.group_by_length: + self.conv2length = {} # Using a dictionary to speed up token length calculation + self.length = [] + for data_item in self.raw_data: + data_item = json.loads(data_item) + if 'length' in data_item: + token_length = data_item['length'] # Use precomputed length if available + else: + # Compute token length using the tokenizer + conversations = '\n'.join([temp['value'] for temp in data_item['conversations']]) + str_length = len(conversations) + if str_length not in self.conv2length: + token_length = tokenizer( + conversations, return_tensors='pt', padding=False, truncation=False, + ).input_ids.size(1) + self.conv2length[str_length] = token_length + num_image_token * ( + max_dynamic_patch + use_thumbnail) + else: + token_length = self.conv2length[str_length] + self.length.append(token_length) + + def __len__(self): + return len(self.raw_data) + + def get_preprocess_function(self): + # Select the appropriate preprocessing function based on the template name + if self.template_name == 'Hermes-2': + preprocess_function = preprocess_mpt + elif self.template_name == 'internlm2-chat': + preprocess_function = preprocess_internlm + elif self.template_name == 'phi3-chat': + preprocess_function = preprocess_phi3 + elif self.template_name == 'internvl2_5': + preprocess_function = preprocess_internvl2_5 + else: + preprocess_function = preprocess + return preprocess_function + + def load_image(self, image_path): + # Load the image using tcs_loader if available, otherwise use PIL + if self.tcs_loader is not None and 's3://' in image_path: + return self.tcs_loader(image_path) + return Image.open(image_path).convert('RGB') + + def get_image_path(self, image_path): + if image_path.startswith('s3://'): # for ceph + image_path = self.root + image_path + else: # for local image + image_path = os.path.join(self.root, image_path) + return image_path + + def get_transform(self): + # Build transformation function + transform = build_transform(is_train=self.is_train, input_size=self.image_size, + pad2square=self.pad2square, normalize_type=self.normalize_type) + return transform + + def multi_modal_get_item(self, data_item): + # Build transformation function + transform = self.get_transform() + + # Ensure the first conversation contains an image placeholder + if '' not in data_item['conversations'][0]['value']: + data_item['conversations'][0]['value'] = '\n' + data_item['conversations'][0]['value'] + + # Merge the image path + image_path = self.get_image_path(data_item['image']) + + # Load the image using tcs_loader if available, otherwise use PIL + image = self.load_image(image_path) + + if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically + images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch, + image_size=self.image_size, use_thumbnail=self.use_thumbnail) + else: # Otherwise, use the original image as a single patch + images = [image] + + # Apply the transformation to each image and stack the results into a tensor + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + + # Ensure that there is only one patch if dynamic image size is not enabled + num_patches = pixel_values.size(0) + if not self.dynamic_image_size: + assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.' + + # Select the appropriate preprocessing function based on the template name + preprocess_function = self.get_preprocess_function() + + # Preprocess the conversations and generate the return dictionary + ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])], + self.tokenizer, [self.num_image_token * num_patches], + group_by_length=self.group_by_length, + use_packed_ds=self.use_packed_ds, ds_name=self.ds_name) + + # Calculate position_ids for packed dataset + position_ids = ret['attention_mask'].long().cumsum(-1) - 1 + position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) + assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}' + + # Create the final return dictionary + ret = dict( + input_ids=ret['input_ids'][0], + labels=ret['labels'][0], + attention_mask=ret['attention_mask'][0], + position_ids=position_ids[0], + pixel_values=pixel_values, + image_flags=torch.tensor([1] * num_patches, dtype=torch.long) + ) + return ret + + def multi_modal_multi_image_get_item(self, data_item): + # Build transformation function + transform = self.get_transform() + + images, num_tiles = [], [] + num_image = len(data_item['image']) + for image_path in data_item['image']: + # Merge the image path + image_path = self.get_image_path(image_path) + # Load the image using tcs_loader if available, otherwise use PIL + image = self.load_image(image_path) + if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically + image = dynamic_preprocess(image, min_num=self.min_dynamic_patch, + max_num=max(1, self.max_dynamic_patch // num_image), + image_size=self.image_size, use_thumbnail=self.use_thumbnail) + images += image + num_tiles.append(len(image)) + else: # Otherwise, use the original image as a single patch + images.append(image) + num_tiles.append(1) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + num_patches = pixel_values.size(0) + + # Select the appropriate preprocessing function based on the template name + preprocess_function = self.get_preprocess_function() + + # Preprocess the conversations and generate the return dictionary + num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles] + ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])], + self.tokenizer, num_image_tokens, group_by_length=self.group_by_length, + use_packed_ds=self.use_packed_ds, ds_name=self.ds_name, num_image=num_image) + + # Calculate position_ids for packed dataset + position_ids = ret['attention_mask'].long().cumsum(-1) - 1 + position_ids.masked_fill_(ret['attention_mask'] == 0, 1) + image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) + assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}' + + # Create the final return dictionary + ret = dict( + input_ids=ret['input_ids'][0], + labels=ret['labels'][0], + attention_mask=ret['attention_mask'][0], + position_ids=position_ids[0], + pixel_values=pixel_values, + image_flags=torch.tensor([1] * num_patches, dtype=torch.long) + ) + return ret + + def video_get_item(self, data_item): + # Build transformation function + transform = self.get_transform() + + # Ensure the first conversation contains a video placeholder + if '