# internvl_ola / intern_ola_load.py
# (Hugging Face upload-page header removed: "jjw0126 / Upload files / 62d115a verified")
import os
import sys
from pathlib import Path
import math
import io
from contextlib import redirect_stdout

import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu  # video frame decoding
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoConfig, AutoModel, AutoTokenizer
import librosa
import whisper
import moviepy as mp

# Input combinations exercised by this script:
#   pure text / image+text / video+text / audio+text / video+audio+text

# ImageNet normalization constants used by the vision preprocessing pipeline.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Add the parent directory of this file to the Python path so the local
# model code (custom InternVL-Ola modules) can be imported.
sys.path.append(str(Path(__file__).parent.parent.resolve()))

# Force offline mode: transformers loads everything from the local
# directory instead of hitting the Hugging Face Hub.
# (The original set TRANSFORMERS_OFFLINE twice and imported torch and the
# transformers symbols twice; the duplicates are consolidated here.)
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
def build_transform(input_size):
    """Build the image preprocessing pipeline: RGB convert, square resize
    (bicubic), tensor conversion, and ImageNet normalization."""
    return T.Compose([
        T.Lambda(lambda im: im if im.mode == 'RGB' else im.convert('RGB')),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the (cols, rows) tiling whose aspect ratio best matches the image.

    Ties are broken by preferring the later (larger) tiling when the source
    image area exceeds half of that tiling's target area.
    """
    best = (1, 1)
    best_diff = float('inf')
    source_area = width * height
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff = diff
            best = (cols, rows)
        elif diff == best_diff and source_area > 0.5 * image_size * image_size * cols * rows:
            best = (cols, rows)
    return best
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split a PIL image into a grid of square ``image_size`` tiles.

    The grid (cols x rows) is chosen so its aspect ratio is closest to the
    image's, with the total tile count in [min_num, max_num]. Optionally a
    single thumbnail of the whole image is appended when more than one tile
    was produced.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by total tile count.
    candidate_ratios = sorted(
        {(c, r)
         for total in range(min_num, max_num + 1)
         for c in range(1, total + 1)
         for r in range(1, total + 1)
         if min_num <= c * r <= max_num},
        key=lambda cr: cr[0] * cr[1])

    cols, rows = find_closest_aspect_ratio(
        aspect_ratio, candidate_ratios, orig_width, orig_height, image_size)

    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # Resize to the tiled target size, then crop out each tile row-major.
    resized = image.resize((target_width, target_height))
    tiles = []
    tiles_per_row = target_width // image_size
    for idx in range(blocks):
        col = idx % tiles_per_row
        row = idx // tiles_per_row
        tiles.append(resized.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == blocks
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image_file, input_size=448, max_num=12):
    """Load one image file and return its stacked, normalized tile tensors
    (shape: (num_tiles, 3, input_size, input_size), thumbnail included)."""
    img = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
# Resolve paths relative to this script: the tokenizer and model weights
# are expected to live alongside the script itself.
script_dir = os.path.dirname(os.path.abspath(__file__))
tokenizer_path = script_dir
# Model path (same directory as this script).
model_path = script_dir
# Load the model configuration (trust_remote_code enables the custom
# InternVL-Ola model classes shipped with the checkpoint).
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
# Only partial weights are available, so initialize from config (random
# weights) and then load whatever exists from the safetensors checkpoint.
with torch.device("cuda"):
    model = AutoModel.from_config(config, trust_remote_code=True)  # random init; weights loaded below
    # Load weights stored in safetensors format.
    from safetensors.torch import load_file
    partial_state_dict = load_file(f"{script_dir}/model.safetensors")  # path to the (partial) weight file
    model.load_state_dict(partial_state_dict, strict=False)  # strict=False permits partial loading
# Make sure every model component lives on CUDA.
model = model.cuda()
model.eval()
# Move the speech encoder to CUDA if the model has one.
if hasattr(model, 'speech_encoder') and model.speech_encoder is not None:
    model.speech_encoder = model.speech_encoder.cuda()
    print("✅ Speech encoder moved to CUDA")
# Move the speech projector to CUDA as well.
if hasattr(model, 'speech_projector') and model.speech_projector is not None:
    model.speech_projector = model.speech_projector.cuda()
    print("✅ Speech projector moved to CUDA")
def print_model_structure(model, max_depth=2, current_depth=0, prefix=""):
    """Recursively print the module tree down to ``max_depth`` levels.

    max_depth: 0 = root only, 1 = direct children, 2 = grandchildren, ...
    Children below the depth limit are elided as "(...)".
    """
    at_root = current_depth == 0
    if at_root:
        print(f"{model.__class__.__name__}(")
    if current_depth < max_depth:
        indent = " " * (current_depth + 1)
        for name, child in model.named_children():
            print(f"{indent}({name}): {child.__class__.__name__}", end="")
            grandchildren = list(child.named_children())
            if grandchildren and current_depth + 1 < max_depth:
                # Still within the depth budget: recurse into this child.
                print("(")
                print_model_structure(child, max_depth, current_depth + 1, prefix + " ")
                print(f"{indent})")
            else:
                # Depth limit reached: elide hidden children, or show "()" for leaves.
                print("(...)" if grandchildren else "()")
    if at_root:
        print(")")
def compare_common_modules(model1, model2, model1_name="model", model2_name="internvl_model"):
    """Compare the parameters of modules shared by two models.

    For every qualified module name present in both models, compares the
    module's direct parameters by shape and value (torch.allclose with
    rtol=1e-5, atol=1e-8) and prints a per-parameter report.

    Returns:
        (identical_params, different_params, missing_params) counts.
    """
    print(f"\n=== 比较 {model1_name} 和 {model2_name} 的共同模块参数 ===")
    # Map qualified module names -> module objects for both models.
    model1_modules = dict(model1.named_modules())
    model2_modules = dict(model2.named_modules())
    # Module names present in both models.
    common_module_names = set(model1_modules.keys()) & set(model2_modules.keys())
    # Drop the root module (its name is the empty string).
    common_module_names.discard('')
    print(f"共同模块数量: {len(common_module_names)}")
    # Per-parameter comparison counters.
    identical_params = 0
    different_params = 0
    missing_params = 0
    for module_name in sorted(common_module_names):
        module1 = model1_modules[module_name]
        module2 = model2_modules[module_name]
        # Direct (non-recursive) parameters only; child modules are visited
        # separately via named_modules above.
        params1 = dict(module1.named_parameters(recurse=False))
        params2 = dict(module2.named_parameters(recurse=False))
        # Skip modules that own no parameters of their own.
        if not params1 and not params2:
            continue
        print(f"\n模块: {module_name}")
        print(f" 类型: {module1.__class__.__name__}")
        # Compare parameters present in both modules.
        common_param_names = set(params1.keys()) & set(params2.keys())
        for param_name in common_param_names:
            param1 = params1[param_name]
            param2 = params2[param_name]
            if param1.shape != param2.shape:
                print(f" 参数 {param_name}: 形状不同 {param1.shape} vs {param2.shape}")
                different_params += 1
            elif torch.allclose(param1, param2, rtol=1e-5, atol=1e-8):
                print(f" 参数 {param_name}: ✓ 相同")
                identical_params += 1
            else:
                print(f" 参数 {param_name}: ✗ 数值不同")
                different_params += 1
                # Report the magnitude of the mismatch.
                diff = (param1 - param2).abs()
                print(f" 最大差异: {diff.max().item():.6e}")
                print(f" 平均差异: {diff.mean().item():.6e}")
        # Report parameters that exist in only one of the two modules.
        only_in_model1 = set(params1.keys()) - set(params2.keys())
        only_in_model2 = set(params2.keys()) - set(params1.keys())
        if only_in_model1:
            print(f" 只在 {model1_name} 中: {list(only_in_model1)}")
            missing_params += len(only_in_model1)
        if only_in_model2:
            print(f" 只在 {model2_name} 中: {list(only_in_model2)}")
            missing_params += len(only_in_model2)
    print(f"\n=== 参数比较总结 ===")
    print(f"相同参数: {identical_params}")
    print(f"不同参数: {different_params}")
    print(f"缺失参数: {missing_params}")
    return identical_params, different_params, missing_params
def compare_module_structure(model1, model2, model1_name="model", model2_name="internvl_model"):
    """Compare the top-level (direct child) module structure of two models.

    Prints which child modules are unique to each model and which are
    shared, then returns (common, only_in_model1, only_in_model2) as sets
    of child-module names.
    """
    print(f"\n=== 比较 {model1_name} 和 {model2_name} 的模块结构差异 ===")
    children1 = dict(model1.named_children())
    children2 = dict(model2.named_children())
    print(f"{model1_name} 的主要组件: {list(children1.keys())}")
    print(f"{model2_name} 的主要组件: {list(children2.keys())}")
    names1 = set(children1.keys())
    names2 = set(children2.keys())
    only_in_model1 = names1 - names2
    only_in_model2 = names2 - names1
    common_modules = names1 & names2
    if only_in_model1:
        print(f"\n只在 {model1_name} 中的模块:")
        for module_name in only_in_model1:
            print(f" - {module_name}: {children1[module_name].__class__.__name__}")
    if only_in_model2:
        print(f"\n只在 {model2_name} 中的模块:")
        for module_name in only_in_model2:
            print(f" - {module_name}: {children2[module_name].__class__.__name__}")
    print(f"\n共同模块: {list(common_modules)}")
    return common_modules, only_in_model1, only_in_model2
def calculate_model_size(model, model_name="model"):
    """Report parameter counts and an approximate memory footprint.

    Statistics are aggregated per direct child module and printed as a
    table sorted by parameter count.

    Returns a dict with 'total_params', 'trainable_params', 'memory_mb',
    'memory_gb', and per-child 'module_stats'.
    """
    print(f"\n=== {model_name} 模型大小分析 ===")
    # Per-child parameter statistics.
    module_stats = {}
    for name, child in model.named_children():
        child_total = sum(p.numel() for p in child.parameters())
        child_trainable = sum(p.numel() for p in child.parameters() if p.requires_grad)
        module_stats[name] = {
            'total': child_total,
            'trainable': child_trainable,
            'type': child.__class__.__name__,
        }
    total_params = sum(s['total'] for s in module_stats.values())
    trainable_params = sum(s['trainable'] for s in module_stats.values())
    # Rough memory estimate: assume ~2.5 bytes per parameter on average
    # (a mix of float32 and bfloat16 weights).
    memory_mb = total_params * 2.5 / (1024 * 1024)
    memory_gb = memory_mb / 1024
    print(f"总参数数量: {total_params:,}")
    print(f"可训练参数: {trainable_params:,}")
    print(f"不可训练参数: {total_params - trainable_params:,}")
    print(f"估计内存占用: {memory_mb:.2f} MB ({memory_gb:.3f} GB)")
    print(f"\n各模块参数统计:")
    print("-" * 60)
    print(f"{'模块名':<20} {'类型':<25} {'参数数量':<15} {'占比':<10}")
    print("-" * 60)
    for name, stats in sorted(module_stats.items(), key=lambda kv: kv[1]['total'], reverse=True):
        percentage = (stats['total'] / total_params) * 100 if total_params > 0 else 0
        print(f"{name:<20} {stats['type']:<25} {stats['total']:<15,} {percentage:<10.2f}%")
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'memory_mb': memory_mb,
        'memory_gb': memory_gb,
        'module_stats': module_stats,
    }
def compare_model_sizes(model1, model2, model1_name="model", model2_name="internvl_model"):
    """Compare the sizes of two models and analyze where they differ.

    Prints parameter-count and estimated-memory comparisons (computed via
    calculate_model_size) and, when model1 is larger, attributes the
    difference to child modules that exist only in model1.

    Returns:
        (stats1, stats2): the calculate_model_size dicts for each model.
    """
    print(f"\n{'='*60}")
    print(f"模型大小对比分析")
    print(f"{'='*60}")
    # Per-model size statistics.
    stats1 = calculate_model_size(model1, model1_name)
    stats2 = calculate_model_size(model2, model2_name)
    print(f"\n=== 模型大小对比总结 ===")
    print("-" * 50)
    # Parameter-count comparison (ratio guarded against division by zero).
    param_diff = stats1['total_params'] - stats2['total_params']
    param_ratio = stats1['total_params'] / stats2['total_params'] if stats2['total_params'] > 0 else 0
    print(f"{model1_name} 总参数: {stats1['total_params']:,}")
    print(f"{model2_name} 总参数: {stats2['total_params']:,}")
    print(f"参数差异: {param_diff:+,} ({param_ratio:.2f}x)")
    # Estimated-memory comparison.
    memory_diff = stats1['memory_gb'] - stats2['memory_gb']
    memory_ratio = stats1['memory_gb'] / stats2['memory_gb'] if stats2['memory_gb'] > 0 else 0
    print(f"{model1_name} 内存占用: {stats1['memory_gb']:.3f} GB")
    print(f"{model2_name} 内存占用: {stats2['memory_gb']:.3f} GB")
    print(f"内存差异: {memory_diff:+.3f} GB ({memory_ratio:.2f}x)")
    # Attribute the size difference to model1-only child modules.
    print(f"\n=== 差异来源分析 ===")
    if param_diff > 0:
        print(f"{model1_name} 比 {model2_name} 多 {param_diff:,} 个参数")
        # Child modules present only in model1.
        model1_modules = set(stats1['module_stats'].keys())
        model2_modules = set(stats2['module_stats'].keys())
        only_in_model1 = model1_modules - model2_modules
        if only_in_model1:
            print(f"额外模块贡献的参数:")
            for module_name in only_in_model1:
                module_params = stats1['module_stats'][module_name]['total']
                percentage = (module_params / param_diff) * 100 if param_diff > 0 else 0
                print(f" - {module_name}: {module_params:,} ({percentage:.1f}%)")
    return stats1, stats2
# Load the tokenizer from the local directory (slow tokenizer + custom code).
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=False)
resource_path = "/data1/cxy/plm-v/modeling/example/"
# set the max number of tiles in `max_num`
pixel_values = load_image(f'{resource_path}image1.jpg', max_num=12).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)
# question = 'Hello, who are you?'
# response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
# print(f'User: {question}\nAssistant: {response}')
# Multimodal inference smoke tests begin here.
print("\n" + "="*80)
print("🧪 开始多模态推理测试")
print("="*80)
def test_inference(test_name, question, pixel_values_input=None, speech_input=None, speech_lengths_input=None, num_patches_list=None):
    """Run one chat-based inference test and report success or failure.

    Builds the keyword arguments for model.chat() from whichever modalities
    are provided, runs inference, and prints the response.

    Returns:
        (True, response) on success, (False, error_message) on failure.

    NOTE(review): relies on module-level globals `model`, `tokenizer`,
    `generation_config`, and `speech_wavs` (raw waveforms for BEATs) —
    `speech_wavs` must have been set by the audio-loading section before
    any speech test runs; it is not a parameter of this function.
    """
    print(f"\n{'='*60}")
    print(f"🧪 测试: {test_name}")
    print(f"📝 问题: {question}")
    print(f"{'='*60}")
    try:
        # Chat() arguments common to every modality.
        chat_kwargs = {
            'tokenizer': tokenizer,
            'pixel_values': pixel_values_input,
            'question': question,
            'generation_config': generation_config,
            'verbose': True
        }
        # Video input: pass the per-frame patch counts.
        if num_patches_list is not None:
            chat_kwargs['num_patches_list'] = num_patches_list
        # Speech input: add mel spectrograms (Whisper) and raw waveforms (BEATs).
        if speech_input is not None:
            chat_kwargs.update({
                'speech': speech_input,  # mel spectrogram, consumed by Whisper
                'speech_lengths': speech_lengths_input,
                'speech_wav': speech_wavs,  # raw audio waveform, consumed by BEATs
            })
        # Run inference.
        response = model.chat(**chat_kwargs)
        print(f"✅ 推理成功!")
        print(f"🤖 回复: {response}")
        return True, response
    except Exception as e:
        print(f"❌ 推理失败: {str(e)}")
        import traceback
        traceback.print_exc()
        return False, str(e)
# Test 1: pure text (expected to work: exercises the trained InternVL weights)
success1, response1 = test_inference(
    test_name="Pure Text",
    question="Hello, who are you? Please introduce yourself briefly.",
    pixel_values_input=None,
    speech_input=None,
    speech_lengths_input=None
)
# Test 2: text + image, visual only (expected to work: trained InternVL)
success2, response2 = test_inference(
    test_name="Text & Image (Visual only)",
    question="<image>\nPlease describe this image in detail.",
    pixel_values_input=pixel_values,
    speech_input=None,
    speech_lengths_input=None
)
print("\n" + "="*60)
print("🔄 准备Speech相关测试 (可能输出乱码,因为speech部分未训练)")
print("="*60)
def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Compute ``num_segments`` evenly spaced frame indices for a video clip.

    bound: optional (start_sec, end_sec) window; when falsy, the whole clip
    is used. Indices are clamped to [first_idx, max_frame] and sampled at
    the midpoint of each segment.
    """
    if bound:
        start_sec, end_sec = bound[0], bound[1]
    else:
        start_sec, end_sec = -100000, 100000
    start_idx = max(first_idx, round(start_sec * fps))
    end_idx = min(round(end_sec * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    return np.array([
        int(start_idx + (seg_size / 2) + np.round(seg_size * k))
        for k in range(num_segments)
    ])
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
    """Decode a video and preprocess sampled frames into model-ready tensors.

    Returns (pixel_values, num_patches_list): pixel_values concatenates the
    normalized tiles of every sampled frame; num_patches_list records the
    tile count per frame.
    """
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(reader) - 1
    fps = float(reader.get_avg_fps())
    transform = build_transform(input_size=input_size)
    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)

    frame_tensors, num_patches_list = [], []
    for frame_index in frame_indices:
        frame = Image.fromarray(reader[frame_index].asnumpy()).convert('RGB')
        tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
        stacked = torch.stack([transform(tile) for tile in tiles])
        num_patches_list.append(stacked.shape[0])
        frame_tensors.append(stacked)
    return torch.cat(frame_tensors), num_patches_list
def load_audio(audio_file_name):
    """Load an audio file and produce Ola-style Whisper inputs.

    Resamples to 16 kHz mono, splits into 30-second (480k-sample) chunks,
    pads/trims each chunk, and computes 128-bin log-mel spectrograms
    (matching the original Ola load_audio preprocessing).

    Returns (mels, speech_length, speech_chunks, speech_wavs):
      mels          - (num_chunks, T, 128) log-mel spectrograms
      speech_length - per-chunk mel frame count
      speech_chunks - chunk count as a 1-element LongTensor
      speech_wavs   - (num_chunks, 480000) raw waveform chunks
    """
    wav, _ = librosa.load(audio_file_name, sr=16000)
    if len(wav.shape) > 1:
        wav = wav[:, 0]  # keep only the first channel of multi-channel audio
    wav = wav.astype(np.float32)

    CHUNK_LIM = 480000  # 30 s at 16 kHz — Whisper's fixed input length
    chunks = []
    wav_chunks = []
    if len(wav) <= CHUNK_LIM:
        # Single chunk: pad/trim to exactly CHUNK_LIM samples.
        chunks.append(whisper.pad_or_trim(wav))
        wav_chunks.append(torch.from_numpy(whisper.pad_or_trim(wav)).unsqueeze(0))
    else:
        # Multiple chunks; only the final partial chunk needs padding.
        for offset in range(0, len(wav), CHUNK_LIM):
            piece = wav[offset:offset + CHUNK_LIM]
            if len(piece) < CHUNK_LIM:
                piece = whisper.pad_or_trim(piece)
            chunks.append(piece)
            wav_chunks.append(torch.from_numpy(piece).unsqueeze(0))

    # One 128-bin log-mel spectrogram per chunk, transposed to time-major.
    mels = torch.cat(
        [whisper.log_mel_spectrogram(c, n_mels=128).permute(1, 0).unsqueeze(0) for c in chunks],
        dim=0)
    speech_wavs = torch.cat(wav_chunks, dim=0)
    # Cap at 25 chunks (~12.5 minutes of audio).
    if mels.shape[0] > 25:
        mels = mels[:25]
        speech_wavs = speech_wavs[:25]
    speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
    speech_chunks = torch.LongTensor([mels.shape[0]])
    return mels, speech_length, speech_chunks, speech_wavs
def extract_audio(videos_file_path):
    """Return the audio track of a video file as a moviepy audio clip."""
    return mp.VideoFileClip(videos_file_path).audio
# Load the video used by the video tests below; failures (or a missing
# file) disable those tests rather than aborting the script.
print("\n📥 加载视频数据...")
try:
    video_path = f'{resource_path}red-panda.mp4'
    if os.path.exists(video_path):
        video_pixel_values, video_num_patches_list = load_video(video_path, num_segments=8, max_num=1)
        video_pixel_values = video_pixel_values.to(torch.bfloat16).cuda()
        video_loaded = True
        print(f"✅ 视频加载成功:")
        print(f" - 视频帧数: {len(video_num_patches_list)}")
        print(f" - 视频像素值形状: {video_pixel_values.shape}")
        print(f" - 每帧patch数: {video_num_patches_list}")
    else:
        print(f"⚠️ 视频文件不存在: {video_path}")
        video_loaded = False
        video_pixel_values = None
        video_num_patches_list = None
except Exception as e:
    print(f"❌ 视频加载失败: {e}")
    video_loaded = False
    video_pixel_values = None
    video_num_patches_list = None
audio_path = f'/data1/cxy/dataset/english.mp3'
# Load the audio data used by the speech tests below.
print("\n📥 加载音频数据...")
try:
    # Ola-style mel-spectrogram preprocessing (see load_audio above).
    mels, speech_lengths, speech_chunks, speech_wavs = load_audio(audio_path)
    print(f"✅ 音频加载成功:")
    print(f" - mel谱图形状: {mels.shape}")
    print(f" - 音频长度: {speech_lengths}")
    print(f" - 音频块数: {speech_chunks}")
    print(f" - 原始音频波形形状: {speech_wavs.shape}")
    # Move the audio tensors to the GPU in the dtypes the model expects.
    mels = mels.to(torch.bfloat16).cuda()
    speech_lengths = speech_lengths.cuda()
    speech_chunks = speech_chunks.cuda()
    speech_wavs = speech_wavs.cuda()
    audio_loaded = True
except Exception as e:
    print(f"❌ 音频加载失败: {e}")
    audio_loaded = False
    # Reset ALL audio globals. The original reset only `mels` and
    # `speech_lengths`, leaving `speech_chunks` and `speech_wavs` undefined
    # on failure — any later reference (e.g. test_inference's `speech_wav`
    # argument) would raise NameError instead of seeing None.
    mels = None
    speech_lengths = None
    speech_chunks = None
    speech_wavs = None
# Test 3: audio only (output may be garbled: the speech branch is untrained)
if audio_loaded:
    success3, response3 = test_inference(
        test_name="Audio only (预期乱码)",
        question="<speech>\nPlease transcribe and summarize what you heard in the audio.",
        pixel_values_input=None,
        speech_input=mels,
        speech_lengths_input=speech_lengths
    )
else:
    print("⚠️ 跳过Audio only测试 (音频加载失败)")
    success3 = False
# Test 4: audio + image (output may be garbled: the speech branch is untrained)
if audio_loaded:
    success4, response4 = test_inference(
        test_name="Audio + Image (预期乱码)",
        question="<image>\nUser's question in speech: <speech>\n",
        pixel_values_input=pixel_values,
        speech_input=mels,
        speech_lengths_input=speech_lengths
    )
else:
    print("⚠️ 跳过Audio + Image测试 (音频加载失败)")
    success4 = False
# Test 5: video + text (expected to work: uses the trained InternVL weights)
if video_loaded:
    # Build the per-frame prompt prefix ("Frame1: <image>\n...") expected by chat().
    video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(video_num_patches_list))])
    video_question = video_prefix + 'What is the red panda doing in this video? Please describe the actions and movements you observe.'
    success5, response5 = test_inference(
        test_name="Video + Text",
        question=video_question,
        pixel_values_input=video_pixel_values,
        speech_input=None,
        speech_lengths_input=None,
        num_patches_list=video_num_patches_list
    )
