Update vllm_plugin/cloverlm_vllm.py
Browse files- vllm_plugin/cloverlm_vllm.py +39 -15
vllm_plugin/cloverlm_vllm.py
CHANGED
|
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.attention import Attention
|
|
| 12 |
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 13 |
from vllm.model_executor.layers.linear import (
|
| 14 |
ColumnParallelLinear,
|
|
|
|
| 15 |
RowParallelLinear,
|
| 16 |
)
|
| 17 |
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
@@ -22,6 +23,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
| 22 |
)
|
| 23 |
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 24 |
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def _build_rope_cos_sin(
|
|
@@ -62,45 +67,64 @@ class CloverLMAttention(nn.Module):
|
|
| 62 |
prefix: str = "",
|
| 63 |
):
|
| 64 |
super().__init__()
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
self.head_dim = head_dim
|
| 68 |
-
self.q_size = num_heads * head_dim
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
self.lq = ColumnParallelLinear(
|
| 72 |
-
d,
|
| 73 |
quant_config=quant_config,
|
| 74 |
prefix=f"{prefix}.lq",
|
| 75 |
)
|
| 76 |
-
self.lk =
|
| 77 |
-
d,
|
| 78 |
quant_config=quant_config,
|
| 79 |
prefix=f"{prefix}.lk",
|
| 80 |
)
|
| 81 |
-
self.lv =
|
| 82 |
-
d,
|
| 83 |
quant_config=quant_config,
|
| 84 |
prefix=f"{prefix}.lv",
|
| 85 |
)
|
| 86 |
self.lo = RowParallelLinear(
|
| 87 |
-
|
| 88 |
quant_config=quant_config,
|
| 89 |
prefix=f"{prefix}.lo",
|
| 90 |
)
|
| 91 |
|
| 92 |
-
# Per-head learnable scale: stored as (1, heads, 1, 1) in checkpoint,
|
| 93 |
-
# reshaped to (heads,) for efficient multiply after sphere norm.
|
| 94 |
self.scale = nn.Parameter(
|
| 95 |
-
torch.empty(1, num_heads, 1, 1),
|
| 96 |
requires_grad=False,
|
| 97 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
self.attn = Attention(
|
| 100 |
-
num_heads=num_heads,
|
| 101 |
head_size=head_dim,
|
| 102 |
scale=1.0,
|
| 103 |
-
num_kv_heads=num_kv_heads,
|
| 104 |
cache_config=cache_config,
|
| 105 |
quant_config=quant_config,
|
| 106 |
prefix=f"{prefix}.attn",
|
|
|
|
| 12 |
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 13 |
from vllm.model_executor.layers.linear import (
|
| 14 |
ColumnParallelLinear,
|
| 15 |
+
ReplicatedLinear,
|
| 16 |
RowParallelLinear,
|
| 17 |
)
|
| 18 |
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
|
|
| 23 |
)
|
| 24 |
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
| 25 |
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
|
| 26 |
+
from vllm.distributed import (
|
| 27 |
+
get_tensor_model_parallel_rank,
|
| 28 |
+
get_tensor_model_parallel_world_size,
|
| 29 |
+
)
|
| 30 |
|
| 31 |
|
| 32 |
def _build_rope_cos_sin(
|
|
|
|
| 67 |
prefix: str = "",
|
| 68 |
):
|
| 69 |
super().__init__()
|
| 70 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 71 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 72 |
+
|
| 73 |
+
self.num_heads = num_heads // tp_size
|
| 74 |
self.head_dim = head_dim
|
| 75 |
+
self.q_size = self.num_heads * head_dim
|
| 76 |
+
|
| 77 |
+
total_q_size = num_heads * head_dim
|
| 78 |
+
total_kv_size = num_kv_heads * head_dim
|
| 79 |
+
|
| 80 |
+
if num_kv_heads % tp_size == 0:
|
| 81 |
+
self.num_kv_heads = num_kv_heads // tp_size
|
| 82 |
+
kv_linear_cls = ColumnParallelLinear
|
| 83 |
+
else:
|
| 84 |
+
self.num_kv_heads = num_kv_heads
|
| 85 |
+
kv_linear_cls = ReplicatedLinear
|
| 86 |
+
|
| 87 |
+
self.kv_size = self.num_kv_heads * head_dim
|
| 88 |
|
| 89 |
self.lq = ColumnParallelLinear(
|
| 90 |
+
d, total_q_size, bias=False,
|
| 91 |
quant_config=quant_config,
|
| 92 |
prefix=f"{prefix}.lq",
|
| 93 |
)
|
| 94 |
+
self.lk = kv_linear_cls(
|
| 95 |
+
d, total_kv_size, bias=False,
|
| 96 |
quant_config=quant_config,
|
| 97 |
prefix=f"{prefix}.lk",
|
| 98 |
)
|
| 99 |
+
self.lv = kv_linear_cls(
|
| 100 |
+
d, total_kv_size, bias=False,
|
| 101 |
quant_config=quant_config,
|
| 102 |
prefix=f"{prefix}.lv",
|
| 103 |
)
|
| 104 |
self.lo = RowParallelLinear(
|
| 105 |
+
total_q_size, d, bias=False,
|
| 106 |
quant_config=quant_config,
|
| 107 |
prefix=f"{prefix}.lo",
|
| 108 |
)
|
| 109 |
|
|
|
|
|
|
|
| 110 |
self.scale = nn.Parameter(
|
| 111 |
+
torch.empty(1, self.num_heads, 1, 1),
|
| 112 |
requires_grad=False,
|
| 113 |
)
|
| 114 |
+
heads_per_tp = self.num_heads
|
| 115 |
+
|
| 116 |
+
def _scale_weight_loader(param, loaded_weight):
|
| 117 |
+
start = tp_rank * heads_per_tp
|
| 118 |
+
end = start + heads_per_tp
|
| 119 |
+
param.data.copy_(loaded_weight[:, start:end, :, :])
|
| 120 |
+
|
| 121 |
+
self.scale.weight_loader = _scale_weight_loader
|
| 122 |
|
| 123 |
self.attn = Attention(
|
| 124 |
+
num_heads=self.num_heads,
|
| 125 |
head_size=head_dim,
|
| 126 |
scale=1.0,
|
| 127 |
+
num_kv_heads=self.num_kv_heads,
|
| 128 |
cache_config=cache_config,
|
| 129 |
quant_config=quant_config,
|
| 130 |
prefix=f"{prefix}.attn",
|