njeffrie commited on
Commit
2aef7bb
·
verified ·
1 Parent(s): 367847b

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +24 -0
  2. model.safetensors +3 -0
  3. modeling_gluformer.py +353 -0
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activ": "relu",
3
+ "architectures": [
4
+ "GluformerForTimeSeries"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_gluformer.GluformerConfig",
8
+ "AutoModel": "modeling_gluformer.GluformerForTimeSeries"
9
+ },
10
+ "d_fcn": 2048,
11
+ "d_model": 512,
12
+ "distil": true,
13
+ "len_pred": 12,
14
+ "len_seq": 180,
15
+ "len_label": 60,
16
+ "model_type": "gluformer",
17
+ "n_heads": 12,
18
+ "num_dec_layers": 1,
19
+ "num_enc_layers": 2,
20
+ "num_features": 5,
21
+ "r_drop": 0.1,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.53.3"
24
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8ad7e01c7fc7a97d2291dd885d39dc794b141605c1d422c9942cab8ddc74999
3
+ size 65480616
modeling_gluformer.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ #from gluformer.model import Gluformer
5
+
6
+ # coding: utf-8
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import math
11
+ import numpy as np
12
+ from math import sqrt
13
+ from datetime import timedelta
14
+
15
+ # === Embedding Modules ===
16
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding, as in "Attention Is All You Need".

    The table is precomputed for ``max_len`` positions and registered as a
    buffer, so it moves with the module across devices but is never trained.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # BUGFIX: the original set `pos_emb.require_grad = False`, a typo for
        # `requires_grad` that silently did nothing.  The line is dropped:
        # torch.zeros already defaults to requires_grad=False, and
        # register_buffer keeps the table out of the module's parameters.
        pos_emb = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        # Geometric progression of frequencies across even feature indices.
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pos_emb[:, 0::2] = torch.sin(position * div_term)
        pos_emb[:, 1::2] = torch.cos(position * div_term)
        pos_emb = pos_emb.unsqueeze(0)  # (1, max_len, d_model): broadcasts over batch
        self.register_buffer('pos_emb', pos_emb)

    def forward(self, x):
        """Return the table sliced to x's sequence length: (1, x.size(1), d_model)."""
        return self.pos_emb[:, :x.size(1)]
29
+
30
class TokenEmbedding(nn.Module):
    """Project the scalar (univariate) glucose series into d_model channels
    using a width-3 circular convolution along the time axis."""

    def __init__(self, d_model):
        super(TokenEmbedding, self).__init__()
        in_channels = 1  # one value per timestep
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=d_model,
                              kernel_size=3, padding=1, padding_mode='circular')

    def forward(self, x):
        # Conv1d expects (batch, channels, time); the data arrives as
        # (batch, time, channels), so transpose in and back out.
        return self.conv(x.transpose(-1, 1)).transpose(-1, 1)
38
+
39
class TemporalEmbedding(nn.Module):
    """Linearly embed the per-step calendar features into the model dimension."""

    def __init__(self, d_model, num_features):
        super(TemporalEmbedding, self).__init__()
        self.embed = nn.Linear(num_features, d_model)

    def forward(self, x):
        # Cast defensively: timestamp features may arrive as float64.
        return self.embed(x.float())
46
+
47
class SubjectEmbedding(nn.Module):
    """Embed a scalar subject identifier into the model dimension via a
    single linear layer."""

    def __init__(self, d_model):
        super(SubjectEmbedding, self).__init__()
        self.id_embedding = nn.Linear(1, d_model)

    def forward(self, x):
        # (batch,) -> (batch, 1): the Linear sees one feature per subject.
        ids = x.float().unsqueeze(1)
        return self.id_embedding(ids)
55
+
56
class DataEmbedding(nn.Module):
    """Sum the value, positional and calendar embeddings, then prepend a
    per-subject embedding token before dropout.

    Output sequence length is input length + 1 because of the subject token.
    """

    def __init__(self, d_model, r_drop, num_features):
        super(DataEmbedding, self).__init__()
        # Submodule creation order kept stable so seeded initialization
        # reproduces the original weights exactly.
        self.value_embedding = TokenEmbedding(d_model)
        self.time_embedding = TemporalEmbedding(d_model, num_features)
        self.positional_embedding = PositionalEmbedding(d_model)
        self.subject_embedding = SubjectEmbedding(d_model)
        self.dropout = nn.Dropout(r_drop)

    def forward(self, x_id, x, x_mark):
        summed = (self.value_embedding(x)
                  + self.positional_embedding(x)
                  + self.time_embedding(x_mark))
        subject_token = self.subject_embedding(x_id).unsqueeze(1)  # (batch, 1, d_model)
        return self.dropout(torch.cat((subject_token, summed), dim=1))
