HemanM commited on
Commit
bd6b890
·
verified ·
1 Parent(s): b54253f

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -1
model.py CHANGED
@@ -1,11 +1,14 @@
 
1
  import torch.nn as nn
2
 
3
  class SimpleEvoModel(nn.Module):
4
- def __init__(self, input_dim=768, hidden_dim=256, output_dim=2):
5
  super().__init__()
6
  self.model = nn.Sequential(
 
7
  nn.Linear(input_dim, hidden_dim),
8
  nn.ReLU(),
 
9
  nn.Linear(hidden_dim, output_dim)
10
  )
11
 
 
1
+ import torch
2
  import torch.nn as nn
3
 
4
  class SimpleEvoModel(nn.Module):
5
+ def __init__(self, input_dim=384, hidden_dim=256, output_dim=2, dropout=0.1):
6
  super().__init__()
7
  self.model = nn.Sequential(
8
+ nn.LayerNorm(input_dim),
9
  nn.Linear(input_dim, hidden_dim),
10
  nn.ReLU(),
11
+ nn.Dropout(dropout),
12
  nn.Linear(hidden_dim, output_dim)
13
  )
14