# PrismAudio / data_utils / v2a_utils / feature_utils_288.py
# Uploader avatar caption: liuhuadai's picture
# Commit message: Update data_utils/v2a_utils/feature_utils_288.py
# Commit: 27f9239 (verified)
from typing import Literal, Optional
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import Normalize
from PrismAudio.models.factory import create_model_from_config
from PrismAudio.models.utils import load_ckpt_state_dict
import einshape
import sys
import os
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM,AutoModel,T5EncoderModel
import logging
import os
import numpy as np
log = logging.getLogger()
import jax
import jax.numpy as jnp
from videoprism import models as vp
from data_utils.ext.synchformer import Synchformer
def copy_state_dict(model, state_dict):
    """Load into ``model`` only the entries of ``state_dict`` that fit it.

    An entry is copied when its key exists in the model's own state dict and
    its tensor shape matches exactly. Everything else is skipped and reported
    on stdout; nothing raises for missing or mismatched keys.

    The caller's ``state_dict`` is left untouched: serialized
    ``nn.Parameter`` values are unwrapped into a local variable rather than
    written back into the mapping.

    Args:
        model (nn.Module): model to load state_dict.
        state_dict (OrderedDict): state_dict to load.
    """
    model_state_dict = model.state_dict()

    # Keys unknown to the model and keys whose shape differs are both
    # reported under "unexpected" (same reporting as before).
    unexpected_keys = [
        key
        for key, value in state_dict.items()
        if key not in model_state_dict
        or value.shape != model_state_dict[key].shape
    ]
    missing_keys = [key for key in model_state_dict if key not in state_dict]
    print("Missing keys in state_dict:", missing_keys)
    print("Unexpected keys in state_dict:", unexpected_keys)

    skipped = set(unexpected_keys)
    for key, value in state_dict.items():
        if key in skipped:
            continue
        if isinstance(value, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            value = value.data
        model_state_dict[key] = value

    # strict=False: the merged dict always has every model key, but keep the
    # lenient mode so loading never aborts on bookkeeping differences.
    model.load_state_dict(model_state_dict, strict=False)
class FeaturesUtils(nn.Module):
    """Collection of feature encoders used to prepare model inputs.

    Depending on the constructor arguments the module holds:

    * a T5Gemma text encoder and its tokenizer (``self.t5``,
      ``self.t5tokenizer``) plus a Synchformer video encoder
      (``self.synchformer``) when ``enable_conditions`` is true;
    * an audio autoencoder (``self.vae``) when ``vae_ckpt`` is given;
    * a VideoPrism video/text model running under JAX, created lazily on the
      first call to :meth:`encode_video_and_text_with_videoprism`.

    Every optional component that is not loaded is set to ``None`` so that
    the corresponding ``encode_*`` method fails with a clear assertion
    message instead of an ``AttributeError``.
    """

    def __init__(
        self,
        *,
        vae_ckpt: Optional[str] = None,
        vae_config: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        need_vae_encoder: bool = True,
    ):
        """Build the requested encoders.

        Args:
            vae_ckpt: Path to the autoencoder checkpoint. When ``None`` no
                VAE is created and ``self.vae`` is ``None``.
            vae_config: Path to the JSON model config consumed by
                ``create_model_from_config``. Needed when ``vae_ckpt`` is set.
            synchformer_ckpt: Path to the Synchformer weights. Needed when
                ``enable_conditions`` is true.
            enable_conditions: Load the T5Gemma encoder/tokenizer and the
                Synchformer model.
            need_vae_encoder: Accepted for interface compatibility; it is not
                read anywhere in this class.
        """
        super().__init__()

        if enable_conditions:
            self.t5 = AutoModelForSeq2SeqLM.from_pretrained("google/t5gemma-l-l-ul2-it").get_encoder()
            self.t5tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-l-l-ul2-it")
            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
        else:
            # Define every attribute the encode_* methods assert on.
            self.t5 = None
            self.t5tokenizer = None
            self.synchformer = None
            # Legacy attribute name, kept in case external code reads it.
            self.tokenizer = None

        if vae_ckpt is not None:
            with open(vae_config, encoding="utf-8") as f:
                vae_config = json.load(f)
            self.vae = create_model_from_config(vae_config)
            print(f"Loading model checkpoint from {vae_ckpt}")
            # Load checkpoint; only keys under the 'autoencoder.' prefix are used.
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))
        else:
            self.vae = None

    def _init_jax(self):
        """Create the VideoPrism model, weights, tokenizer and jitted apply.

        Idempotent: returns immediately once ``flax_model`` and
        ``text_tokenizer`` exist. The JAX device is chosen from
        ``LOCAL_RANK`` (modulo the number of visible devices) when the
        default backend is GPU, otherwise the first available device is used
        and a warning is logged.
        """
        if hasattr(self, "flax_model") and hasattr(self, "text_tokenizer"):
            return  # already init

        backend = jax.default_backend()
        if backend != 'gpu':
            log.warning(
                f"JAX is running on {backend.upper()} instead of GPU! "
                f"Performance will be significantly degraded."
            )
            self.jax_dev = jax.devices()[0]  # CPU exposes only a single device
        else:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            devices = jax.devices()
            # Wrap around so a rank beyond the visible device count still
            # maps to a valid device.
            device_idx = local_rank % len(devices)
            self.jax_dev = devices[device_idx]

        model_name = 'videoprism_lvt_public_v1_large'
        self.flax_model = vp.get_model(model_name)
        state = vp.load_pretrained_weights(model_name)
        self.loaded_state = jax.device_put(state, device=self.jax_dev)
        self.text_tokenizer = vp.load_text_tokenizer('c4_en')
        # NOTE(review): the `device=` argument of jax.jit is deprecated in
        # recent JAX releases -- confirm against the pinned JAX version.
        self.apply_jit = jax.jit(lambda x, y, z: self.flax_model.apply(
            self.loaded_state, x, y, z, train=False, return_intermediate=True
        ), device=self.jax_dev)

    # def train(self, mode: bool) -> None:
    #     return super().train(False)

    def encode_video_and_text_with_videoprism(self, x: torch.Tensor, cot: str, batch_size: int = -1) -> tuple:
        """Encode a video clip and its text with VideoPrism.

        Args:
            x: Video tensor laid out as (batch, frames, 288, 288, 3); the
                spatial size and channel count are asserted.
            cot: Text passed to ``vp.tokenize_texts``.
                NOTE(review): annotated ``str``; presumably the tokenizer
                expects a sequence of strings -- verify against callers.
            batch_size: Unused; kept for signature parity with the other
                encoders.

        Returns:
            A 4-tuple of JAX outputs:
            ``(video_embeddings[0], frame_embeddings[0],
            spatiotemporal_features[0][0], text_embeddings)``.
            Only the first batch element of the video outputs is returned.
        """
        self._init_jax()
        b, t, h, w, c = x.shape
        assert c == 3 and h == 288 and w == 288

        text_ids, text_paddings = vp.tokenize_texts(self.text_tokenizer, cot)
        # Hand everything to JAX on the selected device (goes through host memory).
        x = jax.device_put(x.cpu().numpy(), device=self.jax_dev)
        text_ids = jax.device_put(text_ids, device=self.jax_dev)
        text_paddings = jax.device_put(text_paddings, device=self.jax_dev)

        video_embeddings, text_embeddings, outputs = self.apply_jit(
            x, text_ids, text_paddings
        )
        frame_embed = outputs['frame_embeddings']
        # NOTE(review): `t` is taken from frame_embed.shape[0]; confirm that
        # axis 0 of 'frame_embeddings' is the intended split size.
        spatialtemporal_embed = einshape.jax_einshape(
            'b(ts)d->btsd', outputs['spatiotemporal_features'], t=frame_embed.shape[0]
        )
        return video_embeddings[0], frame_embed[0], spatialtemporal_embed[0][0], text_embeddings

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode video with Synchformer over overlapping 16-frame segments.

        The clip is cut into windows of 16 frames taken every 8 frames;
        trailing frames that do not fill a whole window are dropped. Each
        window is encoded independently and the per-window outputs are
        concatenated along the sequence axis.

        Args:
            x: Video tensor laid out as (batch, frames, 3, 224, 224); the
                channel count and spatial size are asserted.
            batch_size: Number of segments sent to Synchformer per forward
                call. A negative value uses the clip batch size ``b``.

        Returns:
            Tensor rearranged from ``(b s) 1 t d`` to ``b (s t) d``.
        """
        assert self.synchformer is not None, 'Synchformer is not loaded'
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        segment_size = 16
        step_size = 8
        # Without this check torch.stack below fails on an empty list with
        # an unhelpful message.
        assert t >= segment_size, (
            f'Need at least {segment_size} frames to form one segment, got {t}')
        num_segments = (t - segment_size) // step_size + 1

        segments = [
            x[:, start:start + segment_size]
            for start in range(0, num_segments * step_size, step_size)
        ]
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)

        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        outputs = [
            self.synchformer(x[i:i + batch_size])
            for i in range(0, b * num_segments, batch_size)
        ]
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    @torch.inference_mode()
    def encode_t5_text(self, text: list[str]) -> torch.Tensor:
        """Encode a batch of strings with the T5Gemma encoder.

        Inputs are padded to the longest item and never truncated.

        Args:
            text: Strings to encode.

        Returns:
            The encoder's ``last_hidden_state``.
        """
        assert self.t5 is not None, 'T5 model is not loaded'
        assert self.t5tokenizer is not None, 'T5 Tokenizer is not loaded'
        # x: (B, L)
        inputs = self.t5tokenizer(text,
                                  padding=True,
                                  truncation=False,
                                  return_tensors="pt").to(self.device)
        text_features = self.t5(**inputs).last_hidden_state
        return text_features

    @torch.inference_mode()
    def encode_audio(self, x) -> torch.Tensor:
        """Encode ``x`` with the VAE and return the result of ``vae.encode``."""
        assert self.vae is not None, 'VAE is not loaded'
        x = self.vae.encode(x)
        return x

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype