|
|
import requests |
|
|
import re, ujson, os, sys, fire, glob, random, time, json |
|
|
import numpy as np |
|
|
import io |
|
|
import torch |
|
|
from torch.utils.data import default_collate |
|
|
import torchaudio |
|
|
from typing import * |
|
|
from dataclasses import dataclass, field |
|
|
import transformers |
|
|
from transformers.modeling_outputs import ModelOutput |
|
|
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function |
|
|
from functools import lru_cache |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
from qcloud_cos import CosConfig |
|
|
from qcloud_cos import CosS3Client |
|
|
import tos |
|
|
import oss2 |
|
|
import concurrent.futures as cf |
|
|
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size |
|
|
from transformers.image_utils import PILImageResampling |
|
|
from PIL import Image, ImageOps |
|
|
from PIL import ImageFile |
|
|
torch.set_num_threads(1) |
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
import base64 |
|
|
from decord import VideoReader, cpu |
|
|
import cv2 |
|
|
import av |
|
|
import imagesize |
|
|
from ks3.connection import Connection, SubdomainCallingFormat, PathCallingFormat |
|
|
from ks3.prefix import Prefix |
|
|
from ks3.key import Key |
|
|
import ks3.exception |
|
|
from qcloud_cos import CosServiceError |
|
|
|
|
|
|
|
|
def select_best_resolution(image_size, candidate_resolutions):
    """Pick the candidate resolution that best fits an image.

    For every candidate the original image is scaled (aspect ratio kept) so
    that it fits inside the candidate. The candidate that keeps the most
    image pixels wins; ties are broken by the smallest number of unused
    (padding) pixels.

    Args:
        image_size: ``(width, height)`` of the original image,
            e.g. ``(8 * 336, 16 * 336)``.
        candidate_resolutions: Iterable of ``(width, height)`` candidates,
            e.g. ``[(1 * 336, 4 * 336)]``.

    Returns:
        The winning candidate as a ``(width, height)`` tuple, or ``None``
        when ``candidate_resolutions`` is empty.

    Raises:
        ValueError: If ``image_size`` is not a pair or contains a
            non-positive dimension.
    """
    try:
        original_width, original_height = image_size
    except (TypeError, ValueError) as err:
        # Previously this failure was swallowed and resurfaced later as a
        # NameError; report the real problem instead.
        raise ValueError(
            f"image_size must be a (width, height) pair, got {image_size!r}"
        ) from err
    if original_width <= 0 or original_height <= 0:
        raise ValueError(
            f"image_size must contain positive dimensions, got {image_size!r}"
        )

    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")
    original_area = original_width * original_height

    for width, height in candidate_resolutions:
        # Largest uniform scale at which the image still fits the candidate.
        scale = min(width / original_width, height / original_height)
        downscaled_width = int(original_width * scale)
        downscaled_height = int(original_height * scale)

        # Up-scaling cannot add information, so cap at the original area.
        effective_resolution = min(downscaled_width * downscaled_height, original_area)
        wasted_resolution = (width * height) - effective_resolution

        is_better = effective_resolution > max_effective_resolution
        is_tie_with_less_waste = (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        )
        if is_better or is_tie_with_less_waste:
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
|
|
|
|
|
|
|
|
def read_video(image_path, max_frame_number, decode_way):
    """Decode frames from a video and optionally subsample them evenly.

    Args:
        image_path: Video location handed to ``decord.VideoReader`` or
            ``av.open`` (despite the name, this is a video, not an image).
        max_frame_number: When positive and smaller than the number of
            decoded frames, that many frames are kept at evenly spaced
            indices. Values ``<= 0`` disable the cap.
        decode_way: ``'1fps'`` samples one frame every ``round(avg_fps)``
            frames with decord; ``'key'`` decodes only key frames with PyAV.

    Returns:
        A list of frames, or ``None`` when decoding fails or yields nothing.
        Frames are numpy arrays for ``'1fps'`` and PIL images for ``'key'``.

    Raises:
        ValueError: If ``decode_way`` is not one of the supported modes
            (previously this surfaced as an UnboundLocalError).
    """
    if decode_way == '1fps':
        try:
            vr = VideoReader(image_path, ctx=cpu(0))
            # A rounded fps of 0 would make range() raise and drop the whole
            # video, so fall back to taking every frame.
            step = max(1, round(vr.get_avg_fps()))
            frame_idx = list(range(0, len(vr), step))
            frames = list(vr.get_batch(frame_idx).asnumpy())
            del vr  # release the reader as early as possible
        except Exception as e:
            print(image_path)
            print('error is', e)
            return None
    elif decode_way == 'key':
        try:
            with av.open(image_path) as container:
                stream = container.streams.video[0]
                # Ask the decoder to skip everything except key frames.
                stream.codec_context.skip_frame = 'NONKEY'
                frames = [frame.to_image() for frame in container.decode(stream)]
        except Exception as e:
            print('error is', e)
            return None
    else:
        raise ValueError(
            f"unsupported decode_way {decode_way!r}; expected '1fps' or 'key'"
        )

    if not frames:
        return None

    if len(frames) > max_frame_number and max_frame_number > 0:
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        frames = [frames[idx] for idx in indices]
    return frames
