# UniversalAlgorithmic's picture
# Upload 13 files
# 97d8aaa verified
# raw
# history blame
# 16.1 kB
import types
from typing import Optional, List, Union, Callable
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.mobilenetv2 import MobileNetV2
from torchvision.models.resnet import ResNet
from torchvision.models.efficientnet import EfficientNet
from torchvision.models.vision_transformer import VisionTransformer
def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    """Combine per-step losses and rewards into a single scalar loss.

    Each loss is gated by the mask at the same position and the gated losses
    are summed; each reward is likewise gated by the mask at the same position
    and summed. The result is the mean of the element-wise product of the two
    sums. Pairing uses ``zip``, so surplus entries in the longer sequence are
    ignored.

    Args:
        loss_sequence: Iterable of per-sample loss tensors.
        mask_sequence: Iterable of mask tensors, paired positionally with both
            ``loss_sequence`` and ``rewards``.
        rewards: Iterable of reward weights (scalars or tensors).

    Returns:
        Scalar tensor ``mean(sum_i(mask_i * loss_i) * sum_i(reward_i * mask_i))``.
    """
    gated_loss = 0
    for gate, step_loss in zip(mask_sequence, loss_sequence):
        gated_loss = gated_loss + gate * step_loss

    gated_reward = 0
    for step_reward, gate in zip(rewards, mask_sequence):
        gated_reward = gated_reward + step_reward * gate

    return torch.mean(gated_loss * gated_reward)
class _Permute(nn.Module):
    """Permute tensor dimensions; usable as a step inside ``nn.Sequential``."""

    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


class _RMSNorm(nn.Module):
    """Root-mean-square normalisation over dimension 1 with a learnable scale.

    The mean of squares is taken over ``dim=1`` (the channel dimension of a
    channel-first tensor) and the per-channel ``weight`` is broadcast over all
    trailing dimensions.
    """

    __constants__ = ["eps"]
    eps: float

    def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device, dtype=dtype))

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32 and cast back afterwards.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Reshape the scale to (1, C, 1, ..., 1) so it broadcasts per channel.
        weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
        return weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"


class TPBlock(nn.Module):
    """Stack of residual convolutional layers applied as ``x = x + layer(x)``.

    Args:
        depths: Number of residual layers in the stack.
        in_planes: Channel count of the input (and of the output).
        out_planes: Channel count between the two convolutions of a layer;
            defaults to ``in_planes``.
        rank: Selects kernel size / dilation / padding through
            :meth:`generate_hyperparameters`.
        shape_dims: 2, 3 or 4, selecting ``Conv1d``, ``Conv2d`` or ``Conv3d``.
        channel_first: When ``False`` the input is permuted to channel-first
            before the convolutions and permuted back afterwards.
        dtype: Parameter dtype.
        device: Device for the parameters. ``None`` (default) selects CUDA when
            it is available and falls back to CPU otherwise, so the block can
            also be built on hosts without a GPU.
    """

    def __init__(self, depths: int, in_planes: int, out_planes: int = None, rank=1, shape_dims=3,
                 channel_first=True, dtype=torch.float32, device=None) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList([
            self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype, device)
            for _ in range(depths)
        ])

    def forward(self, x: Tensor) -> Tensor:
        """Apply every layer as a residual update and return the result."""
        for layer in self.layers:
            x = x + layer(x)
        return x

    def _make_layer(self, in_planes: int, out_planes: int = None, rank=1, shape_dims=3,
                    channel_first=True, dtype=torch.float32, device=None) -> nn.Sequential:
        """Build one branch: permute, conv, norm, ReLU, conv, norm, ReLU, permute.

        Raises:
            KeyError: If ``shape_dims`` is not 2, 3 or 4.
        """
        out_planes = in_planes if out_planes is None else out_planes
        if device is None:
            # Prefer the GPU as before, but stay usable on CPU-only hosts.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # shape_dims -> (conv class, channel-last to channel-first, and back)
        conv_map = {
            2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
            3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
            4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
        }
        Conv, pre_dims, post_dims = conv_map[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else _Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else _Permute(*post_dims)

        # Both convolutions start at zero, so a freshly built layer adds
        # nothing to its input (the block is the identity at construction).
        # NOTE(review): with bias-free, zero-initialised weights in both
        # convolutions the branch's gradients also appear to be zero at
        # initialisation - confirm this is intended.
        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation,
                     bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv1.weight)
        bn1 = _RMSNorm(out_planes, dtype=dtype, device=device)
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation,
                     bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv2.weight)
        bn2 = _RMSNorm(in_planes, dtype=dtype, device=device)

        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """Return the ``rank``-th (kernel_size, dilation, padding) triple.

        Triples are enumerated starting from ``(1, 1, 0)`` and then by
        increasing padded kernel size ``(kernel_size - 1) * dilation + 1``
        (3, 5, 7, ...); for equal padded sizes, by increasing odd kernel size.
        Only kernel sizes whose ``kernel_size - 1`` divides the padded size
        minus one are used, so the dilation is always an integer.

        Args:
            rank: 1-based position in the enumeration. Values of 1 or less
                return the first triple ``(1, 1, 0)``.

        Returns:
            Tuple[int, int, int]: ``(kernel_size, dilation, padding)`` where
            ``padding = dilation * (kernel_size - 1) // 2``.
        """
        pairs = [(1, 1, 0)]  # Start with smallest possible
        padded_kernel_size = 3
        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                    if len(pairs) >= rank:
                        break
            # Move to next odd padded kernel size
            padded_kernel_size += 2
        return pairs[-1]
class ResNetConfig:
    """Factories for the closures that make up the TRP training forward pass.

    Every method is a ``staticmethod``; where a parameter is called ``self``
    it is the model being wrapped, passed explicitly by the caller.
    """

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via ``avgpool`` and ``fc``.

        Args:
            self: Model exposing ``avgpool`` and ``fc``.
        """
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Output of ``self.fc`` on the pooled, flattened input.
            """
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)
            return logits
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        """Return a closure producing logits for the backbone and each TRP block.

        Args:
            self: Model exposing an iterable ``trp_blocks``.
            shared_head: Callable mapping hidden states to logits.
        """
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden states produced by the backbone.
            Returns:
                logits_sequence (List[Tensor]): ``shared_head(hidden_states)``
                followed by ``shared_head(block(hidden_states))`` for every
                block. Each block receives the same, original hidden states.
            """
            logits_sequence = [shared_head(hidden_states)]
            for layer in self.trp_blocks:
                logits_sequence.append(shared_head(layer(hidden_states)))
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        """Return a closure computing cumulative top-k correctness masks.

        Args:
            label_smoothing: Kept for interface compatibility. Whether labels
                are reduced with ``argmax`` is decided from their rank.
            top_k: A sample counts as correct when its label is among the
                ``top_k`` highest logits.
        """
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensors.
                labels (Tensor): Target labels of shape [B] or [B, C].
            Returns:
                mask_sequence (List[Tensor]): ``len(logits_sequence) + 1``
                float tensors of shape [B]. The first is all ones; entry
                ``i + 1`` is entry ``i`` multiplied by the top-k correctness
                of ``logits_sequence[i]``.
            """
            # Class-probability / one-hot targets are reduced to class indices.
            # Index targets of shape [B] are used as they are, so label
            # smoothing also works with ordinary integer labels.
            if labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)
            mask_sequence = [torch.ones_like(labels, dtype=torch.float32, device=labels.device)]
            with torch.no_grad():
                for logits in logits_sequence:
                    _, topk_indices = torch.topk(logits, top_k, dim=-1)
                    mask = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * mask)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        """Return a closure computing per-sample cross-entropy for each logits tensor.

        Args:
            label_smoothing: Forwarded to ``F.cross_entropy``.
        """
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): List of Logits tensor.
                labels (Tensor): Target labels of shape [B] or [B, C].
            Returns:
                loss_sequence (List[Tensor]): One unreduced (``reduction="none"``)
                loss tensor per entry of ``logits_sequence``.
            """
            # With label smoothing, [B, C] targets are reduced to indices as
            # before; [B] index targets are passed through unchanged instead
            # of failing on argmax(dim=1).
            if label_smoothing > 0.0 and labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)
            return [
                F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for a ResNet-style model.

        Args:
            rewards: Reward weights forwarded to ``compute_policy_loss``.
            label_smoothing: Forwarded to ``gen_mask`` and ``gen_criterion``.
            top_k: Forwarded to ``gen_mask``.

        Returns:
            A function ``func(self, x, targets=None)`` that returns the logits
            in eval mode and ``(logits, loss)`` in training mode, where
            ``targets`` are required.
        """
        def func(self, x: Tensor, targets=None) -> Tensor:
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)

            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if self.training:
                shared_head = ResNetConfig.gen_shared_head(self)
                compute_logits = ResNetConfig.gen_logits(self, shared_head)
                compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
                compute_loss = ResNetConfig.gen_criterion(label_smoothing)

                logits_sequence = compute_logits(hidden_states)
                mask_sequence = compute_mask(logits_sequence, targets)
                loss_sequence = compute_loss(logits_sequence, targets)
                loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)
                return logits, loss
            return logits
        return func
class MobileNetV2Config(ResNetConfig):
    """TRP closure factories for models exposing ``features`` and ``classifier``."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure that pools to 1x1, flattens and applies ``self.classifier``.

        Args:
            self: Model exposing ``classifier``.
        """
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Output of ``self.classifier`` on the pooled input.
            """
            pooled = F.adaptive_avg_pool2d(hidden_states, (1, 1))
            return self.classifier(pooled.flatten(1))
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for a MobileNetV2-style model.

        The returned function yields the logits in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, x: Tensor, targets=None) -> Tensor:
            feature_map = self.features(x)
            # Flatten rather than squeeze so a batch of size 1 keeps its batch axis.
            pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
            logits = self.classifier(pooled.flatten(1))

            if not self.training:
                return logits

            head = MobileNetV2Config.gen_shared_head(self)
            logits_fn = MobileNetV2Config.gen_logits(self, head)
            mask_fn = MobileNetV2Config.gen_mask(label_smoothing, top_k)
            criterion_fn = MobileNetV2Config.gen_criterion(label_smoothing)

            logits_sequence = logits_fn(feature_map)
            mask_sequence = mask_fn(logits_sequence, targets)
            loss_sequence = criterion_fn(logits_sequence, targets)
            return logits, compute_policy_loss(loss_sequence, mask_sequence, rewards)
        return func
class EfficientNetConfig(ResNetConfig):
    """TRP closure factories for models exposing ``features``, ``avgpool`` and ``classifier``."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure that applies ``avgpool``, flattens and applies ``classifier``.

        Args:
            self: Model exposing ``avgpool`` and ``classifier``.
        """
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden States tensor of shape [B, C, H, W].
            Returns:
                logits (Tensor): Output of ``self.classifier`` on the pooled input.
            """
            pooled = self.avgpool(hidden_states)
            return self.classifier(pooled.flatten(1))
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for an EfficientNet-style model.

        The returned function yields the logits in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, x: Tensor, targets=None) -> Tensor:
            feature_map = self.features(x)
            pooled = self.avgpool(feature_map)
            logits = self.classifier(pooled.flatten(1))

            if not self.training:
                return logits

            head = EfficientNetConfig.gen_shared_head(self)
            logits_fn = EfficientNetConfig.gen_logits(self, head)
            mask_fn = EfficientNetConfig.gen_mask(label_smoothing, top_k)
            criterion_fn = EfficientNetConfig.gen_criterion(label_smoothing)

            logits_sequence = logits_fn(feature_map)
            mask_sequence = mask_fn(logits_sequence, targets)
            loss_sequence = criterion_fn(logits_sequence, targets)
            return logits, compute_policy_loss(loss_sequence, mask_sequence, rewards)
        return func
class VisionTransformerConfig(ResNetConfig):
    """TRP closure factories for models exposing ``encoder``, ``class_token`` and ``heads``."""

    @staticmethod
    def gen_shared_head(self):
        """Return a closure that classifies the token at index 0 with ``self.heads``.

        Args:
            self: Model exposing ``heads``.
        """
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Token sequence; index 0 along
                    dimension 1 is used (``gen_forward`` places the class
                    token there before the encoder).
            Returns:
                logits (Tensor): Output of ``self.heads`` for that token.
            """
            return self.heads(hidden_states[:, 0])
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for a VisionTransformer-style model.

        The returned function yields the logits in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, images: Tensor, targets=None):
            tokens = self._process_input(images)
            batch_size = tokens.shape[0]

            # Prepend one class token per sample along the sequence dimension.
            cls_tokens = self.class_token.expand(batch_size, -1, -1)
            tokens = torch.cat([cls_tokens, tokens], dim=1)

            encoded = self.encoder(tokens)
            logits = self.heads(encoded[:, 0])

            if not self.training:
                return logits

            head = VisionTransformerConfig.gen_shared_head(self)
            logits_fn = VisionTransformerConfig.gen_logits(self, head)
            mask_fn = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
            criterion_fn = VisionTransformerConfig.gen_criterion(label_smoothing)

            logits_sequence = logits_fn(encoded)
            mask_sequence = mask_fn(logits_sequence, targets)
            loss_sequence = criterion_fn(logits_sequence, targets)
            return logits, compute_policy_loss(loss_sequence, mask_sequence, rewards)
        return func
def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    """Attach TRP blocks to a supported model and replace its ``forward``.

    For a ``ResNet``, ``MobileNetV2``, ``EfficientNet`` or
    ``VisionTransformer`` instance this sets ``model.trp_blocks`` to one
    ``TPBlock`` per entry of ``depths`` and binds the matching
    ``*Config.gen_forward`` closure as ``model.forward``. Any other model is
    returned unchanged.

    Args:
        model: Model to modify in place.
        depths: Number of layers of each ``TPBlock``; the position of an entry
            is passed to ``TPBlock`` as ``rank``.
        in_planes: ``in_planes`` of every ``TPBlock``.
        out_planes: ``out_planes`` of every ``TPBlock``.
        rewards: Reward weights forwarded to the generated forward function.
        **kwargs: ``label_smoothing`` (float, default ``0.0``) is forwarded to
            the generated forward function; other keys are ignored.

    Returns:
        The same ``model`` object.
    """
    # Default matches gen_forward's own default instead of raising KeyError
    # when the caller does not pass the option.
    label_smoothing = kwargs.get("label_smoothing", 0.0)

    # (model class, name used in the log line, config class, extra TPBlock arguments)
    variants = (
        (ResNet, "ResNet", ResNetConfig, {}),
        (MobileNetV2, "MobileNetV2", MobileNetV2Config, {}),
        (EfficientNet, "EfficientNet", EfficientNetConfig, {}),
        (VisionTransformer, "VisionTransformer", VisionTransformerConfig,
         {"shape_dims": 2, "channel_first": False}),
    )

    for model_cls, label, config, block_kwargs in variants:
        if not isinstance(model, model_cls):
            continue
        print(f"✅ Applying TRP to {label} for Image Classification...")
        # NOTE(review): enumerate starts at 0, and ranks 0 and 1 select the
        # same hyperparameters in TPBlock.generate_hyperparameters - confirm
        # this is intended before changing it.
        model.trp_blocks = torch.nn.ModuleList([
            TPBlock(depths=d, in_planes=in_planes, out_planes=out_planes, rank=k, **block_kwargs)
            for k, d in enumerate(depths)
        ])
        forward_fn = config.gen_forward(rewards, label_smoothing=label_smoothing, top_k=1)
        model.forward = types.MethodType(forward_fn, model)
        break
    return model