Update model.py
Browse files
model.py
CHANGED
|
@@ -46,7 +46,7 @@ class FleshkaTabularTransformer(PreTrainedModel):
|
|
| 46 |
# x: [batch_size, input_dim]
|
| 47 |
out = list()
|
| 48 |
for nx in lx:
|
| 49 |
-
x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0)
|
| 50 |
x = self.input_proj(x) # [batch_size, d_model]
|
| 51 |
x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
|
| 52 |
x = self.transformer(x) # [batch_size, 1, d_model]
|
|
|
|
| 46 |
# x: [batch_size, input_dim]
|
| 47 |
out = list()
|
| 48 |
for nx in lx:
|
| 49 |
+
x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0).to(self.device)
|
| 50 |
x = self.input_proj(x) # [batch_size, d_model]
|
| 51 |
x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
|
| 52 |
x = self.transformer(x) # [batch_size, 1, d_model]
|