| import types |
| from typing import Optional, List, Union, Callable |
| from collections import OrderedDict |
|
|
| import torch |
| from torch import nn, Tensor |
| from torch.nn import functional as F |
|
|
| from torchvision.models.mobilenetv2 import MobileNetV2 |
| from torchvision.models.resnet import ResNet |
| from torchvision.models.efficientnet import EfficientNet |
| from torchvision.models.vision_transformer import VisionTransformer |
| from torchvision.models.segmentation.fcn import FCN |
| from torchvision.models.segmentation.deeplabv3 import DeepLabV3 |
|
|
|
|
def compute_policy_loss(loss_sequence, mask_sequence, rewards):
    """Combine per-step losses, masks and rewards into one scalar loss.

    Args:
        loss_sequence: Iterable of per-step loss tensors.
        mask_sequence: Iterable of per-step mask tensors; paired element-wise
            with ``loss_sequence`` and with ``rewards`` (``zip`` stops at the
            shorter of each pair).
        rewards: Iterable of per-step reward weights.

    Returns:
        Tensor: mean over all elements of
        ``(sum_i mask_i * loss_i) * (sum_i reward_i * mask_i)``.
    """
    masked_losses = [step_mask * step_loss for step_mask, step_loss in zip(mask_sequence, loss_sequence)]
    masked_rewards = [step_reward * step_mask for step_reward, step_mask in zip(rewards, mask_sequence)]

    total_loss = sum(masked_losses)
    total_return = sum(masked_rewards)
    return torch.mean(total_loss * total_return)
