"""
UniMoE Audio Utilities Module

Author: UniMoE Audio Team
"""

import copy
import glob
import json
import math
import os
import re
import shutil
import sys
import time
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, TYPE_CHECKING, Callable

import dac
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchvision
import transformers
from audiotools import AudioSignal
from safetensors import safe_open
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, LogitsProcessor, LogitsProcessorList
from moviepy.video.io.VideoFileClip import VideoFileClip
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode

from qwen_vl_utils import smart_resize, process_vision_info

import deepspeed
from deepspeed import comm as dist
from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
from torch import Tensor

# Backend detection: prefer an Ascend NPU when torch_npu is importable, else CUDA.
try:
    import torch_npu
    IS_CUDA = False
except ImportError:
    IS_CUDA = True

# Optional tutel MoE kernels.
try:
    from tutel import moe as tutel_moe
    TUTEL_INSTALLED = True
except ImportError:
    TUTEL_INSTALLED = False

SYSTEM_MESSAGE = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"""
INPUT_FORMAT = """<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"""
AUDIO_START = "<|AUDIO_START|>"

DEFAULT_VIDEO_PROMPT = "<|vision_start|><|video_pad|><|vision_end|>{}"
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_TOTAL_PIXELS = 16 * 28 * 28
VIDEO_MIN_PIXELS = 16 * 28 * 28
VIDEO_MAX_PIXELS = 64 * 28 * 28
FRAME_FACTOR = 2

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
IMG_PREFIX_FORMAT = "<|IMAGE_PLACE_HOLDER|>"

class Dac:
    """Wrapper around the 16 kHz Descript Audio Codec (DAC) model."""

    def __init__(self):
        base_dir = os.path.dirname(__file__)
        dac_model_dir = os.path.join(base_dir, "dac_model")
        model_path = os.path.join(dac_model_dir, "weights_16khz.pth")

        if not os.path.isfile(model_path):
            print(f"DAC model not found at {model_path}, downloading...")
            os.makedirs(dac_model_dir, exist_ok=True)
            downloaded_path = dac.utils.download(model_type="16khz")
            shutil.move(downloaded_path, model_path)
            print(f"DAC model downloaded and saved to {model_path}")

        # Resolve the weights path, preferring an explicit DAC_WEIGHTS override.
        env_path = os.environ.get("DAC_WEIGHTS")
        candidates = []
        if env_path:
            candidates.append(env_path)
        candidates.extend([
            model_path,
            os.path.join(base_dir, "weights_16khz.pth"),
            os.path.join(os.getcwd(), "utils", "dac_model", "weights_16khz.pth"),
            os.path.join(os.getcwd(), "dac_model", "weights_16khz.pth"),
        ])

        final_model_path = next((p for p in candidates if p and os.path.isfile(p)), None)
        if not final_model_path:
            searched = "\n  - " + "\n  - ".join(candidates)
            raise FileNotFoundError(
                "DAC weights not found. Place weights_16khz.pth in one of the "
                "following locations, or set DAC_WEIGHTS to an absolute path:" + searched
            )

        self.model = dac.DAC.load(final_model_path)
        self.resampler = dict()
        if IS_CUDA:
            self.model = self.model.to("cuda")
        else:
            self.model = self.model.to("npu")

    def encode(self, audio_path):
        """Encode an audio file into DAC codebook indices; returns T frames of 12 codes each."""
        signal = AudioSignal(audio_path)
        # Downmix stereo to mono.
        if signal.audio_data.shape[1] == 2:
            signal.audio_data = 0.5 * (signal.audio_data[:, :1, :] + signal.audio_data[:, 1:, :])
        signal.to(self.model.device)

        # Resample to 16 kHz, caching one resampler per source sample rate.
        if signal.sample_rate != 16000:
            key = str(signal.sample_rate)
            if key not in self.resampler:
                resampler = torchaudio.transforms.Resample(signal.sample_rate, 16000)
                self.resampler[key] = resampler.cuda() if IS_CUDA else resampler.npu()
            signal.audio_data = self.resampler[key](signal.audio_data)
            signal.sample_rate = 16000

        x = self.model.preprocess(signal.audio_data.to(self.model.device), signal.sample_rate)
        z, codes, latents, _, _ = self.model.encode(x)

        # (1, 12, T) -> (T, 12) -> nested Python lists.
        codes = codes[0].clone().detach().transpose(0, 1)
        assert codes.shape[1] == 12 and len(codes.shape) == 2
        codes = codes.tolist()

        return codes

    def decode(self, codes, save_path, min_duration=None):
        """Decode codebook indices of shape (1, 12, T) and save a 16 kHz 16-bit WAV."""
        assert codes.shape[0] == 1 and codes.shape[1] == 12
        z, _, _ = self.model.quantizer.from_codes(codes.to(self.model.device))
        audio_out = self.model.decode(z)[0].detach().cpu()

        # Zero-pad the waveform up to min_duration seconds if requested.
        sample_rate = 16000
        duration = audio_out.size(1) / sample_rate
        if min_duration is not None and duration < min_duration:
            padding_samples = int((min_duration - duration) * sample_rate)
            padding = torch.zeros((audio_out.size(0), padding_samples), dtype=audio_out.dtype, device=audio_out.device)
            audio_out = torch.cat((audio_out, padding), dim=1)

        torchaudio.save(save_path, audio_out, sample_rate=sample_rate, encoding="PCM_S", bits_per_sample=16)

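# Minimal round-trip sketch (illustrative; "sample.wav" is a hypothetical
# local file, not part of this module):
#
#     codec = Dac()
#     codes = codec.encode("sample.wav")                 # T frames x 12 codes
#     codes_1x12xT = torch.tensor(codes).T.unsqueeze(0)  # (T, 12) -> (1, 12, T)
#     codec.decode(codes_1x12xT, "reconstructed.wav")
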
def build_delay_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute gather indices that shift each codec channel c back by delay_pattern[c] steps."""
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3

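# Worked example (illustrative): with delay_pattern=[0, 1, 2] and T=4, the
# (unclamped) source-time grid t_idx_BxTxC[b] is t - delay[c]:
#
#     [[ 0, -1, -2],
#      [ 1,  0, -1],
#      [ 2,  1,  0],
#      [ 3,  2,  1]]
#
# Negative entries mark positions before the sequence start, which
# apply_audio_delay replaces with the BOS value; indices_BTCx3 is the same
# mapping flattened into (batch, clamped-time, channel) gather coordinates.
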
def apply_audio_delay(audio_BxTxC: torch.Tensor, pad_value: int, bos_value: int, precomp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Stagger channels by their delays, filling shifted-in positions with BOS (before start) or PAD (past end)."""
    device = audio_BxTxC.device
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
    mask_bos = t_idx_BxTxC < 0
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]

    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC

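# Effect on one batch item (illustrative): with delay_pattern=[0, 1, 2],
# bos_value=B and input frames x0..x3,
#
#     step 0: [x0, B,  B ]
#     step 1: [x1, x0, B ]
#     step 2: [x2, x1, x0]
#     step 3: [x3, x2, x1]
#
# Channel c is shifted right by delay_pattern[c] steps with BOS filling the
# gap; for non-negative delays t - delay[c] < T always holds, so mask_pad
# (and hence pad_value) never fires here.
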
def build_revert_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute gather indices that undo apply_audio_delay: channel c reads frame t + delay_pattern[c], clamped to T - 1."""
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
    t_idx_BT1 = torch.broadcast_to(torch.arange(T).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C).view(1, 1, C), [B, T, C])
    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3

def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    """Shift each channel forward by its delay, restoring the original alignment."""
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC

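# Round-trip sketch (illustrative): reverting an applied delay restores the
# original alignment up to the trailing max-delay frames:
#
#     B, T, C = 1, 8, 3
#     pattern = [0, 1, 2]
#     x = torch.randint(0, 100, (B, T, C))
#     delayed = apply_audio_delay(x, pad_value=-1, bos_value=-2,
#                                 precomp=build_delay_indices(B, T, C, pattern))
#     restored = revert_audio_delay(delayed, pad_value=-1,
#                                   precomp=build_revert_indices(B, T, C, pattern), T=T)
#     # restored[:, : T - max(pattern), :] equals x[:, : T - max(pattern), :]
#
# Callers typically trim the trailing max(pattern) frames, as generate_output
# does below.
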
def prepare_audio_prompt(model, audio_prompts: List[Optional[torch.Tensor]]):
    """Build the delayed prefill tensor and per-item prefill lengths from optional audio prompts."""
    num_channels = model.config.codec_channels
    audio_bos_value = model.config.codec_bos_value
    delay_pattern = model.config.codec_delay_pattern
    max_delay_pattern = max(delay_pattern)
    batch_size = len(audio_prompts)
    max_len = max(p.shape[0] if p is not None else 0 for p in audio_prompts) + max_delay_pattern + 1
    prefill_steps = []
    prefill = torch.full(
        (batch_size, max_len, num_channels),
        fill_value=-1,
        dtype=torch.int,
        device=model.device,
    )
    prefill[:, 0, :] = audio_bos_value
    for i in range(batch_size):
        prompt = audio_prompts[i]
        if prompt is not None:
            prompt = prompt.to(device=model.device, dtype=torch.int)
            prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
            prefill_steps.append(prompt.shape[0] + 1)
        else:
            prefill_steps.append(1)

    delay_precomp = build_delay_indices(
        B=batch_size,
        T=max_len,
        C=num_channels,
        delay_pattern=delay_pattern,
    )

    delayed_batch = apply_audio_delay(
        audio_BxTxC=prefill,
        pad_value=-1,
        bos_value=audio_bos_value,
        precomp=delay_precomp,
    )

    return delayed_batch, prefill_steps

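# Prefill layout per batch item, before the delay is applied (L = prompt length):
#
#     step 0:            BOS on every channel
#     steps 1 .. L:      the item's prompt codes (if any)
#     steps L+1 .. end:  -1, to be filled in during decoding
#
# prefill_steps[i] = L + 1 is the first step the decoder generates for item i
# (1 when the item has no prompt); apply_audio_delay then staggers the channels.
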
class DecoderOutput:
    """Holds the growing (B, T, C) codec-token buffer during decoding."""

    def __init__(self, prefill, prefill_steps, device: torch.device, labels_prefill=None):
        self.generated_tokens = prefill
        self.prefill_steps = prefill_steps
        self.labels_prefill = labels_prefill
        self.device = device

    def get_tokens_at(self, step_from: int, step_to: Optional[int] = None) -> torch.Tensor:
        if step_to is None:
            step_to = step_from + 1
        return self.generated_tokens[:, step_from:step_to, :].to(self.device)

    def get_labels_at(self, step_from: int, step_to: Optional[int] = None) -> Optional[torch.Tensor]:
        if step_to is None:
            step_to = step_from + 1
        if self.labels_prefill is None:
            return None
        return self.labels_prefill[:, step_from:step_to, :].to(self.device)

    def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
        """Write one decoded step: fill masked (-1) prefill slots in place, or append a new step."""
        dec_out = dec_out.to(self.generated_tokens.dtype).to(self.generated_tokens.device)
        if apply_mask:
            assert step < self.generated_tokens.shape[1]
            mask = self.generated_tokens[:, step, :] == -1
            self.generated_tokens[:, step, :] = torch.where(mask, dec_out, self.generated_tokens[:, step, :])
        else:
            assert step == self.generated_tokens.shape[1]
            self.generated_tokens = torch.cat((self.generated_tokens, dec_out[:, None, :]), dim=1)

def generate_output(model, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor) -> List[torch.Tensor]:
    """Undo the channel delays and trim each batch item to its true length."""
    num_channels = model.config.codec_channels
    batch_size = generated_codes.shape[0]
    seq_length = generated_codes.shape[1]
    delay_pattern = model.config.codec_delay_pattern
    audio_pad_value = model.config.codec_pad_value
    max_delay_pattern = max(delay_pattern)
    revert_precomp = build_revert_indices(
        B=batch_size,
        T=seq_length,
        C=num_channels,
        delay_pattern=delay_pattern,
    )
    codebook = revert_audio_delay(
        audio_BxTxC=generated_codes,
        pad_value=audio_pad_value,
        precomp=revert_precomp,
        T=seq_length,
    )[:, :-max_delay_pattern, :]

    audios = []
    for i in range(batch_size):
        audios.append(codebook[i, : lengths_Bx[i], :].cpu())

    return audios

def frame_process(images, **ele):
    """Stack PIL frames into a (T, C, H, W) float tensor, resized under the Qwen-VL pixel budgets."""
    images = [torchvision.transforms.functional.pil_to_tensor(img) for img in images]
    video = torch.stack(images, dim=0)

    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
    max_pixels_supposed = ele.get("max_pixels", max_pixels)
    if max_pixels_supposed > max_pixels:
        print(f"The given max_pixels [{max_pixels_supposed}] exceeds the limit [{max_pixels}].")
    max_pixels = min(max_pixels_supposed, max_pixels)
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=IMAGE_FACTOR,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video

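# Budget sketch (values taken from the constants above): the per-frame cap is
#
#     max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
#                      int(min_pixels * 1.05))
#
# and smart_resize snaps (height, width) to multiples of IMAGE_FACTOR = 28
# while staying within [min_pixels, max_pixels].
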
def preprocess_codec(model, codec):
    """Convert raw codec tokens into delayed decoder input ids with BOS/EOS/PAD framing."""
    codec_token = torch.tensor(codec, dtype=torch.long)
    codec_token_len = codec_token.shape[0]
    max_delay_pattern = max(model.config.codec_delay_pattern)
    codec_input_ids = torch.zeros((codec_token_len + max_delay_pattern + 1, model.num_channels), dtype=torch.long)

    for c in range(model.num_channels):
        start = model.config.codec_delay_pattern[c] + 1
        codec_input_ids[:start, c] = model.config.codec_bos_value
        codec_input_ids[start : start + codec_token_len, c] = codec_token[:, c]
        codec_input_ids[start + codec_token_len :, c] = model.config.codec_pad_value
        if start + codec_token_len < codec_input_ids.shape[0]:
            codec_input_ids[start + codec_token_len, c] = model.config.codec_eos_value

    return codec_input_ids

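# Per-channel layout produced above, with d = delay_pattern[c], L = codec_token_len:
#
#     [BOS] * (d + 1)  +  codes[:, c]  +  [EOS]  +  [PAD] * (max_delay - d - 1)
#
# Every row has length L + max_delay + 1; the channel with the maximum delay
# exactly fills its row and therefore carries no EOS.
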
def tts_preprocess(batch_caption, prompt_codec, prompt_text, device):
    """Build text and codec inputs for TTS generation with voice-prompt CFG."""
    text_input = []
    codec_input_ids = []
    for caption in batch_caption:
        prompt_caption = "<|SPEECH_PROMPT_START|>" + prompt_text + "<|SPEECH_PROMPT_END|>"
        prompt_caption += "<|VOICE_PROMPT_START|>" + "<|AUDIO_PLACEHOLDER|>" * prompt_codec.shape[0] + "<|VOICE_PROMPT_END|>"

        def wrap_speech(text):
            return prompt_caption + "<|SPEECH_START|>" + text + "<|SPEECH_END|>"

        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(f"<|SPEECH_PROMPT_START|>{prompt_text}<|SPEECH_PROMPT_END|><|VOICE_PROMPT_START|><|VOICE_PROMPT_END|><|SPEECH_START|>{caption}<|SPEECH_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(wrap_speech("")) + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(wrap_speech(caption)) + AUDIO_START)
        codec_input_ids.append(prompt_codec.clone())
        codec_input_ids.append(prompt_codec.clone())

    codec_input_ids = torch.cat(codec_input_ids, dim=0).to(device)

    tts_generation_kwargs = {
        "codec_input_ids": codec_input_ids,
        "cfg_scale": [2, 3],
        "neg_input_size": 3,
    }

    return text_input, tts_generation_kwargs

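# Each caption expands into three sequences for the two-stage classifier-free
# guidance configured above (cfg_scale=[2, 3], neg_input_size=3):
#   1. no voice prompt (empty voice span, no codec placeholders),
#   2. voice prompt only (placeholders present, empty speech target),
#   3. full conditional input (voice prompt + target caption).
# Only sequences 2 and 3 contain <|AUDIO_PLACEHOLDER|> tokens, which is why
# prompt_codec is cloned twice per caption.
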
def t2m_preprocess(batch_caption):
    """Build text inputs for text-to-music generation with classifier-free guidance."""
    text_input = []
    for caption in batch_caption:
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)

    t2m_generation_kwargs = {
        "cfg_scale": 10,
        "neg_input_size": 2,
    }

    return text_input, t2m_generation_kwargs

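# Inputs are emitted in (negative, positive) pairs: an unconditional
# "Low quality." prompt followed by the caption, which neg_input_size=2 and
# cfg_scale=10 presumably pair up for classifier-free guidance downstream.
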
def v2m_preprocess(batch_caption, batch_video, fps=1):
    """Build text and video inputs for video-to-music generation with classifier-free guidance."""

    def extract_images_from_video(video_path, fps=1, max_frames=1):
        """Sample up to max_frames PIL images from the video at the given fps."""
        video = VideoFileClip(video_path)
        duration = video.duration

        images = []
        for i in range(math.ceil(duration * fps)):
            time_in_video = i / fps
            frame = video.get_frame(time_in_video)
            images.append(Image.fromarray(frame))

            if max_frames is not None and i >= max_frames - 1:
                break

        video.close()
        return images

    text_input = []
    video_inputs = []
    fps_inputs = []

    for caption, video in zip(batch_caption, batch_video):
        # The negative ("Low quality.") and positive prompts share the same frames.
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)

        video_input = frame_process(
            extract_images_from_video(video, fps),
            fps=fps,
        )

        video_inputs.append(video_input)
        video_inputs.append(video_input)

        fps_inputs.append(fps)
        fps_inputs.append(fps)

    v2m_generation_kwargs = {
        "cfg_scale": 10,
        "neg_input_size": 2,
    }

    return text_input, video_inputs, fps_inputs, v2m_generation_kwargs