import re
import time
import random
import json
from pathlib import Path

import torch
import requests
from safetensors.torch import save_file

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)

from exl2_wrapper import ExLlamaV2ModuleWrapper

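# Llama 3 Instruct chat template; {instruction} is filled in per prompt below.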
template = (
    '<|start_header_id|>system<|end_header_id|>\n\n'
    'You are a helpful AI assistant.<|eot_id|>\n'
    '<|start_header_id|>user<|end_header_id|>\n\n'
    '{instruction}<|eot_id|>\n'
    '<|start_header_id|>assistant<|end_header_id|>\n\n'
)

model_dir = '/path/to/Meta-Llama-3-8B-Instruct'

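# The harmful prompt set is expected as JSONL, one {"prompt": ...} object per
# line (see the download code below); the harmless set is the Alpaca data.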
harmful_prompts_url = 'ADD_URL_HERE'
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'

torch.cuda.init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

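# Prepare the model and wrap every module so ExLlamaV2ModuleWrapper can record
# the residual stream into model._residual and apply suppression directions.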
config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048
model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = []

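# Cache directory derived from the model path: downloaded prompts, captured
# residuals, and the final suppression directions are stored here.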
out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)

harmful_prompts_file = out_dir / 'harmful_prompts.json'
harmless_prompts_file = out_dir / 'harmless_prompts.json'

refused_residual_file = out_dir / 'refused_residual.pth'
allowed_residual_file = out_dir / 'allowed_residual.pth'
allowed_residual_mean_file = out_dir / 'allowed_residual_mean.pth'

suppress_dir_file = out_dir / 'suppress_dir.safetensors'

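# Generate a short completion for each prompt and collect the residual stream
# snapshots the wrapper records into model._residual. capture_type selects
# which runs to keep: 'refused', 'allowed', or None for all of them.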
refused = []
def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    global model, tokenizer, settings, refused, generator

    refused = []
    residuals = []

    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))

        prompt = template.format(instruction = prompt)

        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)

        # Heuristic: a completion that opens with a stock refusal phrase counts as a refusal.
        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])

        if refusal:
            refused.append(prompt)
        print('-' if refusal else '+', end = '', flush = True)

        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break

        if not silent:
            print(out)

    if not residuals:
        return None

    print(f'\nCaptured {len(residuals)} residual streams')

    # For each recorded layer, stack the final-position residuals from all
    # captured runs into a single (num_captured, hidden_dim) tensor.
    res = []
    for layer in range(len(residuals[0])):
        res.append(torch.cat([t[layer][0, -1, :].unsqueeze(0) for t in residuals], dim = 0))
    return res

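# Download and cache the harmful prompts (JSONL with a 'prompt' field per line).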
if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)
    res.raise_for_status()

    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)

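# Load the model weights, auto-splitting across available GPUs via the lazy cache.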
| print(" -- Loading model...") |
| t = time.time() |
| cache = ExLlamaV2Cache(model, lazy=True) |
| model.load_autosplit(cache) |
| t = time.time() - t |
| print(f" -- Loaded model in {t:.4f} seconds") |
|
|
| print(" -- Loading tokenizer...") |
| tokenizer = ExLlamaV2Tokenizer(config) |
| settings = ExLlamaV2Sampler.Settings() |
| settings.temperature = 0 |
|
|
| generator = ExLlamaV2BaseGenerator(model, cache, tokenizer) |
|
|
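# Everything below runs without autograd: capture residuals for harmful and
# harmless prompts, then iteratively refine per-layer suppression directions.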
with torch.inference_mode():

    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refused residual data')
        refused_residual = torch.load(refused_residual_file)
    print('Done')

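    # Build (or load) the mean residual of prompts the model answers normally;
    # this is the baseline the refusal direction is measured against.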
    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)
                res.raise_for_status()

                # Keep only Alpaca instructions without a separate input field.
                all_prompts = json.loads(res.content.decode('utf8'))
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']

                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)
        print('Done')

        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)

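    # Iteratively refine the per-layer suppression directions: estimate the
    # refused-minus-allowed direction, let the wrapper apply it, re-test on a
    # fresh sample of harmful prompts, and stop once refusals become rare.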
    if getattr(model, '_suppress_dir', None) is None:
        model._suppress_dir = []

    for iteration in range(6):
        print('Iteration', iteration)

        for i in range(len(refused_residual)):
            # Refusal direction for layer i: mean refused residual minus the
            # mean allowed residual, normalized (zeroed if degenerate).
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                # Average with the direction found in earlier iterations.
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)

        # Re-test with suppression active, capturing at most 50 remaining refusals.
        refused_residual = get_residual(random.sample(harmful_prompts, min(2000, len(harmful_prompts))), 4, True, 50, 'refused')

        # Converged once fewer than 30 refusals remain (or none at all).
        if not refused_residual or refused_residual[0].shape[0] < 30:
            break

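# Persist one direction tensor per layer so the result can be reloaded from
# the safetensors file later (e.g. by the wrapper or a separate ablation pass).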
save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)

torch.cuda.synchronize()