Spaces:
Sleeping
Sleeping
| # ignore_header_test | |
| # ruff: noqa: E402 | |
| """""" | |
| """ | |
| Transolver model. This code was modified from https://github.com/thuml/Transolver | |
| The following license is reproduced from the original source: | |
| MIT License | |
| Copyright (c) 2024 THUML @ Tsinghua University | |
| Permission is hereby granted, free of charge, to any person obtaining a copy | |
| of this software and associated documentation files (the "Software"), to deal | |
| in the Software without restriction, including without limitation the rights | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| copies of the Software, and to permit persons to whom the Software is | |
| furnished to do so, subject to the following conditions: | |
| The above copyright notice and this permission notice shall be included in all | |
| copies or substantial portions of the Software. | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| SOFTWARE. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) angle generator.

    For every scalar coordinate this module produces a vector of rotation
    angles: the coordinate (rescaled by ``scale / min_freq``) times each
    inverse frequency, with the whole frequency list repeated twice along the
    last axis.

    Args:
        dim: Controls the frequency bank. Frequencies are built for
            ``range(0, dim, 2)``, so the output width is ``2 * ceil(dim / 2)``.
        min_freq: Coordinates are divided by this value before use.
        scale: Coordinates are multiplied by this value before use.
    """

    def __init__(self, dim, min_freq=1 / 2, scale=1.0):
        super().__init__()
        self.min_freq = min_freq
        self.scale = scale
        exponent = torch.arange(0, dim, 2).float() / dim
        # Registered as a (persistent) buffer so it follows .to()/.cuda()
        # and is stored in the state_dict under the key "inv_freq".
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponent))

    def forward(self, coordinates, device):
        """Return rotation angles for ``coordinates``.

        Args:
            coordinates: Tensor of positions, e.g. ``[b, n]``; it must have at
                least one dimension.
            device: Device the coordinates are moved to before use.

        Returns:
            Tensor of shape ``coordinates.shape + (2 * len(inv_freq),)``.
        """
        positions = coordinates.to(device).type_as(self.inv_freq)
        positions = positions * (self.scale / self.min_freq)
        # Outer product over the last axis: one angle per (position, frequency).
        angles = torch.einsum("...n,f->...nf", positions, self.inv_freq)
        return torch.cat([angles, angles], dim=-1)
def rotate_half(x):
    """Rotate the two halves of the last axis: ``[x1, x2] -> [-x2, x1]``.

    The last axis is split into a first half ``x1`` and a second half ``x2``
    (the same split the einops pattern ``"... (j d) -> ... j d", j=2`` makes).
    The halves are then recombined as ``(-x2, x1)``.

    Args:
        x: Tensor whose last dimension has even size.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        RuntimeError: if the last dimension has odd size.
    """
    width = x.shape[-1]
    if width % 2:
        # RuntimeError matches the error class the einops split raised here.
        raise RuntimeError(f"rotate_half expects an even last dimension, got {width}")
    half = width // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
    """Rotate ``t`` by the angles in ``freqs`` (rotary position embedding).

    Computes ``t * cos(freqs) + rotate_half(t) * sin(freqs)``. ``freqs`` must
    broadcast against ``t``.
    """
    cos_term = t * freqs.cos()
    sin_term = rotate_half(t) * freqs.sin()
    return cos_term + sin_term
def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y):
    """Apply separate rotary embeddings to the two halves of ``t``'s last axis.

    The first ``d // 2`` channels are rotated by ``freqs_x``, the remaining
    channels by ``freqs_y``, and the two results are concatenated back along
    the last axis.

    Args:
        t: Tensor with ``d`` channels on its last axis (original note:
            ``[b, h, n, d]``).
        freqs_x: Angles for the first half; must broadcast against it.
        freqs_y: Angles for the second half; must broadcast against it.
            NOTE(review): the original comment lists both as ``[b, n, d]``.
            Against a ``[b, h, n, d // 2]`` half that only broadcasts if
            ``b == h``, ``b == 1``, or the caller inserts a head axis. Confirm
            at the call site.
    """
    split = t.shape[-1] // 2
    first_half = t[..., :split]
    second_half = t[..., split:]
    return torch.cat(
        (
            apply_rotary_pos_emb(first_half, freqs_x),
            apply_rotary_pos_emb(second_half, freqs_y),
        ),
        dim=-1,
    )
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, then dropout.

    Even channels hold ``sin(pos * w_k)`` and odd channels ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Channel width of the inputs. Odd widths are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest supported sequence length. The table is stored as a
            buffer named ``pe`` of shape ``(1, max_len, d_model)``. With the
            default ``421 * 421`` rows this buffer is large for wide models.
    """

    def __init__(self, d_model, dropout, max_len=421 * 421):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cos (odd) column than sin (even)
        # column, so only the first d_model // 2 angles are used here.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions, then dropout.

        Args:
            x: Tensor with the sequence on dim 1 and ``d_model`` channels last,
                e.g. ``[b, n, d_model]``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the ``max_len`` used at
                construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len].requires_grad_(False)
        return self.dropout(x)
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: accepted for signature compatibility; not used by this
                        implementation.
    :return: an [N x dim] float32 Tensor of positional embeddings: ``dim // 2``
             cosine columns, then ``dim // 2`` sine columns, then one zero
             column when ``dim`` is odd.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Size the pad column from the batch, not by slicing `embedding`:
        # when dim == 1 `embedding` has zero columns and a slice would pad
        # nothing.
        pad = embedding.new_zeros((embedding.shape[0], 1))
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding