Spaces:
Runtime error
Runtime error
| """ | |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation | |
| Official implementation of the paper: | |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" | |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis | |
| Licensed under a modified MIT license | |
| """ | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| from typing import Any, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| # Rotary Positional Encoding, adapted from: | |
| # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py | |
| # 2. https://github.com/naver-ai/rope-vit | |
| # 3. https://github.com/lucidrains/rotary-embedding-torch | |
def init_t_xy(end_x: int, end_y: int):
    """Return the x and y coordinates of every cell of a flattened 2-D grid.

    Cells are enumerated in row-major order (x varies fastest), so flat index
    ``i`` maps to ``x = i % end_x`` and ``y = i // end_x``.

    Args:
        end_x: Grid width (number of positions along x).
        end_y: Grid height (number of positions along y).

    Returns:
        A ``(t_x, t_y)`` pair of float32 tensors, each of length
        ``end_x * end_y``.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    x_coord = (flat_index % end_x).float()
    y_coord = torch.div(flat_index, end_x, rounding_mode="floor").float()
    return x_coord, y_coord
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
    """Build the complex rotation table for 2-D axial rotary encoding.

    Args:
        dim: Feature dimension being rotated; ``dim // 4`` frequency bands
            are generated for each of the two axes.
        end_x: Grid width.
        end_y: Grid height.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``(end_x * end_y, 2 * (dim // 4))``. The
        first half of the last dimension holds the unit-magnitude rotations
        for the x coordinate, the second half those for the y coordinate.
    """
    # The x and y axes use the same frequency bands, so compute them once.
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    bands = 1.0 / (theta ** exponents)

    t_x, t_y = init_t_xy(end_x, end_y)
    angles_x = torch.outer(t_x, bands)
    angles_y = torch.outer(t_y, bands)

    # polar(1, angle) == exp(i * angle): a pure rotation in the complex plane.
    rot_x = torch.polar(torch.ones_like(angles_x), angles_x)
    rot_y = torch.polar(torch.ones_like(angles_y), angles_y)
    return torch.cat([rot_x, rot_y], dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must match the last two dimensions of ``x`` exactly; the
    result keeps those two dimensions and uses size 1 for every leading
    dimension of ``x``.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[-2], x.shape[-1])``.
        x: Tensor with at least two dimensions.

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or the shape
            of ``freqs_cis`` does not match the trailing dimensions of ``x``.
    """
    rank = x.ndim
    assert rank > 1
    seq_len, feat_dim = x.shape[-2], x.shape[-1]
    assert freqs_cis.shape == (seq_len, feat_dim)
    # Size-1 leading axes followed by the two trailing axes of ``x``.
    target_shape = [1] * (rank - 2) + [seq_len, feat_dim]
    return freqs_cis.view(*target_shape)
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """Apply rotary positional encoding to query and key tensors.

    The last dimension of ``xq`` / ``xk`` is read as consecutive
    (real, imaginary) pairs, multiplied element-wise by the complex rotations
    in ``freqs_cis``, and flattened back to the original layout.

    Args:
        xq: Query tensor of shape ``(..., seq_len_q, head_dim)``;
            ``head_dim`` must be even.
        xk: Key tensor of shape ``(..., seq_len_k, head_dim)``. If
            ``seq_len_k`` is 0 the keys are returned unchanged.
        freqs_cis: Complex rotation table of shape
            ``(seq_len_q, head_dim // 2)`` (e.g. from ``compute_axial_cis``).
        repeat_freqs_k: When True, tile the rotation table along the sequence
            axis so it covers keys whose length is an integer multiple of the
            query length.

    Returns:
        ``(xq_out, xk_out)``, each with the dtype and device of the
        corresponding input.

    Raises:
        ValueError: If ``repeat_freqs_k`` is True and the key sequence length
            is not a multiple of the query sequence length.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) merges the (pair, real/imag) axes regardless of tensor rank.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xq_out = xq_out.type_as(xq).to(xq.device)

    if xk.shape[-2] == 0:
        # no keys to rotate, due to dropout
        return xq_out, xk

    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k:
        q_len, k_len = xq_.shape[-2], xk_.shape[-2]
        if k_len % q_len != 0:
            raise ValueError(
                f"key sequence length ({k_len}) must be a multiple of the "
                f"query sequence length ({q_len}) when repeat_freqs_k=True"
            )
        r = k_len // q_len
        # expand + flatten is equivalent to Tensor.repeat along the sequence
        # axis, but works for complex tensors on every device and for any
        # number of leading dimensions.
        leading = [-1] * (freqs_cis.ndim - 2)
        freqs_cis = (
            freqs_cis.unsqueeze(-3).expand(*leading, r, -1, -1).flatten(-3, -2)
        )

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out, xk_out.type_as(xk).to(xk.device)