|
|
|
|
class _Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` so it can sit inside ``nn.Sequential``."""

    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


class _RMSNorm(nn.Module):
    """Root-mean-square normalisation over dimension 1 with a learnable per-channel scale."""

    __constants__ = ["eps"]
    eps: float

    def __init__(self, hidden_size, eps: float = 1e-6, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs))

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32; the result is cast back afterwards.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(dim=1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Reshape the scale to (1, C, 1, ...) so it broadcasts over trailing dims.
        weight = self.weight.view(1, -1, *[1] * (hidden_states.ndim - 2))
        return weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"


class TPBlock(nn.Module):
    """Stack of residual conv -> RMS-norm -> ReLU layers applied as ``x + layer(x)``.

    Args:
        depths: Number of residual layers in the block.
        in_planes: Channel count of the input (and of the block output).
        out_planes: Channel count between the two convolutions of a layer;
            defaults to ``in_planes``.
        rank: Index passed to :meth:`generate_hyperparameters` to pick the
            kernel size, dilation and padding.
        shape_dims: Selects the convolution type: 2 -> ``Conv1d``,
            3 -> ``Conv2d``, 4 -> ``Conv3d``.
        channel_first: When False the input is permuted to channel-first
            before the convolutions and permuted back afterwards.
        dtype: Parameter dtype.
        device: Device for the parameters. ``None`` (default) selects CUDA
            when it is available and falls back to CPU otherwise.
    """

    # shape_dims -> (conv class, permutation to channel-first, permutation back)
    _CONV_SPECS = {
        2: (nn.Conv1d, (0, 2, 1), (0, 2, 1)),
        3: (nn.Conv2d, (0, 3, 1, 2), (0, 2, 3, 1)),
        4: (nn.Conv3d, (0, 4, 1, 2, 3), (0, 2, 3, 4, 1)),
    }

    def __init__(self, depths: int, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32, device=None) -> None:
        super().__init__()
        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.ModuleList([
            self._make_layer(in_planes, out_planes, rank, shape_dims, channel_first, dtype, device)
            for _ in range(depths)
        ])

    def forward(self, x: Tensor) -> Tensor:
        """Apply every layer residually and return a tensor shaped like ``x``."""
        for layer in self.layers:
            x = x + layer(x)
        return x

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` as a ``torch.device``; ``None`` means CUDA if available, else CPU."""
        if device is not None:
            return torch.device(device)
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _make_layer(self, in_planes: int, out_planes: Optional[int] = None, rank=1, shape_dims=3, channel_first=True, dtype=torch.float32, device=None) -> nn.Sequential:
        """Build one layer: [permute] conv, norm, ReLU, conv, norm, ReLU [permute].

        Raises:
            KeyError: If ``shape_dims`` is not 2, 3 or 4.
        """
        out_planes = in_planes if out_planes is None else out_planes
        device = self._resolve_device(device)

        Conv, pre_dims, post_dims = self._CONV_SPECS[shape_dims]
        kernel_size, dilation, padding = self.generate_hyperparameters(rank)

        pre_permute = nn.Identity() if channel_first else _Permute(*pre_dims)
        post_permute = nn.Identity() if channel_first else _Permute(*post_dims)

        # Both convolutions start at zero, so a fresh block is the identity map (x + 0).
        # NOTE(review): with both bias-free convolutions at zero, the gradient reaching
        # either weight also looks to be zero at initialisation -- confirm this is intended.
        conv1 = Conv(in_planes, out_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv1.weight)
        bn1 = _RMSNorm(out_planes, dtype=dtype, device=device)
        relu = nn.ReLU(inplace=True)
        conv2 = Conv(out_planes, in_planes, kernel_size, padding=padding, dilation=dilation, bias=False, dtype=dtype, device=device)
        nn.init.zeros_(conv2.weight)
        bn2 = _RMSNorm(in_planes, dtype=dtype, device=device)

        # The same ReLU instance is deliberately registered at both positions.
        return torch.nn.Sequential(pre_permute, conv1, bn1, relu, conv2, bn2, relu, post_permute)

    @staticmethod
    def generate_hyperparameters(rank: int):
        """Return the ``rank``-th (kernel_size, dilation, padding) triple.

        Triples are enumerated by increasing padded kernel size
        ``(kernel_size - 1) * dilation + 1`` and, for equal padded size, by
        increasing kernel size. The sequence starts
        ``(1, 1, 0), (3, 1, 1), (3, 2, 2), (5, 1, 2), (3, 3, 3), (7, 1, 3), ...``.

        Args:
            rank: 1-based position in the sequence. Any ``rank <= 1`` yields
                the first entry ``(1, 1, 0)``.

        Returns:
            Tuple[int, int, int]: ``(kernel_size, dilation, padding)`` where
            ``kernel_size`` is odd and
            ``padding = dilation * (kernel_size - 1) // 2``.
        """
        pairs = [(1, 1, 0)]
        padded_kernel_size = 3

        while len(pairs) < rank:
            for kernel_size in range(3, padded_kernel_size + 1, 2):
                # Only kernel sizes that reach the padded size with an integer dilation.
                if (padded_kernel_size - 1) % (kernel_size - 1) == 0:
                    dilation = (padded_kernel_size - 1) // (kernel_size - 1)
                    padding = dilation * (kernel_size - 1) // 2
                    pairs.append((kernel_size, dilation, padding))
                    if len(pairs) >= rank:
                        break
            padded_kernel_size += 2

        return pairs[-1]
|
|
|
|
| |
class ResNetConfig:
    """Factories that build the TRP training forward pass for a ResNet-style classifier.

    The ``self`` parameter of the static factories is the model instance; it
    must expose ``conv1``, ``bn1``, ``relu``, ``maxpool``, ``layer1``..``layer4``,
    ``avgpool``, ``fc`` and ``trp_blocks``.
    """

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via ``avgpool`` + ``fc``."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden states accepted by ``self.avgpool``
                    (presumably [B, C, H, W]).

            Returns:
                logits (Tensor): Output of ``self.fc`` on the pooled, flattened features.
            """
            pooled = self.avgpool(hidden_states)
            flat = torch.flatten(pooled, 1)
            return self.fc(flat)
        return func

    @staticmethod
    def gen_logits(self, shared_head):
        """Return a closure producing logits for the raw and the TRP-refined hidden states."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Hidden states fed to ``shared_head`` and
                    to every block in ``self.trp_blocks``.

            Returns:
                logits_sequence (List[Tensor]): ``shared_head(hidden_states)``
                    followed by ``shared_head(block(hidden_states))`` per block.
            """
            logits_sequence = [shared_head(hidden_states)]
            logits_sequence.extend(shared_head(block(hidden_states)) for block in self.trp_blocks)
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        """Return a closure computing cumulative top-k correctness masks.

        Args:
            label_smoothing: Unused by the mask itself; kept so the signature
                stays parallel to :meth:`gen_criterion`.
            top_k: A sample counts as correct when its label is among the
                ``top_k`` highest logits.
        """
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): Logits tensors of shape [B, C].
                labels (Tensor): Class indices of shape [B], or per-class
                    scores of shape [B, C] (reduced with argmax).

            Returns:
                mask_sequence (List[Tensor]): ``len(logits_sequence) + 1`` float32
                    tensors of shape [B]. Entry 0 is all ones; entry i is 1 where
                    every one of the first i logits was top-k correct, else 0.
            """
            # Labels carrying a class dimension are reduced to indices. Keying on
            # rank rather than label_smoothing keeps plain index labels working
            # when smoothing is enabled.
            if labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32)]
            with torch.no_grad():
                for logits in logits_sequence:
                    topk_indices = torch.topk(logits, top_k, dim=-1).indices
                    hit = torch.eq(topk_indices, labels[:, None]).any(dim=-1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * hit)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        """Return a closure computing unreduced cross-entropy for each logits tensor."""
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): Logits tensors of shape [B, C].
                labels (Tensor): Labels of shape [B] or [B, C].

            Returns:
                loss_sequence (List[Tensor]): One per-sample loss tensor
                    (``reduction="none"``) for each entry of ``logits_sequence``.
            """
            # With smoothing enabled, [B, C] labels are collapsed to indices and
            # smoothing is re-applied by cross_entropy; index labels pass through.
            if label_smoothing > 0.0 and labels.ndim > 1:
                labels = torch.argmax(labels, dim=1)

            return [
                F.cross_entropy(logits, labels, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` to be bound to a ResNet-style model.

        The returned function yields ``logits`` in eval mode and
        ``(logits, loss)`` in training mode, where ``loss`` comes from
        ``compute_policy_loss`` over the TRP logits sequence.
        """
        def func(self, x: Tensor, targets=None):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            hidden_states = self.layer4(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.fc(x)

            if not self.training:
                return logits

            # Fail with a clear message instead of an obscure error further down.
            torch._assert(targets is not None, "targets should not be none when in training mode")
            shared_head = ResNetConfig.gen_shared_head(self)
            compute_logits = ResNetConfig.gen_logits(self, shared_head)
            compute_mask = ResNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = ResNetConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

            return logits, loss

        return func
|
|
|
|
| |
class MobileNetV2Config(ResNetConfig):
    """TRP forward-pass factories for a model exposing ``features`` and ``classifier``.

    Mask, criterion and logits-sequence factories are inherited from
    ``ResNetConfig``; only the head and the forward pass differ.
    """

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via global pooling + ``classifier``."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map accepted by
                    ``adaptive_avg_pool2d`` (presumably [B, C, H, W]).

            Returns:
                logits (Tensor): Output of ``self.classifier`` on the pooled,
                    flattened features.
            """
            pooled = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            flat = torch.flatten(pooled, 1)
            return self.classifier(flat)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for the model.

        The returned function yields ``logits`` in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, x: Tensor, targets=None):
            hidden_states = self.features(x)

            x = nn.functional.adaptive_avg_pool2d(hidden_states, (1, 1))
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if not self.training:
                return logits

            # Fail with a clear message instead of an obscure error further down.
            torch._assert(targets is not None, "targets should not be none when in training mode")
            shared_head = MobileNetV2Config.gen_shared_head(self)
            compute_logits = MobileNetV2Config.gen_logits(self, shared_head)
            compute_mask = MobileNetV2Config.gen_mask(label_smoothing, top_k)
            compute_loss = MobileNetV2Config.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

            return logits, loss

        return func
|
|
|
|
| |
class EfficientNetConfig(ResNetConfig):
    """TRP forward-pass factories for a model exposing ``features``, ``avgpool`` and ``classifier``.

    Mask, criterion and logits-sequence factories are inherited from
    ``ResNetConfig``; only the head and the forward pass differ.
    """

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping hidden states to logits via ``avgpool`` + ``classifier``."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map accepted by ``self.avgpool``
                    (presumably [B, C, H, W]).

            Returns:
                logits (Tensor): Output of ``self.classifier`` on the pooled,
                    flattened features.
            """
            pooled = self.avgpool(hidden_states)
            flat = torch.flatten(pooled, 1)
            return self.classifier(flat)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for the model.

        The returned function yields ``logits`` in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, x: Tensor, targets=None):
            hidden_states = self.features(x)
            x = self.avgpool(hidden_states)
            x = torch.flatten(x, 1)
            logits = self.classifier(x)

            if not self.training:
                return logits

            # Fail with a clear message instead of an obscure error further down.
            torch._assert(targets is not None, "targets should not be none when in training mode")
            shared_head = EfficientNetConfig.gen_shared_head(self)
            compute_logits = EfficientNetConfig.gen_logits(self, shared_head)
            compute_mask = EfficientNetConfig.gen_mask(label_smoothing, top_k)
            compute_loss = EfficientNetConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

            return logits, loss

        return func
| |
|
|
| |
class VisionTransformerConfig(ResNetConfig):
    """TRP forward-pass factories for a model exposing ``_process_input``,
    ``class_token``, ``encoder`` and ``heads``.

    Mask, criterion and logits-sequence factories are inherited from
    ``ResNetConfig``; only the head and the forward pass differ.
    """

    @staticmethod
    def gen_shared_head(self):
        """Return a closure mapping encoder output to logits via the first token + ``heads``."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Encoder output; position 0 along
                    dimension 1 is the class token prepended in the forward pass.

            Returns:
                logits (Tensor): Output of ``self.heads`` on the class token.
            """
            class_token = hidden_states[:, 0]
            return self.heads(class_token)
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for the model.

        The returned function yields ``logits`` in eval mode and
        ``(logits, loss)`` in training mode.
        """
        def func(self, images: Tensor, targets=None):
            x = self._process_input(images)
            n = x.shape[0]
            # Prepend one class token per sample along the sequence dimension.
            batch_class_token = self.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            hidden_states = self.encoder(x)
            x = hidden_states[:, 0]

            logits = self.heads(x)

            if not self.training:
                return logits

            # Fail with a clear message instead of an obscure error further down.
            torch._assert(targets is not None, "targets should not be none when in training mode")
            shared_head = VisionTransformerConfig.gen_shared_head(self)
            compute_logits = VisionTransformerConfig.gen_logits(self, shared_head)
            compute_mask = VisionTransformerConfig.gen_mask(label_smoothing, top_k)
            compute_loss = VisionTransformerConfig.gen_criterion(label_smoothing)

            logits_sequence = compute_logits(hidden_states)
            mask_sequence = compute_mask(logits_sequence, targets)
            loss_sequence = compute_loss(logits_sequence, targets)
            loss = compute_policy_loss(loss_sequence, mask_sequence, rewards)

            return logits, loss

        return func
| |
|
|
| |
class FCNConfig(ResNetConfig):
    """TRP forward-pass factories for a segmentation model exposing ``backbone``,
    ``classifier``, an optional ``aux_classifier``, ``out_trp_blocks`` and
    ``aux_trp_blocks``.
    """

    @staticmethod
    def gen_out_shared_head(self, input_shape):
        """Return a closure: ``classifier`` followed by bilinear resize to ``input_shape``."""
        def func(features):
            """
            Args:
                features (Tensor): Feature map accepted by ``self.classifier``.

            Returns:
                result (Tensor): Classifier output resized to ``input_shape``.
            """
            x = self.classifier(features)
            return F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        return func

    @staticmethod
    def gen_aux_shared_head(self, input_shape):
        """Return a closure: ``aux_classifier`` followed by bilinear resize to ``input_shape``."""
        def func(features):
            """
            Args:
                features (Tensor): Feature map accepted by ``self.aux_classifier``.

            Returns:
                result (Tensor): Auxiliary classifier output resized to ``input_shape``.
            """
            x = self.aux_classifier(features)
            return F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        return func

    @staticmethod
    def gen_out_logits(self, shared_head):
        """Return a closure producing logits for the raw and ``out_trp_blocks``-refined features."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map fed to ``shared_head`` and
                    to every block in ``self.out_trp_blocks``.

            Returns:
                logits_sequence (List[Tensor]): ``shared_head(hidden_states)``
                    followed by ``shared_head(block(hidden_states))`` per block.
            """
            logits_sequence = [shared_head(hidden_states)]
            logits_sequence.extend(shared_head(block(hidden_states)) for block in self.out_trp_blocks)
            return logits_sequence
        return func

    @staticmethod
    def gen_aux_logits(self, shared_head):
        """Return a closure producing logits for the raw and ``aux_trp_blocks``-refined features."""
        def func(hidden_states):
            """
            Args:
                hidden_states (Tensor): Feature map fed to ``shared_head`` and
                    to every block in ``self.aux_trp_blocks``.

            Returns:
                logits_sequence (List[Tensor]): ``shared_head(hidden_states)``
                    followed by ``shared_head(block(hidden_states))`` per block.
            """
            logits_sequence = [shared_head(hidden_states)]
            logits_sequence.extend(shared_head(block(hidden_states)) for block in self.aux_trp_blocks)
            return logits_sequence
        return func

    @staticmethod
    def gen_mask(label_smoothing=0.0, top_k=1):
        """Return a closure computing cumulative per-pixel top-k correctness masks.

        Args:
            label_smoothing: Unused by the mask itself; kept so the signature
                stays parallel to :meth:`gen_criterion`.
            top_k: A pixel counts as correct when its label is among the
                ``top_k`` highest logits along dimension 1.
        """
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): Logits tensors of shape [B, C, H, W].
                labels (Tensor): Class indices of shape [B, H, W], or per-class
                    scores of shape [B, C, H, W] (reduced with argmax).

            Returns:
                mask_sequence (List[Tensor]): ``len(logits_sequence) + 1`` float32
                    tensors of shape [B, H, W]. Entry 0 is all ones; entry i is 1
                    where every one of the first i logits was top-k correct.
            """
            # Only labels that carry a class dimension are reduced; [B, H, W]
            # index maps must not be argmax-ed even when smoothing is enabled.
            if labels.ndim > 3:
                labels = torch.argmax(labels, dim=1)

            mask_sequence = [torch.ones_like(labels, dtype=torch.float32)]
            with torch.no_grad():
                for logits in logits_sequence:
                    topk_indices = torch.topk(logits, top_k, dim=1).indices
                    hit = torch.eq(topk_indices, labels[:, None, :, :]).any(dim=1).to(torch.float32)
                    mask_sequence.append(mask_sequence[-1] * hit)
            return mask_sequence
        return func

    @staticmethod
    def gen_criterion(label_smoothing=0.0):
        """Return a closure computing unreduced cross-entropy (``ignore_index=255``) per logits tensor."""
        def func(logits_sequence, labels):
            """
            Args:
                logits_sequence (List[Tensor]): Logits tensors of shape [B, C, H, W].
                labels (Tensor): Labels of shape [B, H, W] or [B, C, H, W].

            Returns:
                loss_sequence (List[Tensor]): One per-pixel loss tensor
                    (``reduction="none"``) for each entry of ``logits_sequence``.
            """
            if label_smoothing > 0.0 and labels.ndim > 3:
                labels = torch.argmax(labels, dim=1)

            return [
                F.cross_entropy(logits, labels, ignore_index=255, reduction="none", label_smoothing=label_smoothing)
                for logits in logits_sequence
            ]
        return func

    @staticmethod
    def gen_forward(rewards, label_smoothing=0.0, top_k=1):
        """Return a replacement ``forward`` for the segmentation model.

        The returned function yields the ``OrderedDict`` of resized outputs
        (``"out"`` and, when an auxiliary classifier exists, ``"aux"``) in eval
        mode, and ``(result, loss)`` in training mode. ``loss`` is the policy
        loss of the main head plus ``0.5`` times that of the auxiliary head
        when the model has one.
        """
        def func(self, images: Tensor, targets=None):
            input_shape = images.shape[-2:]

            features = self.backbone(images)

            result = OrderedDict()
            x = features["out"]
            x = self.classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["out"] = x

            has_aux = self.aux_classifier is not None
            if has_aux:
                x = features["aux"]
                x = self.aux_classifier(x)
                x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
                result["aux"] = x

            if not self.training:
                return result

            torch._assert(targets is not None, "targets should not be none when in training mode")
            compute_mask = FCNConfig.gen_mask(label_smoothing, top_k)
            compute_loss = FCNConfig.gen_criterion(label_smoothing)

            out_shared_head = FCNConfig.gen_out_shared_head(self, input_shape)
            compute_out_logits = FCNConfig.gen_out_logits(self, out_shared_head)
            out_logits_sequence = compute_out_logits(features["out"])
            out_mask_sequence = compute_mask(out_logits_sequence, targets)
            out_loss_sequence = compute_loss(out_logits_sequence, targets)
            loss = compute_policy_loss(out_loss_sequence, out_mask_sequence, rewards)

            # The auxiliary branch is optional: skip it when the model has no
            # aux classifier instead of failing on features["aux"].
            if has_aux:
                aux_shared_head = FCNConfig.gen_aux_shared_head(self, input_shape)
                compute_aux_logits = FCNConfig.gen_aux_logits(self, aux_shared_head)
                aux_logits_sequence = compute_aux_logits(features["aux"])
                aux_mask_sequence = compute_mask(aux_logits_sequence, targets)
                aux_loss_sequence = compute_loss(aux_logits_sequence, targets)
                aux_loss = compute_policy_loss(aux_loss_sequence, aux_mask_sequence, rewards)
                loss = loss + 0.5 * aux_loss

            return result, loss
        return func
|
|
|
|
| |
class DeepLabV3Config(FCNConfig):
    """Configuration for DeepLabV3; inherits every factory from ``FCNConfig`` unchanged."""
|
|
|
|
def apply_trp(model, depths: List[int], in_planes: int, out_planes: int, rewards, **kwargs):
    """Attach TRP blocks to ``model`` and replace its ``forward`` with the TRP variant.

    Args:
        model: A ResNet, MobileNetV2, EfficientNet, VisionTransformer, FCN or
            DeepLabV3 instance. Any other type is returned unmodified.
        depths: One entry per TRP block giving that block's number of layers.
        in_planes: Channel count of the hidden states for the classification
            models. The segmentation branches use fixed widths of 2048 (main
            head) and 1024 (auxiliary head) instead.
        out_planes: Intermediate channel count inside each TRP layer.
        rewards: Per-step reward weights forwarded to the generated ``forward``.
        **kwargs: ``label_smoothing`` (float, default ``0.0``) is used for the
            classification models; segmentation models always use ``0.0``.

    Returns:
        The same ``model`` object, modified in place.
    """
    label_smoothing = kwargs.get("label_smoothing", 0.0)

    def make_blocks(planes, **block_kwargs):
        # NOTE(review): rank is the 0-based position in `depths`; confirm 0-based is intended.
        return torch.nn.ModuleList([
            TPBlock(depths=d, in_planes=planes, out_planes=out_planes, rank=k, **block_kwargs)
            for k, d in enumerate(depths)
        ])

    def bind_forward(config, smoothing):
        forward = config.gen_forward(rewards, label_smoothing=smoothing, top_k=1)
        model.forward = types.MethodType(forward, model)

    if isinstance(model, ResNet):
        print("✅ Applying TRP to ResNet for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        bind_forward(ResNetConfig, label_smoothing)
    elif isinstance(model, MobileNetV2):
        print("✅ Applying TRP to MobileNetV2 for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        bind_forward(MobileNetV2Config, label_smoothing)
    elif isinstance(model, EfficientNet):
        print("✅ Applying TRP to EfficientNet for Image Classification...")
        model.trp_blocks = make_blocks(in_planes)
        bind_forward(EfficientNetConfig, label_smoothing)
    elif isinstance(model, VisionTransformer):
        print("✅ Applying TRP to VisionTransformer for Image Classification...")
        # Token sequences are channel-last, so the blocks use 1-D convolutions with permutes.
        model.trp_blocks = make_blocks(in_planes, shape_dims=2, channel_first=False)
        bind_forward(VisionTransformerConfig, label_smoothing)
    elif isinstance(model, FCN):
        print("✅ Applying TRP to FCN for Semantic Segmentation...")
        model.out_trp_blocks = make_blocks(2048)
        model.aux_trp_blocks = make_blocks(1024)
        bind_forward(FCNConfig, 0.0)
    elif isinstance(model, DeepLabV3):
        print("✅ Applying TRP to DeepLabV3 for Semantic Segmentation...")
        # Each block gets its own depth `d`; passing the whole list made
        # TPBlock fail in range(depths).
        model.out_trp_blocks = make_blocks(2048)
        model.aux_trp_blocks = make_blocks(1024)
        bind_forward(DeepLabV3Config, 0.0)
    return model