SnowFlash383935 commited on
Commit
50af525
·
verified ·
1 Parent(s): 24271f4

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -1
model.py CHANGED
@@ -45,7 +45,7 @@ class FleshkaTabularTransformer(PreTrainedModel):
45
  out = []
46
  for nx in lx:
47
  # Убедимся, что входные данные в float16
48
- x = tensor(self._normalize(nx), dtype=self.input_proj.dtype).unsqueeze(0).to(self.device)
49
  x = self.input_proj(x)
50
  x = x.unsqueeze(1)
51
  x = self.transformer(x)
 
45
  out = []
46
  for nx in lx:
47
  # Убедимся, что входные данные в float16
48
+ x = tensor(self._normalize(nx), dtype=self.input_proj.weight.dtype).unsqueeze(0).to(self.device)
49
  x = self.input_proj(x)
50
  x = x.unsqueeze(1)
51
  x = self.transformer(x)