import math

import torch
import torch.nn as nn
from transformers import ViTModel


class FeatureExtractor(nn.Module):
    def __init__(self, vit_model_name: str, input_size: int, mean, std):
        super().__init__()
        self.vit = ViTModel.from_pretrained(
            vit_model_name,
            output_hidden_states=False,
            ignore_mismatched_sizes=True,
        )
        # Normalization constants stored as buffers so they follow the module
        # across devices without being treated as learnable parameters.
        self.register_buffer("mean", torch.tensor(mean).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, 3, 1, 1))
        self.input_size = input_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, 3, H, W), where H = W = input_size (a multiple of patch_size=16).

        Returns:
            patch_feats: (B, N, hidden_size), where N = (H/16) * (W/16).
        """
        B, C, H, W = x.shape

        # 1) Validate the input size (resizing is left to the caller)
        if H != self.input_size or W != self.input_size:
            raise ValueError(
                f"Input image size must be {self.input_size}x{self.input_size}, "
                f"but got {H}x{W}."
            )

        # 2) Normalize (out of place; the caller's tensor is left untouched)
        x = (x - self.mean) / self.std

        outputs = self.vit(pixel_values=x)
        seq = outputs.last_hidden_state
        patch_feats = seq[:, 1:, :]  # drop the [CLS] token -> (B, N, hidden_size)
        return patch_feats


def get_sinusoidal_pe(steps, d_model, device):
    """
    Compute sinusoidal positional encodings for the given steps.

    Args:
        steps: Tensor of step indices, shape (n_steps,)
        d_model: Feature dimension (e.g., 768)
        device: Device to place the tensor on

    Returns:
        pe: Positional encodings, shape (n_steps, d_model)
    """
    steps = steps.to(torch.float32)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, device=device) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(len(steps), d_model, device=device)
    pe[:, 0::2] = torch.sin(steps.unsqueeze(1) * div_term)
    # The slice keeps the cosine half aligned when d_model is odd.
    pe[:, 1::2] = torch.cos(steps.unsqueeze(1) * div_term[: d_model // 2])
    return pe


class StrokeDecoderLayer(nn.Module):
    def __init__(self, feature_dim, n_heads, ff_dim, dropout):
        super().__init__()
        self.cross_norm = nn.LayerNorm(feature_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=feature_dim, num_heads=n_heads,
            dropout=dropout, batch_first=True,
        )
        self.self_norm = nn.LayerNorm(feature_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=feature_dim, num_heads=n_heads,
            dropout=dropout, batch_first=True,
        )
        self.ffn_norm = nn.LayerNorm(feature_dim)
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, ff_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, feature_dim),
            nn.Dropout(dropout),
        )

    def forward(self, Q, kv):
        # Pre-norm self-attention among the stroke queries
        q_norm = self.self_norm(Q)
        Q = Q + self.self_attn(q_norm, q_norm, q_norm, need_weights=False)[0]
        # Pre-norm cross-attention from queries to the image patch features
        Q = Q + self.cross_attn(self.cross_norm(Q), kv, kv, need_weights=False)[0]
        # Position-wise feed-forward network
        Q = Q + self.ffn(self.ffn_norm(Q))
        return Q


class StrokeTransformer(nn.Module):
    def __init__(self, feature_dim: int, points_per_stroke: int, action_dim: int,
                 n_heads: int, dropout: float, ff_dim: int, n_layers: int = 5):
        super().__init__()
        self.feature_dim = feature_dim
        self.points_per_stroke = points_per_stroke
        self.action_dim = action_dim
        self.max_steps = 300

        # Learnable base queries, one per stroke slot
        self.base_queries = nn.Parameter(torch.randn(self.max_steps, feature_dim))

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            StrokeDecoderLayer(feature_dim, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        ])

        # Output head
        self.param_head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, action_dim),
        )

    def forward(self, patch_feats: torch.Tensor):
        """
        Args:
            patch_feats: (B, N, D) image patch features

        Returns:
            strokes: (B, max_steps, action_dim), all values squashed to [0, 1]
        """
        B, N, D = patch_feats.shape
        device = patch_feats.device

        # Initialize queries (shared across the batch)
        Q = self.base_queries.unsqueeze(0).expand(B, -1, -1)  # (B, T, D)

        # Add sinusoidal positional encodings over the stroke index
        abs_steps = torch.arange(self.max_steps, device=device)
        pe_abs = get_sinusoidal_pe(abs_steps, D, device)
        Q = Q + pe_abs.unsqueeze(0)

        # Run through the decoder layers
        for layer in self.layers:
            Q = layer(Q, patch_feats)

        # Predict stroke parameters
        strokes = self.param_head(Q)

        # Split into the stroke parameter groups and squash each with a sigmoid,
        # mirroring the earlier layout of this head (every group currently uses
        # the same activation).
        pp = self.points_per_stroke
        part1 = torch.sigmoid(strokes[..., 0:2 * pp])
        part2 = torch.sigmoid(strokes[..., 2 * pp:2 * pp + 2])
        part3 = torch.sigmoid(strokes[..., 2 * pp + 2:2 * pp + 4])
        part4 = torch.sigmoid(strokes[..., 2 * pp + 4:])
        strokes = torch.cat([part1, part2, part3, part4], dim=-1)
        return strokes
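

# Minimal usage sketch (illustrative only; the checkpoint name, points_per_stroke,
# and action_dim below are assumptions, not values fixed by this module).
if __name__ == "__main__":
    extractor = FeatureExtractor(
        vit_model_name="google/vit-base-patch16-224-in21k",  # assumed checkpoint
        input_size=224,
        mean=[0.485, 0.456, 0.406],  # standard ImageNet statistics
        std=[0.229, 0.224, 0.225],
    )
    decoder = StrokeTransformer(
        feature_dim=768,        # must match the ViT hidden size
        points_per_stroke=4,    # assumed: 4 control points per stroke
        action_dim=15,          # assumed: 2*4 coords + 2 + 2 + 3 remaining params
        n_heads=8,
        dropout=0.1,
        ff_dim=2048,
    )
    extractor.eval()
    decoder.eval()

    images = torch.rand(2, 3, 224, 224)  # unnormalized RGB in [0, 1]
    with torch.no_grad():
        feats = extractor(images)   # (2, 196, 768): 14x14 patches
        strokes = decoder(feats)    # (2, 300, 15), all values in [0, 1]
    print(feats.shape, strokes.shape)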