|
|
|
|
|
class BaichuanImageProcessor:
    """Turn an image (file path or encoded bytes) into normalised pixel arrays.

    ``config.feature_mode`` selects the mode:

    * ``"anyres"``: a square thumbnail followed by ``crop_size`` tiles cut
      from the image padded to the best-fitting candidate resolution.
    * anything else: a single resized and centre-cropped view.

    Config attributes read here: ``image_size``, ``feature_mode``,
    ``grid_size`` (optionally ``grid_size_max``), ``pooling_kernel``,
    ``crop_size``, ``patch_size``, ``image_mean`` and ``image_std``.
    """

    def __init__(self, config, **kwargs):
        """Store the config and build the list of candidate resolutions.

        When ``config.grid_size`` is an int ``g``, every ``i x j`` grid with
        ``1 <= i, j <= g`` is a candidate, plus single-row and single-column
        strips up to ``grid_size_max`` tiles long (if that attribute is set
        on the config instance). Otherwise ``grid_size`` is treated as an
        iterable of explicit ``(cols, rows)``-style pairs. Extra keyword
        arguments are accepted and ignored.
        """
        self.config = config
        self.image_size = self.config.image_size
        print("#"*40)
        print("using anyres?", self.config.feature_mode)
        print("#"*40)

        size = self.image_size
        grid_size = self.config.grid_size
        if isinstance(grid_size, int):
            self.image_grid_pinpoints = [
                [i * size, size * j]
                for i in range(1, grid_size + 1)
                for j in range(1, grid_size + 1)
            ]
            # Only an attribute stored on the instance counts here, which is
            # why __dict__ is consulted rather than getattr.
            grid_size_max = self.config.__dict__.get('grid_size_max', grid_size)
            for i in range(grid_size + 1, grid_size_max + 1):
                self.image_grid_pinpoints.append([i * size, size])
                self.image_grid_pinpoints.append([size, i * size])
        else:
            self.image_grid_pinpoints = [[it[0] * size, it[1] * size] for it in grid_size]

        print("image_grid_pinpoints",self.image_grid_pinpoints)
        self.kernel = self.config.pooling_kernel

    @property
    def feature_size(self):
        """Number of features per crop: ``((crop_size // patch_size) // kernel) ** 2``."""
        return ((self.config.crop_size // self.config.patch_size) // self.kernel) ** 2

    def fill_to_square(self, img, fill_color=(0, 0, 0)):
        """Return ``img`` pasted at the top-left of a square RGB canvas.

        The canvas side equals the longer image side; the uncovered area is
        filled with ``fill_color``.
        """
        width, height = img.size
        new_size = max(width, height)
        new_img = Image.new('RGB', (new_size, new_size), fill_color)
        new_img.paste(img)
        return new_img

    def _mean_std(self):
        """Return per-channel mean and std shaped to broadcast over CHW data."""
        mean = np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]
        std = np.array(self.config.image_std)[:, np.newaxis, np.newaxis]
        return mean, std

    def image_transform(self, strseq, return_mm_data = True):
        """Load an image and produce model-ready arrays.

        Args:
            strseq: A file path (``str``) or the encoded image bytes.
            return_mm_data: When ``False`` in ``"anyres"`` mode, only sizes
                are computed (via ``imagesize.get``) and every patch slot is
                ``None``.

        Returns:
            ``(pixels, original_size, patch_sizes)``. In ``"anyres"`` mode
            ``pixels`` stacks the thumbnail followed by the tiles in
            row-major order and ``original_size`` is ``(width, height)``.
            Otherwise ``pixels`` is one CHW array and ``original_size`` is
            the first two entries of the decoded array's shape.

        Raises:
            ValueError: If the non-anyres mode is asked to process a path
                with ``return_mm_data=False``, because no pixels are loaded
                in that case (previously an AttributeError on ``None``).
        """
        image = None
        if isinstance(strseq, str):
            if return_mm_data:
                image = Image.open(strseq).convert("RGB")
        else:
            image = Image.open(BytesIO(strseq)).convert("RGB")

        crop = self.config.crop_size

        if self.config.feature_mode == "anyres":
            if return_mm_data:
                ori_size = list(image.size)
                thumbnail = self.fill_to_square(image).resize((self.image_size, self.image_size))
                image_patches = [np.array(thumbnail)]
            else:
                w, h = imagesize.get(strseq)
                ori_size = (w, h)
                image_patches = [None]
            img_size_list = [[self.image_size, self.image_size]]

            best_width, best_height = select_best_resolution(ori_size, self.image_grid_pinpoints)
            if return_mm_data:
                image_padded = ImageOps.pad(image, (best_width, best_height))

            # Walk the padded image tile by tile, rows first.
            for top in range(0, best_height, crop):
                for left in range(0, best_width, crop):
                    if return_mm_data:
                        tile = image_padded.crop((left, top, left + crop, top + crop))
                        image_patches.append(np.array(tile))
                    else:
                        image_patches.append(None)
                    img_size_list.append([crop, crop])

            if return_mm_data:
                mean, std = self._mean_std()
                img_list = [
                    (patch.transpose(2, 0, 1) / 255.0 - mean) / std
                    for patch in image_patches
                ]
            else:
                img_list = image_patches
            # NOTE(review): stacking requires thumbnail and tiles to share a
            # shape, i.e. presumably image_size == crop_size - confirm.
            return np.stack(img_list, axis=0), ori_size, img_size_list

        if image is None:
            raise ValueError(
                "image pixels are required when feature_mode is not 'anyres'; "
                "pass return_mm_data=True or provide image bytes"
            )
        pixels = np.array(image.convert("RGB"))
        image_org_size = pixels.shape[:2]
        output_size = get_resize_output_image_size(pixels, crop, False)
        pixels = resize(pixels, output_size, PILImageResampling.BICUBIC)
        pixels = center_crop(pixels, (crop, crop), return_numpy=True)
        mean, std = self._mean_std()
        normalized = (pixels.transpose(2, 0, 1) / 255.0 - mean) / std
        return normalized, image_org_size, [normalized.shape[1:]]

    @classmethod
    def inference_output_length(cls, config, image_size):
        """Return ``((crop_size // patch_size) // pooling_kernel) ** 2``.

        ``image_size`` is accepted but not used.
        """
        return ((config.crop_size // config.patch_size) // config.pooling_kernel) ** 2

    def get_anyres_image_grid_shape(self, image_size):
        """Return the tile grid chosen for an image of ``image_size``.

        Args:
            image_size: ``(width, height)`` passed to
                ``select_best_resolution``.

        Returns:
            tuple: ``(width, height)`` of the grid, counted in tiles of
            ``config.image_size`` pixels.
        """
        best_width, best_height = select_best_resolution(image_size, self.image_grid_pinpoints)
        return best_width // self.config.image_size, best_height // self.config.image_size
|
|
|
|
|
class BaichuanVideoProcessor:
    """Turn a single picture (path or encoded bytes) into normalised arrays.

    The processing mirrors the image processor in this file; presumably it
    is applied to individual video frames - confirm against the callers.
    ``config.feature_mode`` selects the mode:

    * ``"anyres"``: a square thumbnail followed by ``crop_size`` tiles cut
      from the picture padded to the best-fitting candidate resolution.
    * anything else: a single resized and centre-cropped view.

    Config attributes read here: ``image_size``, ``feature_mode``,
    ``grid_size`` (optionally ``grid_size_max``), ``pooling_kernel``,
    ``crop_size``, ``patch_size``, ``image_mean`` and ``image_std``.
    """

    def __init__(self, config, **kwargs):
        """Store the config and build the list of candidate resolutions.

        When ``config.grid_size`` is an int ``g``, every ``i x j`` grid with
        ``1 <= i, j <= g`` is a candidate, plus single-row and single-column
        strips up to ``grid_size_max`` tiles long (if that attribute is set
        on the config instance). Otherwise ``grid_size`` is treated as an
        iterable of explicit pairs. Extra keyword arguments are ignored.
        """
        self.config = config
        self.image_size = self.config.image_size
        print("#"*40)
        print("using anyres?", self.config.feature_mode)
        print("#"*40)

        size = self.image_size
        grid_size = self.config.grid_size
        if isinstance(grid_size, int):
            self.image_grid_pinpoints = [
                [i * size, size * j]
                for i in range(1, grid_size + 1)
                for j in range(1, grid_size + 1)
            ]
            # Only an attribute stored on the instance counts here, which is
            # why __dict__ is consulted rather than getattr.
            grid_size_max = self.config.__dict__.get('grid_size_max', grid_size)
            for i in range(grid_size + 1, grid_size_max + 1):
                self.image_grid_pinpoints.append([i * size, size])
                self.image_grid_pinpoints.append([size, i * size])
        else:
            self.image_grid_pinpoints = [[it[0] * size, it[1] * size] for it in grid_size]

        print("image_grid_pinpoints",self.image_grid_pinpoints)
        self.kernel = self.config.pooling_kernel

    @property
    def feature_size(self):
        """Number of features per crop: ``((crop_size // patch_size + 1) // kernel) ** 2``.

        NOTE(review): ``inference_output_length`` below omits the ``+ 1``;
        confirm which of the two formulas is intended.
        """
        return ((self.config.crop_size // self.config.patch_size + 1) // self.kernel) ** 2

    def fill_to_square(self, img, fill_color=(0, 0, 0)):
        """Return ``img`` pasted at the top-left of a square RGB canvas.

        The canvas side equals the longer picture side; the uncovered area
        is filled with ``fill_color``.
        """
        width, height = img.size
        new_size = max(width, height)
        new_img = Image.new('RGB', (new_size, new_size), fill_color)
        new_img.paste(img)
        return new_img

    def _mean_std(self):
        """Return per-channel mean and std shaped to broadcast over CHW data."""
        mean = np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]
        std = np.array(self.config.image_std)[:, np.newaxis, np.newaxis]
        return mean, std

    def image_transform(self, strseq, return_mm_data = True):
        """Load a picture and produce model-ready arrays.

        Args:
            strseq: A file path (``str``) or the encoded picture bytes.
            return_mm_data: When ``False`` in ``"anyres"`` mode, only sizes
                are computed (via ``imagesize.get``) and every patch slot is
                ``None``.

        Returns:
            ``(pixels, original_size, patch_sizes)``. In ``"anyres"`` mode
            ``pixels`` stacks the thumbnail followed by the tiles in
            row-major order and ``original_size`` is ``(width, height)``.
            Otherwise ``pixels`` is one CHW array and ``original_size`` is
            the first two entries of the decoded array's shape.

        Raises:
            ValueError: If the non-anyres mode is asked to process a path
                with ``return_mm_data=False``, because no pixels are loaded
                in that case (previously an AttributeError on ``None``).
        """
        image = None
        if isinstance(strseq, str):
            if return_mm_data:
                image = Image.open(strseq).convert("RGB")
        else:
            image = Image.open(BytesIO(strseq)).convert("RGB")

        crop = self.config.crop_size

        if self.config.feature_mode == "anyres":
            if return_mm_data:
                ori_size = list(image.size)
                thumbnail = self.fill_to_square(image).resize((self.image_size, self.image_size))
                image_patches = [np.array(thumbnail)]
            else:
                w, h = imagesize.get(strseq)
                ori_size = (w, h)
                image_patches = [None]
            img_size_list = [[self.image_size, self.image_size]]

            best_width, best_height = select_best_resolution(ori_size, self.image_grid_pinpoints)
            if return_mm_data:
                image_padded = ImageOps.pad(image, (best_width, best_height))

            # Walk the padded picture tile by tile, rows first.
            for top in range(0, best_height, crop):
                for left in range(0, best_width, crop):
                    if return_mm_data:
                        tile = image_padded.crop((left, top, left + crop, top + crop))
                        image_patches.append(np.array(tile))
                    else:
                        image_patches.append(None)
                    img_size_list.append([crop, crop])

            if return_mm_data:
                mean, std = self._mean_std()
                img_list = [
                    (patch.transpose(2, 0, 1) / 255.0 - mean) / std
                    for patch in image_patches
                ]
            else:
                img_list = image_patches
            # NOTE(review): stacking requires thumbnail and tiles to share a
            # shape, i.e. presumably image_size == crop_size - confirm.
            return np.stack(img_list, axis=0), ori_size, img_size_list

        if image is None:
            raise ValueError(
                "image pixels are required when feature_mode is not 'anyres'; "
                "pass return_mm_data=True or provide image bytes"
            )
        pixels = np.array(image.convert("RGB"))
        image_org_size = pixels.shape[:2]
        output_size = get_resize_output_image_size(pixels, crop, False)
        pixels = resize(pixels, output_size, PILImageResampling.BICUBIC)
        pixels = center_crop(pixels, (crop, crop), return_numpy=True)
        mean, std = self._mean_std()
        normalized = (pixels.transpose(2, 0, 1) / 255.0 - mean) / std
        return normalized, image_org_size, [normalized.shape[1:]]

    @classmethod
    def inference_output_length(cls, config, image_size):
        """Return ``((crop_size // patch_size) // pooling_kernel) ** 2``.

        ``image_size`` is accepted but not used.
        """
        return ((config.crop_size // config.patch_size) // config.pooling_kernel) ** 2

    def get_anyres_image_grid_shape(self, image_size):
        """Return the tile grid chosen for a picture of ``image_size``.

        Args:
            image_size: ``(width, height)`` passed to
                ``select_best_resolution``.

        Returns:
            tuple: ``(width, height)`` of the grid, counted in tiles of
            ``config.image_size`` pixels.
        """
        best_width, best_height = select_best_resolution(image_size, self.image_grid_pinpoints)
        return best_width // self.config.image_size, best_height // self.config.image_size
|
|
|
|
|
class BaichuanAudioProcessor:
    """Load audio, split it into chunks and compute log-mel features.

    Config attributes read here: ``n_fft``, ``num_mel_bins``,
    ``sampling_rate``, ``hop_length``, ``max_audio_seconds``,
    ``split_overlap``, ``kernel_size``, ``stride_size``, ``avg_pooler`` and
    the ``mask_*`` augmentation settings.
    """

    def __init__(self, config, **kwargs):
        """Build the mel filter bank and the Hann window from ``config``.

        Requires at least one torchaudio backend to be available. Extra
        keyword arguments are accepted and ignored.
        """
        assert(len(torchaudio.list_audio_backends()) > 0)
        self.config = config
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.n_fft // 2,
            num_mel_filters=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=self.config.sampling_rate / 2.0,
            sampling_rate=self.config.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.window = torch.hann_window(self.config.n_fft)

    @staticmethod
    def zero_mean_unit_var_norm(x):
        """Normalise ``x`` to zero mean and (approximately) unit variance."""
        return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)

    def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
        """Load audio, resample to ``config.sampling_rate`` and mix to mono.

        Args:
            uri: Source passed to ``torchaudio.info`` / ``torchaudio.load``.
            return_tensors: Return a torch tensor when true, else a numpy
                array.
            do_normalize: Apply ``zero_mean_unit_var_norm`` when true.

        Returns:
            The waveform with a single channel kept as the first dimension.

        Raises:
            AssertionError: If the file has more than two channels.
        """
        metadata = torchaudio.info(uri)
        assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels)
        waveform_tensor, _ = torchaudio.load(uri, normalize=True)
        if self.config.sampling_rate != metadata.sample_rate:
            waveform_tensor = torchaudio.functional.resample(
                waveform_tensor,
                metadata.sample_rate,
                self.config.sampling_rate,
                lowpass_filter_width=128,
            )

        # Down-mix stereo by averaging the channels.
        if metadata.num_channels > 1:
            waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)

        if do_normalize:
            waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)

        if return_tensors:
            return waveform_tensor
        return waveform_tensor.numpy()

    def split_with_overlap(self, waveform):
        """Cut a ``(channels, samples)`` waveform into overlapping chunks.

        Chunks are at most ``max_audio_seconds * sampling_rate`` samples
        long; every chunk after the first starts ``split_overlap`` seconds
        before the previous chunk ended. The waveform is returned as a
        single chunk when it is short enough or when ``split_overlap`` is
        negative.

        NOTE(review): an overlap that is not smaller than the chunk length
        does not advance cleanly - presumably configs never do that; verify.
        """
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]

        overlap_samples = int(self.config.sampling_rate * self.config.split_overlap)
        split_waveform, start = [], 0
        while start < wave_samples:
            if start > overlap_samples:
                start -= overlap_samples
            end = min(start + max_audio_samples, wave_samples)
            split_waveform.append(waveform[:, start:end])
            start = end
        return split_waveform

    @classmethod
    def inference_output_length(cls, config, input_length):
        """Compute the lengths produced for ``input_length`` feature frames.

        Two length formulas with ``kernel_size`` are applied (stride 1, then
        ``stride_size``), followed by integer division by ``avg_pooler``.

        Returns:
            ``(encoder_length, bridge_length)``. When ``avg_pooler <= 1`` no
            pooling is applied and ``bridge_length`` equals
            ``encoder_length`` (this case previously raised
            UnboundLocalError).
        """
        kernel_size = config.kernel_size
        stride_size = config.stride_size
        avg_pooler = config.avg_pooler
        padding = kernel_size // 2
        encoder_length = (input_length + 2 * padding - kernel_size) // 1 + 1
        encoder_length = (encoder_length + 2 * padding - kernel_size) // stride_size + 1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        else:
            bridge_length = encoder_length
        return encoder_length, bridge_length

    def extract_fbank_features(self, waveform):
        """Compute a fixed-length log-mel spectrogram for one chunk.

        The waveform is zero-padded or truncated to
        ``max_audio_seconds * sampling_rate`` samples before the STFT.

        Args:
            waveform: Torch tensor of shape ``(channels, samples)`` with at
                least ``n_fft`` samples.

        Returns:
            ``(log_spec, valid_frame_nums)``: the numpy spectrogram of the
            first channel with frames past ``valid_frame_nums`` zeroed, and
            the number of frames that stem from real audio.
        """
        _, wave_samples = waveform.shape
        assert(wave_samples >= self.config.n_fft)
        max_samples = self.config.max_audio_seconds * self.config.sampling_rate
        hop_length = self.config.hop_length
        valid_frame_nums = min(max_samples // hop_length, wave_samples // hop_length + 1)
        if wave_samples < max_samples:
            waveform = torch.nn.functional.pad(waveform, (0, max_samples - wave_samples), "constant", 0)
        else:
            waveform = waveform[:, :max_samples]

        stft = torch.stft(waveform, self.config.n_fft, hop_length, window=self.window, return_complex=True)
        # Drop the last frame and take the squared magnitude.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # The shape unpacking above guarantees a 2-D input, so only the
        # per-item clamp (floor at max - 8.0) is ever needed.
        max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
        log_spec = torch.maximum(log_spec, max_val - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        log_spec = log_spec[0].numpy()
        log_spec[:, valid_frame_nums:] = 0.0
        return log_spec, valid_frame_nums

    def data_augment(self, feature: np.ndarray, input_length, training=True):
        """Zero out random time and mel-bin spans of ``feature``.

        The array is modified in place and also returned. Nothing happens
        when ``training`` is false, when both mask probabilities are
        non-positive, or when the input / mel axis is too short to host the
        configured minimum number of masks.

        Args:
            feature: Array indexed as ``[mel_bin, frame]``.
            input_length: Number of valid frames eligible for time masking.
            training: Enables augmentation.
        """

        def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            # Number of spans: expected count with stochastic rounding,
            # but never fewer than min_masks.
            num_masked_span = int(mask_prob * input_length / mask_length + random.random())
            num_masked_span = max(num_masked_span, min_masks)
            start_indices = list(range(input_length - mask_length))
            random.shuffle(start_indices)
            return start_indices[:num_masked_span]

        cfg = self.config
        if not training or (cfg.mask_time_prob <= 0 and cfg.mask_feature_prob <= 0):
            return feature
        if input_length < cfg.mask_time_length * cfg.mask_time_min_masks + 1:
            return feature
        if cfg.num_mel_bins < cfg.mask_feature_length * cfg.mask_feature_min_masks + 1:
            return feature

        if cfg.mask_time_prob > 0:
            start_indices = mask_start_indices(
                input_length, cfg.mask_time_length, cfg.mask_time_min_masks, cfg.mask_time_prob
            )
            for start_idx in start_indices:
                feature[:, start_idx: start_idx + cfg.mask_time_length] = 0.0
        if cfg.mask_feature_prob > 0:
            start_indices = mask_start_indices(
                cfg.num_mel_bins, cfg.mask_feature_length, cfg.mask_feature_min_masks, cfg.mask_feature_prob
            )
            for start_idx in start_indices:
                feature[start_idx: start_idx + cfg.mask_feature_length, :] = 0.0

        return feature
|
|
|
|
|
class CosClient():
    """Small retrying wrapper around a Tencent COS ``CosS3Client``."""

    def __init__(self, bucket_name='crawl-pic-bj-1317568651',
                 image_cache_path=None,
                 max_retries=5):
        """Create the COS client.

        Args:
            bucket_name: Bucket used when a call does not name one.
            image_cache_path: Stored on the instance; not used by the
                methods of this class.
            max_retries: Attempts made by ``__call__`` and ``save``.
        """
        # NOTE(security): credentials are embedded in source. They are kept
        # here only for backward compatibility and should be rotated and
        # supplied through configuration instead.
        self.config = CosConfig(
            Endpoint="cos.ap-beijing.myqcloud.com",
            SecretId='AKID2uZ4PWSBlBOOAxc4LQNFX0AD2d1B7pJb',
            SecretKey='A7ax85oJbfV8by75rPjpkSzSilaBcxAD',
            Token=None, Scheme='https', Timeout=30, PoolConnections=64,
            PoolMaxSize=64)
        self.client = CosS3Client(self.config)
        self.image_cache_path = image_cache_path
        self.max_retries = max_retries
        self.bucket_name = bucket_name

    def __call__(self, relative_path, bucket_name=None):
        """Download an object and return its bytes, or ``None`` on failure.

        Every failed attempt is followed by a 10 ms pause; errors are not
        reported to the caller.
        """
        if bucket_name is None or len(bucket_name) <= 0:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                response = self.client.get_object(Bucket=bucket_name, Key=relative_path)
                return response['Body'].get_raw_stream().read()
            except Exception:
                time.sleep(0.01)
        return None

    def save(self, file_byte, relative_path, bucket_name=None):
        """Upload ``file_byte``; return the SDK response or ``None`` on failure."""
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                return self.client.put_object(Bucket=bucket_name, Key=relative_path, Body=file_byte)
            except Exception:
                time.sleep(0.01)
        return None

    def exists(self, relative_path, bucket_name=None):
        """Return the result of the SDK's ``object_exists`` for the key."""
        if bucket_name is None:
            bucket_name = self.bucket_name
        return self.client.object_exists(bucket_name, Key=relative_path)

    def listdir(self, folder, bucket_name=None, MaxKeys=1000):
        """List object keys under ``folder`` (at most ``MaxKeys`` of them).

        ``folder`` is normalised to end with ``/`` and to have no leading
        ``/``; an empty string therefore lists from the bucket root. An
        empty listing yields ``[]`` instead of raising.
        """
        if bucket_name is None:
            bucket_name = self.bucket_name
        if not folder.endswith('/'):
            folder += '/'
        if folder.startswith('/'):
            folder = folder[1:]
        response = self.client.list_objects(bucket_name, Prefix=folder, MaxKeys=MaxKeys)
        # The 'Contents' entry is absent when nothing matches the prefix.
        return [c['Key'] for c in response.get('Contents', [])]
|
|
|
|
|
class TosClient(object):
    """Small retrying wrapper around a Volcengine TOS ``TosClientV2``."""

    def __init__(self, bucket_name="audio-dataset", max_retries=5):
        """Create the TOS client.

        Args:
            bucket_name: Bucket used when a call does not name one.
            max_retries: Attempts made by ``__call__``, ``save`` and
                ``exists``.
        """
        # NOTE(security): credentials are embedded in source. They are kept
        # here only for backward compatibility and should be rotated and
        # supplied through configuration instead.
        ak = "AKLTYTM3MWY5MTFhNDgyNDk4YjhmYTE0ZTE3YTk5ZmU1MjU"
        sk = "TVRRM1pUZGtaVEJqWTJJd05HSTNPR0ppWVdKa1lqYzVORFUwTlRobU1UVQ=="
        endpoint = "tos-cn-beijing.ivolces.com"
        region = "cn-beijing"
        self.bucket_name = bucket_name
        self.max_retries = max_retries
        self.client = tos.TosClientV2(ak, sk, endpoint, region)

    def __call__(self, path, bucket_name=None):
        """Download an object and return its bytes, or ``None`` on failure.

        Every failed attempt is followed by a 10 ms pause; errors are not
        reported to the caller.
        """
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                return self.client.get_object(bucket_name, path).read()
            except Exception:
                time.sleep(0.01)
        return None

    def save(self, file_byte, relative_path, bucket_name=None):
        """Upload ``file_byte``; return the SDK response or ``False`` on failure."""
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                return self.client.put_object(bucket_name, relative_path, content=file_byte)
            except Exception:
                time.sleep(0.01)
        return False

    def exists(self, relative_path, bucket_name=None):
        """Return whether the object exists, based on a HEAD request.

        A successful HEAD returns ``True`` immediately and a 404 returns
        ``False`` immediately; previously both outcomes still consumed all
        ``max_retries`` attempts. Other errors are printed and retried, and
        ``False`` is returned once the attempts are exhausted.
        """
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                self.client.head_object(bucket_name, relative_path)
                return True
            except tos.exceptions.TosClientError as e:
                print('fail with client error, message:{}, cause: {}'.format(e.message, e.cause))
            except tos.exceptions.TosServerError as e:
                if e.status_code == 404:
                    # A missing object is a definitive answer, not a
                    # transient failure, so do not retry.
                    print(f"{relative_path}: Object not found.")
                    return False
                print('fail with server error, code: {}'.format(e.code))
                print('error with request id: {}'.format(e.request_id))
                print('error with message: {}'.format(e.message))
                print('error with http code: {}'.format(e.status_code))
                print('error with ec: {}'.format(e.ec))
                print('error with request url: {}'.format(e.request_url))
            except Exception as e:
                print('fail with unknown error: {}'.format(e))
        return False

    def listdir(self, folder, bucket_name=None, MaxKeys=1000):
        """List object keys under ``folder`` (at most ``MaxKeys`` of them).

        ``folder`` is normalised to end with ``/`` and to have no leading
        ``/``; an empty string therefore lists from the bucket root.
        """
        if bucket_name is None:
            bucket_name = self.bucket_name
        if not folder.endswith('/'):
            folder += '/'
        if folder.startswith('/'):
            folder = folder[1:]
        contents = self.client.list_objects_type2(bucket_name, prefix=folder, max_keys=MaxKeys).contents
        return [c.key for c in (contents or [])]
|
|
|
|
|
class OssClient(object):
    """Small retrying reader for an Aliyun OSS bucket."""

    def __init__(self, max_retries=2):
        """Create the OSS bucket handle.

        Args:
            max_retries: Attempts made by ``__call__`` (defaults to the
                previously hard-coded value of 2).
        """
        # NOTE(security): credentials are embedded in source. They are kept
        # here only for backward compatibility and should be rotated and
        # supplied through configuration instead.
        ak = "LTAI5tBtFsjti987FPHzr1gL"
        sk = "K8djwCTIJfiwo1J8PQS13unDNkXksg"
        self.endpoint = "oss-cn-wulanchabu-internal.aliyuncs.com"
        self.bucket_name = "bc-audio-data"
        self.max_retries = max_retries
        self.auth = oss2.Auth(ak, sk)
        self.client = oss2.Bucket(self.auth, self.endpoint, self.bucket_name)

    def __call__(self, path, bucket_name=None):
        """Download an object and return its bytes, or ``None`` on failure.

        Args:
            path: Object key.
            bucket_name: Bucket to read from; defaults to the bucket the
                client was created for. A different name is now honoured by
                opening a handle for that bucket (it used to be ignored).
        """
        if bucket_name is None:
            bucket_name = self.bucket_name
        for _ in range(self.max_retries):
            try:
                if bucket_name == self.bucket_name:
                    bucket = self.client
                else:
                    bucket = oss2.Bucket(self.auth, self.endpoint, bucket_name)
                return bucket.get_object(path).read()
            except Exception:
                time.sleep(0.01)
        return None
|
|
|
|
|
class Ks3Client(object):
    """Reader for a single Kingsoft KS3 bucket."""

    def __init__(self, ak = 'AKLTfVq6ZQojRfyX5tFMbrRX',
                 sk = 'OI4elukIXEO2vnkq58jNKyJKZAG2HcOwXJ81HWa6',
                 host='ks3-cn-beijing-internal.ksyuncs.com',
                 bucket_name = 'crawl-mime-pic'):
        """Connect to ``host`` (non-secure, path-style) and open ``bucket_name``.

        NOTE(security): the default credentials are embedded in source;
        they should be rotated and supplied through configuration.
        """
        self.bucket_name = bucket_name
        connection = Connection(ak, sk, host, is_secure=False, domain_mode=False)
        self.client = connection.get_bucket(self.bucket_name)

    def __call__(self, object_name):
        """Return the object's contents, or ``None`` when unavailable.

        ``None`` is returned when the key lookup yields nothing or when an
        ``S3ResponseError`` occurs; such errors are printed unless their
        code is ``'NoSuchKey'``. Other exceptions propagate.
        """
        try:
            fetched = self.client.get_key(object_name)
            if fetched is None:
                return None
            return fetched.get_contents_as_string()
        except ks3.exception.S3ResponseError as err:
            if err.error_code != 'NoSuchKey':
                print(err)
        return None
|
|
|
|
|
@dataclass
class BaichuanProcessorOutput(ModelOutput):
    """Container for everything the multimodal processor produces.

    All fields default to ``None``.
    """

    # Text side
    input_ids: Optional["List|torch.Tensor"] = None
    labels: Optional["List|torch.Tensor"] = None
    attention_mask: Optional["List|torch.Tensor"] = None
    position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None

    # Audio side
    audios: Optional["List|torch.Tensor"] = None
    encoder_length: Optional["List|torch.Tensor"] = None
    bridge_length: Optional["List|torch.Tensor"] = None

    # Image side
    images: Optional["List|torch.Tensor"] = None
    patch_nums: Optional["List|torch.Tensor"] = None
    images_size: Optional["List|torch.Tensor"] = None
    crop_size: Optional["List|torch.Tensor"] = None
    images_grid: Optional["List|torch.Tensor"] = None

    # Video side
    videos: Optional["List|torch.Tensor"] = None
    videos_patch_nums: Optional["List|torch.Tensor"] = None
    videos_size: Optional["List|torch.Tensor"] = None
    videos_crop_size: Optional["List|torch.Tensor"] = None
    videos_grid: Optional["List|torch.Tensor"] = None

    # Bookkeeping
    raw_text: Optional[str] = None
    index: Optional[int] = None

    def concatenate(self, other):
        """Return a new output combining ``self`` with ``other``.

        For each merged field: if one side is ``None`` the other side is
        used, otherwise the two values are combined with ``+``.

        NOTE(review): ``attention_mask``, ``crop_size``,
        ``videos_crop_size``, ``raw_text`` and ``index`` are not merged and
        come out as ``None`` - confirm this is intended.
        """

        def merge(left, right):
            # A missing side simply yields the other side (possibly None).
            if left is None:
                return right
            if right is None:
                return left
            return left + right

        merged_fields = (
            "input_ids",
            "labels",
            "audios",
            "encoder_length",
            "bridge_length",
            "images",
            "images_grid",
            "patch_nums",
            "videos",
            "videos_grid",
            "videos_patch_nums",
            "position_ids",
            "seqlens",
            "images_size",
            "videos_size",
        )
        combined = {
            name: merge(getattr(self, name), getattr(other, name))
            for name in merged_fields
        }
        return BaichuanProcessorOutput(**combined)
|
|
|
|
|
class BaichuanMMProcessor(object): |
|
|
def __init__(self,
             tokenizer: transformers.PreTrainedTokenizer,
             config,
             training,
             relative_path=None,
             default_client='cos',
             parallel=None,
             **kwargs,
             ):
    """Wire up the tokenizer, per-modality processors and storage clients.

    Args:
        tokenizer: Tokenizer used to turn special token ids into strings.
        config: Model config; ``audio_config``, ``visual_config`` and
            ``video_config`` are each optional.
        training: Stored on the instance as ``self.training``.
        relative_path: Stored on the instance as ``self.relative_path``.
        default_client: Default storage backend, one of ``'cos'``,
            ``'tos'`` or ``'oss'``.
        parallel: Stored on the instance as ``self.parallel``.
        **kwargs: Accepted and ignored.

    Every processor and tag attribute is initialised to ``None`` first, so
    a missing sub-config leaves a ``None`` attribute rather than no
    attribute at all (``video_processor``, ``image_line_tag``,
    ``image_delimiter_tag`` and ``video_place_tag`` used to be undefined in
    that case).
    """
    self.tokenizer = tokenizer
    self.config = config

    self.audio_processor = None
    if hasattr(config, "audio_config"):
        self.audio_processor = BaichuanAudioProcessor(config.audio_config)
    self.visual_processor = None
    if hasattr(config, "visual_config"):
        self.visual_processor = BaichuanImageProcessor(config.visual_config)
    self.video_processor = None
    if hasattr(config, "video_config"):
        self.video_processor = BaichuanVideoProcessor(config.video_config)
    self.training = training

    assert default_client in ['cos', 'tos', 'oss']
    self.default_client = default_client
    self.relative_path = relative_path
    self.parallel = parallel
    self.cos_client = CosClient()
    self.tos_client = TosClient()
    self.oss_client = OssClient()
    self.ks3_client = Ks3Client()

    to_token = self.tokenizer.convert_ids_to_tokens

    self.audio_start_tag = None
    self.audio_end_tag = None
    self.audio_pad_tag = None
    self.audio_delim_tag = None
    if hasattr(self.config, "audio_config"):
        audio_cfg = self.config.audio_config
        self.audio_start_tag = to_token(audio_cfg.audio_start_token_id)
        self.audio_end_tag = to_token(audio_cfg.audio_end_token_id)
        self.audio_pad_tag = to_token(audio_cfg.audio_pad_token_id)
        self.audio_delim_tag = to_token(audio_cfg.audio_delim_token_id)

    self.image_start_tag = None
    self.image_end_tag = None
    self.image_pad_tag = None
    self.image_line_tag = None
    self.image_delimiter_tag = None
    self.video_start_tag = None
    self.video_end_tag = None
    self.video_place_tag = None

    self.videoframe_start_tag = '<videoframe_start_baichuan>'
    self.videoframe_end_tag = '<videoframe_end_baichuan>'
    if hasattr(self.config, "visual_config"):
        visual_cfg = self.config.visual_config
        self.image_start_tag = to_token(visual_cfg.image_start_token_id)
        self.image_end_tag = to_token(visual_cfg.image_end_token_id)
        self.image_pad_tag = to_token(visual_cfg.image_pad_token_id)
        self.image_line_tag = to_token(visual_cfg.image_line_token_id)
        self.image_delimiter_tag = to_token(visual_cfg.image_delimiter_token_id)
    if hasattr(self.config, "video_config"):
        video_cfg = self.config.video_config
        self.video_start_tag = to_token(video_cfg.video_start_token_id)
        self.video_end_tag = to_token(video_cfg.video_end_token_id)
        # The video config's image tags override those taken from the
        # visual config, exactly as before.
        self.image_start_tag = to_token(video_cfg.image_start_token_id)
        self.image_end_tag = to_token(video_cfg.image_end_token_id)
        self.image_pad_tag = to_token(video_cfg.image_pad_token_id)
        self.video_place_tag = to_token(video_cfg.video_place_token_id)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=1024)
def _get_audio(self, audio_info, return_mm_data = True):
    """Load the audio clip described by a JSON string and extract its features.

    Args:
        audio_info: JSON text of an object with a 'path' key and an optional
            'server' key ('tos', 'oss', anything else selects the cos client;
            defaults to ``self.default_client``).
        return_mm_data: accepted for signature parity with the other
            ``_get_*`` loaders; not used here.

    Returns:
        A BaichuanProcessorOutput holding ``audios``, ``encoder_length`` and
        ``bridge_length`` for every split whose bridge length is positive.
        An empty BaichuanProcessorOutput is returned on any failure (the
        error is printed, never raised).

    NOTE(review): results (including failures) are memoised by ``lru_cache``
    keyed on ``self`` and the exact argument pattern.
    NOTE(review): the source this block was recovered from had lost its
    indentation; the local-lookup nesting below is the reading under which
    ``os.path.exists`` never receives ``None`` - confirm against the repo.
    """
    try:
        audio_info = ujson.loads(audio_info)
        if 'path' not in audio_info.keys():
            raise ValueError("can not find path or local in audio_info")

        audio_path = audio_info['path']
        server = audio_info.get('server', self.default_client)

        # 1) Prefer a copy on local disk when a relative root is configured.
        audio_uri = None
        if self.relative_path is not None:
            audio_uri = os.path.join(self.relative_path, audio_path.lstrip('/'))
            if server == 'tos':
                match = re.search(r"(apple_podcasts/.+\.flac)", audio_path)
                # Only redirect to the podcast mirror when the path really has
                # that shape; previously a non-matching path crashed on
                # `None.group` and the whole clip was dropped.
                if match is not None:
                    audio_uri = os.path.join('/data_train/mllm/dingbowen/audio-text-latest2', match.group(1))
            if not os.path.exists(audio_uri):
                audio_uri = None

        # 2) Otherwise fetch through the object-store client for that server.
        if audio_uri is None:
            if server == 'tos':
                audio_uri = self.tos_client(audio_path, 'audio-dataset')
            elif server == 'oss':
                audio_uri = self.oss_client(audio_path, 'bc-audio-data')
            else:
                audio_uri = self.cos_client(audio_path, 'audio-data-tmp-1317568651')

        waveforms = self.audio_processor.load_audio_waveform(audio_uri, True)
        waveforms = self.audio_processor.split_with_overlap(waveforms)

        ret = BaichuanProcessorOutput()
        for waveform in waveforms:
            audio, input_length = self.audio_processor.extract_fbank_features(waveform)
            audio = self.audio_processor.data_augment(audio, input_length, self.training)
            encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
            if bridge_length <= 0:
                # This split yields no bridge tokens; nothing to emit for it.
                continue
            current_ret = BaichuanProcessorOutput(
                audios=[audio],
                encoder_length=[encoder_length],
                bridge_length=[bridge_length],
            )
            if ret.audios is None:
                ret = current_ret
            else:
                ret = ret.concatenate(current_ret)
        return ret
    except Exception as e:
        print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
        return BaichuanProcessorOutput()
|
|
|
|
|
@lru_cache(maxsize=1024)
def _get_image(self, image_info, return_mm_data = True):
    """Load the image described by a JSON string and run the visual transform.

    The payload is looked up in this order: 'base64' (inline bytes), 'local'
    (handed to ``image_transform`` as-is), 'path' (fetched from tos / ks3 /
    cos depending on which bucket key is present).

    Args:
        image_info: JSON text; a single-quoted variant is also accepted.
        return_mm_data: forwarded to ``image_transform`` for 'local' inputs.

    Returns:
        A BaichuanProcessorOutput with ``images``, ``patch_nums``,
        ``crop_size``, ``images_size`` and ``images_grid`` filled, or an empty
        BaichuanProcessorOutput when the image is 16x16 pixels or smaller or
        anything fails (the error is printed, never raised).
    """
    try:
        try:
            image_info = ujson.loads(image_info)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not intercepted. Retry after turning unescaped single quotes
            # into double quotes.
            image_info = re.sub(r"(?<!\\)'", '"', image_info)
            image_info = ujson.loads(image_info)

        if 'base64' in image_info.keys():
            image_data = base64.b64decode(image_info['base64'])
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
        elif 'local' in image_info.keys():
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'],return_mm_data = return_mm_data)
        elif 'path' in image_info.keys():
            if "tos_bucket" in image_info.keys():
                image_bytes = self.tos_client(image_info['path'], image_info['tos_bucket'])
            elif "ks3_bucket" in image_info.keys():
                image_bytes = self.ks3_client(image_info['path'])
            else:
                # 'bucket_name' wins over 'cos_bucket' when both are present.
                cos_bucket = None
                if "cos_bucket" in image_info.keys():
                    cos_bucket = image_info['cos_bucket']
                if "bucket_name" in image_info.keys():
                    cos_bucket = image_info['bucket_name']

                if cos_bucket == 'crawl-mime-pic-1317568651':
                    # This bucket is served through the ks3 client.
                    image_bytes = self.ks3_client(image_info['path'])
                    if image_bytes is None:
                        raise ValueError("#### image in crawl-mime-pic-1317568651 {} not found".format(image_info['path']))
                else:
                    # Two buckets are redirected to differently named buckets.
                    if cos_bucket == 'bc-train-archive-gz-1317568651':
                        cos_bucket = 'bc-train-archive-gz-b-1317568651'
                    elif cos_bucket == 'crawl-pic-1317568651':
                        cos_bucket = 'crawl-pic-bj-1317568651'
                    image_bytes = self.cos_client(image_info['path'], cos_bucket)
            image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
        else:
            raise ValueError("can not find any path in image_info")

        patch_nums = self.visual_processor.inference_output_length(self.config.visual_config, None)*len(image_list)

        if org_size[0] * org_size[1] > 16**2:
            return BaichuanProcessorOutput(
                images=[image_feat],
                patch_nums=[patch_nums],
                crop_size=[image_list],
                images_size= [list(org_size)],
                images_grid=[list(self.visual_processor.get_anyres_image_grid_shape(org_size))]
            )
        print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
        return BaichuanProcessorOutput()
    except Exception as e:
        print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
        return BaichuanProcessorOutput()
|
|
|
|
|
@lru_cache(maxsize=1024)
def _get_video_frame(self, video_frame_info, return_mm_data = True):
    """Transform every ``{...}`` frame descriptor found in a string.

    Each JSON object must carry a 'local' key, whose value is handed to
    ``self.video_processor.image_transform``.

    Args:
        video_frame_info: text containing zero or more JSON objects.
        return_mm_data: forwarded to ``image_transform``.

    Returns:
        A BaichuanProcessorOutput accumulating ``videos``,
        ``videos_patch_nums``, ``videos_crop_size``, ``videos_size`` and
        ``videos_grid`` for every frame larger than 16x16 pixels; smaller
        frames are reported and skipped. Any exception yields an empty
        BaichuanProcessorOutput (the error is printed, never raised).
    """
    try:
        frame_jsons = re.findall(r'\{.*?\}', video_frame_info)
        collected = BaichuanProcessorOutput()

        for frame_json in frame_jsons:
            # The parameter name is rebound on purpose: the log lines below
            # print whichever descriptor was being handled last.
            video_frame_info = ujson.loads(frame_json)

            if 'local' not in video_frame_info.keys():
                raise ValueError("can not find any path in image_info")

            image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'],return_mm_data = return_mm_data)
            patch_nums = self.video_processor.inference_output_length(self.config.video_config, None)*len(image_list)

            if not (org_size[0] * org_size[1] > 16**2):
                print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
                continue

            frame_output = BaichuanProcessorOutput(
                videos=[image_feat],
                videos_patch_nums=[patch_nums],
                videos_crop_size=[image_list],
                videos_size= [list(org_size)],
                videos_grid=[list(self.video_processor.get_anyres_image_grid_shape(org_size))]
            )
            collected = collected.concatenate(frame_output)

        return collected
    except Exception as e:
        print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
        return BaichuanProcessorOutput()
|
|
|
|
|
|
|
|
def _get_video_obj_byte(self, source, path, video_obj_json):
    """Fetch the raw bytes of a video from one of several sources.

    Args:
        source: one of "cos", "local", "base64" or "url".
        path: object key, filesystem path, base64 text or URL, depending on
            ``source``.
        video_obj_json: metadata dict; only its optional "cos_bucket" entry is
            read, and only for ``source == "cos"``.

    Returns:
        The video bytes, or ``None`` when the source is unknown or a local
        file does not exist.
    """
    if source == "cos":
        start_time = time.time()
        video_obj_byte = self.cos_client(path, bucket_name=video_obj_json.get("cos_bucket", None))
        # A download slower than one second triggers a client refresh.
        if (time.time() - start_time) > 1.0:
            self.reflash_cos_client()
        return video_obj_byte

    if source == "local":
        if not os.path.exists(path):
            return None
        # Context manager closes the handle; the previous bare
        # open(...).read() left that to the garbage collector.
        with open(path, "rb") as video_file:
            return video_file.read()

    if source == "base64":
        return base64.b64decode(path)

    if source == "url":
        # NOTE(review): no timeout is set, so a stalled server blocks forever.
        return requests.get(url=path).content

    return None
|
|
|
|
|
|
|
|
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
    """Dump a local video to JPEG frames and return tagged frame paths.

    Frames are written to ``<video path without extension>_frames/<idx>.jpg``.
    Extraction is skipped when that directory already holds files.

    Args:
        video_info: dict with a 'local' key holding the video path.
        max_frame_number: upper bound on the number of frame paths returned;
            frames are picked evenly across the clip. A negative value means
            "no limit" (the default used to crash inside ``np.linspace``).
        decode_way: forwarded to ``read_video``.

    Returns:
        The concatenation of ``image_start_tag + frame_path + image_end_tag``
        for every selected frame, or ``""`` when the video cannot be read.
    """
    video_path = video_info['local']
    # splitext drops only the final extension. The previous
    # `video_path.split('.')[0]` cut at the FIRST dot anywhere in the path, so
    # './a.mp4' and './b.mp4' both mapped to the same '_frames' directory.
    frame_path = os.path.splitext(video_path)[0] + '_frames'

    if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
        os.makedirs(frame_path, exist_ok=True)
        mm_obj_byte = self._get_video_obj_byte('local', video_path, video_info)
        if mm_obj_byte is None:
            return ""
        frames = read_video(io.BytesIO(mm_obj_byte), max_frame_number=max_frame_number, decode_way=decode_way)
        for frame_idx, frame in enumerate(frames):
            output_filename = os.path.join(frame_path, f"{frame_idx}.jpg")
            # Frames are converted RGB -> BGR before cv2 writes them.
            cv2.imwrite(output_filename, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        # Release the decoded frames and raw bytes before building the result.
        del frames
        del mm_obj_byte

    frame_number = sum(1 for filename in os.listdir(frame_path) if filename.endswith('.jpg'))
    if 0 <= max_frame_number < frame_number:
        sample_count = max_frame_number
    else:
        sample_count = frame_number
    indices = np.linspace(0, frame_number - 1, sample_count, dtype=int)

    return "".join(
        f"{self.image_start_tag}{os.path.join(frame_path, f'{idx}.jpg')}{self.image_end_tag}"
        for idx in indices
    )
|
|
|
|
|
def sample_frame(self,frames_str,max_frame = 32):
    """Keep at most ``max_frame`` evenly spaced tagged frames in a string.

    The text is split on ``image_start_tag ... image_end_tag`` spans. Text
    outside those spans is always kept; of the spans themselves only an
    evenly spaced subset survives when there are ``max_frame`` or more.

    Args:
        frames_str: text containing tagged frame spans.
        max_frame: maximum number of frame spans to keep.

    Returns:
        The text with the surplus frame spans removed.
    """
    start_tag = self.image_start_tag
    pieces = re.split(rf'({start_tag}.*?{self.image_end_tag})', frames_str)
    frame_positions = [pos for pos, piece in enumerate(pieces) if start_tag in piece]

    if max_frame > len(frame_positions):
        kept = set(frame_positions)
    else:
        # Fractional stride; truncation picks the frame at each step start.
        stride = len(frame_positions) / max_frame
        kept = {frame_positions[int(step * stride)] for step in range(max_frame)}

    return ''.join(
        piece for pos, piece in enumerate(pieces)
        if pos in kept or start_tag not in piece
    )
|
|
|
|
|
def _get_video_frame_str(self, video_info, return_mm_data = True ):
    """Turn a video descriptor into a string of tagged frame descriptors.

    Two input forms are handled:
      * text that already contains ``videoframe_start_tag`` spans: the
        videoframe tags are renamed to the image tags and the result is
        down-sampled with ``sample_frame``;
      * JSON with a 'local' key: the video is split into frame files and each
        frame path is wrapped as ``image_start_tag{"local": path}image_end_tag``.

    Args:
        video_info: pre-tagged frame text or JSON text.
        return_mm_data: accepted for signature parity; not used here.

    Returns:
        The frame string, or ``""`` when no frames could be produced. ``""``
        is now also returned when frame extraction yields nothing; the
        function used to fall off the end and return ``None``, which made the
        caller's string concatenation raise TypeError.
    """
    try:
        if self.videoframe_start_tag in video_info:
            frames_str = video_info.replace(self.videoframe_start_tag, self.image_start_tag)
            frames_str = frames_str.replace(self.videoframe_end_tag, self.image_end_tag)
            return self.sample_frame(frames_str, max_frame = self.config.video_config.max_frame_num)

        video_info = ujson.loads(video_info)
        if 'local' not in video_info.keys():
            raise ValueError('can not find localpath in video_info')

        frames_str = self._split_video_to_frames(
            video_info,
            max_frame_number=self.config.video_config.max_frame_num,
            decode_way=self.config.video_config.decode_way,
        )
        if frames_str == "":
            return ""

        result = []
        for part in frames_str.split(self.image_end_tag):
            if self.image_start_tag in part:
                before_path, path = part.split(self.image_start_tag)
                # json.dumps escapes quotes/backslashes in the path; for plain
                # paths the text is identical to the old hand-built literal.
                payload = json.dumps({"local": path}, ensure_ascii=False)
                result.append(f'{before_path}{self.image_start_tag}{payload}{self.image_end_tag}')
            else:
                result.append(part)
        return ''.join(result)
    except Exception as e:
        print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
        return ""
|
|
|
|
|
def _replace_audio(self, audio_text, return_mm_data = True):
    """Replace one tagged audio span with its placeholder token string.

    Args:
        audio_text: text of the form ``audio_start_tag + json + audio_end_tag``.
        return_mm_data: forwarded to ``_get_audio``.

    Returns:
        ``audio_start_tag + pads + audio_end_tag``, where ``pads`` holds one
        run of ``audio_pad_tag`` per entry of ``bridge_length`` (run length =
        that entry), joined by ``audio_delim_tag``; ``''`` when the audio
        could not be loaded.
    """
    audio_info = re.sub(re.compile(self.audio_start_tag + "|" + self.audio_end_tag), '', audio_text)
    # Pass return_mm_data BY KEYWORD: lru_cache treats f(x) and
    # f(x, return_mm_data=True) as different keys, and
    # extract_replace_multimodal uses the keyword form. The old bare call
    # missed the cache and decoded the same clip a second time.
    ret = self._get_audio(audio_info, return_mm_data = return_mm_data)
    if ret.bridge_length is None:
        return ''
    pad_runs = [self.audio_pad_tag * length for length in ret.bridge_length]
    return self.audio_start_tag + self.audio_delim_tag.join(pad_runs) + self.audio_end_tag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _replace_image(self, image_text, return_mm_data = True):
    """Build the placeholder token string that stands in for one image.

    Args:
        image_text: the tagged image span being replaced. It is not inspected:
            the placeholder length depends only on the visual config.
        return_mm_data: accepted for signature parity; not used here.

    Returns:
        ``image_start_tag`` + ``image_pad_tag`` repeated
        ``visual_processor.inference_output_length(visual_config, None)``
        times + ``image_end_tag``.
    """
    # The previous version also stripped the tags from `image_text` into a
    # local that was never read; that dead regex work is gone.
    patch_nums = self.visual_processor.inference_output_length(self.config.visual_config, None)
    return self.image_start_tag + self.image_pad_tag * patch_nums + self.image_end_tag
|
|
|
|
|
def _replace_video_frame(self, video_frame_text, return_mm_data = True):
    """Replace a span of video frames with placeholder token strings.

    Args:
        video_frame_text: text containing the tagged frame descriptors.
        return_mm_data: forwarded to ``_get_video_frame``.

    Returns:
        One ``image_start_tag ... image_end_tag`` group per loaded frame,
        concatenated; ``''`` when no frame was loaded. In 'anyres' feature
        mode each group holds a run of ``video_place_tag``, the
        ``image_delimiter_tag`` and a grid of place tags whose rows are joined
        by ``image_line_tag``; otherwise it holds ``videos_patch_nums[i]``
        place tags.
    """
    tag_pattern = re.compile(self.image_start_tag + "|" + self.image_end_tag)
    video_frame_info = tag_pattern.sub('', video_frame_text)
    frames = self._get_video_frame(video_frame_info, return_mm_data)

    patch_counts = frames.videos_patch_nums
    if patch_counts is None:
        return ''

    groups = []
    for frame_idx, patch_count in enumerate(patch_counts):
        if getattr(self.video_processor.config, 'feature_mode', None) == 'anyres':
            feature_size = self.video_processor.feature_size
            num_images = patch_count // feature_size
            num_patch = patch_count // num_images
            grid_width, grid_height = self.video_processor.get_anyres_image_grid_shape(frames.videos_size[frame_idx])

            side = int(feature_size ** 0.5)
            row = self.video_place_tag * side * grid_width
            rows = [row] * grid_height * side
            body = self.video_place_tag * num_patch + self.image_delimiter_tag + f"{self.image_line_tag}".join(rows)
        else:
            body = self.video_place_tag * patch_count
        groups.append(self.image_start_tag + body + self.image_end_tag)

    return ''.join(groups)
|
|
|
|
|
def extract_replace_multimodal(self, text, mtype='audio', return_mm_data = True):
    """Load every tagged item of one modality and swap it for placeholders.

    Args:
        text: input text containing tagged multimodal descriptors.
        mtype: 'audio', 'image' or 'video'; the matching start tag must be set.
        return_mm_data: forwarded to the loader and replacer functions.

    Returns:
        A BaichuanProcessorOutput with the loaded features concatenated and
        ``raw_text`` set to the placeholder-substituted text. If any item
        fails to load, the partially filled output is returned immediately
        and ``raw_text`` is never assigned.

    Raises:
        ValueError: if ``mtype`` is unknown or its start tag is ``None``.
    """
    if (self.audio_start_tag is not None) and (mtype == 'audio'):
        match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag)
        drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag)
        extract_func = self._get_audio
        replace_func = self._replace_audio
    elif (self.image_start_tag is not None) and (mtype == 'image'):
        match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag)
        drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
        extract_func = self._get_image
        replace_func = self._replace_image
    elif (self.video_start_tag is not None) and (mtype == 'video'):
        video_match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag)
        video_drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag)

        # First pass: expand every video descriptor into its frame list.
        for mm_info in re.findall(video_match_regex, text):
            # `or ""` guards against a frame builder that returns None.
            frame_str = self._get_video_frame_str(re.sub(video_drop_regex, '', mm_info)) or ""
            # Literal replacement. The old `re.sub(mm_info, repl, text)` used
            # the matched JSON as a regex pattern (so '.', '+', '(' in paths
            # mis-matched or failed to compile) and the frame text as a
            # replacement template (so escapes such as '\/' raised re.error).
            text = text.replace(mm_info, self.video_start_tag + frame_str + self.video_end_tag)

        # The capture group makes findall return only the text between the
        # video tags; the image tags are then stripped from it.
        match_regex = re.compile(self.video_start_tag+r'(.*?)'+self.video_end_tag)
        drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag)
        extract_func = self._get_video_frame
        replace_func = self._replace_video_frame
    else:
        raise ValueError("mtype not supportted!")

    mm_info_list = [re.sub(drop_regex, '', mm_info) for mm_info in re.findall(match_regex, text)]

    ret = BaichuanProcessorOutput()
    for mm_info in mm_info_list:
        mm_ret = extract_func(mm_info, return_mm_data = return_mm_data)
        if mm_ret.audios is None and mm_ret.images is None and mm_ret.videos is None:
            # Loading failed: bail out with raw_text still unset.
            return ret
        ret = ret.concatenate(mm_ret)

    ret.raw_text = re.sub(match_regex, lambda x: replace_func(x.group(), return_mm_data = return_mm_data), text)
    return ret
|
|
|
|
|
def process_one(self, text, index=0, raw_only=False, return_mm_data = True):
    """Process one sample: load multimodal items, tokenize, build labels.

    Args:
        text: raw sample text with tagged multimodal descriptors and optional
            ``<trainable_start>`` / ``<trainable_end>`` markers.
        index: position of the sample in its batch; stored on the result.
        raw_only: when true, stop after placeholder substitution.
        return_mm_data: forwarded to ``extract_replace_multimodal``.

    Returns:
        A BaichuanProcessorOutput. When any modality fails to load, an output
        carrying only ``index`` is returned. Otherwise ``raw_text`` is set
        and, unless ``raw_only``, so are ``input_ids`` and ``labels`` (-100
        marks positions excluded from the loss).
    """
    ret = BaichuanProcessorOutput(index=index)
    for mtype in self.config.multimodal:
        mret = self.extract_replace_multimodal(text, mtype, return_mm_data = return_mm_data)
        if mret.raw_text is None:
            return BaichuanProcessorOutput(index=index)
        ret = ret.concatenate(mret)
        text = mret.raw_text
    ret.raw_text = text
    if raw_only:
        return ret

    marker_pattern = r'<trainable_start>|<trainable_end>'

    def encode(chunk):
        # Tokenize one piece of text into a plain list of token ids.
        return self.tokenizer(chunk, padding='do_not_pad', truncation=True, return_tensors="np")['input_ids'][0].tolist()

    markers = re.findall(marker_pattern, ret.raw_text.replace('\n', '<LF>'))
    if not markers:
        # No explicit markers: every token except the image pad is a candidate.
        input_ids = encode(ret.raw_text)
        trainable = [token != self.config.visual_config.image_pad_token_id for token in input_ids]
    else:
        input_ids, trainable = [], []
        for pos, chunk in enumerate(re.split(marker_pattern, ret.raw_text)):
            if not chunk.strip():
                continue
            chunk_ids = encode(chunk)
            input_ids.extend(chunk_ids)
            # A chunk is frozen when it leads the text or follows an end marker.
            frozen = pos == 0 or markers[pos - 1] == '<trainable_end>'
            trainable.extend([not frozen] * len(chunk_ids))

    ret.labels = [
        token
        if (keep and (token not in self.config.multimodal_special_token_no_loss_list
                      or token == self.config.visual_config.image_pad_token_id))
        else -100
        for token, keep in zip(input_ids, trainable)
    ]
    ret.input_ids = input_ids
    ret.index = index
    return ret
