Changes for Core ML conversion
Browse files — modelling_RW.py (+3 −3)
modelling_RW.py
CHANGED
|
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
|
|
| 29 |
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
| 30 |
class Linear(nn.Linear):
|
| 31 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 32 |
-
ret = input @ self.weight.T
|
| 33 |
if self.bias is None:
|
| 34 |
return ret
|
| 35 |
else:
|
|
@@ -68,7 +68,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 68 |
self,
|
| 69 |
seq_len: int,
|
| 70 |
device="cuda",
|
| 71 |
-
dtype=torch.bfloat16,
|
| 72 |
) -> torch.Tensor:
|
| 73 |
if seq_len != self.seq_len_cached:
|
| 74 |
self.seq_len_cached = seq_len
|
|
@@ -89,7 +89,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 89 |
|
| 90 |
def forward(self, q, k):
|
| 91 |
batch, seq_len, head_dim = q.shape
|
| 92 |
-
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
|
| 93 |
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 94 |
|
| 95 |
|
|
|
|
| 29 |
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
| 30 |
class Linear(nn.Linear):
|
| 31 |
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
ret = input @ self.weight.permute(1, 0) #transpose(0, 1) #.T
|
| 33 |
if self.bias is None:
|
| 34 |
return ret
|
| 35 |
else:
|
|
|
|
| 68 |
self,
|
| 69 |
seq_len: int,
|
| 70 |
device="cuda",
|
| 71 |
+
dtype=torch.float16,
|
| 72 |
) -> torch.Tensor:
|
| 73 |
if seq_len != self.seq_len_cached:
|
| 74 |
self.seq_len_cached = seq_len
|
|
|
|
| 89 |
|
| 90 |
def forward(self, q, k):
|
| 91 |
batch, seq_len, head_dim = q.shape
|
| 92 |
+
cos, sin = self.cos_sin(seq_len, q.device)
|
| 93 |
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 94 |
|
| 95 |
|