from contextlib import contextmanager
from types import MethodType
from typing import Any, List, Optional

import torch
from peft.tuners import lora
from peft.tuners.lora import LoraLayer

def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using a round-robin algorithm.

    Args:
        num_reqs (int): Total number of requests to distribute.
        num_workers (int): Number of available workers.

    Returns:
        list: A list of lists where each sublist contains the request indices
            assigned to that particular worker.
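
    Example:
        >>> round_robin(5, 2)
        [[0, 2, 4], [1, 3]]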
"""
    distribution = [[] for _ in range(num_workers)]
    for idx in range(num_reqs):
        worker_id = idx % num_workers
        distribution[worker_id].append(idx)
    return distribution


@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge and get_delta_weight methods for controlled merging.

    The patched methods move LoRA tensors onto the base layer's device before
    merging, which matters when adapters and base weights live on different devices.

    Args:
        model: The PEFT model to patch.
        parameter_group: Optional list of parameter names to restrict merging to.

    Yields:
        The patched model (the context manager restores the original methods on exit).
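
    Example (a usage sketch; assumes ``model`` is a ``peft.PeftModel``):
        with patch_lora_merge(model):
            model.merge_adapter()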
"""
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return  # Skip if not in the target parameter group
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    # DoRA keeps a magnitude vector; move it to the base layer's device
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)
        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Ensure LoRA tensors are on the same device as the base layer
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        # Move cached tensors back to the base layer's device when retrieved
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    # Patch all LoraLayer instances
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)

    try:
        yield model
    finally:
        # Cleanup: restore the original methods
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin


@contextmanager
def patch_lora_unmerge(model):
    """Patch the unmerge method to ensure proper device handling."""

    def _cache_pop_patched(self, key: str) -> Any:
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    def unmerge_patched(self):
        if not self.merged:
            return
        # Move DoRA magnitude vectors to the correct device before unmerging
        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)
        return self.unmerge_origin()

    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
            module._cache_pop_origin = module._cache_pop
            module._cache_pop = MethodType(_cache_pop_patched, module)

    try:
        yield model
    finally:
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin
                module._cache_pop = module._cache_pop_origin
                del module._cache_pop_origin