|
|
|
|
|
@torch.no_grad()
def __call__(self, example, parallel=128):
    """Process a single string or a list of strings into model inputs.

    Args:
        example: a ``str`` (one sample) or a ``List`` of strings (a batch).
        parallel: maximum number of worker threads for a batch.

    Returns:
        For a string: the result of ``process_one``. For a list: a
        BaichuanProcessorOutput with left-padded ``input_ids``,
        ``attention_mask`` and ``labels`` tensors plus the collated
        multimodal features. For a ``Dict``: ``None`` (unhandled, as before).

    Raises:
        ValueError: for any other input type.
        AssertionError: if any sample of a batch failed to process.
    """
    if isinstance(example, Dict):
        # NOTE(review): dict inputs were a silent no-op; kept, made explicit.
        return None
    if isinstance(example, str):
        return self.process_one(example)
    if not isinstance(example, List):
        raise ValueError("example format not supported yet")

    with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
        future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
        batch_data = [key.result() for key in cf.as_completed(future_list)]

    valid_num = sum(1 for x in batch_data if x.input_ids is not None)
    assert(valid_num == len(batch_data))
    # as_completed yields in completion order; restore submission order.
    batch_data = sorted(batch_data, key=lambda x: x.index)

    ret = BaichuanProcessorOutput()
    for item in batch_data:
        ret = ret.concatenate(item)

    self.tokenizer.padding_side = "left"
    padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
    ret.input_ids = padding_result["input_ids"]
    ret.attention_mask = padding_result["attention_mask"]

    padding_result = self.tokenizer.pad({"input_ids": [r.labels for r in batch_data]}, return_tensors='pt')
    # tokenizer.pad fills with pad_token_id, which is a real target id. Padded
    # positions must carry -100, the value process_one already uses for
    # "no loss". Labels and input_ids have equal per-sample lengths, so the
    # input attention mask marks exactly the padded label positions.
    ret.labels = padding_result["input_ids"].masked_fill(ret.attention_mask == 0, -100)

    if ret.audios is not None:
        ret.audios = default_collate(ret.audios)
        ret.encoder_length = default_collate(ret.encoder_length)
        ret.bridge_length = default_collate(ret.bridge_length)

    if ret.images is not None:
        ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]

    if ret.videos is not None:
        ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]

    return ret
