# Ocean-OCR / processor_ocean.py
import requests
import re, ujson, os, sys, fire, glob, random, time, json, ast
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client
import tos
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling
from PIL import Image, ImageOps
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import math
def smart_resize(
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
# if height < factor or width < factor:
# raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
if max(height, width) / min(height, width) > 200:
raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
h_bar = round(height / factor) * factor if height > factor else factor
w_bar = round(width / factor) * factor if width > factor else factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
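# Worked example (illustrative sketch, not part of the original pipeline): with the default factor of 28,
# both sides are snapped to the nearest multiple of 28 while respecting the pixel budget, e.g.
#   smart_resize(1000, 750)   # -> (1008, 756); 1008 * 756 lies inside [min_pixels, max_pixels]
#   smart_resize(4000, 3000)  # -> both sides are scaled down so that h_bar * w_bar <= max_pixels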
def select_best_resolution(image_size, candidate_resolutions):
    '''Find the best resolution among the candidates and rescale the original image to it.
    image_size is usually the original size, e.g. (8*336, 16*336)
    candidate_resolutions are the candidate resolutions, e.g. (1*336, 4*336)
    '''
    try:
        original_width, original_height = image_size
    except (TypeError, ValueError) as e:
        raise ValueError(f"image_size must be a (width, height) pair, got {image_size!r}") from e
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
    # iterate over the candidate (width, height) pairs
    for width, height in candidate_resolutions:
        # use the smaller of width / original_width and height / original_height as the scale
        scale = min(width / original_width, height / original_height)  # e.g. scale = min(1/8, 1/4) = 1/8
        # rescale original_width and original_height
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)  # e.g. 1*336, 2*336
        # effective_resolution is the resolution after rescaling, capped by the original resolution
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)  # e.g. min(1*336 * 2*336, 8*336 * 16*336)
        # wasted_resolution is the part of the candidate area not covered by the rescaled image
        wasted_resolution = (width * height) - effective_resolution
        # pick this candidate if (1) its effective resolution beats the current maximum, or
        # (2) it ties on effective resolution but wastes less of the candidate area
        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution  # update max_effective_resolution
            min_wasted_resolution = wasted_resolution  # update min_wasted_resolution
            best_fit = (width, height)
return best_fit
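# Illustrative example (assumed inputs, not from the original code): for an 800x600 image and
# candidates [(336, 336), (672, 672)], the (672, 672) candidate wins because its downscaled size
# (672, 504) preserves the most effective resolution; ties are broken by the smaller wasted area.
#   select_best_resolution((800, 600), [(336, 336), (672, 672)])  # -> (672, 672)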
def read_video(image_path, max_frame_number, decode_way):
if decode_way=='1fps':
try:
vr = VideoReader(image_path, ctx=cpu(0))
total_frame_num = len(vr)
fps = round(vr.get_avg_fps())
frame_idx = [i for i in range(0, len(vr), fps)]
frames = vr.get_batch(frame_idx).asnumpy()
frames = [i for i in frames]
cnt = len(frames)
except Exception as e:
print(image_path)
print('error is', e)
return None
elif decode_way=='key':
try:
with av.open(image_path) as container:
stream = container.streams.video[0]
stream.codec_context.skip_frame = 'NONKEY'
frames = []
fps = int(stream.average_rate)
cnt = 0
                for frame in container.decode(stream):  # store each key frame as a PIL image
image = frame.to_image()
frames.append(image)
cnt += 1
except Exception as e:
print('error is', e)
return None
    else:
        return None  # unknown decode_way
    if frames is None or len(frames)==0:
return None
if len(frames)>max_frame_number and max_frame_number>0:
        # generate evenly spaced indices
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        # pick the frames at those indices
sampled_elements = [frames[idx] for idx in indices]
frames = sampled_elements
return frames
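# Usage sketch (hypothetical path, for illustration only):
#   frames = read_video("clip.mp4", max_frame_number=32, decode_way="1fps")
# returns at most 32 frames sampled roughly once per second (NumPy arrays in '1fps' mode,
# PIL images in 'key' mode), or None if decoding fails.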
class OceanImageProcessor:
def __init__(self, config, **kwargs):
self.config = config # visual_config
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
def image_transform(self, strseq, return_mm_data = True):
image = None
        if isinstance(strseq, str):
            image = Image.open(strseq).convert("RGB")
        else:
            image = Image.open(BytesIO(strseq)).convert("RGB")
        image = np.array(image.convert("RGB"))  # make sure there are three channels (R, G, B) and convert to a NumPy array for the steps below
        image_org_size = image.shape[:2]  # original (height, width); image.shape is (height, width, channels)
        # resize, crop, scale, normalize
        # smart_resize returns a target size whose sides are multiples of patch_size * spatial_merge_size,
        # keeping the aspect ratio while staying within the [min_pixels, max_pixels] budget.
resized_height, resized_width = smart_resize(
image_org_size[0], image_org_size[1],
factor=self.patch_size * self.spatial_merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
output_size = (resized_height, resized_width)
        # output_size = get_resize_output_image_size(image, self.config.crop_size, False)  # resize the short side to 336
        # resize the image to output_size using bicubic interpolation (PILImageResampling.BICUBIC),
        # which usually gives good quality when rescaling.
        image = resize(image, output_size, PILImageResampling.BICUBIC)
        # optionally, a self.config.crop_size x self.config.crop_size square could be cropped from the center:
        # image = center_crop(image, (self.config.crop_size, self.config.crop_size), return_numpy=True)
        img = image.transpose(2, 0, 1)
        # normalize: scale to [0, 1], subtract the per-channel mean and divide by the per-channel std
        image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:,np.newaxis,np.newaxis]
        # split into patches
        patches = image[np.newaxis, :]
if patches.shape[0] == 1:
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
grid_w // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
)
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
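# Shape sketch for image_transform (values assume the default patch_size=14, spatial_merge_size=2,
# temporal_patch_size=2): a 1008x756 input gives grid_t=1, grid_h=72, grid_w=54, so flatten_patches
# has shape (1*72*54, 3*2*14*14) = (3888, 1176); a single image is tiled along the temporal axis to
# fill the temporal_patch_size dimension.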
class OceanAudioProcessor:
    # basic audio feature extraction + input-data parsing + COS fetch/cache helpers
def __init__(
self,
config, # audio processor config
**kwargs
):
        # make sure ffmpeg is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
assert(len(torchaudio.list_audio_backends()) > 0)
self.config = config
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.n_fft // 2,
num_mel_filters=self.config.num_mel_bins,
min_frequency=0.0,
max_frequency=self.config.sampling_rate / 2.0,
sampling_rate=self.config.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
@staticmethod
def zero_mean_unit_var_norm(x):
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
        assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels)  # whisper only accepts mono-channel audio
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
if self.config.sampling_rate != metadata.sample_rate:
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate)
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
if metadata.num_channels > 1:
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
        # normalize to zero mean (Qwen Audio does not do this, but the official Whisper implementation does)
if do_normalize:
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
if return_tensors: # (channels, samples)
return waveform_tensor
else:
return waveform_tensor.numpy()
    def split_with_overlap(self, waveform):  # if the waveform exceeds the max length, split it into overlapping segments
channels, wave_samples = waveform.shape
max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]  # within the max length (or truncation disabled); always return a list
split_waveform, start = [], 0
        while start < wave_samples:  # 2024-07-24: align the overlap in seconds so the segments match across sampling rate / n_fft / hop length configs
            if start > int(self.config.sampling_rate * self.config.split_overlap):
                start -= int(self.config.sampling_rate * self.config.split_overlap)  # 0 means no overlap; > 0 is the overlap in seconds
            end = min(start + max_audio_samples, wave_samples)
            split_waveform.append(waveform[:, start:end])  # note: this can produce very short tail segments, which should be detected and dropped during preprocessing
start = end
return split_waveform
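    # Worked example (assumed config: max_audio_seconds=30, split_overlap=2, any sampling rate):
    # a 70 s waveform is split into the segments [0, 30), [28, 58) and [56, 70) seconds, i.e. each
    # segment after the first starts split_overlap seconds before the previous one ended.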
@classmethod
def inference_output_length(cls, config, input_length):
# for whisper + bridge
kernel_size = config.kernel_size
stride_size = config.stride_size
avg_pooler = config.avg_pooler
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        else:
            bridge_length = encoder_length
        return encoder_length, bridge_length
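    # Length sketch (assuming the usual Whisper-style kernel_size=3 and stride_size=2): the first
    # conv keeps the length, the second halves it, so 3000 mel frames give encoder_length=1500 and
    # bridge_length = 1500 // avg_pooler when avg_pooler > 1.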
def extract_fbank_features(self, waveform):
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
channels, wave_samples = waveform.shape
assert(wave_samples >= self.config.n_fft)
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
else:
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
window = torch.hann_window(self.config.n_fft)
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
        log_spec[:, valid_frame_nums:] = 0.0  # zero out the padded frames; the collate step keeps the max length within the batch
return log_spec, valid_frame_nums
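    # Output sketch: log_spec has shape (num_mel_bins, max_frames) where
    # max_frames = max_audio_seconds * sampling_rate // hop_length (e.g. 30 * 16000 // 160 = 3000
    # with typical Whisper settings); frames beyond valid_frame_nums are zeroed out.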
def data_augment(self, feature: np.array, input_length, training=True):
# reference https://arxiv.org/pdf/1904.08779
# run only on cpu
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            # compute the number of spans to mask, then randomly pick the span start indices
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
num_masked_span = max(num_masked_span, min_masks)
start_indices = list(range(input_length - mask_length))
random.shuffle(start_indices)
start_indices = start_indices[:num_masked_span]
return start_indices
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
return feature
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
return feature
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
return feature
if self.config.mask_time_prob > 0:
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
for start_idx in start_indices:
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
if self.config.mask_feature_prob > 0:
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
for start_idx in start_indices:
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
return feature
class CosClient():
def __init__(self, bucket_name='crawl-pic-1317568651',
max_retries=2):
self.config = CosConfig(
Endpoint="cos.ap-guangzhou.myqcloud.com",
# Region='ap-guangzhou',
SecretId='AKIDnRpxoOghgVs0tkU3Mfv20jAMI0SRDj02',
SecretKey='td9tRlqiPvEJ8i27wXwBIDiy5ye6JGyS',
Token=None, Scheme='https', Timeout=300)
self.client = CosS3Client(self.config)
self.max_retries = max_retries
self.bucket_name = bucket_name
def __call__(self, relative_path, bucket_name=None):
if bucket_name is None or len(bucket_name) <= 0:
bucket_name = self.bucket_name
multimodal_bytes = None
for _ in range(self.max_retries):
try:
response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
fp = response['Body'].get_raw_stream()
multimodal_bytes = fp.read()
break
except Exception as e:
time.sleep(0.01)
continue
return multimodal_bytes
class TosClient(object):
def __init__(self):
ak = "AKLTYTM3MWY5MTFhNDgyNDk4YjhmYTE0ZTE3YTk5ZmU1MjU"
sk = "TVRRM1pUZGtaVEJqWTJJd05HSTNPR0ppWVdKa1lqYzVORFUwTlRobU1UVQ=="
        endpoint = "tos-cn-beijing.ivolces.com"
region = "cn-beijing"
self.bucket_name = "audio-dataset"
self.client = tos.TosClientV2(ak, sk, endpoint, region)
def __call__(self, path, bucket_name=None):
if bucket_name is None:
bucket_name = self.bucket_name
for _ in range(2):
try:
object_stream = self.client.get_object(bucket_name, path)
return object_stream.read()
except Exception as e:
time.sleep(0.01)
continue
return None
@dataclass
class OceanProcessorOutput(ModelOutput):
input_ids: Optional["List|torch.Tensor"] = None
labels: Optional["List|torch.Tensor"] = None
attention_mask: Optional["List|torch.Tensor"] = None
position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None  # must be used together with the Ocean modeling code
# audio fields
audios: Optional["List|torch.Tensor"] = None
encoder_length: Optional["List|torch.Tensor"] = None
bridge_length: Optional["List|torch.Tensor"] = None
# image fields
images: Optional["List|torch.Tensor"] = None
patch_nums: Optional["List|torch.Tensor"] = None
images_size: Optional["List|torch.Tensor"] = None
crop_size: Optional["List|torch.Tensor"] = None
images_grid: Optional["List|torch.Tensor"] = None
# video fields
videos: Optional["List|torch.Tensor"] = None
videos_patch_nums: Optional["List|torch.Tensor"] = None
videos_size: Optional["List|torch.Tensor"] = None
videos_crop_size: Optional["List|torch.Tensor"] = None
videos_grid: Optional["List|torch.Tensor"] = None
# processor fields
raw_text: Optional[str] = None
index: Optional[int] = None
    def concatenate(self, other):  # only valid when the fields are Python lists
def concat_one(a, b):
if a is None and b is None:
return None
elif a is None and b is not None:
return b
elif a is not None and b is None:
return a
else:
return a + b
return OceanProcessorOutput(
input_ids=concat_one(self.input_ids, other.input_ids),
labels=concat_one(self.labels, other.labels),
audios=concat_one(self.audios, other.audios),
encoder_length=concat_one(self.encoder_length, other.encoder_length),
bridge_length=concat_one(self.bridge_length, other.bridge_length),
images=concat_one(self.images, other.images),
images_grid=concat_one(self.images_grid, other.images_grid),
patch_nums=concat_one(self.patch_nums, other.patch_nums),
videos=concat_one(self.videos, other.videos),
videos_grid=concat_one(self.videos_grid, other.videos_grid),
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
position_ids=concat_one(self.position_ids, other.position_ids),
seqlens=concat_one(self.seqlens, other.seqlens),
images_size=concat_one(self.images_size, other.images_size)
)
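# Field-merge sketch (illustrative): concatenating two outputs whose fields are Python lists simply
# extends the lists, e.g. audios=[a] and audios=[b] give audios=[a, b]; fields that are None on both
# sides stay None, and a None field on one side is replaced by the other side's list.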
class OceanMMProcessor(object):
def __init__(self,
tokenizer: transformers.PreTrainedTokenizer,
config,
training,
relative_path=None,
**kwargs,
):
self.tokenizer = tokenizer
self.config = config
self.audio_processor = None
if hasattr(config, "audio_config"):
self.audio_processor = OceanAudioProcessor(config.audio_config)
self.visual_processor = None
if hasattr(config, "visual_config"):
self.visual_processor = OceanImageProcessor(config.visual_config)
self.video_processor = None
if hasattr(config, "video_config"):
self.video_processor = OceanImageProcessor(config.video_config)
self.training = training
self.relative_path = relative_path
self.cos_client = CosClient()
self.tos_client = TosClient()
# audio tag
self.audio_start_tag = None
self.audio_end_tag = None
self.audio_pad_tag = None
self.audio_delim_tag = None
if hasattr(self.config, "audio_config"):
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
# image tag
self.image_start_tag = None
self.image_end_tag = None
self.image_pad_tag = None
self.video_start_tag = None
self.video_end_tag = None
if hasattr(self.config, "visual_config"):
# special token for start_tag
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
# special token for end_tag
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
# special token for pad_tag
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
if hasattr(self.config, "video_config"):
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
# @lru_cache(maxsize=1024)
def _get_audio(self, audio_info, return_mm_data = True):
try:
audio_info = ujson.loads(audio_info)
audio_uri = None
if 'path' in audio_info.keys():
                if self.relative_path is not None:  # prefer a local path first
                    audio_uri = os.path.join(self.relative_path, audio_info['path'])
                    if not os.path.exists(audio_uri):
                        audio_uri = None
                if audio_uri is None:  # not found locally, fall back to COS/TOS
if audio_info.get('server', 'cos') == 'tos':
audio_uri = self.tos_client(audio_info['path'], 'audio-dataset')
else:
audio_uri = self.cos_client(audio_info['path'], 'audio-data-1317568651')
elif 'local' in audio_info.keys():
audio_uri = audio_info['local']
if not os.path.exists(audio_uri):
audio_uri = None
return OceanProcessorOutput()
else:
raise ValueError("can not find path or local in audio_info")
waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
            waveforms = self.audio_processor.split_with_overlap(waveforms)  # split into overlapping segments if too long
            ret = OceanProcessorOutput()  # default init; the audios field starts as None
for waveform in waveforms:
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
audio = self.audio_processor.data_augment(audio, input_length, self.training)
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
                if bridge_length <= 0:  # filter extremely short segments: if len(waveforms) == 1 the result stays empty; if > 1, the last segment was too short and is dropped
continue
current_ret = OceanProcessorOutput(
audios=[audio],
encoder_length=[encoder_length],
bridge_length=[bridge_length])
if ret.audios is None:
ret = current_ret
else:
                    ret = ret.concatenate(current_ret)  # concatenate multiple segments
if not return_mm_data:
ret.audios = [None]
return ret
except Exception as e:
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
return OceanProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_image(self, image_info, return_mm_data = True):
try:
try: # chensong
image_info = ujson.loads(image_info)
except:
#image_info = image_info.replace("'", '"')
image_info = re.sub(r"(?<!\\)'", '"', image_info)
image_info = ujson.loads(image_info)
if 'base64' in image_info.keys():
image_data = base64.b64decode(image_info['base64'])
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
elif 'local' in image_info.keys():
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'],return_mm_data = return_mm_data)
elif 'path' in image_info.keys():
if "tos_bucket" in image_info.keys(): # tos上的每个item,一定要写明tos的桶以及tos_bucket这个key
tos_bucket = image_info['tos_bucket']
image_bytes = self.tos_client(image_info['path'], tos_bucket) # 从cos_client 获得 image
else:
cos_bucket = None
if "cos_bucket" in image_info.keys():
cos_bucket = image_info['cos_bucket']
if "bucket_name" in image_info.keys():
cos_bucket = image_info['bucket_name']
                    image_bytes = self.cos_client(image_info['path'], cos_bucket)  # fetch the image bytes from COS
                # get image_feat (flattened patches), org_size (original image size) and image_list (the patch grid)
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
else:
raise ValueError("can not find any path in image_info")
merge_length = self.visual_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
            if org_size[0] * org_size[1] > 16**2:  # filter out extremely small images
return OceanProcessorOutput(
images=[image_feat],
patch_nums=[patch_nums],
crop_size=[image_list],
images_size= [org_size],
images_grid=[image_list]
)
else:
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
return OceanProcessorOutput()
except Exception as e:
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
return OceanProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_video_frame(self, video_frame_info, return_mm_data = True):
try:
pattern = r'\{.*?\}'
matches = re.findall(pattern, video_frame_info)
ret = OceanProcessorOutput()
            # parse each match in turn
for match in matches:
video_frame_info = ujson.loads(match)
if 'local' in video_frame_info.keys():
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'],return_mm_data = return_mm_data)
else:
raise ValueError("can not find any path in image_info")
merge_length = self.video_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
                if org_size[0] * org_size[1] > 16**2:  # filter out extremely small frames
ret = ret.concatenate(
OceanProcessorOutput(
videos=[image_feat],
videos_patch_nums=[patch_nums],
videos_crop_size=[image_list],
videos_size= [org_size],
videos_grid=[image_list]
)
)
else:
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
return ret
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
return OceanProcessorOutput()
    # read the raw video bytes
def _get_video_obj_byte(self, source, path, video_obj_json):
video_obj_byte = None
if source == "cos":
start_time = time.time()
video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
            if (time.time() - start_time) > 1.0:
                self.cos_client = CosClient()  # refresh the COS client when a fetch is unusually slow
if source == "local":
if os.path.exists(path):
video_obj_byte = open(path, "rb").read()
else:
video_obj_byte = None
if source == "base64":
video_obj_byte = base64.b64decode(path)
if source == "url":
video_obj_byte = requests.get(url=path).content
return video_obj_byte
    # split the video into frames and save them to a subdirectory
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
video_path = video_info['local']
        # local directory where the extracted frames are cached
        frame_path = video_path.split('.')[0] + '_frames'
        if not os.path.exists(frame_path) or len(os.listdir(frame_path))==0:
            # extract and save the frames
            os.makedirs(frame_path, exist_ok=True)
            mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
            if mm_obj_byte is None:  # the video file could not be read
                return ""
            frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way)  # read the sampled frames
            if frames is None:  # decoding failed
                return ""
            for frame_idx, frame in enumerate(frames):
                output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
                frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
                cv2.imwrite(output_filename, frame)
        # select the frames to use
frame_number = len([filename for filename in os.listdir(frame_path) if filename.endswith('.jpg')])
if frame_number>max_frame_number:
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
else:
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
        # build the replacement string: one image-tag span per selected frame
replace_str = ""
for idx in indices:
frame_str = f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
replace_str += frame_str
return replace_str
def _get_video_frame_str(self, video_info, return_mm_data = True ):
try:
video_info = ujson.loads(video_info)
if 'local' in video_info.keys():
                # get a string containing the per-frame image paths, limited to max_frame_number frames
frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
if frames_str != "":
parts = frames_str.split(self.image_end_tag)
result = []
for part in parts:
if self.image_start_tag in part:
before_path, path = part.split(self.image_start_tag)
new_path = f'{self.image_start_tag}{{"local": "{path}"}}{self.image_end_tag}'
result.append(before_path + new_path)
else:
result.append(part)
                    return ''.join(result)
                return ""  # no frames were extracted
            else:
                raise ValueError('can not find a local path in video_info')
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
return ""
# def _replace_audio(self, audio_text, return_mm_data = True):
# audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
# ret = self._get_audio(audio_info, return_mm_data) # 重复取结果 cached result
def _replace_audio(self, audio_text, mminfo_ret_dict):
audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
        # ret = self._get_audio(audio_info)  # re-fetch (cached result)
        ret = mminfo_ret_dict.get(audio_info, OceanProcessorOutput())  # fetch the precomputed result from the dict
        if ret.bridge_length is not None:  # TODO: the tokenizer becomes slow when there are many pad tokens
replaced_text = [self.audio_pad_tag * l for l in ret.bridge_length]
replaced_text = self.audio_delim_tag.join(replaced_text)
return self.audio_start_tag + replaced_text + self.audio_end_tag
return ''
# def _replace_image(self, image_text, return_mm_data = True):
# image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
# ret = self._get_image(image_info, return_mm_data) # 重复取结果 cached result
def _replace_image(self, image_text, mminfo_ret_dict):
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
        # ret = self._get_image(image_info)  # re-fetch (cached result)
        ret = mminfo_ret_dict.get(image_info, OceanProcessorOutput())  # fetch the precomputed result from the dict
        if ret.patch_nums is None:
            return ''
        return self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
# def _replace_video_frame(self, video_frame_text, return_mm_data = True):
# video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
# ret = self._get_video_frame(video_frame_info, return_mm_data) # 重复取结果 cached result
def _replace_video_frame(self, video_frame_text, mminfo_ret_dict):
video_frame_info = re.sub(re.compile(self.video_start_tag + '|' + self.video_end_tag), '', video_frame_text)
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_info)
        # ret = self._get_video_frame(video_frame_info)  # re-fetch (cached result)
ret = mminfo_ret_dict.get(video_frame_info, OceanProcessorOutput())
if ret.videos_patch_nums is None:
return ''
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
return ''.join(video_frame_str)
def extract_replace_multimodal(self, text, mtype='audio', return_mm_data = True):
        # extract the JSON-formatted audio/image info from text, load it and turn it into features,
        # estimate the number of encoder tokens, and fill in that many pad tokens
if (self.audio_start_tag != None) and (mtype == 'audio'):
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag)
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag)
extract_func = self._get_audio
replace_func = self._replace_audio
elif (self.image_start_tag != None) and (mtype == 'image'):
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag)
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
extract_func = self._get_image
replace_func = self._replace_image
elif (self.video_start_tag != None) and (mtype == 'video'):
video_match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag)
video_drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag)
            # handle video first: convert each video path into a sequence of per-frame image paths
            mm_info_list = re.findall(video_match_regex, text)
            for mm_info in mm_info_list:
                frame_str = self._get_video_frame_str(re.sub(video_drop_regex, '', mm_info))
                # replace the span; if the video does not exist, it is replaced with an empty frame string
                # (str.replace is used because mm_info contains characters that are special in regexes)
                text = text.replace(mm_info, self.video_start_tag + frame_str + self.video_end_tag)
            # from here on, treat the extracted frames as multiple images
            match_regex = re.compile(self.video_start_tag+r'(.*?)'+self.video_end_tag)
            drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
extract_func = self._get_video_frame
replace_func = self._replace_video_frame
else:
raise ValueError("mtype not supportted!")
mm_info_list = re.findall(match_regex, text)
mm_info_list = [re.sub(drop_regex, '', mm_info) for mm_info in mm_info_list]
mminfo_ret_dict = {}
ret = OceanProcessorOutput()
        for mm_info in mm_info_list:  # if no span of this modality was matched, the loop is skipped; raw_text is still set below, so the result is never None
mm_ret = extract_func(mm_info, return_mm_data = return_mm_data)
mminfo_ret_dict[mm_info] = mm_ret
            if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None:  # the sample contains audio/image/video but extraction failed, so the whole sample is invalid (ret.raw_text stays None)
return ret
            ret = ret.concatenate(mm_ret)  # there may be multiple results; collect them here
# ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group()), text)
ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group(), mminfo_ret_dict), text)
return ret
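    # Replacement sketch (hypothetical path, for illustration): an input such as
    #   '<audio_start_ocean>{"path": "a.flac"}<audio_end_ocean>transcribe this'
    # comes back with raw_text where the JSON span is replaced by bridge_length copies of the audio
    # pad token between the start/end tags, so the token count matches the audio encoder output.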
def process_one(self, text, index=0, raw_only=False, return_mm_data = True):
ret = OceanProcessorOutput(index=index)
        for mtype in self.config.multimodal:  # loop over the modalities (audio, image, video), fetching results and updating raw_text
            mret = self.extract_replace_multimodal(text, mtype, return_mm_data = return_mm_data)  # also fetches video results
            if mret.raw_text is None:  # the sample contains this modality but extraction failed
return OceanProcessorOutput(index=index)
ret = ret.concatenate(mret)
text = mret.raw_text
ret.raw_text = text
if raw_only:
            return ret  # for SFT and other code paths that apply their own tokenization
        # handle the trainable spans used in pretraining
input_ids, labels = [], []
trainable_sep = re.findall(r'<trainable_start>|<trainable_end>', ret.raw_text.replace('\n', '<LF>'))
if len(trainable_sep) <= 0:
input_ids = self.tokenizer(ret.raw_text, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
labels = [True for _ in input_ids]
else:
split_content = re.split(r'<trainable_start>|<trainable_end>', ret.raw_text)
for i, sc in enumerate(split_content):
if len(sc.strip()) == 0:
                    continue  # skip pieces that are only whitespace
sc_ids = self.tokenizer(sc, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
input_ids.extend(sc_ids)
if i == 0 or trainable_sep[i - 1] == '<trainable_end>': # stop gradient
labels.extend([False] * len(sc_ids))
else:
labels.extend([True] * len(sc_ids))
# input_ids += [self.tokenizer.eos_token_id]
# labels += [True]
ret.labels = [input_ids[j] if (l and input_ids[j] not in self.config.multimodal_special_token_no_loss_list) else -100 for j, l in enumerate(labels)]
ret.input_ids = input_ids
ret.index = index
return ret
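    # Label sketch: for a text like 'Q: hi<trainable_start>A: hello<trainable_end>', only the tokens
    # inside the trainable span keep their ids in ret.labels; everything outside (and any multimodal
    # special tokens listed in multimodal_special_token_no_loss_list) is set to -100.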
@torch.no_grad()
def __call__(self, example, parallel=8):
        # main entry point: supports a pretraining string, an SFT message dict, or a list of strings for batch inference
if isinstance(example, Dict):
pass
elif isinstance(example, str):
return self.process_one(example)
        elif isinstance(example, List):  # batch inference, processed asynchronously with a thread pool
with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
batch_data = [key.result() for key in cf.as_completed(future_list)]
valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
            assert(valid_num == len(batch_data))  # inference strictly requires every sample to be valid
            batch_data = sorted(batch_data, key=lambda x: x.index)  # restore the original order
ret = OceanProcessorOutput()
for i in range(len(batch_data)):
ret = ret.concatenate(batch_data[i])
self.tokenizer.padding_side = "left"
padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
ret.input_ids = padding_result["input_ids"]
            ret.attention_mask = padding_result["attention_mask"]  # batch inference is not packed, so seqlens is not needed
padding_result = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
ret.labels = padding_result["input_ids"]
if ret.audios is not None:
ret.audios = default_collate(ret.audios)
ret.encoder_length = default_collate(ret.encoder_length)
ret.bridge_length = default_collate(ret.bridge_length)
if ret.images is not None:
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
# else:ret.images = default_collate(ret.images)
# ret.patch_nums = default_collate(ret.patch_nums)
            if ret.videos is not None:
                ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]
return ret
else:
raise ValueError("example format supported yet")
@torch.no_grad()
def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=8):
if max_sequence_length is None:
max_sequence_length = self.tokenizer.model_max_length
        # pack N samples into M sequences of length max_sequence_length; each packed sequence keeps its own multimodal inputs
assert isinstance(raw_batch, List)
start_ts = time.time()
if parallel > 1:
with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
future_list = []
for idx, json_text in enumerate(raw_batch):
                    try:  # parse the JSON line
                        json_obj = ujson.loads(json_text.strip())
                    except:
                        try:
                            json_obj = ast.literal_eval(json_text.strip())
                        except:
                            print("parse json obj failed: {}....".format(json_text[:300]))
                            continue
try: # chensong
if isinstance(json_obj, list):
content = json_obj[1]
elif 'raw' in json_obj.keys():
content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["raw"]
else:
content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["content"]
except:
print("parse json raw/content error: {}....".format(json_text[:300]))
continue
future_list.append(executor.submit(self.process_one, content, idx))
                # collect the results (they arrive in completion order, not input order)
batch_data = [key.result() for key in cf.as_completed(future_list)]
else: # debug only
batch_data = []
for json_text in raw_batch:
data = ujson.loads(json_text.strip())
if 'raw' in data.keys():
batch_data.append(self.process_one(data['raw'], 0))
else:
batch_data.append(self.process_one(data['content'], 0))
if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
print('[WARNING] processing each data cost more than 1.0s')
        # pack the text inputs, without any truncation
current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), []
empty_data = OceanProcessorOutput(input_ids=[], labels=[])
        for idx, bd in enumerate(batch_data + [empty_data]):  # append an empty sentinel so the last packed sequence is flushed to output
            if bd.input_ids is None and idx < len(batch_data):
                continue  # the sample failed to load, and this is not the sentinel
            if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and idx < len(batch_data):
                continue  # drop empty or overlong samples, unless this is the sentinel
if current_length + len(bd.input_ids) + 1 > max_sequence_length or idx == len(batch_data):
pad_nums = max_sequence_length - current_length # right padding
if packed_output.input_ids is None or packed_output.labels is None:
packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
packed_output.labels = [-100] * pad_nums
packed_output.position_ids += [0] * (pad_nums+1)
else:
packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
packed_output.labels += [-100] * pad_nums
packed_output.position_ids += [0] * pad_nums
packed_output.attention_mask = [1] * current_length + [0] * pad_nums
packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
output.append(packed_output)
packed_output = OceanProcessorOutput(position_ids=[], seqlens=[]) # reset empty
packed_output = packed_output.concatenate(bd)
            packed_output.input_ids.append(self.tokenizer.eos_token_id)  # the eos token must be appended explicitly
packed_output.labels.append(self.tokenizer.eos_token_id)
packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1)))
packed_output.seqlens.append(len(bd.input_ids) + 1)
current_length = len(packed_output.input_ids)
return output
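    # Packing sketch: two samples of 10 and 20 tokens go into one row of max_sequence_length; each
    # gets an eos appended, position_ids restart from 0 per sample, seqlens records [11, 21, 0, ...],
    # and the remainder of the row is right-padded with pad_token_id / label -100 / attention_mask 0.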
@torch.no_grad()
def collect_batch_pretrain(self, batch_data):
ret = OceanProcessorOutput()
for i in range(len(batch_data)):
ret = ret.concatenate(batch_data[i])
ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
ret.attention_mask = default_collate([np.asarray(x.attention_mask, dtype=np.float32) for x in batch_data]).cuda(non_blocking=True)
ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data]).cuda(non_blocking=True)
ret.raw_text = None
if ret.audios is not None:
ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32)).cuda(non_blocking=True)
ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32)).cuda(non_blocking=True)
ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32)).cuda(non_blocking=True)
if ret.images is not None:
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)).cuda(non_blocking=True) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
if ret.videos is not None:
ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)).cuda(non_blocking=True) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
ret.videos_patch_nums = default_collate(np.asarray(ret.videos_patch_nums, dtype=np.int32)).cuda(non_blocking=True)
return ret
@torch.no_grad()
def collect_batch_sft(self, batch_data):
# list of dict to dataclass
batch_data = [OceanProcessorOutput(**bd) for bd in batch_data]
ret = OceanProcessorOutput()
for i in range(len(batch_data)):
ret = ret.concatenate(batch_data[i])
ret.input_ids = default_collate([np.asarray(x.input_ids, dtype=np.int64) for x in batch_data])
ret.labels = default_collate([np.asarray(x.labels, dtype=np.int64) for x in batch_data])
ret.position_ids = default_collate([np.asarray(x.position_ids, dtype=np.int64) for x in batch_data])
ret.seqlens = default_collate([np.asarray(x.seqlens, dtype=np.int64) for x in batch_data])
ret.raw_text = None
if ret.audios is not None:
ret.audios = default_collate(np.asarray(ret.audios, dtype=np.float32))
ret.encoder_length = default_collate(np.asarray(ret.encoder_length, dtype=np.int32))
ret.bridge_length = default_collate(np.asarray(ret.bridge_length, dtype=np.int32))
if ret.images is not None:
            # convert each image to a torch tensor
ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
if ret.videos is not None:
ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]#default_collate(np.asarray(ret.images, dtype=np.float32)).cuda(non_blocking=True)
# ret.patch_nums = default_collate(np.asarray(ret.patch_nums, dtype=np.int32)).cuda(non_blocking=True)
ret = ret.__dict__
del ret['patch_nums']
del ret['images_size']
del ret['crop_size']
del ret['raw_text']
del ret['index']
del ret['attention_mask']
del ret['videos_patch_nums']
del ret['videos_size']
del ret['videos_crop_size']
return ret
#######################################################
## Unit Test Functions, usage
## python processor_ocean.py test
#######################################################
def test_img_processor():
from transformers import AutoConfig
from transformers.models.clip import CLIPImageProcessor
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
processor = OceanImageProcessor(config.visual_config)
    official_processor = CLIPImageProcessor(size=config.visual_config.crop_size, crop_size=config.visual_config.crop_size,
                                            image_mean=config.visual_config.image_mean, image_std=config.visual_config.image_std,
                                            do_convert_rgb=True)
img_files = ['sogou/7a2c8ffc1bc61146b32805c3390f42e2', 'wukong/77c1db1c0e4200d12b478c33ba3a412d', 'wukong/62e9a5c8eb8b0ea8858a34ba3f1a999f', 'wukong/fb9ab4d7c3fe9f54289948fd6a57fc30']
cos_client = CosClient()
for img_file in img_files:
img_bytes = cos_client(img_file)
        img_rgb = Image.open(io.BytesIO(img_bytes))
        image, org_size, image_grid = processor.image_transform(img_bytes)
        official_image = official_processor.preprocess([img_rgb],
                                                       do_resize=True, do_center_crop=True, do_rescale=True, do_normalize=True,
                                                       return_tensors='np').data['pixel_values'][0]
        print('-'*60)
        print(np.array(img_rgb).shape)
        print(image.shape)
        print(official_image.shape)
        print(image - official_image)
def test_audio_processor():
from transformers.models.whisper import WhisperFeatureExtractor
from transformers import AutoConfig
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    official_processor = WhisperFeatureExtractor(feature_size=128)
processor = OceanAudioProcessor(config.audio_config)
# wave_files = glob.glob('/home/nfs_bc_alignment/sunhaoze/audio-data/openaqa/openaqa-as/audio/*')
wave_files = ['/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac', '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac']
for wave_file in wave_files:
wave = processor.load_audio_waveform(wave_file, True, False)
        official_features = official_processor(wave[0].numpy(), do_normalize=False)
        feat = official_features['input_features'][0]
wave, frame_nums = processor.extract_fbank_features(wave)
print("="*60)
print(feat.shape)
print(wave.shape, frame_nums)
        print('the difference between the official extractor and our implementation: {}'.format(wave_file))
        print(wave[:, :frame_nums] - feat[:, :frame_nums])
print(wave)
# print(wave[120:-1, :])
# print(feat[120:-1, :wave.shape[1]])
zeros_before = np.sum(wave == 0)
aug = processor.data_augment(wave, frame_nums)
zeros_after = np.sum(aug == 0)
print(zeros_before, zeros_after)
def test_audio_long():  # test the splitting strategy for audio longer than 30 seconds
from transformers import AutoConfig, AutoTokenizer
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
config.audio_config.split_overlap = 1
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
processor = OceanMMProcessor(tokenizer, config, True)
examples = ["<audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
"what's the sound's energy? \n sound1 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"}<audio_end_ocean> \n sound2 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}<audio_end_ocean>The speech energy is medium.",
]
ret = processor(examples)
print(ret)
print(torch.sum(ret.input_ids == 151659))
print(torch.sum(ret.input_ids == 151674))
def test_processor():
from transformers import AutoConfig, AutoTokenizer
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
examples = ["<audio_start_ocean>{\"path\": \"vggsound\/7DH5fqj8j6Q.flac\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
"hello, ocean 你好 百川智能。",
"what's the sound's energy? \n <audio_start_ocean>{\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}<audio_end_ocean>The speech energy is medium.",
"sound1: <audio_start_ocean>{\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}<audio_end_ocean>\n sound2: \n<audio_start_ocean>{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}<audio_end_ocean>How is the speech speed related to the estimated speaker age?\n<trainable_start>The slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.<trainable_end>",
"<img_start_ocean>{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}<img_end_ocean>新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评",
"这两个图片有什么关系?图片1<img_start_ocean>{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}<img_end_ocean>图片2\n<img_start_ocean>{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}<img_end_ocean>",
"根据图片和语音给出描述\n图片<img_start_ocean>{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>",
"这些图片和音频不存在<img_start_ocean>{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>"
]
ret = processor(examples[4:-1])
print(ret)
print(torch.sum(ret.input_ids == 151659))
print(torch.sum(ret.input_ids == 151662))
try:
print(ret.bridge_length)
print(ret.patch_nums)
except:
pass
print(torch.sum(ret.attention_mask, dim=1))
def test_grounding():
from transformers import AutoConfig, AutoTokenizer
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
examples = ["<img_start_ocean>{\"path\": \"grit\/663423bf2f0884c034bf75279bce9694\"}<img_end_ocean>\nWhere is \"A woman\" ? Answer: <trainable_start>The bounding box is <box_start_ocean>(0.58,0.8),(0.71,1.0)<box_end_ocean><trainable_end>",
"hello, ocean 你好 百川智能。",
"<img_start_ocean>{\"path\": \"grit\/0e6e3952c584cbac7235940a22514656\"}<img_end_ocean> Generate the caption with grounding: <trainable_start>Photo pour Portrait of <ref_start_ocean>young Asian muslim woman wearing hijab<ref_end_ocean><box_start_ocean>(0.09,0.01),(0.77,1.0)<box_end_ocean> shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit<trainable_end>",
"Recognize the object in the outlined section <img_start_ocean>{\"path\": \"grit\/045823cf6f819670f27aee20af7ae0e6\"}<img_end_ocean> of the picture.<box_start_ocean>(0.07,0.2),(0.91,0.96)<box_end_ocean>\n<trainable_start>Inflatable water trampolines<trainable_end>"
]
ret = processor(examples)
print(ret)
for i, input_ids in enumerate(ret.input_ids):
print("="*60)
print(ret.labels[i])
def test_pack():
from transformers import AutoConfig, AutoTokenizer
config = AutoConfig.from_pretrained("./", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048)
processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')
examples = open('/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000').readlines()[:5]
examples += open('/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000').readlines()[:5]
random.shuffle(examples)
batch_output = processor.pack_batch_pretrain(examples)
for i, b in enumerate(batch_output):
print('='*60)
try:
print(b.input_ids, len(b.input_ids))
print(b.labels, len(b.labels))
print(b.attention_mask, len(b.attention_mask))
print(b.position_ids, len(b.position_ids))
print(b.seqlens, len(b.seqlens))
print(b.audios)
print(b.bridge_length)
except:
continue
batch_for_model = processor.collect_batch_pretrain(batch_output)
print(batch_for_model.input_ids.shape)
print(batch_for_model.labels.shape)
print(batch_for_model.audios.shape)
print(batch_for_model["bridge_length"])
print(batch_for_model.images.shape)
print(batch_for_model["patch_nums"])
print(batch_for_model["position_ids"])
print(batch_for_model["seqlens"])
def test_cos_audio():
cos_client = CosClient()
audio_bytes = cos_client('panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3', 'audio-data-1317568651')
wave, sr = torchaudio.load(audio_bytes, normalize=False)
print(wave.shape, sr)
# torchaudio.save('tmp.flac', wave, sr)
if __name__ == '__main__':
fire.Fire()