Update model.py
Browse files
model.py
CHANGED
|
@@ -42,11 +42,14 @@ class FleshkaTabularTransformer(PreTrainedModel):
|
|
| 42 |
if module.bias is not None:
|
| 43 |
nn.init.zeros_(module.bias)
|
| 44 |
|
| 45 |
-
def forward(self,
|
| 46 |
# x: [batch_size, input_dim]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
if module.bias is not None:
|
| 43 |
nn.init.zeros_(module.bias)
|
| 44 |
|
| 45 |
+
def forward(self, lx: list):
|
| 46 |
# x: [batch_size, input_dim]
|
| 47 |
+
out = list()
|
| 48 |
+
for nx in lx:
|
| 49 |
+
x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0)
|
| 50 |
+
x = self.input_proj(x) # [batch_size, d_model]
|
| 51 |
+
x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
|
| 52 |
+
x = self.transformer(x) # [batch_size, 1, d_model]
|
| 53 |
+
x = x.squeeze(1) # [batch_size, d_model]
|
| 54 |
+
out.append(self.head(x).item().item() > 0)
|
| 55 |
+
return out
|