| import copy |
| import math |
| import random |
| from collections import OrderedDict |
| from dataclasses import asdict |
| from functools import partial |
| from logging import getLogger |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from timm.layers import DropPath |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ |
| from torch.nn.parameter import Parameter |
| from torch.utils.checkpoint import checkpoint |
|
|
| from core.vision_encoder.rope import Rope2D |
| from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint |
|
|
|
|
|
|
| logger = getLogger() |
|
|
|
|
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling (LayerScale).

    The ``gamma`` parameter is created lazily by ``init_tensors()`` rather
    than in ``__init__`` (the owning model drives initialization).
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.dim = dim
        self.init_values = init_values

    def init_tensors(self):
        # Create gamma filled with the configured initial value.
        self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))

    def forward(self, x):
        # In-place variant mutates the input tensor; otherwise allocate a new one.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
class AttentionPooling(nn.Module):
    """Pool a token sequence down to ``num_probe`` tokens.

    A learned probe attends over the input sequence via cross-attention,
    and the result is refined with a pre-norm MLP residual.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_probe: int = 1,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        assert (
            self.embed_dim % num_heads == 0
        ), "embed_dim must be divisible by num_heads"

        # Learned query tokens, shared across the batch.
        self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
        self.attn = nn.MultiheadAttention(
            self.embed_dim, self.num_heads, batch_first=True
        )

        self.layernorm = norm_layer(embed_dim)
        self.mlp_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
                ]
            )
        )

    def forward(self, x: torch.Tensor):
        n_batch = x.shape[0]

        # Broadcast the probe to the batch and match the input dtype.
        queries = self.probe.to(x.dtype).repeat(n_batch, 1, 1)
        pooled, _ = self.attn(queries, x, x, need_weights=False)
        pooled = pooled + self.mlp(self.layernorm(pooled))

        return pooled
|
|
|
|
| class SelfAttention(nn.Module): |
| r""" |
| Implements sequence packed attention and RoPe |
| """ |
|
|
| def __init__( |
| self, |
| embed_dim: int, |
| num_heads: int, |
| rope: Optional[nn.Module] = None, |
| ): |
| super(SelfAttention, self).__init__() |
| self.embed_dim = embed_dim |
|
|
| self.num_heads = num_heads |
| self.head_dim = embed_dim // num_heads |
| assert ( |
| self.head_dim * num_heads == self.embed_dim |
| ), "embed_dim must be divisible by num_heads" |
|
|
| |
| self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) |
| self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) |
| self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
| self.rope = rope |
| self.scale = self.head_dim ** (-0.5) |
|
|
| def init_tensors(self): |
| xavier_uniform_(self.in_proj_weight) |
| constant_(self.in_proj_bias, 0.0) |
| constant_(self.out_proj.bias, 0.0) |
|
|
| def forward(self, x, attn_mask=None): |
| batch, seq, embed_dim = x.shape |
| proj = F.linear(x, self.in_proj_weight, self.in_proj_bias) |
|
|
| |
| proj = ( |
| proj.unflatten(-1, (3, embed_dim)) |
| .unsqueeze(0) |
| .transpose(0, -2) |
| .squeeze(-2) |
| .contiguous() |
| ) |
| q, k, v = proj[0], proj[1], proj[2] |
|
|
| |
| q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads) |
| k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads) |
| v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads) |
|
|
| if self.rope: |
| q, k = self.rope(q, k) |
|
|
| attn = F.scaled_dot_product_attention( |
| q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale |
| ) |
| attn = rearrange(attn, "b h s d -> b s (h d)") |
|
|
| return F.linear(attn, self.out_proj.weight, self.out_proj.bias) |
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention residual then MLP residual.

    Uses SelfAttention (with RoPE) when ``rope`` is given, otherwise a
    standard nn.MultiheadAttention. Optional LayerScale and DropPath on
    both residual branches.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()

        self.attn = (
            SelfAttention(d_model, n_head, rope=rope)
            if rope
            else nn.MultiheadAttention(d_model, n_head, batch_first=True)
        )

        def make_scale():
            # LayerScale only when an init value was configured.
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        self.ls_1 = make_scale()
        self.ls_2 = make_scale()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        use_droppath = drop_path > 0.0
        self.drop_path1 = DropPath(drop_path) if use_droppath else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if use_droppath else nn.Identity()

        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, hidden)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(hidden, d_model)),
                ]
            )
        )

    def _call_attn(
        self,
        q_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            # Float masks are added to the logits, so match the input dtype.
            attn_mask = attn_mask.to(q_x.dtype)

        if isinstance(self.attn, SelfAttention):
            return self.attn(q_x, attn_mask=attn_mask)
        return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        attn_out = self._call_attn(self.ln_1(x), attn_mask=attn_mask)
        x = x + self.drop_path1(self.ls_1(attn_out))
        mlp_out = self.mlp(self.ln_2(x))
        x = x + self.drop_path2(self.ls_2(mlp_out))
        return x
|
|
|
|
class Transformer(nn.Module):
    """A stack of ResidualAttentionBlocks with optional gradient checkpointing
    and early exit at a given layer index."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle activation checkpointing for the residual blocks."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        # Negative indices wrap python-style; e.g. -1 keeps all layers.
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[:self.layers])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,
    ):
        """Run blocks up to and including ``layer_idx`` (negative indices wrap)."""
        stop_idx = (self.layers + layer_idx) % self.layers

        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # BUGFIX: was checkpoint(r, x, None, None, attn_mask), which
                # passed four positional args to a block whose forward takes
                # (x, attn_mask) and raised a TypeError whenever gradient
                # checkpointing was enabled.
                x = checkpoint(r, x, attn_mask, use_reentrant=False)
            else:
                x = r(x, attn_mask=attn_mask)

            if i == stop_idx:
                break

        return x