|
|
|
|
|
@torch.no_grad()
def pack_batch_pretrain(self, raw_batch, max_sequence_length=None, parallel=128):
    """Parse, process and pack JSON records into fixed-length sequences.

    Each record is processed with ``process_one``; the resulting token lists
    are concatenated (each followed by the EOS token) until the next one
    would exceed ``max_sequence_length``, then padded and emitted.

    Args:
        raw_batch: list of JSON strings. A record is a list (element 1 is the
            content) or an object with 'raw' or 'content' (an optional
            'title' is prepended in the threaded path).
        max_sequence_length: packed length; defaults to
            ``self.tokenizer.model_max_length``.
        parallel: worker-thread count; overridden by ``self.parallel`` when
            that is set. Values <= 1 select the sequential path.

    Returns:
        A list of BaichuanProcessorOutput, each with ``input_ids``,
        ``labels``, ``attention_mask``, ``position_ids`` and ``seqlens``.
    """
    # `ast` is not among the module-level imports visible at the top of the
    # file, so the literal_eval fallback below used to die with a NameError
    # that the bare `except:` swallowed. Import it where it is needed.
    import ast

    if self.parallel is not None:
        parallel = self.parallel
    if max_sequence_length is None:
        max_sequence_length = self.tokenizer.model_max_length

    assert isinstance(raw_batch, List)
    start_ts = time.time()
    if parallel > 1:
        with cf.ThreadPoolExecutor(max_workers=parallel) as executor:
            future_list = []
            for idx, json_text in enumerate(raw_batch):
                try:
                    json_obj = ujson.loads(json_text.strip())
                except Exception:
                    try:
                        # Fallback for Python-literal style records.
                        json_obj = ast.literal_eval(json_text.strip())
                    except Exception:
                        print("parse json obj faild: {}....".format(json_text[:300]))
                        continue

                if isinstance(json_obj, list):
                    content = json_obj[1]
                elif not isinstance(json_obj, dict):
                    # Scalars carry no content field; skip instead of
                    # crashing on `.keys()`.
                    continue
                elif 'raw' in json_obj.keys():
                    content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["raw"]
                elif 'content' in json_obj.keys():
                    content = (json_obj["title"] if "title" in json_obj.keys() else "") + json_obj["content"]
                else:
                    continue

                future_list.append(executor.submit(self.process_one, content, idx))

            batch_data = [key.result() for key in cf.as_completed(future_list)]
    else:
        # NOTE(review): unlike the threaded path, this one ignores 'title',
        # does not accept list records and always passes index 0.
        batch_data = []
        for json_text in raw_batch:
            data = ujson.loads(json_text.strip())
            if 'raw' in data.keys():
                batch_data.append(self.process_one(data['raw'], 0))
            else:
                batch_data.append(self.process_one(data['content'], 0))

    if (time.time() - start_ts) / (len(batch_data) + 1e-3) > 1.0:
        print('[WARNING] processing each data cost more than 1.0s')

    current_length, packed_output, output = 0, BaichuanProcessorOutput(position_ids=[], seqlens=[]), []
    # A trailing empty sample forces the last partially filled pack out.
    empty_data = BaichuanProcessorOutput(input_ids=[], labels=[])
    for idx, bd in enumerate(batch_data + [empty_data]):
        is_sentinel = idx == len(batch_data)
        if bd.input_ids is None and not is_sentinel:
            continue
        # Drop empty samples and samples that cannot fit even on their own.
        if (len(bd.input_ids) <= 0 or len(bd.input_ids) + 1 > max_sequence_length) and not is_sentinel:
            continue

        if current_length + len(bd.input_ids) + 1 > max_sequence_length or is_sentinel:
            pad_nums = max_sequence_length - current_length
            if packed_output.input_ids is None or packed_output.labels is None:
                packed_output.input_ids = [self.tokenizer.pad_token_id] * pad_nums
                packed_output.labels = [-100] * pad_nums
                # NOTE(review): this branch pads position_ids with one more
                # entry than input_ids - confirm whether that is intended.
                packed_output.position_ids += [0] * (pad_nums+1)
            else:
                packed_output.input_ids += [self.tokenizer.pad_token_id] * pad_nums
                packed_output.labels += [-100] * pad_nums
                packed_output.position_ids += [0] * pad_nums
            packed_output.attention_mask = [1] * current_length + [0] * pad_nums
            packed_output.seqlens += [0] * (max_sequence_length - len(packed_output.seqlens))
            output.append(packed_output)
            packed_output = BaichuanProcessorOutput(position_ids=[], seqlens=[])

        packed_output = packed_output.concatenate(bd)
        packed_output.input_ids.append(self.tokenizer.eos_token_id)
        packed_output.labels.append(self.tokenizer.eos_token_id)

        # Position ids restart at 0 for every packed sample (EOS included).
        packed_output.position_ids.extend(list(range(len(bd.input_ids) + 1)))
        packed_output.seqlens.append(len(bd.input_ids) + 1)

        current_length = len(packed_output.input_ids)
    return output
