Update model.py
Browse files
model.py
CHANGED
|
@@ -47,3 +47,8 @@ class GroupedAutoEncoder(nn.Module):
|
|
| 47 |
# Concatenate groups back together
|
| 48 |
# reconstructed = torch.cat(decoded_groups, dim=1)
|
| 49 |
return reconstructed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
# Concatenate groups back together
|
| 48 |
# reconstructed = torch.cat(decoded_groups, dim=1)
|
| 49 |
return reconstructed
|
| 50 |
+
|
| 51 |
+
input_dim = 5120
|
| 52 |
+
hidden_dim = 320
|
| 53 |
+
num_groups = 40
|
| 54 |
+
model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).cuda()
|