from typing import List

from .._common import default_net
from ..functional import Tensor, lora_plugin
from ..module import Module


class LoraRuntimeParams(object):
    """Per-module runtime inputs that are forwarded to the LoRA plugin."""

    def __init__(
        self,
        lora_ranks: List[Tensor] = None,
        lora_weights_pointers: List[Tensor] = None,
        host_request_types: Tensor = None,
        host_context_lengths: Tensor = None,
        max_context_length: Tensor = None,
        max_encoder_context_length: Tensor = None,
        host_encoder_input_lengths: Tensor = None,
        weight_index: int = 0,
    ):
        self.lora_ranks = lora_ranks
        self.lora_weights_pointers = lora_weights_pointers
        self.host_request_types = host_request_types
        self.host_context_lengths = host_context_lengths
        self.max_context_length = max_context_length
        self.max_encoder_context_length = max_encoder_context_length
        self.host_encoder_input_lengths = host_encoder_input_lengths
        self.weight_index = weight_index


class Lora(Module):
    """Applies a LoRA adapter to its input via the LoRA plugin."""

    def __init__(self,
                 in_hidden_size: int = 0,
                 out_hidden_sizes: List[int] = [0],
                 max_low_rank: int = 0) -> None:
        super().__init__()

        self.in_hidden_size = in_hidden_size
        self.out_hidden_sizes = out_hidden_sizes
        self.max_low_rank = max_low_rank

    def forward(self,
                x,
                lora_runtime_params: LoraRuntimeParams = None,
                is_cross_attention: bool = False):
        if default_net().plugin_config.lora_plugin:
            result = lora_plugin(
                x,
                in_hidden_size=self.in_hidden_size,
                out_hidden_sizes=self.out_hidden_sizes,
                host_request_types=lora_runtime_params.host_request_types,
                transb=True,
                # Cross-attention reads the encoder-side lengths instead of
                # the decoder context lengths.
                host_context_lengths=lora_runtime_params.host_context_lengths
                if not is_cross_attention else
                lora_runtime_params.host_encoder_input_lengths,
                max_context_length=lora_runtime_params.max_context_length
                if not is_cross_attention else
                lora_runtime_params.max_encoder_context_length,
                max_low_rank=self.max_low_rank,
                lora_ranks=lora_runtime_params.lora_ranks,
                lora_weights_pointers=lora_runtime_params.lora_weights_pointers,
                weight_index=lora_runtime_params.weight_index,
            )
        else:
            assert False, "LoRA is not supported without the lora_plugin"

        return result


class LoraParams(object):
    """Model-wide LoRA tensors.

    `lora_ranks` and `lora_weights_pointers` are indexed by layer and keyed by
    module name (e.g. f"{module}_lora_ranks"); the remaining tensors are shared
    across layers.
    """

    def __init__(
        self,
        lora_ranks=None,
        lora_weights_pointers=None,
        host_context_lengths: Tensor = None,
        max_context_length: Tensor = None,
        max_encoder_context_length: Tensor = None,
        host_request_types: Tensor = None,
        host_encoder_input_lengths: Tensor = None,
        weight_index: int = 0,
    ):
        self.lora_ranks = lora_ranks
        self.lora_weights_pointers = lora_weights_pointers

        self.host_context_lengths = host_context_lengths
        self.max_context_length = max_context_length
        self.max_encoder_context_length = max_encoder_context_length
        self.host_request_types = host_request_types
        self.host_encoder_input_lengths = host_encoder_input_lengths
        self.weight_index = weight_index

    def get_layer_params(self, layer_idx: int):
        # Slice out the entries for a single layer, keeping the shared tensors.
        return LoraParams(
            lora_ranks=[self.lora_ranks[layer_idx]],
            lora_weights_pointers=[self.lora_weights_pointers[layer_idx]],
            host_context_lengths=self.host_context_lengths,
            max_context_length=self.max_context_length,
            max_encoder_context_length=self.max_encoder_context_length,
            host_request_types=self.host_request_types,
            host_encoder_input_lengths=self.host_encoder_input_lengths,
            weight_index=self.weight_index,
        )

    def get_runtime_params(self, layer_idx: int, lora_module: str):
        # Return the runtime params for `lora_module` in layer `layer_idx`,
        # or None if that module has no LoRA weights in this layer.
        if f"{lora_module}_lora_ranks" in self.lora_ranks[layer_idx]:
            return LoraRuntimeParams(
                lora_ranks=[
                    self.lora_ranks[layer_idx][f"{lora_module}_lora_ranks"]
                ],
                lora_weights_pointers=[
                    self.lora_weights_pointers[layer_idx]
                    [f"{lora_module}_lora_weights_pointers"]
                ],
                host_context_lengths=self.host_context_lengths,
                max_context_length=self.max_context_length,
                max_encoder_context_length=self.max_encoder_context_length,
                host_request_types=self.host_request_types,
                host_encoder_input_lengths=self.host_encoder_input_lengths,
                weight_index=self.weight_index,
            )
        else:
            return None
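

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library API): it shows the dict-key
# convention that LoraParams.get_runtime_params() expects. The module name
# "attn_q" and the None placeholders are assumptions for illustration; real
# callers pass runtime tensors prepared by the generation session.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example = LoraParams(
        # One layer; each entry maps f"{module}_lora_ranks" /
        # f"{module}_lora_weights_pointers" to its runtime value.
        lora_ranks=[{"attn_q_lora_ranks": None}],
        lora_weights_pointers=[{"attn_q_lora_weights_pointers": None}],
    )
    # A module with registered weights yields a LoraRuntimeParams instance.
    assert example.get_runtime_params(0, "attn_q") is not None
    # A module without weights in this layer yields None.
    assert example.get_runtime_params(0, "mlp_h_to_4h") is None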