|
|
|
|
|
@torch.no_grad()
def collect_batch_pretrain(self, batch_data):
    """Collate packed pretraining samples into CUDA tensors.

    Args:
        batch_data: list of BaichuanProcessorOutput, each with ``input_ids``,
            ``labels``, ``attention_mask``, ``position_ids`` and ``seqlens``.

    Returns:
        One BaichuanProcessorOutput whose token-level fields are stacked per
        sample and whose multimodal fields are converted from the
        concatenated lists; everything is moved to the GPU with
        ``non_blocking=True``. ``raw_text`` is cleared.
    """
    def stack_field(field_name, dtype):
        # One array per sample, stacked into a batch tensor on the GPU.
        arrays = [np.asarray(getattr(sample, field_name), dtype=dtype) for sample in batch_data]
        return default_collate(arrays).cuda(non_blocking=True)

    def to_gpu(values, dtype):
        # A flat concatenated list converted as a single array.
        return default_collate(np.asarray(values, dtype=dtype)).cuda(non_blocking=True)

    ret = BaichuanProcessorOutput()
    for sample in batch_data:
        ret = ret.concatenate(sample)

    ret.input_ids = stack_field("input_ids", np.int64)
    ret.labels = stack_field("labels", np.int64)
    ret.attention_mask = stack_field("attention_mask", np.float32)
    ret.position_ids = stack_field("position_ids", np.int64)
    ret.seqlens = stack_field("seqlens", np.int64)

    ret.raw_text = None
    if ret.audios is not None:
        ret.audios = to_gpu(ret.audios, np.float32)
        ret.encoder_length = to_gpu(ret.encoder_length, np.int32)
        ret.bridge_length = to_gpu(ret.bridge_length, np.int32)
    if ret.images is not None:
        # Images stay a list of tensors, one per image.
        ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)).cuda(non_blocking=True) for image in ret.images]
        ret.patch_nums = to_gpu(ret.patch_nums, np.int32)
    if ret.videos is not None:
        ret.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)).cuda(non_blocking=True) for video in ret.videos]
        ret.videos_patch_nums = to_gpu(ret.videos_patch_nums, np.int32)

    return ret