68
+
69
+ # === Attention Modules ===
70
class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that never looks ahead: both sides are padded by
    (kernel_size - 1) * dilation and the trailing outputs — the ones that
    would depend on future samples — are discarded, preserving length."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        # Stored before super().__init__ so forward can trim by the same amount.
        self._causal_trim = (kernel_size - 1) * dilation
        super(CausalConv1d, self).__init__(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=self._causal_trim, dilation=dilation, groups=groups, bias=bias)

    def forward(self, input):
        out = super(CausalConv1d, self).forward(input)
        if self._causal_trim == 0:
            return out
        return out[:, :, :-self._causal_trim]
79
+
80
class TriangularCausalMask():
    """Boolean mask of shape (b, 1, n, n), True strictly above the diagonal —
    i.e. at the future positions attention must not attend to."""

    def __init__(self, b, n, device="cpu"):
        with torch.no_grad():
            upper = torch.triu(torch.ones([b, 1, n, n], dtype=torch.bool), diagonal=1)
            self._mask = upper.to(device)

    @property
    def mask(self):
        return self._mask
88
+
89
class MultiheadAttention(nn.Module):
    """Scaled dot-product multi-head attention.

    With ``mask_flag=True`` a strictly-upper-triangular (causal) mask is
    written into the scores before the softmax, so each query position can
    only attend to itself and earlier key positions.
    """

    def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1):
        super(MultiheadAttention, self).__init__()
        self.h, self.d, self.mask_flag = n_heads, d_keys, mask_flag
        self.proj_q = nn.Linear(d_model, self.h * self.d)
        self.proj_k = nn.Linear(d_model, self.h * self.d)
        self.proj_v = nn.Linear(d_model, self.h * self.d)
        self.proj_out = nn.Linear(self.h * self.d, d_model)
        self.dropout = nn.Dropout(r_att_drop)

    def forward(self, q, k, v):
        batch, n_q = q.size(0), q.size(1)
        heads, dim = self.h, self.d
        # Project, then split the last axis into (heads, dim).
        q = self.proj_q(q).reshape(batch, -1, heads, dim)
        k = self.proj_k(k).reshape(batch, -1, heads, dim)
        v = self.proj_v(v).reshape(batch, -1, heads, dim)
        scores = torch.einsum('bnhd,bmhd->bhnm', (q, k))
        if self.mask_flag:
            # -inf at masked slots gives them exactly zero softmax weight.
            causal = TriangularCausalMask(batch, n_q, device=q.device)
            scores.masked_fill_(causal.mask, -np.inf)
        weights = self.dropout(F.softmax(scores / (dim ** .5), dim=-1))
        merged = torch.einsum('bhnm,bmhd->bnhd', (weights, v)).reshape(batch, -1, heads * dim)
        return self.proj_out(merged)
112
+
113
+ # === Encoder Modules ===
114
class ConvLayer(nn.Module):
    """Distillation layer placed between encoder blocks: circular conv +
    batch norm + ELU, then a stride-2 max-pool that roughly halves the
    sequence length."""

    def __init__(self, d_model):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=d_model, out_channels=d_model,
                                  kernel_size=3, padding=1, padding_mode='circular')
        self.norm = nn.BatchNorm1d(d_model)
        self.activ = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Work in (batch, channels, time) for the whole conv/pool stack,
        # transposing back to (batch, time, channels) at the end.
        y = self.downConv(x.transpose(-1, 1))
        y = self.maxPool(self.activ(self.norm(y)))
        return y.transpose(-1, 1)
128
+
129
class EncoderLayer(nn.Module):
    """Transformer encoder block: self-attention with a residual and layer
    norm, then a 1x1-conv feed-forward network with a second residual."""

    def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
        super(EncoderLayer, self).__init__()
        self.att = att
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x):
        # Self-attention sublayer with residual connection.
        x = self.norm1(x + self.dropout(self.att(x, x, x)))
        # Position-wise feed-forward as two 1x1 convolutions over channels.
        hidden = self.dropout(self.activ(self.conv1(x.transpose(-1, 1))))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + ffn_out)
