Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import ViTModel | |
| import math | |
class FeatureExtractor(nn.Module):
    """Wrap a pretrained ViT and expose its per-patch token features.

    Args:
        vit_model_name: HuggingFace model id passed to
            ``ViTModel.from_pretrained``.
        input_size: required spatial size; inputs must be exactly
            ``input_size x input_size`` or ``forward`` raises.
        mean: per-channel normalization mean (3 values).
        std: per-channel normalization std (3 values).
    """

    def __init__(self,
                 vit_model_name: str,
                 input_size: int,
                 mean,
                 std,
                 ):
        super().__init__()
        self.vit = ViTModel.from_pretrained(
            vit_model_name,
            output_hidden_states=False,
            ignore_mismatched_sizes=True
        )
        # Buffers (not Parameters) so the stats follow the module across
        # .to(device)/.cuda() without being trained.
        self.register_buffer("mean", torch.tensor(mean).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, 3, 1, 1))
        self.input_size = input_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: images, shape (B, 3, H, W); H and W must equal ``input_size``.

        Returns:
            patch_feats: (B, N, hidden_size) patch-token features, where
                N = (H/16) * (W/16); the leading CLS token is dropped.

        Raises:
            ValueError: if the spatial size is not input_size x input_size.
        """
        B, C, H, W = x.shape
        # 1) Strict size check — no resizing happens here; callers must
        #    resize upstream before calling this module.
        if H != self.input_size or W != self.input_size:
            raise ValueError(f"Input image size must be {self.input_size}x{self.input_size}, but got {H}x{W}.")
        # 2) Normalize (out of place — the caller's tensor is untouched).
        x = (x - self.mean) / self.std
        outputs = self.vit(pixel_values=x)
        seq = outputs.last_hidden_state
        # Drop the CLS token at position 0, keep only patch tokens.
        patch_feats = seq[:, 1:, :]  # shape (B, N, hidden_size)
        return patch_feats
def get_sinusoidal_pe(steps, d_model, device):
    """Build standard Transformer sinusoidal positional encodings.

    Args:
        steps: 1-D tensor of step indices, shape (n_steps,).
        d_model: feature dimension (e.g., 768).
        device: device on which the encodings are allocated.

    Returns:
        Tensor of shape (n_steps, d_model): sin on even columns,
        cos on odd columns.
    """
    positions = steps.to(torch.float32).unsqueeze(1)  # (n_steps, 1)
    # Frequencies 10000^(-2i/d_model), i = 0 .. ceil(d_model/2)-1.
    freqs = torch.exp(
        torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model)
    )
    encoding = torch.zeros(positions.shape[0], d_model, device=device)
    encoding[:, 0::2] = torch.sin(positions * freqs)
    # Odd columns have one fewer slot than even columns when d_model is odd,
    # hence the slice on the frequency vector.
    encoding[:, 1::2] = torch.cos(positions * freqs[: d_model // 2])
    return encoding
class StrokeDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: self-attention, cross-attention, FFN.

    Each sub-layer is applied residually as ``x + sublayer(norm(x))``.

    Args:
        feature_dim: embedding dimension of queries and memory.
        n_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout: dropout probability used in attention and FFN.
    """

    def __init__(self, feature_dim, n_heads, ff_dim, dropout):
        super().__init__()
        self.cross_norm = nn.LayerNorm(feature_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim,
                                                num_heads=n_heads,
                                                dropout=dropout,
                                                batch_first=True)
        self.self_norm = nn.LayerNorm(feature_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim=feature_dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        self.ffn_norm = nn.LayerNorm(feature_dim)
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, ff_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, feature_dim),
            nn.Dropout(dropout)
        )

    def forward(self, Q, kv):
        """
        Args:
            Q: query sequence, shape (B, T, feature_dim).
            kv: key/value memory, shape (B, N, feature_dim) — attended to
                without normalization, as in the original design.

        Returns:
            Updated queries, shape (B, T, feature_dim).
        """
        # Self-attention: normalize once and reuse for q, k and v
        # (the original recomputed the same LayerNorm three times).
        q_norm = self.self_norm(Q)
        Q = Q + self.self_attn(q_norm, q_norm, q_norm,
                               need_weights=False)[0]
        # Cross-attention over the memory kv
        Q = Q + self.cross_attn(self.cross_norm(Q),
                                kv, kv,
                                need_weights=False)[0]
        # Final FFN
        Q = Q + self.ffn(self.ffn_norm(Q))
        return Q
class StrokeTransformer(nn.Module):
    """Decode ViT patch features into a fixed-length sequence of stroke params.

    A set of ``max_steps`` learnable queries (plus sinusoidal positional
    encodings) cross-attends to the image patch features through a stack of
    pre-norm decoder layers; an MLP head maps each query to ``action_dim``
    stroke parameters squashed into (0, 1).

    Args:
        feature_dim: embedding dimension D (must match the patch features).
        points_per_stroke: number of 2-D control points per stroke.
        action_dim: number of output parameters per stroke.
        n_heads: attention heads per decoder layer.
        dropout: dropout probability.
        ff_dim: feed-forward hidden width per decoder layer.
        n_layers: number of decoder layers.
    """

    def __init__(self,
                 feature_dim: int,
                 points_per_stroke: int,
                 action_dim: int,
                 n_heads: int,
                 dropout: float,
                 ff_dim: int,
                 n_layers: int = 5):
        super().__init__()
        self.feature_dim = feature_dim
        self.points_per_stroke = points_per_stroke
        self.action_dim = action_dim
        # Fixed number of stroke queries (upper bound on strokes per image).
        self.max_steps = 300
        # learnable base queries, one per stroke slot
        self.base_queries = nn.Parameter(torch.randn(self.max_steps, feature_dim))
        # stack of decoder layers
        self.layers = nn.ModuleList([
            StrokeDecoderLayer(feature_dim, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        ])
        # output head: per-query raw action parameters
        self.param_head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, action_dim)
        )

    def forward(self, patch_feats: torch.Tensor):
        """
        Args:
            patch_feats: image patch features, shape (B, N, D) with
                D == feature_dim.

        Returns:
            strokes: (B, max_steps, action_dim), every entry in (0, 1).
        """
        B, N, D = patch_feats.shape
        device = patch_feats.device
        # Broadcast the shared learnable queries over the batch.
        Q = self.base_queries.unsqueeze(0).expand(B, -1, -1)  # (B, T, D)
        # Add absolute sinusoidal step encodings; allocate the index tensor
        # directly on the target device instead of CPU-then-copy.
        abs_steps = torch.arange(self.max_steps, device=device)
        pe_abs = get_sinusoidal_pe(abs_steps, D, device)
        Q = Q + pe_abs.unsqueeze(0)
        # run through decoder layers
        for layer in self.layers:
            Q = layer(Q, patch_feats)
        # predict parameters
        strokes = self.param_head(Q)
        # The original sigmoid-squashed four contiguous slices
        # (2*points_per_stroke coords, two 2-wide groups, the remainder)
        # and re-concatenated them; a single sigmoid over the whole tensor
        # is exactly equivalent.
        return torch.sigmoid(strokes)