han-xudong commited on
Commit
5650402
·
verified ·
1 Parent(s): 14b1ce8

Upload modeling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling.py +5 -4
modeling.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from transformers import PreTrainedModel, PretrainedConfig
4
 
5
  class BallNetConfig(PretrainedConfig):
@@ -24,16 +25,16 @@ class BallNetConfig(PretrainedConfig):
24
 
25
 
26
  class Normalizer(nn.Module):
27
- def __init__(self, mean: Tensor, std: Tensor, eps: float = 1e-8):
28
  super().__init__()
29
  self.register_buffer("mean", mean)
30
  self.register_buffer("std", std)
31
  self.eps = eps
32
 
33
- def normalize(self, x: Tensor) -> Tensor:
34
  return (x - self.mean) / (self.std + self.eps)
35
 
36
- def denormalize(self, x: Tensor) -> Tensor:
37
  return x * (self.std + self.eps) + self.mean
38
 
39
 
@@ -108,7 +109,7 @@ class BallNetModel(PreTrainedModel):
108
 
109
  self.post_init()
110
 
111
- def forward(self, x: Tensor, **kwargs):
112
  """
113
  x: (B, 6)
114
  """
 
1
  import torch
2
  import torch.nn as nn
3
+ from typing import List
4
  from transformers import PreTrainedModel, PretrainedConfig
5
 
6
  class BallNetConfig(PretrainedConfig):
 
25
 
26
 
27
  class Normalizer(nn.Module):
28
+ def __init__(self, mean, std, eps: float = 1e-8):
29
  super().__init__()
30
  self.register_buffer("mean", mean)
31
  self.register_buffer("std", std)
32
  self.eps = eps
33
 
34
+ def normalize(self, x):
35
  return (x - self.mean) / (self.std + self.eps)
36
 
37
+ def denormalize(self, x):
38
  return x * (self.std + self.eps) + self.mean
39
 
40
 
 
109
 
110
  self.post_init()
111
 
112
+ def forward(self, x, **kwargs):
113
  """
114
  x: (B, 6)
115
  """