|
|
import requests |
|
|
import re, ujson, os, sys, fire, glob, random, time, json |
|
|
import numpy as np |
|
|
import io |
|
|
import torch |
|
|
from torch.utils.data import default_collate |
|
|
import torchaudio |
|
|
from typing import * |
|
|
from dataclasses import dataclass, field |
|
|
import transformers |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function |
|
|
from functools import lru_cache |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
from qcloud_cos import CosConfig |
|
|
from qcloud_cos import CosS3Client |
|
|
import tos |
|
|
import concurrent.futures as cf |
|
|
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size |
|
|
from transformers.image_utils import PILImageResampling |
|
|
from PIL import Image, ImageOps |
|
|
from PIL import ImageFile |
|
|
# Let PIL decode image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
import base64 |
|
|
from decord import VideoReader, cpu |
|
|
import cv2 |
|
|
import av |
|
|
import imagesize |
|
|
import math |
|
|
|
|
|
|
|
|
def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Compute a target ``(height, width)`` for resizing an image.

    The returned size satisfies:

    1. Both dimensions are divisible by ``factor``.
    2. The total number of pixels is within ``[min_pixels, max_pixels]`` as far
       as rounding to multiples of ``factor`` permits.
    3. The aspect ratio of the image is maintained as closely as possible.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        factor: Both output dimensions are multiples of this value.
        min_pixels: Lower bound on the output ``h * w``.
        max_pixels: Upper bound on the output ``h * w``.

    Returns:
        Tuple ``(h_bar, w_bar)``; each is a positive multiple of ``factor``.

    Raises:
        ValueError: If either dimension is not positive, or the aspect ratio
            exceeds 200.
    """
    if height <= 0 or width <= 0:
        # Guard against ZeroDivisionError in the aspect-ratio check below.
        raise ValueError(f"height and width must be positive, got ({height}, {width})")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )
    h_bar = round(height / factor) * factor if height > factor else factor
    w_bar = round(width / factor) * factor if width > factor else factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring can collapse the short side to 0 for very elongated images;
        # clamp so every dimension stays at least one ``factor`` wide.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
|
|
|
|
|
|
|
|
def select_best_resolution(image_size, candidate_resolutions):
    """Find the candidate resolution that best fits the original image.

    Each candidate ``(width, height)`` is scored by how many of the original
    image's pixels survive when the image is scaled (aspect ratio preserved)
    to fit inside it. The candidate with the largest effective resolution
    wins; ties go to the candidate wasting the least area.

    Args:
        image_size: Original ``(width, height)``, e.g. ``(8*336, 16*336)``.
        candidate_resolutions: Iterable of ``(width, height)`` candidates,
            e.g. ``[(1*336, 4*336), ...]``.

    Returns:
        The best ``(width, height)`` candidate, or ``None`` when
        ``candidate_resolutions`` is empty.

    Raises:
        ValueError: If ``image_size`` cannot be unpacked into two values.
    """
    try:
        original_width, original_height = image_size
    except (TypeError, ValueError) as err:
        # Previously swallowed, which surfaced later as an UnboundLocalError.
        raise ValueError(f"image_size must be a (width, height) pair, got {image_size!r}") from err

    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")
    original_area = original_width * original_height

    for width, height in candidate_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width = int(original_width * scale)
        downscaled_height = int(original_height * scale)
        # Upscaling adds no information, so cap at the original pixel count.
        effective_resolution = min(downscaled_width * downscaled_height, original_area)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
|
|
|
|
|
|
|
|
def read_video(image_path, max_frame_number, decode_way):
    """Decode frames from a video.

    Args:
        image_path: Source passed to ``decord.VideoReader`` / ``av.open``
            (callers in this file pass an ``io.BytesIO``).
        max_frame_number: If positive and more frames were decoded, uniformly
            subsample down to exactly this many frames.
        decode_way: ``'1fps'`` reads roughly one frame per second with decord
            (numpy arrays); ``'key'`` decodes only key frames with PyAV
            (``frame.to_image()`` results).

    Returns:
        List of frames, or ``None`` if decoding failed or yielded no frames.

    Raises:
        ValueError: If ``decode_way`` is not ``'1fps'`` or ``'key'``.
    """
    if decode_way == '1fps':
        try:
            vr = VideoReader(image_path, ctx=cpu(0))
            # round() gives 0 for videos below 0.5 fps, which is an invalid
            # range step; fall back to taking every frame.
            step = max(1, round(vr.get_avg_fps()))
            frame_idx = list(range(0, len(vr), step))
            frames = list(vr.get_batch(frame_idx).asnumpy())
        except Exception as e:
            print(image_path)
            print('error is', e)
            return None
    elif decode_way == 'key':
        try:
            with av.open(image_path) as container:
                stream = container.streams.video[0]
                # Ask the decoder to skip all non-key frames.
                stream.codec_context.skip_frame = 'NONKEY'
                frames = [frame.to_image() for frame in container.decode(stream)]
        except Exception as e:
            print('error is', e)
            return None
    else:
        # Previously fell through with `frames` unbound (UnboundLocalError).
        raise ValueError(f"unsupported decode_way {decode_way!r}; expected '1fps' or 'key'")

    if not frames:
        return None
    if max_frame_number > 0 and len(frames) > max_frame_number:
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        frames = [frames[idx] for idx in indices]
    return frames
|
|
|
|
|
class OceanImageProcessor:
    """Convert a single image into flattened vision patches.

    Size/patch settings are read from ``config`` and fall back to the defaults
    set in ``__init__`` when absent. ``config.image_mean`` and
    ``config.image_std`` are required by :meth:`image_transform`.
    """

    def __init__(self, config, **kwargs):
        self.config = config
        self.min_pixels = getattr(config, 'min_pixels', 56 * 56)
        self.max_pixels = getattr(config, 'max_pixels', 28 * 28 * 1280)
        self.patch_size = getattr(config, 'patch_size', 14)
        self.temporal_patch_size = getattr(config, 'temporal_patch_size', 2)
        self.merge_size = getattr(config, 'merge_size', 2)
        self.spatial_merge_size = getattr(config, 'spatial_merge_size', 2)

    def image_transform(self, strseq, return_mm_data = True):
        """Load, resize, normalise and patchify one image.

        Args:
            strseq: A filesystem path (``str`` / ``os.PathLike``) or the encoded
                image file contents (``bytes`` / ``bytearray``).
            return_mm_data: Kept for backward compatibility; the image is
                decoded either way.

        Returns:
            ``(flatten_patches, image_org_size, (grid_t, grid_h, grid_w))`` where
            ``flatten_patches`` has shape
            ``(grid_t * grid_h * grid_w, C * temporal_patch_size * patch_size**2)``
            and ``image_org_size`` is the original ``(height, width)``.

        Raises:
            TypeError: If ``strseq`` is neither a path nor bytes.
        """
        if isinstance(strseq, (str, os.PathLike)):
            source = strseq
        elif isinstance(strseq, (bytes, bytearray)):
            # base64 / COS / TOS callers hand over raw encoded bytes.
            source = BytesIO(strseq)
        else:
            raise TypeError(
                "image_transform expects a file path or encoded image bytes, "
                f"got {type(strseq).__name__}"
            )
        # Context manager releases the file handle once pixels are loaded.
        with Image.open(source) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        image_org_size = image.shape[:2]

        resized_height, resized_width = smart_resize(
            image_org_size[0], image_org_size[1],
            factor=self.patch_size * self.spatial_merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        image = resize(image, (resized_height, resized_width), PILImageResampling.BICUBIC)

        # HWC -> CHW, scale to [0, 1], then normalise per channel.
        img = image.transpose(2, 0, 1)
        mean = np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]
        std = np.array(self.config.image_std)[:, np.newaxis, np.newaxis]
        image = (img / 255.0 - mean) / std

        # A still image is repeated along time to fill one temporal patch.
        patches = image[np.newaxis, :]
        if patches.shape[0] == 1:
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
            grid_w // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
        )
        # Order patches so each spatial_merge_size x spatial_merge_size group is contiguous.
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )
        return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
|
|
|
|
|
|
|
|
class OceanAudioProcessor:
    """Load audio and turn it into log-mel features for the audio encoder.

    ``config`` must provide ``n_fft``, ``num_mel_bins``, ``sampling_rate``,
    ``hop_length``, ``max_audio_seconds`` and ``split_overlap``.
    :meth:`data_augment` also reads the ``mask_*`` settings, and
    :meth:`inference_output_length` reads ``kernel_size``, ``stride_size`` and
    ``avg_pooler``.
    """

    def __init__(
        self,
        config,
        **kwargs
    ):
        if len(torchaudio.list_audio_backends()) == 0:
            raise RuntimeError("torchaudio has no available audio backend")
        self.config = config
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.n_fft // 2,
            num_mel_filters=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=self.config.sampling_rate / 2.0,
            sampling_rate=self.config.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    @staticmethod
    def zero_mean_unit_var_norm(x):
        """Normalise ``x`` to zero mean and unit variance (eps-stabilised)."""
        return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)

    def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
        """Load audio as a single-channel waveform at ``config.sampling_rate``.

        Args:
            uri: Path, file-like object, or raw encoded bytes of the audio file.
            return_tensors: Return a ``torch.Tensor`` if True, else numpy.
            do_normalize: Apply :meth:`zero_mean_unit_var_norm`.

        Raises:
            ValueError: If the file has more than two channels.
        """
        if isinstance(uri, (bytes, bytearray)):
            # torchaudio reads paths and file-like objects, not raw bytes; the
            # COS/TOS clients in this file return bytes.
            uri = BytesIO(uri)
        # torchaudio.info consumes a file-like stream; rewind before load().
        rewind_to = uri.tell() if hasattr(uri, "seek") else None
        metadata = torchaudio.info(uri)
        if metadata.num_channels > 2:
            raise ValueError("acoustic file with {} channels.".format(metadata.num_channels))
        if rewind_to is not None:
            uri.seek(rewind_to)
        waveform_tensor, _ = torchaudio.load(uri, normalize=True)
        if self.config.sampling_rate != metadata.sample_rate:
            waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate)

        if metadata.num_channels > 1:
            # Down-mix stereo to one channel, keeping the channel axis.
            waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)

        if do_normalize:
            waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)

        return waveform_tensor if return_tensors else waveform_tensor.numpy()

    def split_with_overlap(self, waveform):
        """Split a ``(channels, samples)`` waveform into overlapping chunks.

        Each chunk has at most ``max_audio_seconds`` of audio, and consecutive
        chunks overlap by ``split_overlap`` seconds. A waveform that already
        fits, or a negative ``split_overlap``, is returned as a single chunk.

        Raises:
            ValueError: If the overlap is not shorter than a chunk (the loop
                would never advance).
        """
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]

        overlap_samples = int(self.config.sampling_rate * self.config.split_overlap)
        if overlap_samples >= max_audio_samples:
            raise ValueError(
                "split_overlap ({}s) must be shorter than max_audio_seconds ({}s)".format(
                    self.config.split_overlap, self.config.max_audio_seconds))

        split_waveform, start = [], 0
        while start < wave_samples:
            if start > overlap_samples:
                start -= overlap_samples
            end = min(start + max_audio_samples, wave_samples)
            split_waveform.append(waveform[:, start:end])
            start = end
        return split_waveform

    @classmethod
    def inference_output_length(cls, config, input_length):
        """Return ``(encoder_length, bridge_length)`` for ``input_length`` frames.

        Applies the output-length formula ``(L + 2*(k//2) - k) // s + 1`` twice
        (stride 1, then ``stride_size``); presumably this mirrors the encoder's
        downsampling layers. ``bridge_length`` divides by ``avg_pooler`` when it
        is greater than 1, and otherwise equals ``encoder_length``.
        """
        kernel_size = config.kernel_size
        stride_size = config.stride_size
        avg_pooler = config.avg_pooler
        padding = kernel_size // 2
        encoder_length = (input_length + 2 * padding - kernel_size) // 1 + 1
        encoder_length = (encoder_length + 2 * padding - kernel_size) // stride_size + 1
        # Previously `bridge_length` was unbound (UnboundLocalError) when avg_pooler <= 1.
        bridge_length = encoder_length // avg_pooler if avg_pooler > 1 else encoder_length
        return encoder_length, bridge_length

    def extract_fbank_features(self, waveform):
        """Compute a normalised log-mel spectrogram for one waveform chunk.

        The waveform is zero-padded or truncated to exactly
        ``max_audio_seconds * sampling_rate`` samples. Output frames beyond the
        real audio are set to 0.

        Args:
            waveform: Tensor of shape ``(channels, samples)``.

        Returns:
            ``(log_spec, valid_frame_nums)``: a numpy array for the first channel
            indexed ``[mel_bin, frame]``, and the count of frames covering real audio.

        Raises:
            ValueError: If the waveform is shorter than ``n_fft`` samples.
        """
        channels, wave_samples = waveform.shape
        if wave_samples < self.config.n_fft:
            raise ValueError("waveform has {} samples, fewer than n_fft={}".format(wave_samples, self.config.n_fft))
        target_samples = self.config.max_audio_seconds * self.config.sampling_rate
        valid_frame_nums = min(target_samples // self.config.hop_length, wave_samples // self.config.hop_length + 1)
        if wave_samples < target_samples:
            waveform = torch.nn.functional.pad(waveform, (0, target_samples - wave_samples), "constant", 0)
        else:
            waveform = waveform[:, :target_samples]

        window = torch.hann_window(self.config.n_fft)
        stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=window, return_complex=True)
        # Power spectrum; the last STFT frame is dropped.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()

        # The waveform is always 2-D here, so clip each channel's dynamic range
        # to 8 (log10 units) below its own peak, then rescale.
        max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
        log_spec = torch.maximum(log_spec, max_val - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        log_spec = log_spec[0].numpy()
        log_spec[:, valid_frame_nums:] = 0.0

        return log_spec, valid_frame_nums

    def data_augment(self, feature: np.ndarray, input_length, training=True):
        """Randomly zero spans along the time and mel axes (training only).

        ``feature`` is indexed ``[mel_bin, frame]`` and is modified in place;
        the same array is returned. Masking is skipped when not training, when
        both mask probabilities are non-positive, or when the input is too
        short for the configured minimum number of masks.
        """
        def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            # Expected span count with stochastic rounding, at least min_masks.
            num_masked_span = int(mask_prob * input_length / mask_length + random.random())
            num_masked_span = max(num_masked_span, min_masks)
            start_indices = list(range(input_length - mask_length))
            random.shuffle(start_indices)
            return start_indices[:num_masked_span]

        if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
            return feature
        if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
            return feature
        if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
            return feature

        if self.config.mask_time_prob > 0:
            start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
            for start_idx in start_indices:
                feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
        if self.config.mask_feature_prob > 0:
            start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
            for start_idx in start_indices:
                feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0

        return feature
|
|
|
|
|
class CosClient():
    """Fetch object bytes from Tencent Cloud COS with a small retry loop.

    Credentials come from the ``secret_id`` / ``secret_key`` arguments or,
    when omitted, from the ``COS_SECRET_ID`` / ``COS_SECRET_KEY`` environment
    variables. Secrets must not live in source code; any keys previously
    committed here should be treated as leaked and rotated. Without
    credentials the client is disabled and every lookup returns ``None``.
    """

    def __init__(self, bucket_name='crawl-pic-1317568651',
                 max_retries=2, secret_id=None, secret_key=None,
                 endpoint="cos.ap-guangzhou.myqcloud.com"):
        self.max_retries = max_retries
        self.bucket_name = bucket_name
        secret_id = secret_id or os.environ.get("COS_SECRET_ID")
        secret_key = secret_key or os.environ.get("COS_SECRET_KEY")
        if not (secret_id and secret_key):
            print("**** CosClient disabled: set COS_SECRET_ID and COS_SECRET_KEY *****")
            self.config = None
            self.client = None
            return
        self.config = CosConfig(
            Endpoint=endpoint,
            SecretId=secret_id,
            SecretKey=secret_key,
            Token=None, Scheme='https', Timeout=300)
        self.client = CosS3Client(self.config)

    def __call__(self, relative_path, bucket_name=None):
        """Return the bytes stored at ``relative_path``, or ``None`` on failure.

        An empty or missing ``bucket_name`` falls back to ``self.bucket_name``.
        Up to ``max_retries`` attempts are made.
        """
        if not bucket_name:
            bucket_name = self.bucket_name
        if self.client is None:
            return None
        last_error = None
        for _ in range(self.max_retries):
            try:
                response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
                return response['Body'].get_raw_stream().read()
            except Exception as e:  # network/SDK errors are retried
                last_error = e
                time.sleep(0.01)
        if last_error is not None:
            print("**** cos get_object failed: {}, key: {} *****".format(str(last_error), relative_path))
        return None
|
|
|
|
|
class TosClient(object):
    """Fetch object bytes from Volcengine TOS with a small retry loop.

    Credentials come from the ``ak`` / ``sk`` arguments or, when omitted, from
    the ``TOS_ACCESS_KEY`` / ``TOS_SECRET_KEY`` environment variables. Secrets
    must not live in source code; any keys previously committed here should be
    treated as leaked and rotated. Without credentials the client is disabled
    and every lookup returns ``None``.
    """

    def __init__(self, ak=None, sk=None, endpoint="tos-cn-beijing.ivolces.com",
                 region="cn-beijing", bucket_name="audio-dataset", max_retries=2):
        self.bucket_name = bucket_name
        self.max_retries = max_retries
        ak = ak or os.environ.get("TOS_ACCESS_KEY")
        sk = sk or os.environ.get("TOS_SECRET_KEY")
        if not (ak and sk):
            print("**** TosClient disabled: set TOS_ACCESS_KEY and TOS_SECRET_KEY *****")
            self.client = None
            return
        self.client = tos.TosClientV2(ak, sk, endpoint, region)

    def __call__(self, path, bucket_name=None):
        """Return the bytes stored at ``path``, or ``None`` on failure.

        An empty or missing ``bucket_name`` falls back to ``self.bucket_name``,
        matching ``CosClient``. Up to ``max_retries`` attempts are made.
        """
        if not bucket_name:
            bucket_name = self.bucket_name
        if self.client is None:
            return None
        last_error = None
        for _ in range(self.max_retries):
            try:
                object_stream = self.client.get_object(bucket_name, path)
                return object_stream.read()
            except Exception as e:  # network/SDK errors are retried
                last_error = e
                time.sleep(0.01)
        if last_error is not None:
            print("**** tos get_object failed: {}, key: {} *****".format(str(last_error), path))
        return None
|
|
|
|
|
@dataclass
class OceanProcessorOutput(ModelOutput):
    """Container for the token ids and multimodal data produced by the processor.

    List-valued fields hold one entry per audio chunk / image / video frame
    and are merged by :meth:`concatenate`.
    """
    input_ids: Optional["List|torch.Tensor"] = None
    labels: Optional["List|torch.Tensor"] = None
    attention_mask: Optional["List|torch.Tensor"] = None
    position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None
    # audio
    audios: Optional["List|torch.Tensor"] = None
    encoder_length: Optional["List|torch.Tensor"] = None
    bridge_length: Optional["List|torch.Tensor"] = None
    # image
    images: Optional["List|torch.Tensor"] = None
    patch_nums: Optional["List|torch.Tensor"] = None
    images_size: Optional["List|torch.Tensor"] = None
    crop_size: Optional["List|torch.Tensor"] = None
    images_grid: Optional["List|torch.Tensor"] = None
    # video
    videos: Optional["List|torch.Tensor"] = None
    videos_patch_nums: Optional["List|torch.Tensor"] = None
    videos_size: Optional["List|torch.Tensor"] = None
    videos_crop_size: Optional["List|torch.Tensor"] = None
    videos_grid: Optional["List|torch.Tensor"] = None
    # other
    raw_text: Optional[str] = None
    index: Optional[int] = None

    def concatenate(self, other):
        """Return a new output whose list fields are ``self`` followed by ``other``.

        A ``None`` field on either side is treated as empty. ``index`` is taken
        from ``self`` when set, else from ``other``. ``attention_mask`` and
        ``raw_text`` are not merged (left ``None``); callers set them explicitly.
        """
        def concat_one(a, b):
            if a is None:
                return b
            if b is None:
                return a
            return a + b

        # crop_size / videos_size / videos_crop_size were previously dropped here.
        merged_fields = (
            "input_ids", "labels", "position_ids", "seqlens",
            "audios", "encoder_length", "bridge_length",
            "images", "patch_nums", "images_size", "crop_size", "images_grid",
            "videos", "videos_patch_nums", "videos_size", "videos_crop_size", "videos_grid",
        )
        return OceanProcessorOutput(
            **{name: concat_one(getattr(self, name), getattr(other, name)) for name in merged_fields},
            index=self.index if self.index is not None else other.index,
        )
|
|
|
|
|
class OceanMMProcessor(object): |
|
|
def __init__(self,
             tokenizer: transformers.PreTrainedTokenizer,
             config,
             training,
             relative_path=None,
             **kwargs,
             ):
    """Set up per-modality processors, storage clients and special-token strings.

    Args:
        tokenizer: Used to map the configured special-token ids to token strings.
        config: Model config; optional ``audio_config`` / ``visual_config`` /
            ``video_config`` attributes enable the matching modality.
        training: Stored; ``_get_audio`` forwards it to ``data_augment``.
        relative_path: Optional local root checked before fetching audio remotely.
        **kwargs: Accepted for backward compatibility; ignored.
    """
    self.tokenizer = tokenizer
    self.config = config
    self.audio_processor = OceanAudioProcessor(config.audio_config) if hasattr(config, "audio_config") else None
    self.visual_processor = OceanImageProcessor(config.visual_config) if hasattr(config, "visual_config") else None
    self.video_processor = OceanImageProcessor(config.video_config) if hasattr(config, "video_config") else None
    self.training = training
    self.relative_path = relative_path
    self.cos_client = CosClient()
    self.tos_client = TosClient()

    # Every tag attribute exists even when its modality is disabled; previously
    # image_line_tag / image_delimiter_tag / video_place_tag could be missing.
    self.audio_start_tag = None
    self.audio_end_tag = None
    self.audio_pad_tag = None
    self.audio_delim_tag = None
    self.image_start_tag = None
    self.image_end_tag = None
    self.image_pad_tag = None
    self.image_line_tag = None
    self.image_delimiter_tag = None
    self.video_start_tag = None
    self.video_end_tag = None
    self.video_place_tag = None

    to_token = self.tokenizer.convert_ids_to_tokens
    if hasattr(config, "audio_config"):
        audio_cfg = config.audio_config
        self.audio_start_tag = to_token(audio_cfg.audio_start_token_id)
        self.audio_end_tag = to_token(audio_cfg.audio_end_token_id)
        self.audio_pad_tag = to_token(audio_cfg.audio_pad_token_id)
        self.audio_delim_tag = to_token(audio_cfg.audio_delim_token_id)

    if hasattr(config, "visual_config"):
        visual_cfg = config.visual_config
        self.image_start_tag = to_token(visual_cfg.image_start_token_id)
        self.image_end_tag = to_token(visual_cfg.image_end_token_id)
        self.image_pad_tag = to_token(visual_cfg.image_pad_token_id)
        self.image_line_tag = to_token(visual_cfg.image_line_token_id)
        self.image_delimiter_tag = to_token(visual_cfg.image_delimiter_token_id)

    if hasattr(config, "video_config"):
        video_cfg = config.video_config
        self.video_start_tag = to_token(video_cfg.video_start_token_id)
        self.video_end_tag = to_token(video_cfg.video_end_token_id)
        # When both configs exist, the video config's image tags override the visual config's.
        self.image_start_tag = to_token(video_cfg.image_start_token_id)
        self.image_end_tag = to_token(video_cfg.image_end_token_id)
        self.image_pad_tag = to_token(video_cfg.image_pad_token_id)
        self.video_place_tag = to_token(video_cfg.video_place_token_id)
|
|
|
|
|
|
|
|
|
|
|
def _get_audio(self, audio_info, return_mm_data = True):
    """Load one audio clip and turn it into feature chunks.

    Args:
        audio_info: JSON string with either
            * ``path``: tried under ``self.relative_path`` first; otherwise fetched
              from TOS when ``"server": "tos"``, else from COS; or
            * ``local``: a local file path (a missing file yields an empty output).
        return_mm_data: When False, feature arrays are replaced by ``None``
            placeholders (one per chunk) while the length fields are kept.

    Returns:
        OceanProcessorOutput with parallel ``audios`` / ``encoder_length`` /
        ``bridge_length`` lists (one entry per chunk), or an empty output on error.
    """
    try:
        audio_info = ujson.loads(audio_info)
        audio_uri = None
        if 'path' in audio_info:
            if self.relative_path is not None:
                local_candidate = os.path.join(self.relative_path, audio_info['path'])
                if os.path.exists(local_candidate):
                    audio_uri = local_candidate
            if audio_uri is None:
                # Remote clients return the encoded file bytes, or None on failure.
                if audio_info.get('server', 'cos') == 'tos':
                    audio_uri = self.tos_client(audio_info['path'], 'audio-dataset')
                else:
                    audio_uri = self.cos_client(audio_info['path'], 'audio-data-1317568651')
            if audio_uri is None:
                raise ValueError("can not fetch audio data for path {}".format(audio_info['path']))
        elif 'local' in audio_info:
            audio_uri = audio_info['local']
            if not os.path.exists(audio_uri):
                return OceanProcessorOutput()
        else:
            raise ValueError("can not find path or local in audio_info")

        waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
        waveforms = self.audio_processor.split_with_overlap(waveforms)
        ret = OceanProcessorOutput()
        for waveform in waveforms:
            audio, input_length = self.audio_processor.extract_fbank_features(waveform)
            audio = self.audio_processor.data_augment(audio, input_length, self.training)
            encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
            if bridge_length <= 0:
                # Chunk too short to produce any bridge tokens.
                continue
            current_ret = OceanProcessorOutput(
                audios=[audio],
                encoder_length=[encoder_length],
                bridge_length=[bridge_length])
            ret = current_ret if ret.audios is None else ret.concatenate(current_ret)
        if not return_mm_data and ret.audios is not None:
            # One placeholder per chunk keeps audios aligned with bridge_length.
            ret.audios = [None] * len(ret.audios)
        return ret
    except Exception as e:
        print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
        return OceanProcessorOutput()
|
|
|
|
|
|
|
|
def _get_image(self, image_info, return_mm_data = True):
    """Load one image described by a JSON string and patchify it.

    ``image_info`` must contain one of ``base64`` (encoded file contents),
    ``local`` (filesystem path) or ``path`` (object key fetched from TOS when
    ``tos_bucket`` is given, otherwise from COS using ``bucket_name`` or
    ``cos_bucket``). Python-style single-quoted dicts are tolerated.

    Returns:
        OceanProcessorOutput with ``images``, ``patch_nums``, ``crop_size``,
        ``images_size`` and ``images_grid`` (one entry each), or an empty output
        if loading fails or the image area is at most 16*16 pixels.
    """
    try:
        try:
            image_info = ujson.loads(image_info)
        except ValueError:
            # Retry after turning unescaped single quotes into double quotes.
            image_info = ujson.loads(re.sub(r"(?<!\\)'", '"', image_info))

        if 'base64' in image_info:
            image_data = base64.b64decode(image_info['base64'])
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
        elif 'local' in image_info:
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'], return_mm_data=return_mm_data)
        elif 'path' in image_info:
            if "tos_bucket" in image_info:
                image_bytes = self.tos_client(image_info['path'], image_info['tos_bucket'])
            else:
                # bucket_name takes precedence over cos_bucket when both are present.
                cos_bucket = image_info.get('bucket_name', image_info.get('cos_bucket'))
                image_bytes = self.cos_client(image_info['path'], cos_bucket)
            if image_bytes is None:
                raise ValueError("can not fetch image bytes for path {}".format(image_info['path']))
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
        else:
            raise ValueError("can not find any path in image_info")

        # Token count after merging merge_size x merge_size patches.
        merge_length = self.visual_processor.merge_size ** 2
        patch_nums = np.array(image_list).prod() // merge_length

        if org_size[0] * org_size[1] <= 16 ** 2:
            print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
            return OceanProcessorOutput()
        return OceanProcessorOutput(
            images=[image_feat],
            patch_nums=[patch_nums],
            crop_size=[image_list],
            images_size=[org_size],
            images_grid=[image_list],
        )
    except Exception as e:
        print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
        return OceanProcessorOutput()
|
|
|
|
|
|
|
|
def _get_video_frame(self, video_frame_info, return_mm_data = True):
    """Patchify every frame referenced in a run of JSON objects.

    ``video_frame_info`` is text containing one ``{"local": <frame path>}``
    object per frame. Frames whose original area is at most 16*16 pixels are
    skipped with a message.

    Returns:
        OceanProcessorOutput with parallel ``videos``, ``videos_patch_nums``,
        ``videos_crop_size``, ``videos_size`` and ``videos_grid`` lists, or an
        empty output on error.
    """
    try:
        frame_specs = re.findall(r'\{.*?\}', video_frame_info)
        ret = OceanProcessorOutput()
        for spec in frame_specs:
            # Use a separate name so the error log below still shows the full input.
            frame_info = ujson.loads(spec)
            if 'local' not in frame_info:
                raise ValueError("can not find any path in image_info")
            image_feat, org_size, image_list = self.video_processor.image_transform(frame_info['local'], return_mm_data=return_mm_data)

            merge_length = self.video_processor.merge_size ** 2
            patch_nums = np.array(image_list).prod() // merge_length

            if org_size[0] * org_size[1] > 16 ** 2:
                ret = ret.concatenate(
                    OceanProcessorOutput(
                        videos=[image_feat],
                        videos_patch_nums=[patch_nums],
                        videos_crop_size=[image_list],
                        videos_size=[org_size],
                        videos_grid=[image_list],
                    )
                )
            else:
                print("**** video too small: {}, info: {} *****".format(str(org_size), str(frame_info)))
        return ret
    except Exception as e:
        print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
        return OceanProcessorOutput()
|
|
|
|
|
|
|
|
def _get_video_obj_byte(self, source, path, video_obj_json):
    """Fetch the raw bytes of a video from one of several sources.

    Args:
        source: ``"cos"``, ``"local"``, ``"base64"`` or ``"url"``.
        path: COS key, local file path, base64 payload or URL respectively.
        video_obj_json: Dict; only ``cos_bucket`` is read (for ``"cos"``).

    Returns:
        The bytes, or ``None`` when a local file is missing or ``source`` is
        not recognised.
    """
    if source == "cos":
        start_time = time.time()
        video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
        # Slow fetches trigger reflash_cos_client (defined outside this view;
        # presumably rebuilds the COS client).
        if (time.time() - start_time) > 1.0:
            self.reflash_cos_client()
        return video_obj_byte
    if source == "local":
        if not os.path.exists(path):
            return None
        # Context manager closes the file (the previous bare open() leaked it).
        with open(path, "rb") as f:
            return f.read()
    if source == "base64":
        return base64.b64decode(path)
    if source == "url":
        # Without a timeout, a stalled server would block this call forever.
        return requests.get(url=path, timeout=300).content
    return None
|
|
|
|
|
|
|
|
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
    """Extract a local video's frames to JPEG files and build frame tags.

    Frames are cached as ``0.jpg, 1.jpg, ...`` in
    ``<video path without extension>_frames/``. Decoding is skipped when that
    directory already contains files. When ``max_frame_number > 0``, at most
    that many frames are referenced, sampled uniformly; otherwise all are used.

    Returns:
        ``{image_start_tag}<frame path>{image_end_tag}`` for each selected frame,
        concatenated, or ``""`` if the video could not be read.
    """
    video_path = video_info['local']
    # splitext strips only the real extension; split('.')[0] broke on dotted
    # directory names and on relative paths such as "./clip.mp4".
    frame_path = os.path.splitext(video_path)[0] + '_frames'
    if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
        os.makedirs(frame_path, exist_ok=True)
        mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
        if mm_obj_byte is None:
            return ""
        frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way)
        if frames is None:
            return ""
        for frame_idx, frame in enumerate(frames):
            output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
            # 'key' decoding yields PIL images; np.asarray makes both modes cv2-compatible.
            cv2.imwrite(output_filename, cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR))

    frame_number = len([name for name in os.listdir(frame_path) if name.endswith('.jpg')])
    # A non-positive limit means "keep every frame" (np.linspace rejects negative counts).
    if 0 < max_frame_number < frame_number:
        indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
    else:
        indices = range(frame_number)

    return "".join(
        f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
        for idx in indices
    )
|
|
|
|
|
def _get_video_frame_str(self, video_info, return_mm_data = True ):
    """Expand a ``{"local": <video path>}`` JSON spec into per-frame image spans.

    Each frame becomes ``{image_start_tag}{"local": <frame path>}{image_end_tag}``.
    ``return_mm_data`` is accepted for interface symmetry and is unused.

    Returns:
        The concatenated spans, or ``""`` if no frames were produced or an
        error occurred (never ``None``).
    """
    try:
        video_info = ujson.loads(video_info)
        if 'local' not in video_info:
            raise ValueError('can not find localpath in video_info')
        frames_str = self._split_video_to_frames(
            video_info,
            max_frame_number=self.config.video_config.max_frame_num,
            decode_way=self.config.video_config.decode_way,
        )
        if frames_str == "":
            # Previously fell through and returned None, which broke the
            # caller's string concatenation.
            return ""
        pieces = []
        for part in frames_str.split(self.image_end_tag):
            before_path, tag, path = part.partition(self.image_start_tag)
            if tag:
                # json.dumps escapes quotes/backslashes so the spec stays valid JSON.
                frame_json = json.dumps({"local": path}, ensure_ascii=False)
                pieces.append(f'{before_path}{self.image_start_tag}{frame_json}{self.image_end_tag}')
            else:
                pieces.append(part)
        return ''.join(pieces)
    except Exception as e:
        print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
        return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_audio(self, audio_text, mminfo_ret_dict):
    """Replace one ``audio_start ... audio_end`` span with pad placeholders.

    The text between the tags is the lookup key into ``mminfo_ret_dict``. For
    each chunk, ``bridge_length[i]`` pad tags are emitted, with chunks joined by
    the delimiter tag. Returns ``""`` when the key is unknown or has no
    bridge lengths.
    """
    # re.escape: special tokens may contain regex metacharacters such as "|".
    tag_regex = re.compile(re.escape(self.audio_start_tag) + "|" + re.escape(self.audio_end_tag))
    audio_info = tag_regex.sub('', audio_text)
    ret = mminfo_ret_dict.get(audio_info)
    if ret is None or ret.bridge_length is None:
        return ''
    replaced_text = self.audio_delim_tag.join(self.audio_pad_tag * n for n in ret.bridge_length)
    return self.audio_start_tag + replaced_text + self.audio_end_tag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_image(self, image_text, mminfo_ret_dict):
    """Replace one image span with ``patch_nums[0]`` image pad tokens.

    The text between the image tags is the lookup key into
    ``mminfo_ret_dict``. Returns ``""`` when the key is unknown or produced no
    patches.
    """
    # re.escape: special tokens may contain regex metacharacters such as "|".
    tag_regex = re.compile(re.escape(self.image_start_tag) + "|" + re.escape(self.image_end_tag))
    image_info = tag_regex.sub('', image_text)
    ret = mminfo_ret_dict.get(image_info)
    if ret is None or ret.patch_nums is None:
        return ''
    return self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_video_frame(self, video_frame_text, mminfo_ret_dict):
    """Replace one video span with per-frame runs of video placeholder tokens.

    Video tags and then image tags are stripped from ``video_frame_text`` to
    form the lookup key. Each frame yields ``image_start_tag``, then
    ``videos_patch_nums[i]`` copies of ``video_place_tag``, then
    ``image_end_tag``. Returns ``""`` when the key is unknown or has no frames.
    """
    # re.escape: special tokens may contain regex metacharacters such as "|".
    video_tags = re.compile(re.escape(self.video_start_tag) + '|' + re.escape(self.video_end_tag))
    image_tags = re.compile(re.escape(self.image_start_tag) + "|" + re.escape(self.image_end_tag))
    video_frame_info = image_tags.sub('', video_tags.sub('', video_frame_text))
    ret = mminfo_ret_dict.get(video_frame_info)
    if ret is None or ret.videos_patch_nums is None:
        return ''
    return ''.join(
        self.image_start_tag + self.video_place_tag * n + self.image_end_tag
        for n in ret.videos_patch_nums
    )
|
|
|
|
|
def extract_replace_multimodal(self, text, mtype='audio', return_mm_data = True):
    """Load every ``mtype`` span in ``text`` and replace it with placeholder tokens.

    For ``'video'``, each video spec is first expanded into per-frame image
    spans via ``_get_video_frame_str``.

    Returns:
        OceanProcessorOutput holding the concatenated multimodal data, with
        ``raw_text`` set to the rewritten text. If any single item fails to
        load, the data gathered so far is returned with ``raw_text`` left as
        ``None`` (``process_one`` then discards the sample).

    Raises:
        ValueError: If ``mtype`` is unknown or its tags are not configured.
    """
    if self.audio_start_tag is not None and mtype == 'audio':
        start, end = re.escape(self.audio_start_tag), re.escape(self.audio_end_tag)
        match_regex = re.compile(start + '.*?' + end)
        drop_regex = re.compile(start + "|" + end)
        extract_func = self._get_audio
        replace_func = self._replace_audio
    elif self.image_start_tag is not None and mtype == 'image':
        start, end = re.escape(self.image_start_tag), re.escape(self.image_end_tag)
        match_regex = re.compile(start + '.*?' + end)
        drop_regex = re.compile(start + "|" + end)
        extract_func = self._get_image
        replace_func = self._replace_image
    elif self.video_start_tag is not None and mtype == 'video':
        video_start, video_end = re.escape(self.video_start_tag), re.escape(self.video_end_tag)
        video_match_regex = re.compile(video_start + '.*?' + video_end)
        video_drop_regex = re.compile(video_start + "|" + video_end)
        for mm_info in video_match_regex.findall(text):
            frame_str = self._get_video_frame_str(video_drop_regex.sub('', mm_info))
            # Literal replacement: mm_info is JSON/path text rather than a regex, and
            # frame_str may contain backslashes that re.sub would treat as escapes.
            text = text.replace(mm_info, self.video_start_tag + frame_str + self.video_end_tag)
        # The capture group makes findall return only the text between the video tags.
        match_regex = re.compile(video_start + r'(.*?)' + video_end)
        drop_regex = re.compile(re.escape(self.image_start_tag) + "|" + re.escape(self.image_end_tag))
        extract_func = self._get_video_frame
        replace_func = self._replace_video_frame
    else:
        raise ValueError("mtype not supportted!")

    mm_info_list = [drop_regex.sub('', mm_info) for mm_info in match_regex.findall(text)]

    mminfo_ret_dict = {}
    ret = OceanProcessorOutput()
    for mm_info in mm_info_list:
        mm_ret = extract_func(mm_info, return_mm_data=return_mm_data)
        mminfo_ret_dict[mm_info] = mm_ret
        if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None:
            # Any item that fails to load invalidates the whole sample.
            return ret
        ret = ret.concatenate(mm_ret)

    ret.raw_text = match_regex.sub(lambda m: replace_func(m.group(), mminfo_ret_dict), text)
    return ret
|
|
|
|
|
def process_one(self, text, index=0, raw_only=False, return_mm_data = True):
    """Expand the multimodal tags of one sample and tokenize it.

    Every modality in ``self.config.multimodal`` is handled in turn by
    ``extract_replace_multimodal``. The rewritten text of one pass feeds the
    next pass, and the per-modality outputs are merged with ``concatenate``.

    Args:
        text: raw sample text, possibly containing multimodal tags and
            ``<trainable_start>`` / ``<trainable_end>`` markers.
        index: position of the sample in its batch; stored on the result.
        raw_only: if true, return right after the text rewrite (no tokenization).
        return_mm_data: forwarded unchanged to ``extract_replace_multimodal``.

    Returns:
        An ``OceanProcessorOutput``.
        - If any modality pass yields ``raw_text is None``, the result carries only ``index``.
        - Otherwise ``raw_text`` holds the rewritten text and, unless ``raw_only``,
          ``input_ids`` and ``labels`` are filled in.
        - When trainable markers are present, tokens outside
          ``<trainable_start>...<trainable_end>`` spans get label -100 and
          whitespace-only segments are dropped. Without markers every token is trainable.
        - Ids in ``self.config.multimodal_special_token_no_loss_list`` always get -100.
    """
    merged = OceanProcessorOutput(index=index)
    for modality in self.config.multimodal:
        partial = self.extract_replace_multimodal(text, modality, return_mm_data=return_mm_data)
        if partial.raw_text is None:
            # A modality pass failed: give back an (almost) empty output for this sample.
            return OceanProcessorOutput(index=index)
        merged = merged.concatenate(partial)
        text = partial.raw_text
    merged.raw_text = text

    if raw_only:
        return merged

    sep_pattern = r'<trainable_start>|<trainable_end>'
    separators = re.findall(sep_pattern, merged.raw_text.replace('\n', '<LF>'))
    token_ids, trainable_mask = [], []
    if separators:
        # Segment k (k >= 1) is trainable exactly when the separator before it is <trainable_start>.
        for seg_no, segment in enumerate(re.split(sep_pattern, merged.raw_text)):
            if not segment.strip():
                continue
            seg_ids = self.tokenizer(segment, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
            token_ids.extend(seg_ids)
            is_trainable = seg_no > 0 and separators[seg_no - 1] != '<trainable_end>'
            trainable_mask.extend([is_trainable] * len(seg_ids))
    else:
        token_ids = self.tokenizer(merged.raw_text, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()
        trainable_mask = [True] * len(token_ids)

    merged.labels = [
        tok if (keep and tok not in self.config.multimodal_special_token_no_loss_list) else -100
        for tok, keep in zip(token_ids, trainable_mask)
    ]
    merged.input_ids = token_ids
    merged.index = index
    return merged
|
|
|
|
|
@torch.no_grad()
def __call__(self, example, parallel=8):
    """Turn one sample or a list of samples into model inputs.

    Args:
        example: a ``str`` (one sample) or a ``list`` of ``str`` (a batch).
            A ``dict`` is accepted but not processed: ``None`` is returned.
        parallel: upper bound on the number of worker threads used for a batch.

    Returns:
        For a ``str``: the ``process_one`` output for it.
        For a ``list``: one ``OceanProcessorOutput`` merging every sample.
        - ``input_ids``, ``attention_mask`` and ``labels`` are left-padded ``torch`` tensors.
        - Padded label positions are -100.
        - Audio fields are collated with ``default_collate``.
        - Images and videos are lists of float32 tensors.

    Raises:
        ValueError: the list is empty, or ``example`` has an unsupported type.
        AssertionError: some sample of the batch produced no ``input_ids``.
    """
    if isinstance(example, dict):
        # NOTE(review): dict input is not implemented; it falls through to None.
        return None
    if isinstance(example, str):
        return self.process_one(example)
    if not isinstance(example, list):
        raise ValueError("example format not supported yet")
    if not example:
        # ThreadPoolExecutor(max_workers=0) would fail with an unrelated message.
        raise ValueError("example list is empty")

    with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
        future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
        # Reading futures in submission order keeps the batch in sample order.
        batch_data = [future.result() for future in future_list]

    failed = [x.index for x in batch_data if x.input_ids is None]
    assert not failed, "samples failed to process: {}".format(failed)

    ret = OceanProcessorOutput()
    for item in batch_data:
        ret = ret.concatenate(item)

    # NOTE: this permanently switches the shared tokenizer to left padding.
    self.tokenizer.padding_side = "left"
    padded_inputs = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
    ret.input_ids = padded_inputs["input_ids"]
    ret.attention_mask = padded_inputs["attention_mask"]

    padded_labels = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
    # tokenizer.pad fills with pad_token_id; mark those fills -100 so they carry no loss.
    ret.labels = padded_labels["input_ids"].masked_fill(padded_labels["attention_mask"] == 0, -100)

    if ret.audios is not None:
        ret.audios = default_collate(ret.audios)
        ret.encoder_length = default_collate(ret.encoder_length)
        ret.bridge_length = default_collate(ret.bridge_length)

    if ret.images is not None:
        ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]

    if ret.videos is not None:
        # Convert videos in place; writing them into ret.images would discard the images.
        ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in ret.videos]

    return ret
|
|
|
|
|
@torch.no_grad()
def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=8):
    """Parse raw pretraining lines, tokenize them and pack them into fixed-length rows.

    Each entry of ``raw_batch`` is a JSON string, with a Python-literal fallback.
    The sample text is taken as follows:
    - a list: ``obj[1]``;
    - otherwise: the optional ``"title"`` followed by ``"raw"`` (preferred) or ``"content"``.
    Lines that cannot be parsed are reported and skipped. Serial and threaded
    processing use the same parsing rules.

    Samples are packed greedily, in input order:
    - Each sample is followed by ``eos_token_id``, which is also its label.
    - A row is closed when the next sample would overflow ``max_sequence_length``.
    - A closed row is right-padded with ``pad_token_id``, label -100,
      position 0 and attention 0.
    - ``seqlens`` lists each packed sample's length, zero-padded to
      ``max_sequence_length``.
    - Failed samples, empty samples, and samples too long to fit a row on their
      own are dropped.

    Args:
        raw_batch: list of raw text lines.
        max_sequence_length: row length; defaults to ``self.tokenizer.model_max_length``.
        parallel: number of worker threads for ``process_one``; values <= 1 run serially.

    Returns:
        A list of ``OceanProcessorOutput`` rows of length ``max_sequence_length``.
    """
    import ast  # for the Python-literal fallback; the module header does not import it

    if max_sequence_length is None:
        max_sequence_length = self.tokenizer.model_max_length

    assert isinstance(raw_batch, list)

    def _extract_content(json_text):
        # Return the sample text of one raw line, or None when the line is unusable.
        try:
            json_obj = ujson.loads(json_text.strip())
        except Exception:
            try:
                json_obj = ast.literal_eval(json_text.strip())
            except Exception:
                print("parse json obj faild: {}....".format(json_text[:300]))
                return None
        try:
            if isinstance(json_obj, list):
                return json_obj[1]
            title = json_obj["title"] if "title" in json_obj.keys() else ""
            if 'raw' in json_obj.keys():
                return title + json_obj["raw"]
            return title + json_obj["content"]
        except Exception:
            print("parse json raw/content error: {}....".format(json_text[:300]))
            return None

    start_ts = time.time()

    # (original index, text) for every line that parsed.
    contents = []
    for idx, json_text in enumerate(raw_batch):
        content = _extract_content(json_text)
        if content is not None:
            contents.append((idx, content))

    if parallel > 1:
        with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
            future_list = [executor.submit(self.process_one, content, idx) for idx, content in contents]
            # Submission order (not completion order) keeps packing deterministic.
            batch_data = [future.result() for future in future_list]
    else:
        batch_data = [self.process_one(content, idx) for idx, content in contents]

    if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
        print('[WARNING] processing each data cost more than 1.0s')

    current_length, packed_output, output = 0, OceanProcessorOutput(position_ids=[], seqlens=[]), []
    # Sentinel appended after the real samples; it forces the final row to be flushed.
    empty_data = OceanProcessorOutput(input_ids=[], labels=[])
    for idx, bd in enumerate(batch_data + [empty_data]):
        is_sentinel = idx == len(batch_data)
        if not is_sentinel:
            # Drop failed samples, empty samples and samples that can never fit a row (+1 for eos).
            if bd.input_ids is None or len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length:
                continue

        if current_length + len(bd.input_ids) + 1 > max_sequence_length or is_sentinel:
            pad_nums = max_sequence_length - current_length
            if packed_output.input_ids is None or packed_output.labels is None:
                # Nothing was packed into this row: emit an all-padding row.
                packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
                packed_output.labels = [-100] * pad_nums
            else:
                packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
                packed_output.labels += [-100] * pad_nums
            # One position per pad token, so position_ids is always as long as input_ids.
            packed_output.position_ids += [0] * pad_nums
            packed_output.attention_mask = [1] * current_length + [0] * pad_nums
            packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
            output.append(packed_output)
            packed_output = OceanProcessorOutput(position_ids=[], seqlens=[])

        # Measure before concatenate/append so list aliasing cannot change the length.
        sample_len = len(bd.input_ids) + 1  # sample tokens + trailing eos
        packed_output = packed_output.concatenate(bd)
        packed_output.input_ids.append(self.tokenizer.eos_token_id)
        packed_output.labels.append(self.tokenizer.eos_token_id)
        packed_output.position_ids.extend(range(sample_len))
        packed_output.seqlens.append(sample_len)
        current_length = len(packed_output.input_ids)

    return output
|
|
|
|
|
@torch.no_grad()
def collect_batch_pretrain(self, batch_data):
    """Collate packed pretraining rows into one CUDA batch.

    ``batch_data`` is a list of packed rows, such as those built by ``pack_batch_pretrain``.
    - Per-row sequence fields are stacked with ``default_collate``.
    - Multimodal fields are first merged across rows with ``concatenate``.
    - Every resulting tensor is moved with ``.cuda(non_blocking=True)``.
    - Images and videos become lists of per-item float32 tensors.
    - ``raw_text`` is cleared.
    """
    merged = OceanProcessorOutput()
    for row in batch_data:
        merged = merged.concatenate(row)

    def _stack_rows(field, dtype):
        # Stack one per-row field across all rows, then send it to the GPU.
        return default_collate(
            [np.asarray(getattr(item, field), dtype=dtype) for item in batch_data]
        ).cuda(non_blocking=True)

    def _collate_merged(values, dtype):
        # Collate an already-merged multimodal field, then send it to the GPU.
        return default_collate(np.asarray(values, dtype=dtype)).cuda(non_blocking=True)

    merged.input_ids = _stack_rows("input_ids", np.int64)
    merged.labels = _stack_rows("labels", np.int64)
    merged.attention_mask = _stack_rows("attention_mask", np.float32)
    merged.position_ids = _stack_rows("position_ids", np.int64)
    merged.seqlens = _stack_rows("seqlens", np.int64)

    merged.raw_text = None

    if merged.audios is not None:
        merged.audios = _collate_merged(merged.audios, np.float32)
        merged.encoder_length = _collate_merged(merged.encoder_length, np.int32)
        merged.bridge_length = _collate_merged(merged.bridge_length, np.int32)

    if merged.images is not None:
        merged.images = [
            torch.from_numpy(np.asarray(img, dtype=np.float32)).cuda(non_blocking=True)
            for img in merged.images
        ]
        merged.patch_nums = _collate_merged(merged.patch_nums, np.int32)

    if merged.videos is not None:
        merged.videos = [
            torch.from_numpy(np.asarray(clip, dtype=np.float32)).cuda(non_blocking=True)
            for clip in merged.videos
        ]
        merged.videos_patch_nums = _collate_merged(merged.videos_patch_nums, np.int32)

    return merged
|
|
|
|
|
@torch.no_grad()
def collect_batch_sft(self, batch_data):
    """Collate SFT samples, given as keyword dicts, into a plain dict of CPU tensors.

    - Each element of ``batch_data`` is expanded with ``OceanProcessorOutput(**bd)``.
    - The sequence fields are stacked as int64 with ``default_collate``.
    - Multimodal fields are merged with ``concatenate``, then converted: audio
      fields are collated, images and videos become lists of float32 tensors.

    The merged object's ``__dict__`` is returned after a fixed set of
    bookkeeping keys is deleted from it. A missing key raises ``KeyError``.
    """
    samples = [OceanProcessorOutput(**bd) for bd in batch_data]
    merged = OceanProcessorOutput()
    for sample in samples:
        merged = merged.concatenate(sample)

    for field in ("input_ids", "labels", "position_ids", "seqlens"):
        stacked = default_collate([np.asarray(getattr(s, field), dtype=np.int64) for s in samples])
        setattr(merged, field, stacked)

    merged.raw_text = None

    if merged.audios is not None:
        merged.audios = default_collate(np.asarray(merged.audios, dtype=np.float32))
        merged.encoder_length = default_collate(np.asarray(merged.encoder_length, dtype=np.int32))
        merged.bridge_length = default_collate(np.asarray(merged.bridge_length, dtype=np.int32))

    if merged.images is not None:
        merged.images = [torch.from_numpy(np.asarray(img, dtype=np.float32)) for img in merged.images]

    if merged.videos is not None:
        merged.videos = [torch.from_numpy(np.asarray(clip, dtype=np.float32)) for clip in merged.videos]

    result = merged.__dict__
    # Remove these keys before returning; the deletions also mutate ``merged``.
    for key in ('patch_nums', 'images_size', 'crop_size', 'raw_text', 'index',
                'attention_mask', 'videos_patch_nums', 'videos_size', 'videos_crop_size'):
        del result[key]
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_img_processor():
    """Manual check: compare ``OceanImageProcessor`` with ``CLIPImageProcessor``.

    Loads the config from the current directory and fetches a few images through
    ``CosClient``. For each image it prints the raw image shape, both
    preprocessed shapes, and their element-wise difference.
    """
    from transformers import AutoConfig
    from transformers.models.clip import CLIPImageProcessor

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    vis_cfg = config.visual_config
    processor = OceanImageProcessor(vis_cfg)
    official_processor = CLIPImageProcessor(
        size=vis_cfg.crop_size,
        crop_size=vis_cfg.crop_size,
        image_mean=vis_cfg.image_mean,
        image_std=vis_cfg.image_std,
        do_convert_rgb=True,
    )

    img_files = ['sogou/7a2c8ffc1bc61146b32805c3390f42e2', 'wukong/77c1db1c0e4200d12b478c33ba3a412d', 'wukong/62e9a5c8eb8b0ea8858a34ba3f1a999f', 'wukong/fb9ab4d7c3fe9f54289948fd6a57fc30']
    cos_client = CosClient()

    for img_key in img_files:
        raw_bytes = cos_client(img_key)
        pil_img = Image.open(io.BytesIO(raw_bytes))
        ours, _ = processor.image_transform(raw_bytes)
        reference = official_processor.preprocess(
            [pil_img],
            do_resize=True, do_center_crop=True, do_rescale=True, do_normalize=True,
            return_tensors='np',
        ).data['pixel_values'][0]

        print('-' * 60)
        print(np.array(pil_img).shape)
        print(ours.shape)
        print(reference.shape)
        print(ours - reference)
|
|
|
|
|
def test_audio_processor():
    """Manual check: compare ``OceanAudioProcessor`` fbank features with Whisper's extractor.

    For each sample file it prints:
    - both feature shapes;
    - their difference over the valid frames;
    - the count of zero entries before and after ``data_augment``.
    """
    from transformers.models.whisper import WhisperFeatureExtractor
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    official_processor = WhisperFeatureExtractor(feature_size=128)
    processor = OceanAudioProcessor(config.audio_config)

    wave_files = ['/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac', '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac']

    for wave_file in wave_files:
        waveform = processor.load_audio_waveform(wave_file, True, False)
        reference = official_processor(waveform[0].numpy(), do_normalize=False)['input_features'][0]
        feats, frame_nums = processor.extract_fbank_features(waveform)

        print("=" * 60)
        print(reference.shape)
        print(feats.shape, frame_nums)
        print('the difference between offical extractor and our implementation: {}'.format(wave_file))
        print(feats[:, :frame_nums] - reference[:, :frame_nums])
        print(feats)

        # Count zeros to see how much data_augment masks out.
        zeros_before = np.sum(feats == 0)
        augmented = processor.data_augment(feats, frame_nums)
        zeros_after = np.sum(augmented == 0)
        print(zeros_before, zeros_after)
|
|
|
|
|
|
|
|
def test_audio_long():
    """Manual check: run ``OceanMMProcessor`` on long audio samples with ``split_overlap = 1``.

    Prints the batch output and how often token ids 151659 and 151674 occur in
    ``input_ids``.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    config.audio_config.split_overlap = 1
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = OceanMMProcessor(tokenizer, config, True)

    examples = [
        "<audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
        "what's the sound's energy? \n sound1 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"}<audio_end_ocean> \n sound2 <audio_start_ocean>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}<audio_end_ocean>The speech energy is medium.",
    ]

    batch = processor(examples)
    print(batch)
    for token_id in (151659, 151674):
        print(torch.sum(batch.input_ids == token_id))