else:
    print("⚠️ 跳过Video + Text测试 (视频加载失败)")
    success5 = False
# # Test 5 (alternate, disabled): video + audio (may be garbled: speech untrained)
# if audio_loaded:
#     success5, response5 = test_inference(
#         test_name="Video + Audio (预期乱码)",
#         question="<speech><image>\nDescribe what you hear and see in this content.",
#         pixel_values_input=pixel_values,
#         speech_input=mels,
#         speech_lengths_input=speech_lengths
#     )
# else:
#     print("⚠️ 跳过Video + Audio测试 (音频加载失败)")
#     success5 = False
# Summary of all test outcomes.
print("\n" + "="*80)
print("📊 多模态推理测试总结")
print("="*80)
test_results = [
    ("Pure Text", success1, "PASS", "应该正常 (训练好的InternVL)"),
    ("Text & Image", success2, "PASS", "应该正常 (训练好的InternVL)"),
    ("Video + Text", success5 if video_loaded else False, "PASS", "应该正常 (训练好的InternVL)"),
    ("Audio only", success3 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
    ("Audio + Image", success4 if audio_loaded else False, "GARBLED", "可能乱码 (speech未训练)"),
]
for test_name, success, expected, note in test_results:
    status = "✅ PASS" if success else "❌ FAIL"
    print(f"{status} {test_name:<15} (预期: {expected:<8}) - {note}")
passed = sum(1 for _, success, _, _ in test_results if success)
total = len(test_results)
print(f"\n📈 测试统计: {passed}/{total} 通过")
# At least 2 of {pure text, text & image, video + text} should pass if the
# base (non-speech) pipeline is intact.
if passed >= 2:
    print("🎉 基础功能正常,Speech集成架构成功!")
    print("💡 Speech相关测试如果输出乱码是正常的,因为speech部分还未训练")
    if passed >= 3:
        print("🌟 所有基础模态测试都通过了!")
else:
    print("⚠️ 基础功能可能存在问题,需要进一步检查")
print("\n=== 多模态推理测试完成 ===")