flpelerin commited on
Commit
41c262e
·
verified ·
1 Parent(s): 148ab0d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -45,7 +45,7 @@ class minGRU(nn.Module):
45
  gate = gate.sigmoid()
46
  out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
47
 
48
- h_next = out[:, -1]
49
  out = self.out_proj(out)
50
 
51
  return out, h_next
@@ -70,7 +70,7 @@ class minGRU(nn.Module):
70
  out = heinsen_associative_scan_log(log_coeffs, log_values)
71
  out = out[:, -seq_len:]
72
 
73
- h_next = out[:, -1]
74
  out = self.out_proj(out)
75
 
76
  return out, h_next
 
45
  gate = gate.sigmoid()
46
  out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
47
 
48
+ h_next = out[:, -1:]
49
  out = self.out_proj(out)
50
 
51
  return out, h_next
 
70
  out = heinsen_associative_scan_log(log_coeffs, log_values)
71
  out = out[:, -seq_len:]
72
 
73
+ h_next = out[:, -1:]
74
  out = self.out_proj(out)
75
 
76
  return out, h_next