Spaces:
Running on Zero
Running on Zero
File size: 6,967 Bytes
ddb382a 8031e67 ddb382a 27f9239 ddb382a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | from typing import Literal, Optional
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import Normalize
from PrismAudio.models.factory import create_model_from_config
from PrismAudio.models.utils import load_ckpt_state_dict
import einshape
import sys
import os
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM,AutoModel,T5EncoderModel
import logging
import os
import numpy as np
# Module-level logger named after this module; records still propagate to the
# root logger's handlers, but no longer log *as* the root logger.
log = logging.getLogger(__name__)
import jax
import jax.numpy as jnp
from videoprism import models as vp
from data_utils.ext.synchformer import Synchformer
def copy_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, copying only keys that match exactly.

    A key is copied when it exists in the model's own state dict AND its
    tensor shape agrees with the model's; every other key is skipped instead
    of raising, so a partially compatible checkpoint can still be loaded.
    The caller's ``state_dict`` is left unmodified.

    Args:
        model (nn.Module): model to load state_dict into.
        state_dict (Mapping[str, torch.Tensor]): state_dict to load from.

    Prints two lists before loading:
        * "Missing keys": keys the model has but ``state_dict`` lacks.
        * "Unexpected keys": keys absent from the model OR whose shape
          differs from the model's tensor.
    """
    model_state_dict = model.state_dict()
    missing_keys = [key for key in model_state_dict if key not in state_dict]
    unexpected_keys = []
    for key, value in state_dict.items():
        if key not in model_state_dict or value.shape != model_state_dict[key].shape:
            # Shape mismatches are reported together with unknown keys.
            unexpected_keys.append(key)
            continue
        if isinstance(value, torch.nn.Parameter):
            # Backwards compatibility for serialized parameters: load the raw
            # tensor. Unwrapped into a local so the caller's dict is untouched.
            value = value.data
        model_state_dict[key] = value
    print("Missing keys in state_dict:", missing_keys)
    print("Unexpected keys in state_dict:", unexpected_keys)
    model.load_state_dict(model_state_dict, strict=False)
class FeaturesUtils(nn.Module):
    """Container for the feature extractors used to build conditioning inputs.

    Depending on the constructor arguments it holds:

    * a T5Gemma text encoder and its tokenizer (``t5`` / ``t5tokenizer``),
    * a Synchformer video encoder (``synchformer``),
    * an autoencoder built from a JSON config (``vae``),
    * lazily, on first use, a VideoPrism model running under JAX
      (see ``_init_jax``).

    Components that are not loaded are set to ``None`` so the ``encode_*``
    methods fail on their "not loaded" assertion rather than with an
    ``AttributeError``.
    """

    def __init__(
        self,
        *,
        vae_ckpt: Optional[str] = None,
        vae_config: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        need_vae_encoder: bool = True,
    ):
        """Build the requested feature extractors.

        Args:
            vae_ckpt: path to the autoencoder checkpoint; when ``None`` no
                autoencoder is built and ``self.vae`` is ``None``.
            vae_config: path to the JSON model config passed to
                ``create_model_from_config``; required when ``vae_ckpt`` is set.
            synchformer_ckpt: path to Synchformer weights, loaded with
                ``torch.load`` when ``enable_conditions`` is True.
            enable_conditions: load the T5Gemma encoder/tokenizer and Synchformer.
            need_vae_encoder: currently unused; kept for interface compatibility.

        Raises:
            ValueError: if ``vae_ckpt`` is given without ``vae_config``.
        """
        super().__init__()
        if enable_conditions:
            self.t5 = AutoModelForSeq2SeqLM.from_pretrained("google/t5gemma-l-l-ul2-it").get_encoder()
            self.t5tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-l-l-ul2-it")
            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
        else:
            # Explicit None so encode_t5_text / encode_video_with_sync reach their
            # "not loaded" assertions instead of raising AttributeError.
            self.t5 = None
            self.t5tokenizer = None
            self.synchformer = None
            self.tokenizer = None  # kept for callers that may read this attribute

        if vae_ckpt is None:
            self.vae = None
        else:
            if vae_config is None:
                raise ValueError('vae_config must be provided together with vae_ckpt')
            with open(vae_config, encoding='utf-8') as f:
                vae_config = json.load(f)
            self.vae = create_model_from_config(vae_config)
            print(f"Loading model checkpoint from {vae_ckpt}")
            # NOTE(review): prefix='autoencoder.' presumably selects/strips the
            # autoencoder sub-module's keys — confirm in load_ckpt_state_dict.
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))

    def _init_jax(self):
        """Lazily set up VideoPrism (Flax) and a jitted forward function.

        Picks a JAX device, loads the 'videoprism_lvt_public_v1_large' weights
        onto it, loads the 'c4_en' text tokenizer and builds ``self.apply_jit``.
        Returns immediately if a previous call completed.
        """
        # apply_jit is assigned last, so requiring it means a partially failed
        # earlier initialisation is retried rather than treated as complete.
        if all(hasattr(self, name) for name in ("flax_model", "text_tokenizer", "apply_jit")):
            return  # already initialised
        backend = jax.default_backend()
        if backend != 'gpu':
            log.warning(
                f"JAX is running on {backend.upper()} instead of GPU! "
                f"Performance will be significantly degraded."
            )
            self.jax_dev = jax.devices()[0]  # CPU has only one device
        else:
            # One GPU per process, chosen by LOCAL_RANK (presumably set by a
            # distributed launcher such as torchrun; defaults to 0).
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            devices = jax.devices()
            device_idx = local_rank % len(devices)
            self.jax_dev = devices[device_idx]
        model_name = 'videoprism_lvt_public_v1_large'
        self.flax_model = vp.get_model(model_name)
        state = vp.load_pretrained_weights(model_name)
        self.loaded_state = jax.device_put(state, device=self.jax_dev)
        self.text_tokenizer = vp.load_text_tokenizer('c4_en')
        # NOTE(review): the `device=` argument of jax.jit may be deprecated in
        # newer JAX releases — confirm against the pinned JAX version.
        self.apply_jit = jax.jit(lambda x, y, z: self.flax_model.apply(
            self.loaded_state, x, y, z, train=False, return_intermediate=True
        ), device=self.jax_dev)

    # def train(self, mode: bool) -> None:
    #     return super().train(False)

    def encode_video_and_text_with_videoprism(self, x: torch.Tensor, cot: str, batch_size: int = -1) -> tuple:
        """Encode a video clip and text jointly with VideoPrism (JAX).

        Args:
            x: 5-D tensor (B, T, 288, 288, 3), channels-last. It is copied to
                host memory with ``.cpu().numpy()`` and then to the JAX device.
            cot: text passed to ``vp.tokenize_texts``.
            batch_size: unused; kept for signature parity with
                ``encode_video_with_sync``.

        Returns:
            Tuple ``(video_embeddings[0], frame_embed[0],
            spatialtemporal_embed[0][0], text_embeddings)`` of JAX arrays
            (not torch tensors).
        """
        self._init_jax()
        b, t, h, w, c = x.shape  # unpacking also enforces a 5-D input
        assert c == 3 and h == 288 and w == 288
        text_ids, text_paddings = vp.tokenize_texts(self.text_tokenizer, cot)
        x = jax.device_put(x.cpu().numpy(), device=self.jax_dev)
        text_ids = jax.device_put(text_ids, device=self.jax_dev)
        text_paddings = jax.device_put(text_paddings, device=self.jax_dev)
        video_embeddings, text_embeddings, outputs = self.apply_jit(
            x, text_ids, text_paddings
        )
        frame_embed = outputs['frame_embeddings']
        # NOTE(review): `t` below is frame_embed.shape[0] (the leading axis,
        # presumably batch), not the frame count unpacked from x above. With a
        # leading size of 1 the reshape yields a single time group, so
        # spatialtemporal_embed[0][0] holds all tokens of the first sample —
        # confirm this is intended.
        spatialtemporal_embed = einshape.jax_einshape(
            'b(ts)d->btsd', outputs['spatiotemporal_features'], t=frame_embed.shape[0]
        )
        return video_embeddings[0], frame_embed[0], spatialtemporal_embed[0][0], text_embeddings

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        """Encode a video with Synchformer over overlapping 16-frame windows.

        Args:
            x: tensor (B, T, 3, 224, 224) with T >= 16.
            batch_size: number of windows per Synchformer forward pass; a
                negative value means B windows per pass.

        Returns:
            Tensor (B, S * T', D), where S is the number of windows and T', D
            come from the Synchformer output ('(b s) 1 t d').

        Raises:
            ValueError: if the clip has fewer than 16 frames.
        """
        assert self.synchformer is not None, 'Synchformer is not loaded'
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224
        segment_size = 16
        step_size = 8
        if t < segment_size:
            # Otherwise num_segments <= 0 and torch.stack fails on an empty list.
            raise ValueError(
                f'encode_video_with_sync needs at least {segment_size} frames, got {t}')
        num_segments = (t - segment_size) // step_size + 1
        # Windows of segment_size frames every step_size frames; trailing frames
        # that do not fill a whole window are dropped.
        segments = [x[:, i * step_size:i * step_size + segment_size] for i in range(num_segments)]
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)
        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        outputs = [self.synchformer(x[i:i + batch_size]) for i in range(0, b * num_segments, batch_size)]
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    @torch.inference_mode()
    def encode_t5_text(self, text: list[str]) -> torch.Tensor:
        """Encode a batch of strings with the T5Gemma encoder.

        Tokenizes with padding and without truncation, moves the inputs to
        ``self.device`` and returns the encoder's ``last_hidden_state``
        (presumably (B, L, D) — confirm against the model).
        """
        assert self.t5 is not None, 'T5 model is not loaded'
        assert self.t5tokenizer is not None, 'T5 Tokenizer is not loaded'
        inputs = self.t5tokenizer(text,
                                  padding=True,
                                  truncation=False,
                                  return_tensors="pt").to(self.device)
        text_features = self.t5(**inputs).last_hidden_state
        return text_features

    @torch.inference_mode()
    def encode_audio(self, x) -> torch.Tensor:
        """Return ``self.vae.encode(x)``; requires a VAE to have been loaded."""
        assert self.vae is not None, 'VAE is not loaded'
        x = self.vae.encode(x)
        return x

    @property
    def device(self):
        """Device of the first registered parameter (StopIteration if there are none)."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter (StopIteration if there are none)."""
        return next(self.parameters()).dtype
|