import torch
import torch.nn as nn
import torch.nn.functional as F


class MPSAdaptiveAvgPool2d(nn.Module):
    """Wrapper around ``nn.AdaptiveAvgPool2d`` with a CPU fallback for MPS.

    Apple's MPS backend only supports adaptive average pooling when each
    spatial input size is an integer multiple of the corresponding output
    size. When that condition does not hold, the input is moved to the CPU,
    pooled there, and the result is moved back to the original device.
    On CPU/CUDA (or when the shape is MPS-supported) this behaves exactly
    like ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, output_size):
        """
        Args:
            output_size: Target output size — an int, or a tuple/list of
                two values. Entries may be ``None``, which (as with
                ``nn.AdaptiveAvgPool2d``) keeps that input dimension.
        """
        super().__init__()
        self.output_size = output_size
        self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        """Pool ``x``, falling back to CPU when MPS cannot handle the shape.

        Args:
            x: Input tensor of shape ``(N, C, H, W)``.

        Returns:
            Pooled tensor on the same device as ``x``.
        """
        if x.device.type == 'mps':
            h, w = x.shape[2], x.shape[3]
            # Accept int, tuple, or list forms of output_size (all are
            # valid for nn.AdaptiveAvgPool2d).
            if isinstance(self.output_size, (tuple, list)):
                out_h, out_w = self.output_size
            else:
                out_h = out_w = self.output_size
            # ``None`` means "keep the input size", which is trivially
            # divisible — substitute the input dimension so the check below
            # doesn't crash on a None operand.
            out_h = h if out_h is None else out_h
            out_w = w if out_w is None else out_w
            # MPS requires input sizes to be divisible by output sizes;
            # otherwise run the pool on CPU and move the result back.
            if h % out_h != 0 or w % out_w != 0:
                device = x.device
                return self.pool(x.cpu()).to(device)
        # CUDA/CPU, or an MPS-supported size combination.
        return self.pool(x)