File size: 1,732 Bytes
e9f9fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from ...vision import *

# Only the model constructor is exported by `import *`; the building blocks below
# (sep_conv, conv, ConvSkip, middle_flow) are not listed here.
__all__ = ['xception']

def sep_conv(ni,nf,pad=None,pool=False,act=True):
    """Depthwise-separable conv unit: [ReLU] -> 3x3 depthwise conv -> 1x1 pointwise conv -> BatchNorm [-> MaxPool].

    ni: input channels. nf: output channels.
    pad: padding of the 3x3 depthwise conv; None (default) means 1, which keeps H/W unchanged.
    pool: if True, append a 2x2 stride-2 max-pool that halves H/W, rounding odd sizes up.
    act: if True, prepend a ReLU (pre-activation), otherwise start directly with the convs.
    Returns an `nn.Sequential`.
    """
    if pad is None: pad = 1
    layers =  [nn.ReLU()] if act else []
    layers += [
        nn.Conv2d(ni,ni,3,1,pad,groups=ni,bias=False),  # depthwise: groups=ni -> one 3x3 filter per channel
        nn.Conv2d(ni,nf,1,bias=False),                  # pointwise: 1x1 conv mixes channels ni -> nf
        nn.BatchNorm2d(nf)
    ]
    # ceil_mode=True gives ceil(n/2) for odd n (same as floor for even n). This matches the
    # 1x1 stride-2 shortcut conv in ConvSkip, so the residual sum works for odd feature maps.
    if pool: layers.append(nn.MaxPool2d(2, ceil_mode=True))
    return nn.Sequential(*layers)

def conv(ni,nf,ks=1,stride=1, pad=None, act=True):
    """Plain conv unit: bias-free Conv2d -> BatchNorm2d, optionally followed by a ReLU.

    ni/nf: input/output channels. ks: kernel size. stride: conv stride.
    pad: conv padding; None (default) means ks//2, which preserves H/W for odd ks at stride 1.
    act: append a ReLU after the BatchNorm when True.
    Returns an `nn.Sequential`.
    """
    padding = ks//2 if pad is None else pad
    # bias=False because the following BatchNorm supplies its own shift.
    body = [nn.Conv2d(ni, nf, ks, stride, padding, bias=False), nn.BatchNorm2d(nf)]
    if not act:
        return nn.Sequential(*body)
    return nn.Sequential(*body, nn.ReLU())

class ConvSkip(Module):
    """Downsampling residual block: a 1x1 stride-2 conv shortcut plus two separable convs ending in a max-pool.

    ni: input channels. nf: output channels (defaults to `ni` when None).
    act: whether the first separable conv starts with a ReLU (the second one always does).
    """
    def __init__(self,ni,nf=None,act=True):
        # Resolve the default before building any layer so `None` never reaches conv/sep_conv.
        if nf is None: nf = ni
        self.nf,self.ni = nf,ni
        # Shortcut path: 1x1 conv with stride 2 + BatchNorm, no activation.
        self.conv = conv(ni,nf,stride=2, act=False)
        # Main path: channel-preserving sep_conv, then a sep_conv to nf channels that pools (pool=True).
        self.m = nn.Sequential(
            sep_conv(ni,ni,act=act),
            sep_conv(ni,nf,pool=True)
        )

    def forward(self,x):
        # NOTE: the sum needs both paths to produce the same spatial size. The 1x1 stride-2
        # shortcut yields ceil(H/2), so the main path's pooling must round the same way for odd H/W.
        return self.conv(x) + self.m(x)

def middle_flow(nf):
    """Middle-flow residual block: three width-preserving `sep_conv(nf, nf)` units closed by a `MergeLayer`.

    nf: channel count, unchanged through the block. Returns a `SequentialEx`.
    """
    # Each call constructs a fresh module, so the three units do not share weights.
    units = [sep_conv(nf, nf) for _ in range(3)]
    # NOTE(review): MergeLayer presumably adds the block's input back (residual) — confirm in fastai layers.
    return SequentialEx(*units, MergeLayer())

def xception(c, k=8, n_middle=8):
    """Build an Xception-style classifier (preview version: not tested yet, use at own risk, no pretrained model yet).

    c: number of output classes (out_features of the final Linear layer).
    k: width multiplier; with the default k=8 the widths are 32, 64, 128, 256, 728, 1024, 1536, 2048.
    n_middle: number of `middle_flow` blocks.
    Expects 3-channel input. Returns an `nn.Sequential`.
    """
    # Entry flow: two plain convs (the first with stride 2), then three downsampling residual blocks.
    entry = [
        conv(3, k*4, 3, 2),
        conv(k*4, k*8, 3),
        ConvSkip(k*8, k*16, act=False),  # the conv above already ends in a ReLU
        ConvSkip(k*16, k*32),
        ConvSkip(k*32, k*91),
    ]
    # Middle flow: n_middle blocks at constant width k*91.
    middle = [middle_flow(k*91) for _ in range(n_middle)]
    # Exit flow: one more downsampling block, two separable convs, then global pooling and the classifier head.
    tail = [
        ConvSkip(k*91, k*128),
        sep_conv(k*128, k*192, act=False),
        sep_conv(k*192, k*256),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        Flatten(),
        nn.Linear(k*256, c),
    ]
    return nn.Sequential(*entry, *middle, *tail)