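Update model.py: take h_next as out[:, -1:] instead of out[:, -1] (slice keeps axis 1 with length 1) in both return paths of minGRU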
Browse files
model.py
CHANGED
|
@@ -45,7 +45,7 @@ class minGRU(nn.Module):
|
|
| 45 |
gate = gate.sigmoid()
|
| 46 |
out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
|
| 47 |
|
| 48 |
-
h_next = out[:, -1]
|
| 49 |
out = self.out_proj(out)
|
| 50 |
|
| 51 |
return out, h_next
|
|
@@ -70,7 +70,7 @@ class minGRU(nn.Module):
|
|
| 70 |
out = heinsen_associative_scan_log(log_coeffs, log_values)
|
| 71 |
out = out[:, -seq_len:]
|
| 72 |
|
| 73 |
-
h_next = out[:, -1]
|
| 74 |
out = self.out_proj(out)
|
| 75 |
|
| 76 |
return out, h_next
|
|
|
|
| 45 |
gate = gate.sigmoid()
|
| 46 |
out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
|
| 47 |
|
| 48 |
+
h_next = out[:, -1:]
|
| 49 |
out = self.out_proj(out)
|
| 50 |
|
| 51 |
return out, h_next
|
|
|
|
| 70 |
out = heinsen_associative_scan_log(log_coeffs, log_values)
|
| 71 |
out = out[:, -seq_len:]
|
| 72 |
|
| 73 |
+
h_next = out[:, -1:]
|
| 74 |
out = self.out_proj(out)
|
| 75 |
|
| 76 |
return out, h_next
|