Junnyfirst34 commited on
Commit
31899f4
·
verified ·
1 Parent(s): b138e12

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +19 -0
train.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class SensorEncoder(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.fc = nn.Linear(10, 32)
8
+
9
+ def forward(self, x):
10
+ return self.fc(x)
11
+
12
+ def main():
13
+ model = SensorEncoder()
14
+ dummy_input = torch.randn(1, 10)
15
+ output = model(dummy_input)
16
+ print(output)
17
+
18
+ if __name__ == "__main__":
19
+ main()