File size: 1,202 Bytes
8285881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import torch.nn.functional as F


class MPSAdaptiveAvgPool2d(nn.Module):
    """Adaptive average pooling with a CPU fallback for MPS tensors.

    Wraps ``nn.AdaptiveAvgPool2d``. When the input lives on an ``mps``
    device and its spatial size is not an exact multiple of the requested
    output size, the pooling is executed on CPU and the result is moved
    back to the input's device. In every other case the wrapped module is
    called directly on the input.

    Args:
        output_size: Target size, in any form ``nn.AdaptiveAvgPool2d``
            accepts: a single int, or a tuple/list ``(H, W)``. An entry
            may be ``None``, meaning "keep the input's size on that axis".
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size
        self.pool = nn.AdaptiveAvgPool2d(output_size)

    def _resolve_output_size(self, height, width):
        """Return the concrete ``(out_h, out_w)`` for a given input size."""
        size = self.output_size
        # Lists are accepted by nn.AdaptiveAvgPool2d too, so treat them
        # like tuples instead of mistaking them for a scalar size.
        if isinstance(size, (tuple, list)):
            out_h, out_w = size
        else:
            out_h = out_w = size
        # A ``None`` entry means the output keeps the input's extent.
        if out_h is None:
            out_h = height
        if out_w is None:
            out_w = width
        return out_h, out_w

    def _needs_cpu_fallback(self, height, width):
        """Tell whether an ``height`` x ``width`` input must be pooled on CPU.

        Returns ``True`` when either spatial dimension is not divisible by
        its target size. Non-positive target sizes also return ``True`` so
        the modulo below never divides by zero; the CPU implementation is
        then left to accept or reject the request.
        """
        out_h, out_w = self._resolve_output_size(height, width)
        if out_h <= 0 or out_w <= 0:
            return True
        return bool(height % out_h != 0 or width % out_w != 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x``, routing through CPU when the MPS check fails.

        Args:
            x: Input tensor whose last two dimensions are height and
                width (batched ``(N, C, H, W)`` or unbatched ``(C, H, W)``).

        Returns:
            The pooled tensor, on the same device as ``x``.
        """
        # Inputs with fewer than 3 dims are invalid for 2-D pooling; skip
        # the size check and let the wrapped module raise its own error.
        if x.device.type == 'mps' and x.dim() >= 3:
            # Read the spatial dims from the end so unbatched inputs work.
            height, width = x.shape[-2], x.shape[-1]

            # MPS requires input sizes to be divisible by output sizes
            if self._needs_cpu_fallback(height, width):
                # Fallback to CPU for this operation, then restore device
                return self.pool(x.cpu()).to(x.device)

        # Use normal pooling for CUDA/CPU or when MPS supports the sizes
        return self.pool(x)