mansaripo committed on
Commit
6f7dfbe
·
verified ·
1 Parent(s): d11440f

Update vllm_plugin/cloverlm_vllm.py

Browse files
Files changed (1) hide show
  1. vllm_plugin/cloverlm_vllm.py +39 -15
vllm_plugin/cloverlm_vllm.py CHANGED
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.attention import Attention
12
  from vllm.model_executor.layers.layernorm import RMSNorm
13
  from vllm.model_executor.layers.linear import (
14
  ColumnParallelLinear,
 
15
  RowParallelLinear,
16
  )
17
  from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -22,6 +23,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
22
  )
23
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
  from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
 
 
 
 
25
 
26
 
27
  def _build_rope_cos_sin(
@@ -62,45 +67,64 @@ class CloverLMAttention(nn.Module):
62
  prefix: str = "",
63
  ):
64
  super().__init__()
65
- self.num_heads = num_heads
66
- self.num_kv_heads = num_kv_heads
 
 
67
  self.head_dim = head_dim
68
- self.q_size = num_heads * head_dim
69
- self.kv_size = num_kv_heads * head_dim
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  self.lq = ColumnParallelLinear(
72
- d, self.q_size, bias=False,
73
  quant_config=quant_config,
74
  prefix=f"{prefix}.lq",
75
  )
76
- self.lk = ColumnParallelLinear(
77
- d, self.kv_size, bias=False,
78
  quant_config=quant_config,
79
  prefix=f"{prefix}.lk",
80
  )
81
- self.lv = ColumnParallelLinear(
82
- d, self.kv_size, bias=False,
83
  quant_config=quant_config,
84
  prefix=f"{prefix}.lv",
85
  )
86
  self.lo = RowParallelLinear(
87
- self.q_size, d, bias=False,
88
  quant_config=quant_config,
89
  prefix=f"{prefix}.lo",
90
  )
91
 
92
- # Per-head learnable scale: stored as (1, heads, 1, 1) in checkpoint,
93
- # reshaped to (heads,) for efficient multiply after sphere norm.
94
  self.scale = nn.Parameter(
95
- torch.empty(1, num_heads, 1, 1),
96
  requires_grad=False,
97
  )
 
 
 
 
 
 
 
 
98
 
99
  self.attn = Attention(
100
- num_heads=num_heads,
101
  head_size=head_dim,
102
  scale=1.0,
103
- num_kv_heads=num_kv_heads,
104
  cache_config=cache_config,
105
  quant_config=quant_config,
106
  prefix=f"{prefix}.attn",
 
12
  from vllm.model_executor.layers.layernorm import RMSNorm
13
  from vllm.model_executor.layers.linear import (
14
  ColumnParallelLinear,
15
+ ReplicatedLinear,
16
  RowParallelLinear,
17
  )
18
  from vllm.model_executor.layers.logits_processor import LogitsProcessor
 
23
  )
24
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
  from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
26
+ from vllm.distributed import (
27
+ get_tensor_model_parallel_rank,
28
+ get_tensor_model_parallel_world_size,
29
+ )
30
 
31
 
32
  def _build_rope_cos_sin(
 
67
  prefix: str = "",
68
  ):
69
  super().__init__()
70
+ tp_size = get_tensor_model_parallel_world_size()
71
+ tp_rank = get_tensor_model_parallel_rank()
72
+
73
+ self.num_heads = num_heads // tp_size
74
  self.head_dim = head_dim
75
+ self.q_size = self.num_heads * head_dim
76
+
77
+ total_q_size = num_heads * head_dim
78
+ total_kv_size = num_kv_heads * head_dim
79
+
80
+ if num_kv_heads % tp_size == 0:
81
+ self.num_kv_heads = num_kv_heads // tp_size
82
+ kv_linear_cls = ColumnParallelLinear
83
+ else:
84
+ self.num_kv_heads = num_kv_heads
85
+ kv_linear_cls = ReplicatedLinear
86
+
87
+ self.kv_size = self.num_kv_heads * head_dim
88
 
89
  self.lq = ColumnParallelLinear(
90
+ d, total_q_size, bias=False,
91
  quant_config=quant_config,
92
  prefix=f"{prefix}.lq",
93
  )
94
+ self.lk = kv_linear_cls(
95
+ d, total_kv_size, bias=False,
96
  quant_config=quant_config,
97
  prefix=f"{prefix}.lk",
98
  )
99
+ self.lv = kv_linear_cls(
100
+ d, total_kv_size, bias=False,
101
  quant_config=quant_config,
102
  prefix=f"{prefix}.lv",
103
  )
104
  self.lo = RowParallelLinear(
105
+ total_q_size, d, bias=False,
106
  quant_config=quant_config,
107
  prefix=f"{prefix}.lo",
108
  )
109
 
 
 
110
  self.scale = nn.Parameter(
111
+ torch.empty(1, self.num_heads, 1, 1),
112
  requires_grad=False,
113
  )
114
+ heads_per_tp = self.num_heads
115
+
116
+ def _scale_weight_loader(param, loaded_weight):
117
+ start = tp_rank * heads_per_tp
118
+ end = start + heads_per_tp
119
+ param.data.copy_(loaded_weight[:, start:end, :, :])
120
+
121
+ self.scale.weight_loader = _scale_weight_loader
122
 
123
  self.attn = Attention(
124
+ num_heads=self.num_heads,
125
  head_size=head_dim,
126
  scale=1.0,
127
+ num_kv_heads=self.num_kv_heads,
128
  cache_config=cache_config,
129
  quant_config=quant_config,
130
  prefix=f"{prefix}.attn",