File size: 395 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch

def swish(x):
    """Apply the Swish / SiLU activation, ``x * sigmoid(x)``.

    Args:
        x: A tensor, or a (possibly nested) list of tensors. Lists are
            updated in place, element by element, and the same list object
            is returned.

    Returns:
        The activated tensor, or the input list with every element replaced
        by its activated value.

    Notes:
        If the out-of-place computation fails with an allocation error
        (``RuntimeError``, the base class of torch out-of-memory errors, or
        ``MemoryError``), the activation is instead applied one index of
        dim 2 at a time and written back into ``x`` *in place*. That
        fallback mutates the caller's tensor and requires ``x`` to have at
        least 3 dimensions. Any other exception propagates unchanged.
    """
    if isinstance(x, list):
        for i, item in enumerate(x):
            x[i] = swish(item)
        return x
    try:
        return x * torch.sigmoid(x)
    except (RuntimeError, MemoryError):
        # Presumably a memory fallback (e.g. OOM on a large 5-D video
        # tensor): only one slice-sized temporary is alive at a time, and
        # the result overwrites ``x`` instead of allocating a full output.
        for i in range(x.shape[2]):
            chunk = x[:, :, i:i + 1, ...]
            x[:, :, i:i + 1, ...] = chunk * torch.sigmoid(chunk)
        return x