| import os |
| import torch |
| from safetensors import safe_open |
|
|
class ExLlamaV2ModuleWrapper:
    """Proxy around one entry of ``model.modules`` that edits the hidden state.

    The wrapper replaces the wrapped module's ``forward`` with
    :meth:`wrapped_forward`, which removes the component of the hidden state
    along a stored per-module direction both before and after delegating to
    the real module.  It can also record the inputs it sees into
    ``model._residual``.

    Every attribute other than ``forward`` is looked up on the wrapped module
    first (see :meth:`__getattribute__`), so the wrapper is largely
    transparent to code that inspects the module.

    Shared state lives on the model object:

    * ``model._suppress_dir`` -- ``None`` (suppression disabled) or a list of
      direction tensors indexed by ``idx - 2``.
    * ``model._residual`` -- ``None`` (recording disabled) or a list that
      collects CPU copies of module inputs.
    """

    @classmethod
    def wrap(cls, model, load = True):
        """Wrap the modules of ``model`` in place and optionally load directions.

        Every entry of ``model.modules`` except the first and the last two is
        replaced by a wrapper.  When ``load`` is true, the file
        ``suppress_dir.safetensors`` inside ``model.config.model_dir`` is read
        (if present) into ``model._suppress_dir``; tensors are fetched under
        the keys ``_suppress_dir_0`` .. ``_suppress_dir_{N-1}`` where ``N`` is
        the number of keys in the file.

        Args:
            model: Object exposing a mutable ``modules`` list and, when
                ``load`` is true, ``config.model_dir``.
            load: Whether to look for and load the direction file.
        """
        # Only slots are reassigned (the list length never changes), so
        # mutating while enumerating is safe here.
        last_wrapped = len(model.modules) - 2
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= last_wrapped:
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if not os.path.exists(suppress_dir_file):
            # NOTE(review): the modules have already been wrapped above; with
            # no directions loaded the wrappers simply pass data through.
            print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
            return

        print(f'Loading suppress direction file "{suppress_dir_file}"')
        with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
            model._suppress_dir = [
                f.get_tensor(f'_suppress_dir_{layer}')
                for layer in range(len(f.keys()))
            ]

    def __init__(self, model, module, idx):
        """Store the wrapped module and make sure the shared state exists.

        Args:
            model: Owner of the shared ``_suppress_dir`` / ``_residual`` state.
            module: The module whose ``forward`` is being intercepted.
            idx: Position of ``module`` within ``model.modules``.
        """
        # Initialise the shared slots only once so that state set up by an
        # earlier wrapper (or by the caller) is not overwritten.
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        """Resolve ``forward`` locally and everything else on the module first.

        Lookup order: ``forward`` always maps to :meth:`wrapped_forward`; any
        other name is tried on the wrapped module and only falls back to the
        wrapper itself if the module raises ``AttributeError``.  Because this
        also applies to ``__class__``, ``isinstance`` checks against the
        wrapped module's type succeed on the wrapper.
        """
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')

        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        """Return ``x`` with its component along the stored direction removed.

        Computes ``x - (x @ r) * r^T`` where ``r`` is the direction stored for
        this wrapper, flattened to a column vector.  ``x`` is returned
        unchanged when no directions are loaded.

        Args:
            x: Tensor whose last dimension matches the direction's length.

        Returns:
            A new tensor with the projection subtracted, or ``x`` itself when
            ``model._suppress_dir`` is ``None``.
        """
        # Read the wrapper's own state directly: going through the proxying
        # __getattribute__ would pick up a same-named attribute of the wrapped
        # module (e.g. ``idx``) instead.
        own = object.__getattribute__
        directions = own(self, 'model')._suppress_dir
        if directions is None:
            return x

        # NOTE(review): for idx == 1 this index is -1, i.e. the last stored
        # direction -- confirm that this wrap-around is intended.
        direction = directions[own(self, 'idx') - 2]

        # Match both device and dtype of the activations; matmul rejects
        # mixed dtypes.  ``.to`` is a no-op when nothing needs converting and
        # the tensor is never modified in place, so no defensive copy is made.
        r = direction.to(device=x.device, dtype=x.dtype).view(-1, 1)

        # Presumably ``r`` is unit-length, making this an orthogonal
        # projection -- the code itself does not normalise it.
        proj_scalar = torch.matmul(x, r)
        return x - proj_scalar * r.transpose(0, 1)

    def wrapped_forward(self, *args, **kwargs):
        """Run the wrapped module with suppression applied on both sides.

        The first positional argument is treated as the hidden state.  If
        recording is enabled (``model._residual`` is a list), fewer than
        ``idx`` entries have been recorded so far, and the hidden state's
        second dimension is 1, a CPU copy of the incoming hidden state is
        appended to ``model._residual``.

        Args:
            *args: Positional arguments for the wrapped ``forward``; the first
                one must be the hidden-state tensor.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            The wrapped module's output after :meth:`suppress`.
        """
        # Bypass the attribute proxy for the wrapper's own state (see the
        # comment in ``suppress``).
        own = object.__getattribute__
        model = own(self, 'model')
        module = own(self, 'module')
        idx = own(self, 'idx')
        suppress = own(self, 'suppress')

        hidden = args[0]

        residual = model._residual
        if residual is not None and len(residual) < idx and hidden.shape[1] == 1:
            residual.append(hidden.clone().to('cpu'))

        x = suppress(hidden)
        x = module.forward(x, *args[1:], **kwargs)
        return suppress(x)
|
|