Added generator code
- Meta-Llama-3-70B-Instruct-8bpw/suppress_dir.safetensors +3 -0
- Meta-Llama-3-8B-Instruct/suppress_dir.safetensors +3 -0
- Phi-3-mini-128k-instruct/suppress_dir.safetensors +3 -0
- README.md +86 -0
- exl2_wrapper.py +61 -0
- gen.py +198 -0
- test_inference.py +542 -0
Meta-Llama-3-70B-Instruct-8bpw/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:faec2cc2c48d1a925a58d08a5396e3255f50d269ccc66b6610defd5ce6074cfe
size 2634640
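The three-line file above is a Git LFS pointer rather than the tensor data itself: it records the spec version, the SHA-256 object ID, and the size in bytes of the real file. As a minimal sketch (the helper name `parse_lfs_pointer` is hypothetical and not part of this repository), such a pointer could be read like this:

```python
def parse_lfs_pointer(text):
    """Parse the key/value lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Each line is "<key> <value>", separated by the first space.
        key, _, value = line.partition(' ')
        fields[key] = value
    # The oid value has the form "<hash algorithm>:<hex digest>".
    algo, _, digest = fields['oid'].partition(':')
    return {
        'version': fields['version'],
        'hash_algo': algo,
        'digest': digest,
        'size': int(fields['size']),
    }

pointer = parse_lfs_pointer(
    'version https://git-lfs.github.com/spec/v1\n'
    'oid sha256:faec2cc2c48d1a925a58d08a5396e3255f50d269ccc66b6610defd5ce6074cfe\n'
    'size 2634640\n'
)
print(pointer['hash_algo'], pointer['size'])
```

The size field is handy as a quick sanity check that a downloaded `suppress_dir.safetensors` is the real file and not the pointer.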
Meta-Llama-3-8B-Instruct/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de22c6df410a1bb839b3ae66a1d3b7aadcc1254d81a3c7fae17b8d509ed1f801
size 529440
Phi-3-mini-128k-instruct/suppress_dir.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0d2417d1c9684e73f44b5338f024975fcefd8d777a124633e27f6e9cc13e56a
size 398360
README.md
ADDED
|
@@ -0,0 +1,86 @@
---
license: mit
pipeline_tag: text-generation
---

ZoRA: Zero Rank Adaption
=
Inspired by [*Refusal in LLMs is mediated by a single direction*](https://www.alignmentforum.org/posts/jGuXSZgv6qfdhMCuJ/refusal-in-llms-is-mediated-by-a-single-direction), ZoRA refines the original approach to suppressing refusals in large language models. Its key features are:
* **Layer-wise ablation**: Measure and ablate a separate set of vectors for each layer
* **Multi-pass refinement**: Re-measure multiple times to refine the vectors
* **Single-token generation**: Measure refusal at the beginning of the response
* **Inference engine injection**: Load a small set of vectors to suppress refusals directly into a high-performance inference engine

This approach leaves the original model weights unchanged; only a small set of suppression vectors is loaded alongside them. See below for vector generation details.

ZoRA currently supports ExLlamaV2 only and is intended for research purposes. Feedback on the viability of these models with suppression applied is welcome.

Usage
=
Put `suppress_dir.safetensors` into the model directory and wrap your ExLlamaV2 model object in your code:
```
from exl2_wrapper import ExLlamaV2ModuleWrapper
ExLlamaV2ModuleWrapper.wrap(model)
```

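Once a model is wrapped, the input and output of each wrapped module have their component along a per-layer direction removed. For a unit vector `r` this is the projection `x' = x - (x·r) r`. A minimal sketch in plain Python follows; the function name `remove_direction` and the toy vectors are illustrative only, and the wrapper itself operates on torch tensors:

```python
def remove_direction(x, r):
    """Return x with its component along the unit vector r projected out."""
    dot = sum(xi * ri for xi, ri in zip(x, r))
    return [xi - dot * ri for xi, ri in zip(x, r)]

r = [0.6, 0.8]   # unit-length direction (0.36 + 0.64 = 1)
x = [2.0, 1.0]   # toy activation vector
x_clean = remove_direction(x, r)

# After the projection, x_clean has (numerically) no component along r.
residual_dot = sum(a * b for a, b in zip(x_clean, r))
print(x_clean, residual_dot)
```

If `r` is not unit length, the same formula removes only part of the component (or overshoots), so the sketch assumes a normalized direction.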
Example
=
A modified `test_inference.py` from [exllamav2](https://github.com/turboderp/exllamav2) is included for testing. For example:
```
python test_inference.py -m Meta-Llama-3-70B-Instruct-8bpw -p '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour prompt.<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n' -gs auto
```

Generator
=
The code to generate the ablation vectors has been added. To run it, you need to set the URL for the harmful prompts (`harmful_prompts_url` in `gen.py`).

Here is sample output for the Llama3-8b model:

```
Downloading harmful prompts
Done
-- Loading model...
-- Loaded model in 2.7671 seconds
-- Loading tokenizer...
Building refused residual data
Processing 5000 prompts
---------------------------------------------------------------------------------------------------- 100
---------------------------------------------------------------------------------------------------- 200
[...]
---------------------------------------------+------------------------------------------------------ 1898
---------------------------------------------------------------------------------------------------- 1998
--
Max capture reached
Captured 2000 residual streams
Done
Building allowed residual data
Downloading harmless prompts
Done
Processing 31323 prompts
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 100
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 200
[...]
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1898
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1998
++
Max capture reached
Captured 2000 residual streams
Done
Calculating mean allowed residual
Done
Iteration 0
Processing 2000 prompts
---+++++++++++++++++++++++++-+-+++++++++-++++++++++++++-+++-++-++++++++++++++-++++---++++++++-++++-+ 15
+++++++-++++++++++++++-+-++++++++++++++++++++++++++++-+++++++++--+++++++++++-++++++++++++++++++++++- 23
+++++++++++++++++++++++-++-++++++++++++++++-++++++++++-++-++++++++++++++++++++-++++++++--+++++++++++ 31
--+-+++++++++++++-++++++-+++++-+++-+++++-++++-++++++++++-++++-++++++++-++++++++++++++++++-++++++++++ 44
-++++++++-+++++++++-++++++++--++++-
Max capture reached
Captured 50 residual streams
Iteration 1
Processing 2000 prompts
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 0
[...]
```
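In the sample output, each `-` marks a response that matched the refusal pattern and each `+` one that did not; the number at the end of a full row is the count of residual streams captured so far. This reading is inferred from the `print` calls in `gen.py`. A small sketch, using a hypothetical helper name, that tallies one such row:

```python
def summarize_row(row):
    """Count refusal ('-') and non-refusal ('+') marks in one progress row."""
    parts = row.split()
    marks = parts[0] if parts else ''
    # A trailing integer, when present, is the running capture count.
    captured = int(parts[1]) if len(parts) > 1 else None
    return {
        'refused': marks.count('-'),
        'answered': marks.count('+'),
        'captured': captured,
    }

summary = summarize_row('--+-++++ 3')
print(summary)
```

Rows cut short by "Max capture reached" carry no trailing count, which is why `captured` may be `None`.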
exl2_wrapper.py
ADDED
|
@@ -0,0 +1,61 @@
import os
import torch
from safetensors import safe_open

class ExLlamaV2ModuleWrapper:
    @classmethod
    def wrap(cls, model, load = True):
        # Wrap every module except the first one and the last two
        for idx, module in enumerate(model.modules):
            if idx == 0 or idx >= (len(model.modules) - 2):
                continue
            model.modules[idx] = ExLlamaV2ModuleWrapper(model, module, idx)

        if not load:
            return

        suppress_dir_file = os.path.join(model.config.model_dir, 'suppress_dir.safetensors')
        if os.path.exists(suppress_dir_file):
            print(f'Loading suppress direction file "{suppress_dir_file}"')
            with safe_open(suppress_dir_file, framework='pt', device='cpu') as f:
                model._suppress_dir = []
                for layer in range(len(f.keys())):
                    model._suppress_dir.append(f.get_tensor(f'_suppress_dir_{layer}'))
        else:
            print(f'No suppress direction file, not wrapping. Tried to load: "{suppress_dir_file}"')
            return

    def __init__(self, model, module, idx):
        if not hasattr(model, '_suppress_dir'):
            model._suppress_dir = None
        if not hasattr(model, '_residual'):
            model._residual = None
        self.model = model
        self.module = module
        self.idx = idx

    def __getattribute__(self, name):
        # Intercept forward(); delegate every other lookup to the wrapped module first
        if name == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')

        try:
            return getattr(object.__getattribute__(self, 'module'), name)
        except AttributeError:
            pass
        return object.__getattribute__(self, name)

    def suppress(self, x):
        if self.model._suppress_dir is not None:
            # Note: self.idx - 2 is -1 for the first wrapped module (idx 1),
            # which Python resolves to the last stored direction
            r = self.model._suppress_dir[self.idx - 2].clone().to(x.device)
            r = r.view(-1, 1)
            # Project x onto r and subtract the projection
            proj_scalar = torch.matmul(x, r)
            proj = proj_scalar * r.transpose(0, 1)
            x = x - proj
        return x

    def wrapped_forward(self, *args, **kwargs):
        # When residual capture is enabled, record single-token inputs once per module
        if self.model._residual is not None:
            if len(self.model._residual) < self.idx and args[0].shape[1] == 1:
                self.model._residual.append(args[0].clone().to('cpu'))
        x = self.suppress(args[0])
        x = self.module.forward(*((x,) + args[1:]), **kwargs)
        return self.suppress(x)
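The wrapper relies on `__getattribute__` to intercept `forward` while passing every other attribute lookup through to the wrapped module. The following is a stripped-down sketch of that delegation pattern; `Inner` and `Proxy` are illustrative stand-ins, not ExLlamaV2 classes:

```python
class Inner:
    """Stand-in for a wrapped module."""
    name = 'inner'

    def forward(self, x):
        return x + 1

class Proxy:
    """Overrides forward() and delegates all other attributes to the inner object."""

    def __init__(self, module):
        self.module = module

    def __getattribute__(self, attr):
        if attr == 'forward':
            return object.__getattribute__(self, 'wrapped_forward')
        try:
            # Prefer the wrapped object's attribute when it has one.
            return getattr(object.__getattribute__(self, 'module'), attr)
        except AttributeError:
            pass
        # Fall back to the proxy's own attributes (e.g. 'module' itself).
        return object.__getattribute__(self, attr)

    def wrapped_forward(self, x):
        # Post-process the inner result; here it is simply doubled.
        return self.module.forward(x) * 2

p = Proxy(Inner())
print(p.name, p.forward(3))
```

Because delegation is tried first, the proxy looks like the original module to calling code that reads its attributes, which is presumably why the wrapper can be swapped into `model.modules` in place.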
gen.py
ADDED
|
@@ -0,0 +1,198 @@
import re
import time
import random
import io
from pathlib import Path
import json
import torch
import requests
from safetensors.torch import save_file

from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

from exl2_wrapper import ExLlamaV2ModuleWrapper

### START Settings

template = '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n'

model_dir = '/path/to/Meta-Llama-3-8B-Instruct'

harmful_prompts_url = 'ADD_URL_HERE'
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'

### END Settings

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048
model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = [] # Enable residual capture


out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)

harmful_prompts_file = out_dir / Path('harmful_prompts.json')
harmless_prompts_file = out_dir / Path('harmless_prompts.json')

refused_residual_file = out_dir / Path('refused_residual.pth')
allowed_residual_file = out_dir / Path('allowed_residual.pth')
allowed_residual_mean_file = out_dir / Path('allowed_residual_mean.pth')

suppress_dir_file = out_dir / Path('suppress_dir.safetensors')

refused = []
def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    global model, tokenizer, settings, refused, generator

    refused = []
    residuals = []

    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))

        prompt = template.format(instruction = prompt)

        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)

        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])

        if refusal:
            refused.append(prompt)
        print('-' if refusal else '+', end='', flush = True)

        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break

        if not silent:
            print(out)

    if not len(residuals):
        return None

    print(f'\nCaptured {len(residuals)} residual streams')

    res = []
    for l in range(len(residuals[0])):
        res.append(torch.cat([t[l][0, -1, :].unsqueeze(0) for t in residuals], dim=0))
    return res

if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)

    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)

print(" -- Loading model...")
t = time.time()
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")

print(" -- Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

with torch.inference_mode():

    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refusal residual data')
        refused_residual = torch.load(refused_residual_file)
    print('Done')

    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)

                all_prompts = json.loads(res.content.decode('utf8'))
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']

                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)

        print('Done')

        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)

    if model._suppress_dir is None:
        model._suppress_dir = []

    for o in range(6):
        print('Iteration', o)

        for i in range(len(refused_residual)):
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)

        refused_residual = get_residual(random.sample(harmful_prompts, 2000), 4, True, 50, 'refused')

        if not refused_residual or refused_residual[0].shape[0] < 30:
            break


save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)

torch.cuda.synchronize()

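The refinement loop in `gen.py` boils down to two vector operations per layer: a normalized difference of means, and an average with the previous estimate. A plain-Python sketch of just that arithmetic, with toy two-dimensional data and hypothetical helper names (the script itself uses torch tensors):

```python
import math

def mean_vector(rows):
    """Column-wise mean of a list of equal-length vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]

def unit_difference(rows, baseline_mean, eps=1e-4):
    """Normalized (mean of rows - baseline_mean), or zeros if the norm is tiny."""
    diff = [a - b for a, b in zip(mean_vector(rows), baseline_mean)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm <= eps:
        return [0.0] * len(diff)
    return [d / norm for d in diff]

def refine(previous, new):
    """Average the previous direction estimate with a newly measured one."""
    return [(p + n) / 2 for p, n in zip(previous, new)]

baseline_mean = [0.0, 0.0]
first = unit_difference([[3.0, 0.0], [5.0, 0.0]], baseline_mean)
second = unit_difference([[0.0, 2.0]], baseline_mean)
combined = refine(first, second)
print(first, second, combined)
```

One thing the toy numbers make visible: the average of two unit vectors is generally shorter than unit length, so the refined direction is not renormalized by this averaging step.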
test_inference.py
ADDED
|
@@ -0,0 +1,542 @@

from exllamav2 import(
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
    model_init,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

from exllamav2.attn import ExLlamaV2Attention
from exllamav2.mlp import ExLlamaV2MLP
from exllamav2.moe_mlp import ExLlamaV2MoEMLP
from exllamav2.parallel_decoder import ExLlamaV2ParallelDecoder

import argparse, os, math, time
import torch
import torch.nn.functional as F
from conversion.tokenize import get_tokens
from conversion.quantize import list_live_tensors
import gc

# from exllamav2.mlp import set_catch

import sys
import json

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.set_float32_matmul_precision("medium")

# (!!!) NOTE: These go on top of the engine arguments that can be found in `model_init.py` (!!!)
parser = argparse.ArgumentParser(description = "Test inference on ExLlamaV2 model")
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 128, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-et", "--eval_token", action = "store_true", help = "Evaluate perplexity on token-by-token inference using cache")
parser.add_argument("-e8", "--eval_token_8bit", action = "store_true", help = "Evaluate perplexity on token-by-token inference using 8-bit (FP8) cache")
parser.add_argument("-eq4", "--eval_token_q4", action = "store_true", help = "Evaluate perplexity on token-by-token inference using Q4 cache")
# parser.add_argument("-eb", "--eval_bos", action = "store_true", help = "Add BOS token to every row in perplexity test (required by Gemma and maybe other models.)")
parser.add_argument("-p", "--prompt", type = str, help = "Generate from prompt (basic sampling settings)")
parser.add_argument("-pnb", "--prompt_no_bos", action = "store_true", help = "Don't add BOS token to prompt")
parser.add_argument("-t", "--tokens", type = int, default = 128, help = "Max no. tokens")
parser.add_argument("-ps", "--prompt_speed", action = "store_true", help = "Test prompt processing (batch) speed over context length")
parser.add_argument("-s", "--speed", action = "store_true", help = "Test raw generation speed over context length")
parser.add_argument("-mix", "--mix_layers", type = str, help = "Load replacement layers from secondary model. Example: --mix_layers 1,6-7:/mnt/models/other_model")
parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512 (experimental)")
parser.add_argument("-rr", "--rank_reduce", type = str, help = "Rank-reduction for MLP layers of model, in reverse order (for experimentation)")
parser.add_argument("-mol", "--max_output_len", type = int, help = "Set max output chunk size (incompatible with ppl tests)")

# Initialize model and tokenizer

model_init.add_args(parser)
args = parser.parse_args()

# Check conflicting settings

if args.stream_layers:
    if args.eval_token or args.eval_token_8bit or args.eval_token_q4:
        print(" ## Can't test token ppl while streaming layers")
        sys.exit()
    if args.prompt:
        print(" ## Can't generate while streaming layers")
        sys.exit()
    if args.speed or args.prompt_speed:
        print(" ## Can't test speed while streaming layers")
        sys.exit()
    if args.gpu_split:
        print(" ## Can only use one GPU when streaming layers")
        sys.exit()
if args.eval_dataset:
    if args.length and args.eval_length != args.length:
        print(" !! Overriding model context length to match eval row length")
        args.length = args.eval_length

# Init

model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args,
                                   allow_auto_split = True,
                                   skip_load = args.stream_layers,
                                   benchmark = True,
                                   max_output_len = args.max_output_len)
cache = None

from exl2_wrapper import ExLlamaV2ModuleWrapper
ExLlamaV2ModuleWrapper.wrap(model)

# Auto split

if not model.loaded and not args.stream_layers:

    if args.mix_layers:
        print(" !! Warning, auto split does not account for VRAM requirement of replacement layers")

    print(" -- Loading model...")
    cache = ExLlamaV2Cache(model, lazy = True)
    t = time.time()
    model.load_autosplit(cache)
    t = time.time() - t
    print(f" -- Loaded model in {t:.4f} seconds")

if args.stream_layers:

    stream_batch_size = 2
    model.config.max_batch_size = stream_batch_size
    model.load(lazy = True)

# Rank reduction

if args.rank_reduce:

    if args.stream_layers:
        print(" ## --rank_reduce can not be combined with --stream_layers")
        sys.exit()

    rr = args.rank_reduce.split(",")
    idx = len(model.modules) - 1
    for r in rr:
        k = float(r)

        while True:
            idx -= 1
            module = model.modules[idx]
            if isinstance(module, ExLlamaV2ParallelDecoder): break
            if isinstance(module, ExLlamaV2MLP): break
            if isinstance(module, ExLlamaV2MoEMLP): break
            if idx < 0:
                print(" ## Not enough layers")
                sys.exit()

        print(f" -- Reducing {module.key} ({module.name}) to {k * 100:.2f}%")
        module.rank_reduce(k)

# Replacement

if args.mix_layers:
    intervals_, extra_dir = args.mix_layers.split(":")

    print(f" -- Loading replacement layers from: {extra_dir}")

    extra_config = ExLlamaV2Config()
    extra_config.model_dir = extra_dir
    extra_config.prepare()
    intervals = intervals_.split(",")
    for interval in intervals:
        ab = interval.split("-")
        a, b = int(ab[0]), int(ab[-1])
        for idx in range(a, b + 1):
            print(f" -- Layer {idx}...")
            layerkey = "model.layers." + str(idx) + "."
            remove = [k for k in model.config.tensor_file_map.keys() if k.startswith(layerkey)]
            replace = [k for k in extra_config.tensor_file_map.keys() if k.startswith(layerkey)]
            # reload = [k for k in model.modules_dict.keys() if k.startswith(layerkey)]
            for k in remove: del model.config.tensor_file_map[k]
            for k in replace: model.config.tensor_file_map[k] = extra_config.tensor_file_map[k]
            # for k in reload:
            #     model.modules_dict[k].unload()
            #     model.modules_dict[k].load()
            if not args.stream_layers:
                model.modules[idx * 2 + 1].reload()
                model.modules[idx * 2 + 2].reload()

# Test generation

if args.prompt:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)

        ids = tokenizer.encode(args.prompt)
        tokens_prompt = ids.shape[-1]

        print(f" -- Warmup...")

        generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
        if not args.no_warmup: generator.warmup()

        print(f" -- Generating...")
        print()

        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = 0.75
        settings.top_k = 100
        settings.top_p = 0.75
        settings.token_repetition_penalty = 1.05
        settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

        time_begin = time.time()

        output = generator.generate_simple(args.prompt, settings, args.tokens, token_healing = True, add_bos = not args.prompt_no_bos)

        torch.cuda.synchronize()
        time_prompt = time.time()

        time_end = time.time()

    print(output)
    print()

    total_gen = time_end - time_begin
    print(f" -- Response generated in {total_gen:.2f} seconds, {args.tokens} tokens, {args.tokens / total_gen:.2f} tokens/second (includes prompt eval.)")


# Test perplexity

if args.eval_dataset or args.standard_perplexity:

    with torch.inference_mode():

        print(f" -- Running perplexity test")

        if args.standard_perplexity:

            eval_length = args.eval_length
            if args.eval_dataset:
                print(f" !! Note, overriding specified --eval_dataset with {args.standard_perplexity}")

            from datasets import load_dataset

            if args.standard_perplexity == "wiki2":
                ds = "wikitext"
                part = "wikitext-2-raw-v1"
                split = "test"
            # if args.standard_perplexity == "c4":
            #     ds = "allenai/c4"
            #     part = "allenai--c4"
            #     split = "train"

            print(f" -- Loading dataset {ds}, {part}, {split}...")
            test = load_dataset(ds, part, split = split)

            print(f" -- Tokenizing samples...")
            text = "\n\n".join(test["text"])
            eval_tokens = tokenizer.encode(text)

            stride = 512
            seqs = []
            eval_len = []
            a = 0
            while True:
                b = a + model.config.max_seq_len
                if b > eval_tokens.shape[-1]: break
                seqs.append(eval_tokens[:, a:b])
                eval_len.append(b if a == 0 else stride)
                a += stride

            eval_tokens = torch.cat(seqs, dim = 0)

        else:

            eval_dataset = args.eval_dataset
            eval_rows = args.eval_rows
            eval_length = args.eval_length

            print(f" -- Dataset: {eval_dataset}")
            print(f" -- Tokenizing eval data, {eval_rows} rows x {eval_length} tokens...")

            eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
            eval_len = [eval_tokens.shape[1]] * eval_tokens.shape[0]

            # if args.eval_bos:
            if model.config.arch.requires_bos:
                boss = torch.full((eval_tokens.shape[0], 1), tokenizer.bos_token_id, dtype = torch.long)
                eval_tokens = torch.cat((boss, eval_tokens[:, :-1]), dim = 1)

        logprob_sum = 0.0
        logprob_count = 0

        def ppl(input_ids__, logits__, lengths__):

            logprob_sum_ = 0.0
            logprob_count_ = 0

            assert logits__.shape[0] == input_ids__.shape[0]
            ll = logits__.shape[1]

            for bi in range(logits__.shape[0]):
                cl = max(ll - lengths__[bi], 0)
                logits_ = logits__[bi:bi+1, cl:, :]
                input_ids_ = input_ids__[bi:bi+1, cl:]

                chunksize = logits_.shape[1] * 4000 // logits_.shape[2] + 1
                b_ = 0
                while b_ < logits_.shape[1]:
                    a_ = b_
                    b_ = min(b_ + chunksize, logits_.shape[1])

                    logits_f = logits_[:, a_:b_, :].float() + 1e-10
                    target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)

                    log_probs = F.log_softmax(logits_f, dim=-1)
                    token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
                    logprob_sum_ += token_log_probs.sum().item()
                    logprob_count_ += target_ids.numel()

            return logprob_sum_, logprob_count_

        if args.stream_layers:

            print(f" -- Inference (streamed)", end = "")
            sys.stdout.flush()

            batch_size, seq_len = eval_tokens.shape
            attn_params = ExLlamaV2Attention.Params(stream_batch_size, seq_len, 0, None, None)
            # attn_mask = model.build_attn_mask(stream_batch_size, seq_len, 0, None, "cuda:0")

            for idx, module in enumerate(model.modules):
                module.set_device_idx(-1 if idx == 0 else 0)

            model.modules[0].load()
            hidden_state = model.modules[0].forward(eval_tokens)
            model.modules[0].unload()

            for idx, module in enumerate(model.modules):
                if idx == 0: continue

                print(".", end = "")
                sys.stdout.flush()
                module.load()

                b = 0
                while b < eval_tokens.shape[0]:
                    a = b
                    b = min(b + stream_batch_size, eval_tokens.shape[0])
                    x = hidden_state[a:b, :, :].to("cuda:0")
                    x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)

                    if idx < len(model.modules) - 1:
                        hidden_state[a:b, :, :] = x.to("cpu")

                    else:
                        input_ids = eval_tokens[a:b, :]
                        logits = x[:, :-1, :]

                        # if model.config.logit_scale != 1:
                        #     logits.mul_(model.config.logit_scale)

                        logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[a:b])
                        logprob_sum += logprob_sum__
                        logprob_count += logprob_count__

                module.unload()

            print()

        else:

            print(f" -- Inference", end = "")
            sys.stdout.flush()

            if cache is None:
                cache = ExLlamaV2Cache(model, max_seq_len = eval_length) if eval_length > model.config.max_input_len else None

            for i in range(eval_tokens.shape[0]):

                if i % 10 == 0: print(".", end = "")
                sys.stdout.flush()

                input_ids = eval_tokens[i:i+1, :]

                input_ids = input_ids[:, :]
                if cache is not None: cache.current_seq_len = 0
                logits = model.forward(input_ids, cache)
                logits = logits[:, :-1, :]

                logprob_sum__, logprob_count__ = ppl(input_ids, logits, eval_len[i:i+1])
                logprob_sum += logprob_sum__
                logprob_count += logprob_count__

            print()

        mean_log_prob = logprob_sum / logprob_count
        perplexity = math.exp(-mean_log_prob)
        print(f" -- Evaluation perplexity: {perplexity:.4f}")

        def test_ppl_token():
            global logprob_sum, logprob_count, i, input_ids
            global logits, target_ids, log_probs, token_log_probs
            global mean_log_prob, perplexity

            # set_catch("model.layers.3")

            logprob_sum = 0
            logprob_count = 0

            for i in range(eval_tokens.shape[0]):

                cache.current_seq_len = 0

                for j in range(eval_tokens.shape[1] - 1):
                    if j % 256 == 0: print(".", end = "")
                    sys.stdout.flush()

                    input_ids = eval_tokens[i:i + 1, j:j + 1]
                    logits = model.forward(input_ids, cache)
                    logits = logits.float() + 1e-10

                    log_probs = F.log_softmax(logits, dim = -1)
                    logprob_sum += log_probs[0, 0, eval_tokens[i, j+1]]
                    logprob_count += 1

                    # mean_log_prob = logprob_sum / logprob_count
                    # perplexity = math.exp(-mean_log_prob)
                    # print(f" -- Token {j}: {perplexity:.4f}")

            print()

            mean_log_prob = logprob_sum / logprob_count
            perplexity = math.exp(-mean_log_prob)
            print(f" -- Evaluation perplexity: {perplexity:.4f}")

        if args.eval_token:
            if args.standard_perplexity:
                print(f" !! Note, can't evaluate token perplexity on standard test")
            else:
                print(f" -- Inference (token)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache(model, max_seq_len = eval_length)
                test_ppl_token()

        if args.eval_token_8bit:
            if args.standard_perplexity:
                print(f" !! Note, can't evaluate token perplexity on standard test")
            else:
                print(f" -- Inference (token, 8-bit cache)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache_8bit(model, max_seq_len = eval_length)
                test_ppl_token()

        if args.eval_token_q4:
            if args.standard_perplexity:
                print(f" !! Note, can't evaluate token perplexity on standard test")
            else:
                print(f" -- Inference (token, Q4 cache)", end = "")
                sys.stdout.flush()
                cache = ExLlamaV2Cache_Q4(model, max_seq_len = eval_length)
                test_ppl_token()


# Test prompt speed

if args.prompt_speed:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)

        ids = torch.randint(0, model.config.vocab_size - 1, (1, model.config.max_seq_len))

        print(f" -- Warmup...")

        if not args.no_warmup:
            model.forward(ids[:, -1:])

        print(f" -- Measuring prompt speed...")

        torch.cuda.synchronize()

        current_len = 128
        step = 128
        prompt_iters = 3
        while True:

            total_time = 0
            for i in range(prompt_iters):

                torch.cuda.synchronize()
                time_begin = time.time()

                cache.current_seq_len = 0
                model.forward(ids[:, :current_len], cache, preprocess_only = True)

                torch.cuda.synchronize()
                time_end = time.time()
                total_time += time_end - time_begin

            tps = current_len / (total_time / prompt_iters)

            print(f" ** Length {current_len:>5} tokens: {tps:>11.4f} t/s")

            if current_len >= 1024: step = 1024
            if current_len >= 4096: step = 4096
            if current_len >= 16384: step = 8192

            current_len_ = current_len
            current_len = min(current_len + step, model.config.max_seq_len)
            if current_len == current_len_: break


# Test token speed

if args.speed:

    with torch.inference_mode():

        if cache is None:
            cache = ExLlamaV2Cache(model)
        cache.current_seq_len = 0

        print(f" -- Measuring token speed...")
        ids = tokenizer.encode("X")
        model.forward(ids[:, :])

        current_idx = ids.shape[-1]
        next_stop = 128

        while True:

            time_begin = time.time()

            tokens = next_stop - current_idx
            for i in range(tokens):

                logits = model.forward(ids[:, -1:], cache)
                sample = torch.argmax(logits[0, -1]).cpu().unsqueeze(0).unsqueeze(0)
                ids = torch.cat((ids, sample), dim=-1)

            time_end = time.time()
            tps = tokens / (time_end - time_begin)

            print(f" ** Position {current_idx:>5} + {tokens:>3} tokens: {tps:>9.4f} t/s")

            current_idx = next_stop
            next_stop = min(next_stop + 128, model.config.max_seq_len)
            if next_stop == current_idx: break