|
|
|
|
class VisionTransformer(nn.Module):
    """Vision Transformer producing pooled embeddings or token features.

    Supports 2D RoPE, learned absolute position embeddings (bilinearly
    interpolated for non-native resolutions), an optional class token, and
    "attn" / "tok" / "avg" / "none" pooling, followed by an optional linear
    projection to ``output_dim``.
    """

    def __init__(
        self,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        use_ln_pre: bool = True,
        use_ln_post: bool = True,
        ls_init_value: float = None,
        drop_path: float = 0.0,
        image_size: int = 448,
        use_abs_posemb: bool = True,
        use_rope2d: bool = True,
        use_cls_token: bool = False,
        output_dim: Optional[int] = 1280,
        attn_pooler_heads: int = 8,
        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
    ):
        super().__init__()
        assert pool_type in ("attn", "tok", "avg", "none")
        self.pool_type = pool_type
        self.patch_size = patch_size

        # proj_dim is None when no output projection is wanted; output_dim
        # then falls back to the transformer width.
        self.output_dim = output_dim or width
        self.proj_dim = output_dim
        self.heads = heads
        self.width = width
        self.layers = layers

        self.use_abs_posemb = use_abs_posemb
        self.use_cls_token = use_cls_token
        self.use_rope2d = use_rope2d
        self.image_size = image_size

        # Patchify: non-overlapping patch_size x patch_size convolution, no bias.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        # Per-head 2D rotary embedding; grid is (re)built in forward_features.
        self.rope = (
            Rope2D(
                dim=width // heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
            rope=self.rope,
        )

        if pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=attn_pooler_heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def init_tensors(self):
        """Recursively initialize submodules that define ``init_tensors``,
        then create this module's own parameters (class token, absolute
        position embedding, output projection)."""
        def init_submodule_tensors(module):
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        # NOTE(review): assumes use_rope2d=True; self.rope is None otherwise
        # and this call would raise AttributeError — confirm intended configs.
        self.rope.init_tensors()

        # Scale random init by 1/sqrt(width), as in CLIP-style models.
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            # Native grid size; other resolutions are handled by interpolation
            # in _sample_abs_posemb.
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                init_scale
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
                )
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(
                init_scale * torch.randn(self.width, self.proj_dim)
            )

    def load_ckpt(self, ckpt_path: str, verbose: bool = True):
        """Load weights from ``ckpt_path``, unwrapping common checkpoint
        containers ("state_dict"/"weights") and stripping "module." /
        "visual." prefixes so a full CLIP checkpoint loads into just the
        vision tower. Loads non-strictly and reports missing/unexpected keys."""
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
        # If this is a dual-tower checkpoint, keep only the vision keys.
        if any(k.startswith("visual.") for k in _sd):
            _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}

        m, u = self.load_state_dict(_sd, strict=False)

        if verbose or (m or u):
            logger.info(f"Missing keys for loading vision encoder: {m}")
            logger.info(f"Unexpected keys for loading vision encoder: {u}")
            print(f"Missing keys for loading vision encoder: {m}")
            print(f"Unexpected keys for loading vision encoder: {u}")

    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.transformer.truncate(layer_idx)
        self.layers = self.transformer.layers

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None,
        **kwdargs
    ):
        """Build a model from a named config in PE_VISION_CONFIG, with optional
        keyword overrides and optional pretrained weight loading."""
        if name not in PE_VISION_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        args = asdict(PE_VISION_CONFIG[name])
        args.update(kwdargs)

        model = cls(**args)
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))

        return model

    @classmethod
    def available_configs(cls):
        """List the names of all registered vision configs."""
        return list(PE_VISION_CONFIG.keys())

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing in the transformer."""
        self.transformer.set_grad_checkpointing(enable=enable)

    def _sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Interpolates the absolute position embedding if necessary."""
        # Fast path: input resolution matches the native posemb grid.
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            # The class-token embedding is not spatial; keep it aside.
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        # (grid*grid, width) -> (1, width, grid, grid) for F.interpolate.
        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def _pool(self, x: torch.Tensor):
        """Reduce (batch, seq, width) token features according to pool_type."""
        if self.pool_type == "tok":
            # First token (class token).
            return x[:, 0]
        elif self.pool_type == "avg":
            return x.mean(dim=1)
        elif self.pool_type == "attn":
            return self.attn_pool(x).squeeze(1)
        elif self.pool_type == "none":
            return x
        else:
            raise NotImplementedError

    def forward_features(
        self,
        x: torch.Tensor,
        norm: bool = False,
        layer_idx: int = -1,
        strip_cls_token: bool = False
    ):
        """Return token features for image batch ``x`` (batch, 3, H, W).

        ``layer_idx`` allows early exit from the transformer; ``norm`` applies
        ln_post; ``strip_cls_token`` drops the leading class token if present.
        """
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        # Patchify and flatten to (batch, grid_h * grid_w, width).
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            # Prepend the class token to every sequence.
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                dim=1,
            )

        if self.use_abs_posemb:
            x = x + self._sample_abs_posemb(grid_h, grid_w)

        if self.use_rope2d:
            # Rebuild RoPE frequencies for the current grid/device.
            self.rope.update_grid(x.device, grid_h, grid_w)

        x = self.ln_pre(x)
        x = self.transformer(x, layer_idx=layer_idx)

        if norm:
            x = self.ln_post(x)

        if strip_cls_token and self.use_cls_token:
            x = x[:, 1:, :]

        return x

    def forward(self, x: torch.Tensor, **kwargs):
        """Full forward: features -> pooling -> optional output projection."""
        x = self.forward_features(x, norm=True, **kwargs)
        x = self._pool(x)

        if self.proj_dim is not None:
            x = x @ self.proj

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextTransformer(nn.Module):
    """CLIP-style causal text encoder.

    Token embedding + learned positions -> Transformer (optionally causal)
    -> final LayerNorm -> pooling ("first"/"last"/"argmax"/"none") ->
    projection to ``output_dim``.
    """

    def __init__(
        self,
        context_length: int = 72,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        output_dim: int = 1280,
        no_causal_mask: bool = False,
        pad_id: int = 0,
        pool_type: str = "argmax",
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        output_tokens: bool = False,
        use_ln_post: bool = True,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.pool_type = pool_type
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id
        self.layers = layers

        self.token_embedding = nn.Embedding(vocab_size, width)
        # NOTE(review): allocated uninitialized (torch.empty); presumably
        # filled by a checkpoint load via load_ckpt — confirm before training
        # from scratch.
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))

        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            # Additive causal mask (-inf above the diagonal); not saved in
            # state_dict (persistent=False).
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        # NOTE(review): unreachable given the assert above restricts pool_type
        # to ("first", "last", "argmax", "none"); kept for compatibility.
        if pool_type == "attn" or pool_type == "attn_eos":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            # NOTE(review): also torch.empty — expected to come from a checkpoint.
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self):
        """Return an additive (num_pos, num_pos) causal mask: 0 on/below the
        diagonal, -inf above it."""
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def load_ckpt(self, ckpt_path: str, verbose: bool = True):
        """Load weights, unwrapping "state_dict"/"weights" containers and
        stripping "module." prefixes; non-strict, reports mismatched keys."""
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}

        m, u = self.load_state_dict(_sd, strict=False)

        if verbose or (m or u):
            logger.info(f"Missing keys for loading model: {m}")
            logger.info(f"Unexpected keys for loading model: {u}")
            print(f"Missing keys for loading model: {m}")
            print(f"Unexpected keys for loading model: {u}")

    def build_cls_mask(self, text):
        """Build an additive mask that blocks attention to padding positions.

        NOTE(review): not referenced by forward() in this file — appears to be
        for a cls-token pooling variant; confirm before removing.
        """
        cls_mask = (text != self.pad_id).unsqueeze(1)
        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
        additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
        additive_mask.fill_(0)
        additive_mask.masked_fill_(~cls_mask, float("-inf"))
        # One mask per attention head.
        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
        return additive_mask

    def text_global_pool(
        self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
    ):
        """Pool (batch, seq, width) token features to (pooled, tokens)."""
        if pool_type == "first":
            pooled, tokens = x[:, 0], x[:, 1:]
        elif pool_type == "last":
            pooled, tokens = x[:, -1], x[:, :-1]
        elif pool_type == "argmax":
            # Take the feature at the position of the highest token id —
            # presumably the EOT token has the largest id; verify tokenizer.
            assert text is not None
            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
        else:
            pooled = tokens = x

        return pooled, tokens

    def forward(self, text):
        """Encode ``text`` (batch, seq) of token ids into pooled embeddings
        (and optionally all token features when output_tokens=True)."""
        seq_len = text.shape[1]
        x = self.token_embedding(text)
        attn_mask = self.attn_mask
        if attn_mask is not None:
            # Crop the precomputed causal mask to the actual sequence length.
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)

        x = self.ln_final(x)
        pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled
|
|
|
|
|
|
|
|
class CLIP(TextTransformer):
    """Dual-encoder CLIP model.

    The text tower is inherited from TextTransformer; the vision tower is a
    VisionTransformer stored as ``self.visual``. A learnable logit scale
    (temperature) is exposed for contrastive training.
    """

    def __init__(
        self,
        vision_cfg: PEConfig,
        text_cfg: PETextConfig,
        init_logit_scale: float = np.log(1 / 0.07)
    ):
        super().__init__(**asdict(text_cfg))
        self.visual = VisionTransformer(**asdict(vision_cfg))
        self.image_size = self.visual.image_size
        self.logit_scale = nn.Parameter(init_logit_scale * torch.ones([]))

    def encode_image(self, image, normalize: bool = False):
        """Embed an image batch; optionally L2-normalize."""
        feats = self.visual(image)
        if normalize:
            feats = F.normalize(feats, dim=-1)
        return feats

    def encode_video(self, video, normalize: bool = False):
        """Embed a (batch, frames, 3, H, W) video by mean-pooling per-frame
        image embeddings."""
        b, n, c, h, w = video.shape
        per_frame = self.encode_image(video.reshape(b * n, c, h, w), normalize=normalize)
        return per_frame.reshape(b, n, -1).mean(dim=1)

    def encode_text(self, text, normalize: bool = False):
        """Embed tokenized text; optionally L2-normalize."""
        feats = super().forward(text)
        if normalize:
            feats = F.normalize(feats, dim=-1)
        return feats

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
    ):
        """Return (image_features, text_features, logit_scale.exp()); either
        modality may be None, in which case its features are None."""
        image_features = None if image is None else self.encode_image(image, normalize=True)
        text_features = None if text is None else self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None
    ):
        """Build a CLIP model from a config name present in both the vision
        and text config registries; optionally load pretrained weights."""
        if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))

        return model

    @classmethod
    def available_configs(cls):
        """Names available in both the vision and text config registries."""
        return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]