AliHamza852 committed on
Commit
73066f5
·
verified ·
1 Parent(s): 8391596

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import numpy as np
 
4
 
5
  SEQUENCE_LENGTH = 10
6
  INPUT_SIZE = 1
@@ -13,24 +14,55 @@ class Seq2Seq(nn.Module):
13
  super(Seq2Seq, self).__init__()
14
  self.seq_len = seq_len
15
  self.hidden_size = hidden_size
16
-
17
  self.encoder_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
18
  self.decoder_lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
19
  self.decoder_linear = nn.Linear(hidden_size, output_size)
20
 
21
  def forward(self, x):
22
  _, (hidden, cell) = self.encoder_lstm(x)
23
-
24
  context_vector = hidden.permute(1, 0, 2)
25
  decoder_input = context_vector.repeat(1, self.seq_len, 1)
26
-
27
  decoder_output, _ = self.decoder_lstm(decoder_input, (hidden, cell))
28
-
29
  prediction = self.decoder_linear(decoder_output)
30
-
31
  return prediction
32
 
33
- model_path = 'seq2seq_model_weights.pth'
34
  model = Seq2Seq(INPUT_SIZE, HIDDEN_UNITS, OUTPUT_SIZE, SEQUENCE_LENGTH).to(device)
35
  model.load_state_dict(torch.load(model_path, map_location=device))
36
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import numpy as np
4
+ import gradio as gr
5
 
6
  SEQUENCE_LENGTH = 10
7
  INPUT_SIZE = 1
 
14
  super(Seq2Seq, self).__init__()
15
  self.seq_len = seq_len
16
  self.hidden_size = hidden_size
 
17
  self.encoder_lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
18
  self.decoder_lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
19
  self.decoder_linear = nn.Linear(hidden_size, output_size)
20
 
21
  def forward(self, x):
22
  _, (hidden, cell) = self.encoder_lstm(x)
 
23
  context_vector = hidden.permute(1, 0, 2)
24
  decoder_input = context_vector.repeat(1, self.seq_len, 1)
 
25
  decoder_output, _ = self.decoder_lstm(decoder_input, (hidden, cell))
 
26
  prediction = self.decoder_linear(decoder_output)
 
27
  return prediction
28
 
29
+ model_path = 'seq2seq_model_weights.pth'
30
  model = Seq2Seq(INPUT_SIZE, HIDDEN_UNITS, OUTPUT_SIZE, SEQUENCE_LENGTH).to(device)
31
  model.load_state_dict(torch.load(model_path, map_location=device))
32
+ model.eval()
33
+
34
+ def predict_sequence(input_text):
35
+ try:
36
+ numbers = [float(n.strip()) for n in input_text.split(',')]
37
+
38
+ if len(numbers) != SEQUENCE_LENGTH:
39
+ return f"Error: Please enter exactly {SEQUENCE_LENGTH} numbers, separated by commas."
40
+
41
+ input_array = np.array(numbers).reshape(1, SEQUENCE_LENGTH, 1)
42
+ input_tensor = torch.from_numpy(input_array).float().to(device)
43
+
44
+ with torch.no_grad():
45
+ prediction_tensor = model(input_tensor)
46
+
47
+ output_array = prediction_tensor.cpu().numpy().flatten()
48
+
49
+ output_text = ", ".join([f"{n:.1f}" for n in output_array])
50
+
51
+ return output_text
52
+
53
+ except Exception as e:
54
+ return f"An error occurred: {str(e)}"
55
+
56
+ demo = gr.Interface(
57
+ fn=predict_sequence,
58
+ inputs=gr.Textbox(
59
+ label="Input Sequence",
60
+ placeholder=f"Enter {SEQUENCE_LENGTH} numbers, e.g., 1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
61
+ ),
62
+ outputs=gr.Textbox(label="Predicted Sequence"),
63
+ title="Q11: Seq2Seq Model (n -> n+1)",
64
+ description="Enter a sequence of 10 numbers to predict the next sequence.",
65
+ allow_flagging="never"
66
+ )
67
+
68
+ demo.launch()