import math, torch
from functools import partial
from torch import nn, Tensor
from torchvision.transforms.functional import normalize
from transformers import AutoModel
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

from .configuration_live import LiveConfigMixin

def _siglip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], rescale_factor=0.00392156862745098, **kwargs):
    # rescale_factor is 1/255: map raw pixel values to [0, 1] before per-channel normalization
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
    last_hidden_state = vision_outputs.last_hidden_state
    if frame_token_pooled:
        # SigLIP has no CLS token: the whole sequence is patch tokens, so the grid side is sqrt(seq_len)
        s = int(math.sqrt(last_hidden_state.shape[1]))
        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
            last_hidden_state.reshape(
                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
            ).permute(0, 3, 1, 2),  # (B, s, s, C) -> (B, C, s, s) for pooling
            frame_token_pooled
        ).flatten(2, 3).permute(0, 2, 1)  # back to (B, h*w, C)
        if not frame_token_cls:
            return spatial_tokens
    if frame_token_cls:
        cls_token = vision_outputs.pooler_output[:, None]  # (B, 1, C)
        if not frame_token_pooled:
            return cls_token
    return torch.cat([cls_token, spatial_tokens], dim=1)
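
# Shape sketch (illustrative, for google/siglip-large-patch16-384): a 384x384
# frame yields 24x24 = 576 patch tokens of width 1024, so last_hidden_state is
# (B, 576, 1024); pooling to e.g. (3, 3) gives (B, 9, 1024), and prepending the
# attention-pooled token from pooler_output gives (B, 10, 1024).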

def _clip_vision_encode(vision_model: nn.Module, frames: Tensor, frame_token_cls: bool, frame_token_pooled: tuple,
    mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, rescale_factor=0.00392156862745098, **kwargs):
    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
    last_hidden_state = vision_outputs.last_hidden_state
    if frame_token_pooled:
        # CLIP prepends a CLS token, so only shape[1] - 1 positions are patch tokens
        s = int(math.sqrt(last_hidden_state.shape[1] - 1))
        spatial_tokens = torch.nn.functional.adaptive_avg_pool2d(
            last_hidden_state[:, 1:].reshape(
                last_hidden_state.shape[0], s, s, last_hidden_state.shape[-1]
            ).permute(0, 3, 1, 2),
            frame_token_pooled
        ).flatten(2, 3).permute(0, 2, 1)
        if not frame_token_cls:
            return spatial_tokens
    if frame_token_cls:
        # keep the singleton sequence dim, (B, 1, C), so it matches spatial_tokens for concatenation
        cls_token = last_hidden_state[:, :1]
        if not frame_token_pooled:
            return cls_token
    return torch.cat([cls_token, spatial_tokens], dim=1)
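
# Shape sketch (illustrative, for openai/clip-vit-large-patch14-336): a 336x336
# frame yields 1 CLS token + 24x24 = 576 patch tokens, so last_hidden_state is
# (B, 577, 1024); the [:, 1:] slice drops the CLS token before the spatial
# reshape and pooling.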

def build_live_vision(config: LiveConfigMixin):
    model = AutoModel.from_pretrained(config.vision_pretrained).vision_model
    if config.vision_pretrained == 'google/siglip-large-patch16-384':
        return model, partial(_siglip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
    elif config.vision_pretrained in ('laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k', 'openai/clip-vit-large-patch14-336'):
        return model, partial(_clip_vision_encode, frame_token_cls=config.frame_token_cls, frame_token_pooled=config.frame_token_pooled)
    else:
        raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}')
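
# Usage sketch (illustrative, not part of the original module): any object
# exposing the three attributes below can stand in for LiveConfigMixin; the
# values are assumptions for demonstration only. Requires a GPU, the pretrained
# weights, and a package context (this module uses a relative import above).
if __name__ == '__main__':
    from types import SimpleNamespace
    config = SimpleNamespace(
        vision_pretrained='google/siglip-large-patch16-384',
        frame_token_cls=True,
        frame_token_pooled=(3, 3),
    )
    model, encode = build_live_vision(config)
    # raw uint8-range pixel values; rescale_factor (1/255) maps them to [0, 1]
    frames = torch.randint(0, 256, (2, 3, 384, 384)).float().cuda()
    tokens = encode(model.cuda(), frames)
    print(tokens.shape)  # expected: (2, 1 + 3*3, 1024), i.e. pooled token + 3x3 spatial tokens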