| |
| import torch |
| import numbers |
| from peft import LoraConfig |
|
|
|
|
| def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"): |
| target_modules = [] |
| for name, module in transformer.named_modules(): |
| if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear): |
| target_modules.append(name) |
|
|
| transformer_lora_config = LoraConfig( |
| r=rank, |
| lora_alpha=alpha, |
| init_lora_weights=init_lora_weights, |
| target_modules=target_modules, |
| ) |
| return transformer_lora_config |
|
|
|
|
|
|
| class TensorList(object): |
|
|
| def __init__(self, tensors): |
| """ |
| tensors: a list of torch.Tensor objects. No need to have uniform shape. |
| """ |
| assert isinstance(tensors, (list, tuple)) |
| assert all(isinstance(u, torch.Tensor) for u in tensors) |
| assert len(set([u.ndim for u in tensors])) == 1 |
| assert len(set([u.dtype for u in tensors])) == 1 |
| assert len(set([u.device for u in tensors])) == 1 |
| self.tensors = tensors |
| |
| def to(self, *args, **kwargs): |
| return TensorList([u.to(*args, **kwargs) for u in self.tensors]) |
| |
| def size(self, dim): |
| assert dim == 0, 'only support get the 0th size' |
| return len(self.tensors) |
| |
| def pow(self, *args, **kwargs): |
| return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) |
| |
| def squeeze(self, dim): |
| assert dim != 0 |
| if dim > 0: |
| dim -= 1 |
| return TensorList([u.squeeze(dim) for u in self.tensors]) |
| |
| def type(self, *args, **kwargs): |
| return TensorList([u.type(*args, **kwargs) for u in self.tensors]) |
| |
| def type_as(self, other): |
| assert isinstance(other, (torch.Tensor, TensorList)) |
| if isinstance(other, torch.Tensor): |
| return TensorList([u.type_as(other) for u in self.tensors]) |
| else: |
| return TensorList([u.type(other.dtype) for u in self.tensors]) |
| |
| @property |
| def dtype(self): |
| return self.tensors[0].dtype |
| |
| @property |
| def device(self): |
| return self.tensors[0].device |
| |
| @property |
| def ndim(self): |
| return 1 + self.tensors[0].ndim |
| |
| def __getitem__(self, index): |
| return self.tensors[index] |
| |
| def __len__(self): |
| return len(self.tensors) |
| |
| def __add__(self, other): |
| return self._apply(other, lambda u, v: u + v) |
| |
| def __radd__(self, other): |
| return self._apply(other, lambda u, v: v + u) |
| |
| def __sub__(self, other): |
| return self._apply(other, lambda u, v: u - v) |
| |
| def __rsub__(self, other): |
| return self._apply(other, lambda u, v: v - u) |
| |
| def __mul__(self, other): |
| return self._apply(other, lambda u, v: u * v) |
| |
| def __rmul__(self, other): |
| return self._apply(other, lambda u, v: v * u) |
| |
| def __floordiv__(self, other): |
| return self._apply(other, lambda u, v: u // v) |
| |
| def __truediv__(self, other): |
| return self._apply(other, lambda u, v: u / v) |
| |
| def __rfloordiv__(self, other): |
| return self._apply(other, lambda u, v: v // u) |
| |
| def __rtruediv__(self, other): |
| return self._apply(other, lambda u, v: v / u) |
| |
| def __pow__(self, other): |
| return self._apply(other, lambda u, v: u ** v) |
| |
| def __rpow__(self, other): |
| return self._apply(other, lambda u, v: v ** u) |
| |
| def __neg__(self): |
| return TensorList([-u for u in self.tensors]) |
| |
| def __iter__(self): |
| for tensor in self.tensors: |
| yield tensor |
| |
| def __repr__(self): |
| return 'TensorList: \n' + repr(self.tensors) |
|
|
| def _apply(self, other, op): |
| if isinstance(other, (list, tuple, TensorList)) or ( |
| isinstance(other, torch.Tensor) and ( |
| other.numel() > 1 or other.ndim > 1 |
| ) |
| ): |
| assert len(other) == len(self.tensors) |
| return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) |
| elif isinstance(other, numbers.Number) or ( |
| isinstance(other, torch.Tensor) and ( |
| other.numel() == 1 and other.ndim <= 1 |
| ) |
| ): |
| return TensorList([op(u, other) for u in self.tensors]) |
| else: |
| raise TypeError( |
| f'unsupported operand for *: "TensorList" and "{type(other)}"' |
| ) |