146
+
147
class Encoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with distilling conv
    layers (one conv after each encoder layer except the last)."""

    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.enc_layers = nn.ModuleList(enc_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x):
        if self.conv_layers is None:
            for layer in self.enc_layers:
                x = layer(x)
        else:
            # zip stops at the shorter conv list (len N-1), so the final
            # encoder layer runs afterwards without a trailing conv.
            for layer, conv in zip(self.enc_layers, self.conv_layers):
                x = conv(layer(x))
            x = self.enc_layers[-1](x)
        if self.norm is not None:
            x = self.norm(x)
        return x
165
+
166
+ # === Decoder Modules ===
167
class DecoderLayer(nn.Module):
    """Transformer decoder block: self-attention, cross-attention over the
    encoder memory, then a 1x1-conv feed-forward network; each sublayer is
    followed by a residual add and a layer norm.

    Note: unlike EncoderLayer, the attention outputs here are added without
    dropout (matching the original implementation).
    """

    def __init__(self, self_att, cross_att, d_model, d_fcn, r_drop, activ="relu"):
        super(DecoderLayer, self).__init__()
        self.self_att = self_att
        self.cross_att = cross_att
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x_dec, x_enc):
        # Self-attention over the decoder sequence.
        x_dec = self.norm1(x_dec + self.self_att(x_dec, x_dec, x_dec))
        # Cross-attention: decoder queries, encoder keys/values.
        x_dec = self.norm2(x_dec + self.cross_att(x_dec, x_enc, x_enc))
        # Position-wise feed-forward via 1x1 convolutions.
        hidden = self.dropout(self.activ(self.conv1(x_dec.transpose(-1, 1))))
        ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm3(x_dec + ffn_out)
187
+
188
class Decoder(nn.Module):
    """Stack of decoder layers (each consuming the encoder memory), followed
    by an optional final normalization."""

    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x_dec, x_enc):
        for layer in self.layers:
            x_dec = layer(x_dec, x_enc)
        return x_dec if self.norm is None else self.norm(x_dec)
199
+
200
+ # === Variance Module ===
201
class Variance(nn.Module):
    """Predict one scalar per series from the embedded encoder input,
    squashed into (-10, 10) by a scaled tanh (used as a log-variance head)."""

    def __init__(self, d_model, r_drop, len_seq):
        super(Variance, self).__init__()
        self.proj1 = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(r_drop)
        self.activ1 = nn.ReLU()
        # len_seq + 1: presumably accounts for the subject token prepended
        # by DataEmbedding — confirm against the caller.
        self.proj2 = nn.Linear(len_seq + 1, 1)
        self.activ2 = nn.Tanh()

    def forward(self, x):
        # Collapse the feature axis to one value per timestep...
        per_step = self.dropout(self.activ1(self.proj1(x)))
        # ...then collapse the (len_seq + 1) time axis to a single scalar.
        pooled = self.proj2(per_step.transpose(-1, 1))
        return 10 * self.activ2(pooled)
217
+
218
+ # === Gluformer Model ===
219
class Gluformer(nn.Module):
    """Informer-style encoder/decoder transformer for glucose forecasting.

    ``forward`` returns a tuple of
      - the last ``len_pred`` projected decoder steps (the forecast), and
      - a scalar variance-head output computed from the *embedded* encoder
        input, before the encoder stack itself runs.
    """

    def __init__(self, d_model, n_heads, d_fcn, r_drop, activ,
                 num_enc_layers, num_dec_layers, distil, len_seq, len_pred,
                 num_features=5):
        super(Gluformer, self).__init__()
        self.len_pred = len_pred
        d_keys = d_model // n_heads
        # Submodule creation order matches the original so seeded
        # initialization reproduces the same weights.
        self.enc_embedding = DataEmbedding(d_model, r_drop, num_features)
        self.dec_embedding = DataEmbedding(d_model, r_drop, num_features)

        def make_attention(causal):
            # Fresh attention module; causal=True masks future positions.
            return MultiheadAttention(d_model=d_model, n_heads=n_heads,
                                      d_keys=d_keys, mask_flag=causal,
                                      r_att_drop=r_drop)

        enc_blocks = [
            EncoderLayer(att=make_attention(False), d_model=d_model,
                         d_fcn=d_fcn, r_drop=r_drop, activ=activ)
            for _ in range(num_enc_layers)
        ]
        # One distilling conv between each pair of consecutive encoder layers.
        distil_convs = [ConvLayer(d_model) for _ in range(num_enc_layers - 1)] if distil else None
        self.encoder = Encoder(enc_blocks, distil_convs,
                               norm_layer=torch.nn.LayerNorm(d_model))

        dec_blocks = [
            DecoderLayer(self_att=make_attention(True),
                         cross_att=make_attention(False),
                         d_model=d_model, d_fcn=d_fcn,
                         r_drop=r_drop, activ=activ)
            for _ in range(num_dec_layers)
        ]
        self.decoder = Decoder(dec_blocks, norm_layer=torch.nn.LayerNorm(d_model))

        out_dim = 1  # univariate glucose forecast
        self.projection = nn.Linear(d_model, out_dim, bias=True)
        self.var = Variance(d_model, r_drop, len_seq)

    def forward(self, x_id, x_enc, x_mark_enc, x_dec, x_mark_dec):
        enc_out = self.enc_embedding(x_id, x_enc, x_mark_enc)
        # Variance head reads the raw embedded input, not the encoded memory.
        var_out = self.var(enc_out)
        enc_out = self.encoder(enc_out)
        dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
        dec_out = self.projection(self.decoder(dec_out, enc_out))
        # Keep only the forecast horizon; earlier steps are decoder context.
        return dec_out[:, -self.len_pred:, :], var_out
263
+
264
class GluformerConfig(PretrainedConfig):
    """Hugging Face configuration for the Gluformer model.

    Args:
        d_model: transformer hidden size.
        n_heads: number of attention heads.
        d_fcn: feed-forward (1x1-conv) hidden size.
        r_drop: dropout rate used throughout the model.
        activ: feed-forward activation, "relu" or "gelu".
        num_enc_layers: encoder depth.
        num_dec_layers: decoder depth.
        distil: insert distilling conv layers between encoder blocks.
        len_seq: encoder input length.
        len_pred: forecast horizon (number of steps to predict).
        num_features: number of calendar features per timestep.
        len_label: decoder history length (known steps fed to the decoder).
    """
    model_type = "gluformer"

    def __init__(self, d_model=64, n_heads=4, d_fcn=128, r_drop=0.1,
                 activ="relu", num_enc_layers=2, num_dec_layers=2,
                 distil=False, len_seq=48, len_pred=12, num_features=5,
                 len_label=60, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_fcn = d_fcn
        self.r_drop = r_drop
        self.activ = activ
        self.num_enc_layers = num_enc_layers
        self.num_dec_layers = num_dec_layers
        self.distil = distil
        self.len_seq = len_seq
        self.len_pred = len_pred
        self.num_features = num_features
        # BUGFIX: len_label is read by GluformerForTimeSeries and shipped in
        # config.json, but was previously only set when passed via **kwargs;
        # a default-constructed config crashed with AttributeError.  Declared
        # explicitly (appended after the original parameters so positional
        # callers are unaffected); default 60 matches the released config.
        self.len_label = len_label
279
+
280
class Preprocessor:
    """Prepare raw (subject_id, timestamps, glucose_values) for Gluformer.

    - Normalizes glucose from [LOWER, UPPER] mg/dL into [-SCALE_1, SCALE_1].
    - Converts timestamps into five normalized calendar features.
    - Builds the decoder input: the last ``len_label`` known values followed
      by ``len_pred`` zeros (the slots to be predicted), with matching
      decoder timestamps extended 5 minutes per future step.
    """

    # Glucose bounds (mg/dL) used for min-max scaling.
    UPPER = 402
    LOWER = 38
    # Scaled range is [-SCALE_1, SCALE_1], i.e. width SCALE_1 * SCALE_2.
    SCALE_1 = 5
    SCALE_2 = 2

    def __init__(self, len_seq, len_pred, len_label):
        self.len_seq = len_seq      # encoder input length
        self.len_pred = len_pred    # forecast horizon
        self.len_label = len_label  # decoder history length

    def normalize_glucose(self, glucose):
        """Map mg/dL values into [-SCALE_1, SCALE_1]."""
        return (glucose - self.LOWER) / (self.UPPER - self.LOWER) * (self.SCALE_1 * self.SCALE_2) - self.SCALE_1

    def unnormalize_glucose(self, glucose):
        """Inverse of :meth:`normalize_glucose`."""
        return (glucose + self.SCALE_1) / (self.SCALE_1 * self.SCALE_2) * (self.UPPER - self.LOWER) + self.LOWER

    def normalize_datetime(self, date):
        """Encode a datetime as five features in roughly [-1, 1]:
        day-of-year, day-of-month, weekday, hour, minute."""
        DAYS_YEAR = 182.5
        DAYS_MONTH = 15.5
        DAYS_WEEK = 3.5
        HOURS_DAY = 12.0
        MINUTES_HOUR = 30.0
        OFFSET = 1
        return np.array([date.timetuple().tm_yday / DAYS_YEAR - OFFSET,
                         date.day / DAYS_MONTH - OFFSET,
                         date.weekday() / DAYS_WEEK - OFFSET,
                         date.hour / HOURS_DAY - OFFSET,
                         date.minute / MINUTES_HOUR - OFFSET], dtype=float)

    def __call__(self, subject_id, timestamps, glucose_values):
        """Return (subject_id, glucose_values, decoder_input, x_ts, y_ts)
        as float tensors ready for Gluformer.forward.

        Assumes ``timestamps``/``glucose_values`` hold exactly ``len_seq``
        entries sampled 5 minutes apart — TODO confirm against callers.
        """
        subject_id = torch.tensor([subject_id]).float()
        glucose_values = torch.tensor(glucose_values).reshape(1, self.len_seq, 1).float()
        glucose_values = self.normalize_glucose(glucose_values)

        # Decoder sees the last len_label known values plus len_pred zeros;
        # its timestamps are those len_label stamps plus len_pred future ones.
        # BUGFIX: future stamps previously used range(len_pred), so the first
        # "future" timestamp was +0 minutes — a duplicate of the last observed
        # one.  Start at +5 minutes so all len_pred slots lie in the future.
        y_timestamps = timestamps[-self.len_label:] + [
            timestamps[-1] + timedelta(minutes=5 * (i + 1)) for i in range(self.len_pred)
        ]
        decoder_input = torch.cat([glucose_values[:, -self.len_label:, :],
                                   torch.zeros(1, self.len_pred, 1).float()], dim=1)

        x_ts = torch.tensor(np.vstack([self.normalize_datetime(date) for date in timestamps])).float().unsqueeze(0)
        y_ts = torch.tensor(np.vstack([self.normalize_datetime(date) for date in y_timestamps])).float().unsqueeze(0)
        return subject_id, glucose_values, decoder_input, x_ts, y_ts
328
+
329
class GluformerForTimeSeries(PreTrainedModel):
    """Hugging Face wrapper: preprocesses raw inputs and runs the Gluformer
    core, returning predictions in mg/dL together with the variance-head
    output."""

    config_class = GluformerConfig
    base_model_prefix = "gluformer"

    def __init__(self, config: GluformerConfig):
        super().__init__(config)
        self.model = Gluformer(
            d_model=config.d_model,
            n_heads=config.n_heads,
            d_fcn=config.d_fcn,
            r_drop=config.r_drop,
            activ=config.activ,
            num_enc_layers=config.num_enc_layers,
            num_dec_layers=config.num_dec_layers,
            distil=config.distil,
            len_seq=config.len_seq,
            len_pred=config.len_pred,
            num_features=config.num_features,
        )
        self.preprocessor = Preprocessor(config.len_seq, config.len_pred, config.len_label)

    def forward(self, subject_id, timestamps, glucose_values):
        """Preprocess raw inputs, run the model, and un-normalize the forecast.

        Returns (forecast in mg/dL, variance-head output).
        """
        sid, enc_vals, dec_vals, enc_marks, dec_marks = self.preprocessor(
            subject_id, timestamps, glucose_values)
        prediction, log_var = self.model(sid, enc_vals, enc_marks, dec_vals, dec_marks)
        return self.preprocessor.unnormalize_glucose(prediction), log_var