SnowFlash383935 commited on
Commit
bae15c5
·
verified ·
1 Parent(s): bb736c2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -7
model.py CHANGED
@@ -42,11 +42,14 @@ class FleshkaTabularTransformer(PreTrainedModel):
42
  if module.bias is not None:
43
  nn.init.zeros_(module.bias)
44
 
45
- def forward(self, x: list):
46
  # x: [batch_size, input_dim]
47
- x = tensor(self._normalize(x), dtype=float32).unsqueeze(0)
48
- x = self.input_proj(x) # [batch_size, d_model]
49
- x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
50
- x = self.transformer(x) # [batch_size, 1, d_model]
51
- x = x.squeeze(1) # [batch_size, d_model]
52
- return self.head(x) # [batch_size, 1]
 
 
 
 
42
  if module.bias is not None:
43
  nn.init.zeros_(module.bias)
44
 
45
+ def forward(self, lx: list):
46
  # x: [batch_size, input_dim]
47
+ out = list()
48
+ for nx in lx:
49
+ x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0)
50
+ x = self.input_proj(x) # [batch_size, d_model]
51
+ x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
52
+ x = self.transformer(x) # [batch_size, 1, d_model]
53
+ x = x.squeeze(1) # [batch_size, d_model]
54
+ out.append(self.head(x).item().item() > 0)
55
+ return out