| import torch | |
| from torch import nn | |
| from sglang.srt.distributed import ( | |
| get_tensor_model_parallel_rank, | |
| split_tensor_along_last_dim, | |
| tensor_model_parallel_all_gather, | |
| tensor_model_parallel_all_reduce, | |
| ) | |
| from sglang.srt.layers.linear import ( | |
| ColumnParallelLinear, | |
| MergedColumnParallelLinear, | |
| QKVParallelLinear, | |
| RowParallelLinear, | |
| ) | |
| from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding | |
| from sglang.srt.lora.backend.base_backend import BaseLoRABackend | |
class BaseLayerWithLoRA(nn.Module):
    """Base wrapper that pairs an existing layer with a LoRA backend.

    On its own this class adds no LoRA computation: ``forward`` delegates to
    the wrapped layer and the remaining methods are no-op hooks that
    subclasses are expected to override.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        lora_backend: BaseLoRABackend,
    ):
        """Store the wrapped layer and the backend; LoRA starts disabled."""
        super().__init__()
        self.base_layer: nn.Module = base_layer
        self.lora_backend: BaseLoRABackend = lora_backend
        # Stays False until a subclass's set_lora_info attaches LoRA buffers.
        self.set_lora: bool = False

    def forward(self, x: torch.Tensor):
        """Return the result of the wrapped layer's ``forward`` applied to ``x``."""
        wrapped = self.base_layer
        return wrapped.forward(x)

    def set_lora_info(self, *args):
        """Hook for attaching LoRA buffers; the base version ignores its arguments."""
        return None

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        """Hook for slicing LoRA A weights per rank; the base version returns None."""
        return None

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        """Hook for slicing LoRA B weights per rank; the base version returns None."""
        return None
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """
    LoRA (Low-Rank Adaptation) wrapper for a vocab-parallel embedding layer.

    Note: LoRA is not implemented for this layer yet. The wrapper defines no
    forward or LoRA hooks of its own, so it relies entirely on what it
    inherits from BaseLayerWithLoRA; it only re-exposes the base layer's
    weight. LoRA support is intended to be added in a future version.
    """

    def __init__(
        self,
        base_layer: VocabParallelEmbedding,
        lora_backend: BaseLoRABackend,
    ) -> None:
        """Wrap ``base_layer`` and expose its ``weight`` on this module."""
        super().__init__(base_layer, lora_backend)
        # Same parameter object as the base layer's, not a copy.
        embedding_weight = base_layer.weight
        self.weight = embedding_weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a column-parallel linear layer.

    The base layer's output is computed through its ``quant_method`` and,
    once LoRA buffers have been attached via ``set_lora_info``, the LoRA
    contribution is applied by the backend on top of that output.
    """

    def __init__(
        self,
        base_layer: ColumnParallelLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        """Wrap ``base_layer`` and precompute the output offset tensor."""
        super().__init__(base_layer, lora_backend)
        first_partition = self.base_layer.output_partition_sizes[0]
        # int32 boundaries [0, first_partition], placed on the device of the
        # base layer's first parameter.
        param_device = next(self.base_layer.parameters()).device
        self.output_offset = torch.tensor(
            [0, first_partition],
            dtype=torch.int32,
            device=param_device,
        )

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        """Attach the LoRA A/B buffers and enable the LoRA path in ``forward``."""
        self.set_lora = True
        self.A_buffer, self.B_buffer = A_buffer, B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the backend's A then B sgemm and return the backend's result.

        ``base_output`` is handed to the B step together with
        ``output_offset``; how they are combined is up to the backend.
        """
        backend = self.lora_backend
        intermediate = backend.run_lora_a_sgemm(x, self.A_buffer)
        return backend.run_lora_b_sgemm(
            x=intermediate,
            weights=self.B_buffer,
            output_offset=self.output_offset,
            base_output=base_output,
        )

    def forward(self, input_: torch.Tensor):
        """Compute the layer output, optionally with LoRA applied.

        Returns:
            A tuple ``(output, output_bias)``. ``output_bias`` is the base
            layer's bias when ``skip_bias_add`` is set, otherwise ``None``.
        """
        # duplicate the logic in ColumnParallelLinear
        layer = self.base_layer
        # The bias is passed to the matmul only when it is not being deferred.
        fused_bias = None if layer.skip_bias_add else layer.bias
        result = layer.quant_method.apply(layer, input_, fused_bias)

        if self.set_lora:
            result = self.apply_lora(result, input_)

        if layer.gather_output:
            result = tensor_model_parallel_all_gather(result)

        deferred_bias = layer.bias if layer.skip_bias_add else None
        return result, deferred_bias

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        """Return ``A`` unchanged; no per-rank slicing is done for LoRA A here."""
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        """Return the rows of ``B`` that belong to ``tp_rank``.

        The slice is taken along dim 0 and is ``output_partition_sizes[0]``
        rows long.
        """
        rows_per_rank = self.base_layer.output_partition_sizes[0]
        first_row = tp_rank * rows_per_rank
        return B[first_row : first_row + rows_per_rank, :]
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a merged (gate + up) column-parallel linear layer.

    The LoRA contribution is computed by the backend's ``run_gate_up_lora``
    using a single pair of stacked gate/up buffers.
    """

    def __init__(
        self,
        base_layer: MergedColumnParallelLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        """Wrap ``base_layer`` and precompute the gate/up output offsets."""
        super().__init__(base_layer, lora_backend)
        # The offsets depend only on the base layer, so build them once here
        # (as the sibling wrappers do in their constructors) instead of
        # allocating a new device tensor on every set_lora_info call.
        # NOTE: the first partition size is used for both halves; this
        # assumes the gate and up projections have identical shard sizes.
        shard_size = self.base_layer.output_partition_sizes[0]
        self.output_offset = torch.tensor(
            [
                0,
                shard_size,
                2 * shard_size,
            ],
            dtype=torch.int32,
            device=next(self.base_layer.parameters()).device,
        )

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        """Attach the gate/up LoRA buffers and enable the LoRA path."""
        self.set_lora = True
        self.A_buffer_gate_up = A_buffer
        self.B_buffer_gate_up = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the backend's ``run_gate_up_lora`` and return its result."""
        return self.lora_backend.run_gate_up_lora(
            x=x,
            gate_up_lora_a=self.A_buffer_gate_up,
            gate_up_lora_b=self.B_buffer_gate_up,
            output_offset=self.output_offset,
            base_output=base_output,
        )

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        """Return ``A`` unchanged; no per-rank slicing is done for LoRA A here."""
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        """Return this rank's gate rows followed by its up rows of ``B``.

        The up rows start ``output_sizes[0]`` rows after the gate rows; both
        slices are ``output_partition_sizes[0]`` rows long and are
        concatenated along dim 0.
        """
        # The same shard size is used for gate and up; this assumes both
        # projections have identical output sizes.
        shard_size = self.base_layer.output_partition_sizes[0]
        gate_size = self.base_layer.output_sizes[0]
        start_idx = tp_rank * shard_size
        end_idx = start_idx + shard_size
        return torch.concat(
            (
                B[start_idx:end_idx, :],
                B[gate_size + start_idx : gate_size + end_idx, :],
            ),
            dim=0,
        )
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper for a fused QKV column-parallel linear layer.

    The LoRA contribution is computed by the backend's ``run_qkv_lora`` from
    a single pair of stacked q/k/v buffers.
    """

    def __init__(
        self,
        base_layer: QKVParallelLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        """Wrap ``base_layer`` and precompute the q/k/v output boundaries."""
        super().__init__(base_layer, lora_backend)
        q_width = self.base_layer.q_proj_shard_size
        kv_width = self.base_layer.kv_proj_shard_size

        # Boundaries of the q, k and v segments in this rank's output.
        boundaries = [
            0,
            q_width,
            q_width + kv_width,
            q_width + 2 * kv_width,
        ]
        param_device = next(self.base_layer.parameters()).device
        self.output_offset = torch.tensor(
            boundaries,
            dtype=torch.int32,
            device=param_device,
        )

        # For computing number of launched blocks
        self.max_qkv_out_dim = max(q_width, kv_width)

    def set_lora_info(
        self,
        A_buffer_qkv: torch.Tensor,
        B_buffer_qkv: torch.Tensor,
    ):
        """Attach the stacked qkv LoRA buffers and enable the LoRA path."""
        self.set_lora = True
        self.A_buffer_qkv, self.B_buffer_qkv = A_buffer_qkv, B_buffer_qkv

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the backend's ``run_qkv_lora`` and return its result."""
        return self.lora_backend.run_qkv_lora(
            x=x,
            qkv_lora_a=self.A_buffer_qkv,
            qkv_lora_b=self.B_buffer_qkv,
            base_output=base_output,
            output_offset=self.output_offset,
            max_qkv_out_dim=self.max_qkv_out_dim,
        )

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        """Return ``A`` unchanged; no per-rank slicing is done for LoRA A here."""
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
        """Return this rank's q, k and v rows of ``B``, concatenated along dim 0.

        The q slice is indexed by ``tp_rank``; the k and v slices are indexed
        by ``tp_rank // num_kv_head_replicas``, so ranks within the same
        replica group select the same k/v rows.
        """
        layer = self.base_layer
        q_width = layer.q_proj_shard_size
        kv_width = layer.kv_proj_shard_size

        q_lo = q_width * tp_rank
        kv_lo = kv_width * (tp_rank // layer.num_kv_head_replicas)

        # Row offsets at which the k and v sections begin inside B.
        q_total, k_total, _ = layer.output_sizes
        k_base = q_total
        v_base = q_total + k_total

        pieces = (
            B[q_lo : q_lo + q_width, :],
            B[k_base + kv_lo : k_base + kv_lo + kv_width, :],
            B[v_base + kv_lo : v_base + kv_lo + kv_width, :],
        )
        return torch.concat(pieces, dim=0)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a row-parallel linear layer.

    The base layer's partial output is computed through its ``quant_method``
    and, once LoRA buffers are attached, the LoRA contribution is applied by
    the backend before the optional all-reduce.
    """

    def __init__(
        self,
        base_layer: RowParallelLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        """Wrap ``base_layer`` and precompute the output offset tensor."""
        super().__init__(base_layer, lora_backend)
        # The offsets depend only on the base layer, so build them once here
        # (as the column-parallel wrapper does) instead of allocating a new
        # device tensor on every set_lora_info call.
        output_size = self.base_layer.output_size
        self.output_offset = torch.tensor(
            [
                0,
                output_size,
            ],
            dtype=torch.int32,
            device=next(self.base_layer.parameters()).device,
        )

    def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
        """Attach the LoRA A/B buffers and enable the LoRA path in ``forward``."""
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Run the backend's A then B sgemm and return the backend's result."""
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        return self.lora_backend.run_lora_b_sgemm(
            x=lora_a_output,
            weights=self.B_buffer,
            output_offset=self.output_offset,
            base_output=base_output,
        )

    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
        """Compute the layer output, optionally with LoRA applied.

        Args:
            input_: Input tensor. If the base layer's ``input_is_parallel`` is
                false, it is split along the last dim into ``tp_size``
                partitions and only this rank's partition is used.
            skip_all_reduce: When true, the all-reduce is skipped even if the
                base layer would otherwise reduce its results.

        Returns:
            A tuple ``(output, output_bias)``. ``output_bias`` is the base
            layer's bias when ``skip_bias_add`` is set, otherwise ``None``.
        """
        # duplicate the logic in RowParallelLinear
        layer = self.base_layer
        if layer.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=layer.tp_size
            )
            input_parallel = splitted_input[tp_rank].contiguous()

        output_parallel = layer.quant_method.apply(layer, input_parallel)

        if self.set_lora:
            output_parallel = self.apply_lora(output_parallel, input_parallel)

        if layer.reduce_results and layer.tp_size > 1 and not skip_all_reduce:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        # Bias is deferred to the caller when skip_bias_add is set.
        if layer.skip_bias_add:
            return output_, layer.bias

        # NOTE(review): with skip_all_reduce=True the bias is still added to
        # the un-reduced output here; verify callers account for that.
        if layer.bias is not None:
            output_ = output_ + layer.bias
        return output_, None

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        """Return the columns of ``A`` that belong to ``tp_rank``.

        The slice is taken along dim 1, is ``input_size_per_partition``
        columns wide, and is returned as a contiguous tensor.
        """
        shard_size = self.base_layer.input_size_per_partition
        start_idx = tp_rank * shard_size
        end_idx = start_idx + shard_size
        return A[:, start_idx:end_idx].contiguous()

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        """Return ``B`` unchanged; no per-rank slicing is done for LoRA B here."""
        return B