|
|
|
|
|
@torch.no_grad()
def collect_batch_sft(self, batch_data):
    """Collate SFT samples given as dicts into a plain dict of tensors.

    Args:
        batch_data: list of dicts, each accepted as keyword arguments by
            BaichuanProcessorOutput.

    Returns:
        The ``__dict__`` of the merged output with ``input_ids``, ``labels``,
        ``position_ids`` and ``seqlens`` stacked per sample, multimodal fields
        converted to tensors (left on the CPU), and the bookkeeping keys listed
        in ``dropped_keys`` below removed.
    """
    samples = [BaichuanProcessorOutput(**bd) for bd in batch_data]

    merged = BaichuanProcessorOutput()
    for sample in samples:
        merged = merged.concatenate(sample)

    for field_name in ("input_ids", "labels", "position_ids", "seqlens"):
        stacked = default_collate([np.asarray(getattr(sample, field_name), dtype=np.int64) for sample in samples])
        setattr(merged, field_name, stacked)

    merged.raw_text = None
    if merged.audios is not None:
        merged.audios = default_collate(np.asarray(merged.audios, dtype=np.float32))
        merged.encoder_length = default_collate(np.asarray(merged.encoder_length, dtype=np.int32))
        merged.bridge_length = default_collate(np.asarray(merged.bridge_length, dtype=np.int32))
    if merged.images is not None:
        merged.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in merged.images]
    if merged.videos is not None:
        merged.videos = [torch.from_numpy(np.asarray(video, dtype=np.float32)) for video in merged.videos]

    # Hand back the attribute dict itself, minus the keys that are dropped
    # from the returned batch.
    ret = merged.__dict__
    dropped_keys = (
        'patch_nums', 'images_size', 'crop_size', 'raw_text', 'index',
        'attention_mask', 'videos_patch_nums', 'videos_size', 'videos_crop_size',
    )
    for key in dropped_keys:
        del ret[key]
    return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_img_processor():
    """Manual smoke test: fetch sample images through cos and save them as PNG.

    Builds the project image processor and the CLIP reference processor from
    the config in the current directory, then downloads each sample key and
    writes it to ``test.png`` (each image overwrites the previous one).
    """
    from transformers import AutoConfig
    from transformers.models.clip import CLIPImageProcessor

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    visual_config = config.visual_config
    # Both processors are only constructed here; neither is called below.
    processor = BaichuanImageProcessor(visual_config)
    offical_processor = CLIPImageProcessor(
        size=visual_config.crop_size,
        crop_size=visual_config.crop_size,
        image_mean=visual_config.image_mean,
        image_std=visual_config.image_std,
        do_convert_rgb=True,
    )

    cos_client = CosClient()
    for img_file in (
        'coyo-700m/c612a5f97f905cd1f0d7567549bba2b0',
        "laion400m/cae522939353c1d047c2d0f19e0543e0",
        'laion-coco/b44445525261273773c5b6aad1ee5b4e',
    ):
        img_bytes = cos_client(img_file)
        Image.open(io.BytesIO(img_bytes)).save("test.png", format="PNG")
|
|
|
|
|
def test_audio_processor():
    """Manual smoke test: compare project fbank features with Whisper's.

    For each sample file, prints the feature shapes, the element-wise
    difference over the valid frames, and the number of zero entries before
    and after ``data_augment``.
    """
    from transformers.models.whisper import WhisperFeatureExtractor
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    offical_processor = WhisperFeatureExtractor(feature_size=128)
    processor = BaichuanAudioProcessor(config.audio_config)

    wave_files = [
        '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/7ZY0U5tfKyQ.flac',
        '/home/nfs_bc_alignment/sunhaoze/sounds/audioset_full/Osly4Shchs4.flac',
    ]
    for wave_file in wave_files:
        waveform = processor.load_audio_waveform(wave_file, True, False)
        reference = offical_processor(waveform[0].numpy(), do_normalize=False)['input_features'][0]
        features, frame_nums = processor.extract_fbank_features(waveform)

        print("="*60)
        print(reference.shape)
        print(features.shape, frame_nums)
        print('the difference between offical extractor and our implementation: {}'.format(wave_file))
        print(features[:, :frame_nums] - reference[:, :frame_nums])
        print(features)

        # Count zero entries before and after augmentation.
        zeros_before = np.sum(features == 0)
        augmented = processor.data_augment(features, frame_nums)
        zeros_after = np.sum(augmented == 0)
        print(zeros_before, zeros_after)