|
|
|
|
|
def test_processor():
    """Manual check: run ``OceanMMProcessor`` on mixed image/audio/text samples.

    Processes ``examples[4:-1]`` and prints:
    - the batch;
    - the counts of token ids 151659 and 151662;
    - ``bridge_length`` and ``patch_nums``, best effort;
    - the per-row sums of the attention mask.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    examples = [
        "<audio_start_ocean>{\"path\": \"vggsound\/7DH5fqj8j6Q.flac\"}<audio_end_ocean>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
        "hello, ocean 你好 百川智能。",
        "what's the sound's energy? \n <audio_start_ocean>{\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}<audio_end_ocean>The speech energy is medium.",
        "sound1: <audio_start_ocean>{\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}<audio_end_ocean>\n sound2: \n<audio_start_ocean>{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}<audio_end_ocean>How is the speech speed related to the estimated speaker age?\n<trainable_start>The slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.<trainable_end>",
        "<img_start_ocean>{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}<img_end_ocean>新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评",
        "这两个图片有什么关系?图片1<img_start_ocean>{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}<img_end_ocean>图片2\n<img_start_ocean>{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}<img_end_ocean>",
        "根据图片和语音给出描述\n图片<img_start_ocean>{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>",
        "这些图片和音频不存在<img_start_ocean>{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_ocean>语音<audio_start_ocean>{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_ocean><trainable_start>这是一只猫<trainable_end>"
    ]

    batch = processor(examples[4:-1])
    print(batch)
    for token_id in (151659, 151662):
        print(torch.sum(batch.input_ids == token_id))

    # Best-effort debug output: any failure here is ignored on purpose.
    try:
        print(batch.bridge_length)
        print(batch.patch_nums)
    except BaseException:
        pass

    print(torch.sum(batch.attention_mask, dim=1))
|
|
|
|
|
|
|
|
def test_grounding():
    """Manual check: run ``OceanMMProcessor`` on grounding samples (boxes and refs).

    Prints the batch, then the label row of every sample.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    examples = [
        "<img_start_ocean>{\"path\": \"grit\/663423bf2f0884c034bf75279bce9694\"}<img_end_ocean>\nWhere is \"A woman\" ? Answer: <trainable_start>The bounding box is <box_start_ocean>(0.58,0.8),(0.71,1.0)<box_end_ocean><trainable_end>",
        "hello, ocean 你好 百川智能。",
        "<img_start_ocean>{\"path\": \"grit\/0e6e3952c584cbac7235940a22514656\"}<img_end_ocean> Generate the caption with grounding: <trainable_start>Photo pour Portrait of <ref_start_ocean>young Asian muslim woman wearing hijab<ref_end_ocean><box_start_ocean>(0.09,0.01),(0.77,1.0)<box_end_ocean> shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit<trainable_end>",
        "Recognize the object in the outlined section <img_start_ocean>{\"path\": \"grit\/045823cf6f819670f27aee20af7ae0e6\"}<img_end_ocean> of the picture.<box_start_ocean>(0.07,0.2),(0.91,0.96)<box_end_ocean>\n<trainable_start>Inflatable water trampolines<trainable_end>"
    ]

    batch = processor(examples)
    print(batch)

    for row in range(len(batch.input_ids)):
        print("=" * 60)
        print(batch.labels[row])