def get_lora_layer(
    layer: nn.Module, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
    """Wrap ``layer`` in the LoRA wrapper class that matches its type.

    Args:
        layer: The base layer to wrap.
        lora_backend: Backend handed to the wrapper's constructor.

    Returns:
        A new wrapper instance from the first matching entry in the table.

    Raises:
        TypeError: If ``layer`` is not an instance of any supported type.
    """
    # The order matters: matching uses isinstance and the first hit wins, so
    # the QKV and merged column-parallel entries must precede the generic
    # ColumnParallelLinear entry (presumably because they subclass it).
    supported_layer_types = (
        (VocabParallelEmbedding, VocabParallelEmbeddingWithLoRA),
        (QKVParallelLinear, QKVParallelLinearWithLoRA),
        (MergedColumnParallelLinear, MergedColumnParallelLinearWithLoRA),
        (ColumnParallelLinear, ColumnParallelLinearWithLoRA),
        (RowParallelLinear, RowParallelLinearWithLoRA),
    )
    for src_layer_type, lora_layer_type in supported_layer_types:
        if isinstance(layer, src_layer_type):
            return lora_layer_type(layer, lora_backend)
    raise TypeError(f"No corresponding LoRA layer supported for {type(layer)}.")
Xet Storage Details
- Size:
- 11.9 kB
- Xet hash:
- 5ae781f1673b1a3432f8e8943d9c3fd9d4b2dab0911774b6f27f3621fb92808b
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.