jinmang2 committed on
Commit
eb21b46
·
1 Parent(s): 55949a6

Create modeling_dalle.py

Browse files
Files changed (1) hide show
  1. modeling_dalle.py +160 -0
modeling_dalle.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig

15
class Conv2d(nn.Module):
    """2d convolution with explicitly managed weight and bias parameters.

    The weight is initialized from N(0, 1/sqrt(n_in * kw**2)) and the bias
    with zeros, both created on ``config.device``.  When ``use_float16`` is
    set and the parameters live on a CUDA device, the forward pass runs in
    half precision; otherwise inputs are coerced to float32.

    Args:
        n_in: number of input channels (>= 1).
        n_out: number of output channels (>= 1).
        kw: square kernel size; must be odd so the fixed padding below
            preserves spatial dimensions.
        config: model config providing ``device`` and ``requires_grad``.
        use_float16: run the conv in fp16 on CUDA devices.
    """

    def __init__(self, n_in, n_out, kw, config, use_float16=True):
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1
        assert kw >= 1 and kw % 2 == 1  # odd kernel -> 'same' padding works

        self.n_in = n_in
        self.n_out = n_out
        self.kw = kw
        self.config = config
        self.use_float16 = use_float16

        # BUG FIX: the original created these tensors with
        # requires_grad=config.requires_grad and then wrapped them in
        # nn.Parameter(...), which defaults to requires_grad=True — the
        # config flag was silently discarded.  Worse, if the flag was True,
        # the in-place w.normal_() on a grad-requiring leaf tensor raises a
        # RuntimeError.  Build the tensors without autograd, initialize
        # them, and hand the flag to nn.Parameter explicitly.
        w = torch.empty(
            (n_out, n_in, kw, kw),
            dtype=torch.float32,
            device=config.device,
        )
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))

        b = torch.zeros(
            (n_out,),
            dtype=torch.float32,
            device=config.device,
        )

        self.w = nn.Parameter(w, requires_grad=config.requires_grad)
        self.b = nn.Parameter(b, requires_grad=config.requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, selecting fp16 or fp32 as configured."""
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        # (kw - 1) // 2 padding keeps H and W unchanged for odd kernels.
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

    def extra_repr(self):
        """Summarize construction arguments for nn.Module.__repr__."""
        inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
        inner_repr += f"use_float16={self.use_float16}, "
        inner_repr += f"device={self.config.device}, "
        inner_repr += f"requires_grad={self.config.requires_grad}"
        return inner_repr
63
+
64
+
65
class EncoderBlock(nn.Module):
    """Residual bottleneck block used by the DALL-E encoder.

    The residual branch runs four ReLU + Conv2d stages (three 3x3 convs at
    ``n_out // 4`` hidden channels, then a 1x1 conv back up to ``n_out``)
    and is scaled by ``1 / n_layers**2`` before being added to the identity
    branch.  When the channel count changes, a 1x1 convolution projects the
    identity branch to match.
    """

    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4
        self.post_gain = 1 / (n_layers ** 2)

        # Project the skip connection only when the shapes disagree.
        self.id_path = (
            nn.Identity()
            if n_in == n_out
            else Conv2d(n_in, n_out, 1, config)
        )

        stages = [
            ('relu_1', nn.ReLU()),
            ('conv_1', Conv2d(n_in, self.n_hid, 3, config)),
            ('relu_2', nn.ReLU()),
            ('conv_2', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_3', nn.ReLU()),
            ('conv_3', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_4', nn.ReLU()),
            ('conv_4', Conv2d(self.n_hid, n_out, 1, config)),
        ]
        self.res_path = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return id_path(x) + post_gain * res_path(x)."""
        residual = self.res_path(x)
        return self.id_path(x) + self.post_gain * residual
96
+
97
+
98
class DallEPreTrainedModel(PreTrainedModel):
    """Common base class tying ``DallEConfig`` to the transformers
    ``PreTrainedModel`` interface for all DALL-E model heads."""

    config_class = DallEConfig
    base_model_prefix = "dalle"
101
+
102
+
103
class DallEEncoder(DallEPreTrainedModel):
    """Convolutional encoder producing per-position vocab logits.

    Four groups of ``config.n_blk_per_group`` EncoderBlocks; groups 1-3 end
    in 2x max pooling, and the channel width doubles each group
    (``n_hid`` -> ``8 * n_hid``).  The output head is a ReLU followed by a
    1x1 convolution (kept in float32) with ``config.vocab_size`` channels.
    """

    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group

        in_channels = config.input_channels
        n_hid = config.n_hid

        # BUG FIX: forward() validates x.shape[1] against
        # self.input_channels, which was never assigned (only the local
        # in_channels existed), so every forward pass raised
        # AttributeError.  Persist it on the module.
        self.input_channels = in_channels

        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(in_channels, n_hid, 7, config)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(n_hid, n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       # First block of the group widens the channels.
                       n_hid if i == 0 else 2 * n_hid,
                       2 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       2 * n_hid if i == 0 else 4 * n_hid,
                       4 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            # Final group has no pooling stage.
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       4 * n_hid if i == 0 else 8 * n_hid,
                       8 * n_hid, n_layers, config))
                  for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', Conv2d(
                    8 * n_hid, config.vocab_size,
                    1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a float32 NCHW image batch.

        Raises:
            ValueError: if ``x`` is not 4d, has the wrong channel count,
                or is not float32.
        """
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)