Update model.py
Browse files
model.py
CHANGED
|
@@ -98,7 +98,7 @@ ALL_LAYERNORM_LAYERS.append(BharataiRMSNorm)
|
|
| 98 |
|
| 99 |
|
| 100 |
class BharataiRotaryEmbedding(nn.Module):
|
| 101 |
-
def __init__(self, dim, max_position_embeddings=
|
| 102 |
super().__init__()
|
| 103 |
|
| 104 |
self.dim = dim
|
|
@@ -136,7 +136,7 @@ class BharataiRotaryEmbedding(nn.Module):
|
|
| 136 |
class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
| 137 |
"""BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 138 |
|
| 139 |
-
def __init__(self, dim, max_position_embeddings=
|
| 140 |
self.scaling_factor = scaling_factor
|
| 141 |
super().__init__(dim, max_position_embeddings, base, device)
|
| 142 |
|
|
@@ -155,7 +155,7 @@ class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
|
| 155 |
class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
| 156 |
"""BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 157 |
|
| 158 |
-
def __init__(self, dim, max_position_embeddings=
|
| 159 |
self.scaling_factor = scaling_factor
|
| 160 |
super().__init__(dim, max_position_embeddings, base, device)
|
| 161 |
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
class BharataiRotaryEmbedding(nn.Module):
|
| 101 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None):
|
| 102 |
super().__init__()
|
| 103 |
|
| 104 |
self.dim = dim
|
|
|
|
| 136 |
class BharataiLinearScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
| 137 |
"""BharataiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
| 138 |
|
| 139 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
|
| 140 |
self.scaling_factor = scaling_factor
|
| 141 |
super().__init__(dim, max_position_embeddings, base, device)
|
| 142 |
|
|
|
|
| 155 |
class BharataiDynamicNTKScalingRotaryEmbedding(BharataiRotaryEmbedding):
|
| 156 |
"""BharataiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
| 157 |
|
| 158 |
+
def __init__(self, dim, max_position_embeddings=16384, base=10000, device=None, scaling_factor=1.0):
|
| 159 |
self.scaling_factor = scaling_factor
|
| 160 |
super().__init__(dim, max_position_embeddings, base, device)
|
| 161 |
|