|
|
|
|
|
def test_pack():
    """Manual check: pack a few pretraining lines and collate them onto the GPU.

    Steps:
    1. Read the first five lines of each of two local corpora and shuffle them.
    2. Run ``pack_batch_pretrain`` and print every packed row.
    3. Collate the rows with ``collect_batch_pretrain`` and print the resulting fields.
    """
    import itertools

    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048)
    processor = OceanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    examples = []
    for path in ('/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000',
                 '/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000'):
        # Close the handle and read only the lines we keep, not the whole corpus file.
        with open(path, encoding='utf-8') as f:
            examples += list(itertools.islice(f, 5))

    random.shuffle(examples)

    batch_output = processor.pack_batch_pretrain(examples)
    for row in batch_output:
        print('=' * 60)
        try:
            print(row.input_ids, len(row.input_ids))
            print(row.labels, len(row.labels))
            print(row.attention_mask, len(row.attention_mask))
            print(row.position_ids, len(row.position_ids))
            print(row.seqlens, len(row.seqlens))
            print(row.audios)
            print(row.bridge_length)
        except Exception:
            # Best-effort debug print: skip rows missing a field.
            continue

    batch_for_model = processor.collect_batch_pretrain(batch_output)
    print(batch_for_model.input_ids.shape)
    print(batch_for_model.labels.shape)
    # Multimodal fields stay None when the batch carries no such data.
    if batch_for_model.audios is not None:
        print(batch_for_model.audios.shape)
        print(batch_for_model["bridge_length"])
    if batch_for_model.images is not None:
        # collect_batch_pretrain returns images as a list of tensors (no .shape on the list).
        print([image.shape for image in batch_for_model.images])
        print(batch_for_model["patch_nums"])
    print(batch_for_model["position_ids"])
    print(batch_for_model["seqlens"])
|
|
|
|
|
def test_cos_audio():
    """Manual check: fetch one audio clip from COS and decode it with ``torchaudio``.

    Prints the decoded waveform shape and its sample rate.
    """
    cos_client = CosClient()
    audio_data = cos_client('panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3', 'audio-data-1317568651')
    # torchaudio.load takes a path or a file-like object, not raw bytes.
    # CosClient appears to return bytes: test_img_processor wraps its result in BytesIO.
    if isinstance(audio_data, (bytes, bytearray)):
        audio_data = io.BytesIO(audio_data)
    wave, sr = torchaudio.load(audio_data, normalize=False)
    print(wave.shape, sr)
|
|
|
|
|
|
|
|
# Script entry point: python-fire turns this module's top-level names
# (e.g. the test_* helpers above) into command-line subcommands.
if __name__ == '__main__':
    fire.Fire()