|
|
|
|
|
|
|
|
def test_audio_long():
    """Manual smoke test: run the processor on long audio with split overlap 1.

    Prints the processed batch and how often token ids 151659 and 151674
    occur in ``input_ids``.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    config.audio_config.split_overlap = 1
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = BaichuanMMProcessor(tokenizer, config, True)

    examples = [
        "<audio_start_baichuan>{\"path\": \"panda\/testdata\/podcast_demo_30s\/easy_chat_xianliaohuier_30s\/easy_chat_xianliaohuier-133.mp3\"}<audio_end_baichuan>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
        "what's the sound's energy? \n sound1 <audio_start_baichuan>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-116.mp3\"}<audio_end_baichuan> \n sound2 <audio_start_baichuan>{\"path\": \"panda\/testdata\/podcast_demo_30s\/btrt_talk_heihua_30s\/btrt_talk_heihua-221.mp3\"}<audio_end_baichuan>The speech energy is medium.",
    ]

    ret = processor(examples)
    print(ret)
    for token_id in (151659, 151674):
        print(torch.sum(ret.input_ids == token_id))
|
|
|
|
|
def test_processor():
    """Manual smoke test: run the processor on mixed text/audio/image samples.

    Only ``examples[4:-1]`` is processed. Prints the batch, the counts of
    token ids 151659 and 151662, the bridge lengths and patch counts when
    available, and the per-sample number of attended positions.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = BaichuanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    examples = [
        "<audio_start_baichuan>{\"path\": \"vggsound\/7DH5fqj8j6Q.flac\"}<audio_end_baichuan>What is the level of noise from the speech?\n<trainable_start>The speech energy\n is medium.<trainable_end>",
        "hello, baichuan 你好 百川智能。",
        "what's the sound's energy? \n <audio_start_baichuan>{\"path\": \"iemocap\/Ses01F_script01_3_F022.wav\"}<audio_end_baichuan>The speech energy is medium.",
        "sound1: <audio_start_baichuan>{\"path\": \"audioset_full\/9B53NVDNT8U.flac\"}<audio_end_baichuan>\n sound2: \n<audio_start_baichuan>{\"path\": \"audioset_full\/a2dgzb9GDSQ.flac\"}<audio_end_baichuan>How is the speech speed related to the estimated speaker age?\n<trainable_start>The slow speech speed suggests a more deliberate and thoughtful approach often seen in mature individuals.<trainable_end>",
        "<img_start_baichuan>{\"path\": \"sogou\/7351ae4f3fbe58ff0e4cc165cfabb3ed\"}<img_end_baichuan>新和记潮汕牛肉火锅的牛肉丸好不好吃 用户评价口味怎么样 常州美食牛肉丸实拍图片 大众点评",
        "这两个图片有什么关系?图片1<img_start_baichuan>{\"path\": \"sogou\/ac91d57ab68335913ed41aa283e76356\"}<img_end_baichuan>图片2\n<img_start_baichuan>{\"path\": \"sogou\/6ad5e632b74265d9ef689e45936ab1aa\"}<img_end_baichuan>",
        "根据图片和语音给出描述\n图片<img_start_baichuan>{\"path\": \"sogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_baichuan>语音<audio_start_baichuan>{\"path\": \"voxceleb2\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_baichuan><trainable_start>这是一只猫<trainable_end>",
        "这些图片和音频不存在<img_start_baichuan>{\"path\": \"soogou\/32274c1ab28d11f8c490cf7ae15b36f1\"}<img_end_baichuan>语音<audio_start_baichuan>{\"path\": \"voxceleb_1\/id06726_s2lysJWkjus_00169.m4a\"}<audio_end_baichuan><trainable_start>这是一只猫<trainable_end>"
    ]

    ret = processor(examples[4:-1])
    print(ret)
    for token_id in (151659, 151662):
        print(torch.sum(ret.input_ids == token_id))
    # Best-effort prints: any failure here is ignored.
    try:
        print(ret.bridge_length)
        print(ret.patch_nums)
    except:
        pass
    print(torch.sum(ret.attention_mask, dim=1))
|
|
|
|
|
|
|
|
def test_gen_processor():
    """Run the processor on image-generation style samples and print the result.

    Loads the config and tokenizer from a hard-coded checkpoint directory,
    builds a ``BaichuanMMProcessor`` and feeds it three copies of one sample:
    an English image description followed by an image path wrapped in
    ``<image_gen_start>`` ... ``<image_gen_end>`` tags. Nothing is asserted;
    the output is meant to be inspected by eye.
    """
    from transformers import AutoConfig, AutoTokenizer

    model_dir = "/data_train2/mllm/wanfeng/code/mmtrain/modeling/baichuan3b2dot5+emu3+gen"
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, model_max_length=4096)
    processor = BaichuanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    # One sample, split at paragraph boundaries for readability; adjacent
    # literals are concatenated by the parser into a single string.
    sample = (
        "The image is a collage of multiple figures, predominantly featuring a central female character with long, wavy hair, smiling and holding a microphone. She is wearing a sleeveless top and a skirt with a ruffled texture. Surrounding her are several other figures, each with distinct hairstyles and clothing, suggesting a variety of characters or poses. The background is a blurred image of a stage with spotlights, indicating a performance setting.\n\n"
        "Overlaying the image are graphic elements and text. The text reads \"11 Gifs de Violetta Disney In My World Being What I Am,\" which suggests that the collage is a collection of animated GIFs related to the Disney character Violetta. The text is stylized with a mix of fonts and colors, including purple and white, which complements the color scheme of the image.\n\n"
        "The image has a dreamy, ethereal quality, with a soft-focus effect applied to the background and some of the figures, which gives it a slightly surreal appearance. The color palette is warm, with a dominance of pinks and purples, contributing to a cohesive visual theme.\n\n"
        "The overall layout of the image is a balanced composition with the central figure prominently displayed and other elements arranged around her, creating a sense of movement and activity. The image is designed to evoke the theme and spirit of the Disney show \"Violetta,\" showcasing the main character and her world."
        "<image_gen_start><trainable_start><img_start_baichuan>{\"path\": \"densefusion-1m/263eb49f9dc93bcd98434a26b944e44b\"}<img_end_baichuan><trainable_end><image_gen_end>"
    )
    # The batch is the same sample repeated three times.
    examples = [sample] * 3

    ret = processor(examples)
    print(ret)

    # Occurrence counts of two hard-coded token ids in the batch
    # (presumably special-token ids -- TODO confirm against the tokenizer).
    print(torch.sum(ret.input_ids == 151659))
    print(torch.sum(ret.input_ids == 151662))

    # These fields may be absent from the output; report that instead of
    # silently swallowing every exception (including KeyboardInterrupt).
    try:
        print(ret.bridge_length)
        print(ret.patch_nums)
    except Exception as exc:
        print(f"bridge_length/patch_nums not available: {exc!r}")

    # Per-row sum of the attention mask.
    print(torch.sum(ret.attention_mask, dim=1))
|
|
|
|
|
def test_grounding():
    """Run the processor on grounding-style samples and print the labels.

    Loads the config and tokenizer from the current directory, builds a
    ``BaichuanMMProcessor`` and processes four samples: three that combine an
    image path with ``<box_start_baichuan>`` / ``<ref_start_baichuan>``
    markup, and one plain-text sample. Prints the full output and then the
    ``labels`` entry of every row. Nothing is asserted.
    """
    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=4096)
    processor = BaichuanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    # The embedded JSON uses the escaped-slash form (backslash + '/'). The
    # backslash is written as "\\" because "\/" is not a valid Python escape
    # sequence; the runtime value of each string is unchanged.
    examples = [
        "<img_start_baichuan>{\"path\": \"grit\\/663423bf2f0884c034bf75279bce9694\"}<img_end_baichuan>\nWhere is \"A woman\" ? Answer: <trainable_start>The bounding box is <box_start_baichuan>(0.58,0.8),(0.71,1.0)<box_end_baichuan><trainable_end>",
        "hello, baichuan 你好 百川智能。",
        "<img_start_baichuan>{\"path\": \"grit\\/0e6e3952c584cbac7235940a22514656\"}<img_end_baichuan> Generate the caption with grounding: <trainable_start>Photo pour Portrait of <ref_start_baichuan>young Asian muslim woman wearing hijab<ref_end_baichuan><box_start_baichuan>(0.09,0.01),(0.77,1.0)<box_end_baichuan> shows regret gesture, hand on her forehead, forget something important, against red background - image libre de droit<trainable_end>",
        "Recognize the object in the outlined section <img_start_baichuan>{\"path\": \"grit\\/045823cf6f819670f27aee20af7ae0e6\"}<img_end_baichuan> of the picture.<box_start_baichuan>(0.07,0.2),(0.91,0.96)<box_end_baichuan>\n<trainable_start>Inflatable water trampolines<trainable_end>",
    ]

    ret = processor(examples)
    print(ret)

    # One separator line plus the labels for each row of input_ids; only the
    # row index is needed, the input_ids row itself is not printed.
    for row, _ in enumerate(ret.input_ids):
        print("="*60)
        print(ret.labels[row])
|
|
|
|
|
def test_pack():
    """Run the pretraining pack/collect path on a few on-disk samples.

    Reads the first five lines from each of two hard-coded data files,
    shuffles them, packs them with ``processor.pack_batch_pretrain`` and
    prints the fields of every packed item. The packed items are then merged
    with ``processor.collect_batch_pretrain`` and the shapes / values of the
    merged batch are printed. Nothing is asserted.
    """
    from itertools import islice

    from transformers import AutoConfig, AutoTokenizer

    config = AutoConfig.from_pretrained("./", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("./", model_max_length=2048)
    processor = BaichuanMMProcessor(tokenizer, config, True, '/home/nfs_bc_alignment/sunhaoze/sounds')

    sample_files = (
        '/cpfs/29f69eb5e2e60f26/user/sunhaoze/pretrain-v6/sogou/part-00000',
        '/home/nfs_bc_alignment/sunhaoze/text/openaqa-as-stage2-v1/part-00000',
    )
    examples = []
    for path in sample_files:
        # Close each file deterministically and read only the five lines
        # needed instead of loading the whole shard.
        with open(path) as fh:
            examples.extend(islice(fh, 5))
    random.shuffle(examples)

    batch_output = processor.pack_batch_pretrain(examples)
    for b in batch_output:
        print('='*60)
        # Best effort: an item that lacks one of these fields is reported and
        # skipped rather than aborting the whole dump.
        try:
            print(b.input_ids, len(b.input_ids))
            print(b.labels, len(b.labels))
            print(b.attention_mask, len(b.attention_mask))
            print(b.position_ids, len(b.position_ids))
            print(b.seqlens, len(b.seqlens))
            print(b.audios)
            print(b.bridge_length)
        except Exception as exc:
            print(f"could not print packed item: {exc!r}")

    batch_for_model = processor.collect_batch_pretrain(batch_output)
    print(batch_for_model.input_ids.shape)
    print(batch_for_model.labels.shape)
    print(batch_for_model.audios.shape)
    print(batch_for_model["bridge_length"])
    print(batch_for_model.images.shape)
    print(batch_for_model["patch_nums"])
    print(batch_for_model["position_ids"])
    print(batch_for_model["seqlens"])
|
|
|
|
|
def test_cos_audio():
    """Fetch one audio clip through ``CosClient`` and print its shape and sample rate.

    The object returned by the client is handed directly to
    ``torchaudio.load`` with ``normalize=False``. Nothing is asserted.
    """
    client = CosClient()
    clip_path = 'panda/data/common_voice/cv-corpus-18.0-2024-06-14/zh-CN/clips/common_voice_zh-CN_19428637.mp3'
    # Second positional argument; looks like a bucket name -- confirm against
    # CosClient.__call__.
    bucket_name = 'audio-data-1317568651'
    payload = client(clip_path, bucket_name)

    waveform, sample_rate = torchaudio.load(payload, normalize=False)
    print(waveform.shape, sample_rate)
|
|
|
|
|
|
|
|
# Script entry point: calling fire.Fire() with no component hands the module's
# contents to python-fire, so each test function above can be run by name from
# the command line.
if __name__ == "__main__":
    fire.Fire()