Force curope, TODO: fix pytorch rope implementation for bfloat16
Browse files- models/pos_embed.py +7 -4
models/pos_embed.py
CHANGED
|
@@ -104,8 +104,9 @@ try:
|
|
| 104 |
from extensions.curope import cuRoPE2D
|
| 105 |
RoPE2D = cuRoPE2D
|
| 106 |
except ImportError:
|
| 107 |
-
|
| 108 |
-
|
|
|
|
| 109 |
class RoPE2D(torch.nn.Module):
|
| 110 |
|
| 111 |
def __init__(self, freq=100.0, F0=1.0):
|
|
@@ -135,7 +136,7 @@ except ImportError:
|
|
| 135 |
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
|
| 136 |
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
|
| 137 |
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
| 138 |
-
|
| 139 |
def forward(self, tokens, positions):
|
| 140 |
"""
|
| 141 |
input:
|
|
@@ -144,6 +145,8 @@ except ImportError:
|
|
| 144 |
output:
|
| 145 |
* tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
|
| 146 |
"""
|
|
|
|
|
|
|
| 147 |
assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
|
| 148 |
D = tokens.size(3) // 2
|
| 149 |
assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
|
|
@@ -153,4 +156,4 @@ except ImportError:
|
|
| 153 |
y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
|
| 154 |
x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
|
| 155 |
tokens = torch.cat((y, x), dim=-1)
|
| 156 |
-
return tokens
|
|
|
|
| 104 |
from extensions.curope import cuRoPE2D
|
| 105 |
RoPE2D = cuRoPE2D
|
| 106 |
except ImportError:
|
| 107 |
+
# critical error: the CUDA cuRoPE2D extension is required; abort rather than fall back to the slow pytorch version
|
| 108 |
+
raise ImportError("CUDA-compiled version of RoPE2D is required but could not be found. Please compile the CUDA extension before running.")
|
| 109 |
+
|
| 110 |
class RoPE2D(torch.nn.Module):
|
| 111 |
|
| 112 |
def __init__(self, freq=100.0, F0=1.0):
|
|
|
|
| 136 |
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
|
| 137 |
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
|
| 138 |
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
| 139 |
+
|
| 140 |
def forward(self, tokens, positions):
|
| 141 |
"""
|
| 142 |
input:
|
|
|
|
| 145 |
output:
|
| 146 |
* tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
|
| 147 |
"""
|
| 148 |
+
tokens = tokens.to(torch.float32)
|
| 149 |
+
#positions = positions.to(torch.float32)
|
| 150 |
assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
|
| 151 |
D = tokens.size(3) // 2
|
| 152 |
assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
|
|
|
|
| 156 |
y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
|
| 157 |
x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
|
| 158 |
tokens = torch.cat((y, x), dim=-1)
|
| 159 |
+
return tokens
|