xjlulu committed on
Commit a716d4e
1 Parent(s): a5b6307
Files changed (1)
  1. model.py +26 -19
model.py CHANGED
@@ -1,5 +1,4 @@
 from typing import Dict
-
 import torch
 import torch.nn as nn
 
@@ -16,14 +15,11 @@ class SeqClassifier(nn.Module):
         num_class: int,
     ) -> None:
         super(SeqClassifier, self).__init__()
+
+        # Word embeddings layer
         self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
-        self.hidden_size=hidden_size
-        self.num_layers=num_layers
-        self.dropout=dropout
-        self.bidirectional=bidirectional
-        self.num_class=num_class
 
-        # model architecture
+        # LSTM layer
         self.rnn = nn.LSTM(
             input_size=embeddings.size(1),
             hidden_size=hidden_size,
@@ -32,38 +28,49 @@ class SeqClassifier(nn.Module):
             bidirectional=bidirectional,
             batch_first=True
         )
-        self.dropout_layer = nn.Dropout(p=self.dropout)
+
+        # Model parameters (set before self.fc, whose size depends on them)
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+        self.bidirectional = bidirectional
+        self.num_class = num_class
+
+        # Dropout layer
+        self.dropout_layer = nn.Dropout(p=dropout)
+
+        # Fully connected layer for classification
         self.fc = nn.Linear(self.encoder_output_size, num_class)
 
     @property
     def encoder_output_size(self) -> int:
-        # calculate the output dimension of rnn
+        # Calculate the output dimension of the RNN
        if self.bidirectional:
             return self.hidden_size * 2
         else:
             return self.hidden_size
 
     def forward(self, batch) -> torch.Tensor:
-        # Embed the input into the word embedding space, i.e. turn word indices into word vectors
+        # Embed the input into the word embedding space
         embedded = self.embed(batch)
 
-        # Pass through the LSTM layer
+        # Pass through the LSTM layer
         rnn_output, _ = self.rnn(embedded)
         rnn_output = self.dropout_layer(rnn_output)
 
         if not self.training:
-            last_hidden_state_forward = rnn_output[ -1, :self.hidden_size] # Forward-direction hidden state
-            last_hidden_state_backward = rnn_output[ 0, self.hidden_size:] # Backward-direction hidden state
+            last_hidden_state_forward = rnn_output[-1, :self.hidden_size] # Forward hidden state
+            last_hidden_state_backward = rnn_output[0, self.hidden_size:] # Backward hidden state
             combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=0)
 
-            # Pass through the fully connected layer
+            # Pass through the fully connected layer
             logits = self.fc(combined_hidden_state)
-            return logits # Return the prediction result
+            return logits # Return predictions
 
-        last_hidden_state_forward = rnn_output[:, -1, :self.hidden_size] # Forward-direction hidden state
-        last_hidden_state_backward = rnn_output[:, 0, self.hidden_size:] # Backward-direction hidden state
+        last_hidden_state_forward = rnn_output[:, -1, :self.hidden_size] # Forward hidden state
+        last_hidden_state_backward = rnn_output[:, 0, self.hidden_size:] # Backward hidden state
         combined_hidden_state = torch.cat((last_hidden_state_forward, last_hidden_state_backward), dim=1)
 
-        # Pass through the fully connected layer
+        # Pass through the fully connected layer
         logits = self.fc(combined_hidden_state)
-        return logits # Return the prediction result
+        return logits # Return predictions
 
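For reference, a minimal usage sketch of the revised SeqClassifier. The vocabulary size, embedding dimension, hyperparameters, and class count below are illustrative assumptions, not values from this commit. Note the two code paths in forward(): the training branch expects a batched (batch, seq_len) tensor of token indices, while the eval branch slices rnn_output without a batch dimension, so it is exercised here with a single unbatched sequence (nn.LSTM accepts unbatched input in recent PyTorch versions).

# Hypothetical usage sketch; all sizes are assumptions, not from the commit.
import torch

from model import SeqClassifier

# Stand-in for a pretrained embedding matrix: 5000-word vocab, 300-dim vectors.
embeddings = torch.randn(5000, 300)

model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    bidirectional=True,
    num_class=150,
)

# Training path: a batch of 8 sequences, each 32 token indices long.
model.train()
batch = torch.randint(0, 5000, (8, 32))
logits = model(batch)         # shape (8, 150)

# Eval path: forward() indexes rnn_output without a batch dimension here,
# so it expects one unbatched sequence of shape (seq_len,).
model.eval()
sequence = torch.randint(0, 5000, (32,))
with torch.no_grad():
    logits = model(sequence)  # shape (150,)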