# image2painting / model.py
# Source: Hugging Face upload by Lasercatz ("Upload 9 files", commit 97bca33, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
import math
class FeatureExtractor(nn.Module):
    """Pretrained ViT backbone that maps an image batch to per-patch features.

    Inputs are resized to ``input_size`` when necessary, normalized with the
    given per-channel mean/std, and passed through the ViT; the CLS token is
    stripped so only patch embeddings are returned.
    """

    def __init__(self,
                 vit_model_name: str,
                 input_size: int,
                 mean,
                 std,
                 ):
        """
        Args:
            vit_model_name: Hugging Face model id or path for ViTModel.from_pretrained.
            input_size: Spatial size (height == width) the ViT expects.
            mean: Per-channel normalization mean, length 3.
            std: Per-channel normalization std, length 3.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(
            vit_model_name,
            output_hidden_states=False,
            ignore_mismatched_sizes=True
        )
        # Buffers (not Parameters): they move with .to(device) but are not trained.
        self.register_buffer("mean", torch.tensor(mean).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, 3, 1, 1))
        self.input_size = input_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 3, H, W) image batch; any H/W is accepted and resized.
        Returns:
            patch_feats: (B, N, hidden_size), where N = (input_size/16)**2
            for the standard 16-pixel ViT patch size.
        """
        B, C, H, W = x.shape
        # 1) Resize if needed. The original code raised ValueError here despite
        #    the comment promising a resize (and F being imported but unused);
        #    interpolate realizes that intent and leaves correct sizes untouched.
        if H != self.input_size or W != self.input_size:
            x = F.interpolate(x, size=(self.input_size, self.input_size),
                              mode="bilinear", align_corners=False)
        # 2) Normalize (out of place — the caller's tensor is not modified).
        x = (x - self.mean) / self.std  # type: ignore
        outputs = self.vit(pixel_values=x)
        seq = outputs.last_hidden_state
        patch_feats = seq[:, 1:, :]  # drop CLS token -> (B, N, hidden_size)
        return patch_feats
def get_sinusoidal_pe(steps, d_model, device):
    """
    Build sinusoidal positional encodings for the given step indices.

    Args:
        steps: 1-D tensor of step indices, shape (n_steps,).
        d_model: Feature dimension (e.g., 768).
        device: Device on which to allocate the encoding.

    Returns:
        pe: Tensor of shape (n_steps, d_model); even columns carry sin and
        odd columns carry cos at geometrically spaced frequencies.
    """
    positions = steps.to(torch.float32).unsqueeze(1)       # (n_steps, 1)
    even_slots = torch.arange(0, d_model, 2, device=device)
    inv_freq = torch.exp(even_slots * (-math.log(10000.0) / d_model))
    angles = positions * inv_freq                          # (n_steps, ceil(d_model/2))
    pe = torch.zeros(steps.shape[0], d_model, device=device)
    pe[:, 0::2] = torch.sin(angles)
    # cos receives one fewer column than sin when d_model is odd.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe
class StrokeDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer for stroke queries.

    Applies, in order, self-attention over the queries, cross-attention
    into the image patch features, and a position-wise feed-forward
    network; each sublayer is wrapped in a residual connection with a
    LayerNorm applied to its input (the cross-attention memory itself is
    used unnormalized).
    """

    def __init__(self, feature_dim, n_heads, ff_dim, dropout):
        super().__init__()
        # NOTE: registration order (cross before self) is kept from the
        # original so seeded initialization is reproducible.
        self.cross_norm = nn.LayerNorm(feature_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim=feature_dim,
                                                num_heads=n_heads,
                                                dropout=dropout,
                                                batch_first=True)
        self.self_norm = nn.LayerNorm(feature_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim=feature_dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        self.ffn_norm = nn.LayerNorm(feature_dim)
        self.ffn = nn.Sequential(
            nn.Linear(feature_dim, ff_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, feature_dim),
            nn.Dropout(dropout)
        )

    def forward(self, Q, kv):
        """Q: (B, T, D) stroke queries; kv: (B, N, D) memory. Returns (B, T, D)."""
        # Self-attention: normalize once instead of three times.
        normed = self.self_norm(Q)
        sa_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        Q = Q + sa_out
        # Cross-attention into the patch features.
        ca_out, _ = self.cross_attn(self.cross_norm(Q), kv, kv, need_weights=False)
        Q = Q + ca_out
        # Feed-forward sublayer.
        return Q + self.ffn(self.ffn_norm(Q))
class StrokeTransformer(nn.Module):
    """
    Decodes a fixed-length sequence of stroke-parameter vectors from ViT
    patch features using a stack of pre-norm decoder layers over learnable
    stroke queries.
    """

    def __init__(self,
                 feature_dim: int,
                 points_per_stroke: int,
                 action_dim: int,
                 n_heads: int,
                 dropout: float,
                 ff_dim: int,
                 n_layers: int = 5):
        super().__init__()
        self.feature_dim = feature_dim
        self.points_per_stroke = points_per_stroke
        self.action_dim = action_dim
        # Fixed number of stroke slots (and therefore predicted strokes).
        self.max_steps = 300
        # One learnable query per stroke slot, shared across the batch.
        self.base_queries = nn.Parameter(torch.randn(self.max_steps, feature_dim))
        # Decoder stack.
        self.layers = nn.ModuleList(
            StrokeDecoderLayer(feature_dim, n_heads, ff_dim, dropout)
            for _ in range(n_layers)
        )
        # Maps each refined query to a raw stroke-parameter vector.
        self.param_head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, action_dim)
        )

    def forward(self, patch_feats: torch.Tensor):
        """
        Args:
            patch_feats: (B, N, D) patch features from the ViT backbone.
        Returns:
            strokes: (B, max_steps, action_dim), every entry squashed to (0, 1).
        """
        batch, _, dim = patch_feats.shape
        device = patch_feats.device
        # Broadcast the learnable queries over the batch and add absolute
        # sinusoidal position information for each stroke slot.
        step_ids = torch.arange(self.max_steps, device=device)
        pos = get_sinusoidal_pe(step_ids, dim, device)
        queries = self.base_queries.unsqueeze(0).expand(batch, -1, -1) + pos.unsqueeze(0)
        # Refine the queries against the image features.
        for block in self.layers:
            queries = block(queries, patch_feats)
        raw = self.param_head(queries)
        # The original sliced raw into four contiguous parameter groups
        # (2*pp control-point coords, then 2 + 2 + remaining fields) and
        # sigmoided each; since every group uses the same elementwise
        # sigmoid, one call over the whole tensor is numerically identical.
        return torch.sigmoid(raw)