AliHamza852 commited on
Commit
c4ccaf8
·
verified ·
1 Parent(s): 878b290

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ SEQUENCE_LENGTH = 10
6
+ INPUT_SIZE = 1
7
+ OUTPUT_SIZE = 1
8
+ HIDDEN_UNITS = 128
9
+ device = torch.device('cpu')
10
+
11
+ class Seq2Seq(nn.Module):
12
+ def __init__(self, input_size, hidden_size, output_size, seq_len):
13
+ super(Seq2Seq, self).__init__()
14
+ self.seq_len = seq_len
15
+ self.hidden_size = hidden_size
16
+
17
+ self.encoder_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
18
+ self.decoder_lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
19
+ self.decoder_linear = nn.Linear(hidden_size, output_size)
20
+
21
+ def forward(self, x):
22
+ _, (hidden, cell) = self.encoder_lstm(x)
23
+
24
+ context_vector = hidden.permute(1, 0, 2)
25
+ decoder_input = context_vector.repeat(1, self.seq_len, 1)
26
+
27
+ decoder_output, _ = self.decoder_lstm(decoder_input, (hidden, cell))
28
+
29
+ prediction = self.decoder_linear(decoder_output)
30
+
31
+ return prediction
32
+
33
+ model_path = 'seq2seq_model_weights.pth'
34
+ model = Seq2Seq(INPUT_SIZE, HIDDEN_UNITS, OUTPUT_SIZE, SEQUENCE_LENGTH).to(device)
35
+ model.load_state_dict(torch.load(model_path, map_location=device))
36
+ model.eval()