File size: 10,999 Bytes
c5681ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import torch
from nnsight import LanguageModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from huggingface_hub import hf_hub_download


def load_saes(cfg, device):
    """
    Load steering vectors from SAE decoders hosted on the Hugging Face Hub.

    For each configured feature, downloads the SAE checkpoint for its layer
    (cached under ./downloads), takes column ``feature_idx`` of the
    checkpoint's ``"decoder.weight"`` as the steering vector and moves it to
    ``device``.

    Args:
        cfg: Configuration dict. Reads
            'features': list of [layer_idx, feature_idx] or
                [layer_idx, feature_idx, strength] entries (strength
                defaults to 0.0);
            'reduced_strengths' (optional, default False): when true, each
                strength is multiplied by its layer index;
            'sae_path': Hub repo id holding the SAE checkpoints;
            'sae_filename_prefix' / 'sae_filename_suffix': the checkpoint
                file name is prefix + layer_idx + suffix.
        device: Device to place the vectors on ('cuda' or 'cpu').

    Returns:
        List of dicts with keys 'layer', 'feature', 'strength', 'vector';
        empty when no features are configured.
    """
    features = cfg['features']
    if not features:
        print("No features specified, returning empty steering components.")
        return []

    cache_dir = "./downloads"
    # .get() so configs without this key work, matching load_saes_from_file().
    reduced_strengths = cfg.get('reduced_strengths', False)

    steering_components = []
    for feature in features:
        layer_idx, feature_idx = feature[0], feature[1]
        strength = feature[2] if len(feature) > 2 else 0.0

        # If the strengths in the config file were given in reduced form, scale them by layer index
        if reduced_strengths:
            strength *= layer_idx

        # Display strength (avoid division by zero)
        reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
        print(f"Loading feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")

        sae_filename = cfg['sae_filename_prefix'] + f"{layer_idx}" + cfg['sae_filename_suffix']
        file_path = hf_hub_download(repo_id=cfg['sae_path'], filename=sae_filename, cache_dir=cache_dir)
        # NOTE(review): torch.load unpickles the downloaded checkpoint; only point
        # 'sae_path' at repositories you trust.
        sae = torch.load(file_path, map_location="cpu")
        # Indexing a column yields a *view* sharing storage with the whole decoder
        # matrix, and .to("cpu") on a CPU tensor returns that same view. clone()
        # makes a standalone copy so `del sae` actually frees the decoder.
        vec = sae["decoder.weight"][:, feature_idx].clone().to(device, non_blocking=True)
        del sae

        steering_components.append({
            'layer': layer_idx,
            'feature': feature_idx,
            'strength': strength,
            'vector': vec
        })

    return steering_components


def load_saes_from_file(file_path, cfg, device):
    """
    Build steering components from a local file of pre-extracted vectors.

    A fast alternative to load_saes(): instead of downloading whole SAE
    checkpoints it reads one .pt file (written by extract_steering_vectors.py)
    holding a dict keyed by ``(layer_idx, feature_idx)`` tuples.

    Args:
        file_path: Path to the .pt file with the vector dictionary.
        cfg: Config dict. Reads 'features' (entries [layer, feature] or
            [layer, feature, strength], strength defaulting to 0.0) and the
            optional 'reduced_strengths' flag (when true, each strength is
            multiplied by its layer index).
        device: Device the vectors are moved to ('cuda' or 'cpu').

    Returns:
        List of dicts with keys 'layer', 'feature', 'strength', 'vector';
        empty when no features are configured.

    Raises:
        FileNotFoundError: file_path does not exist.
        KeyError: a configured (layer, feature) pair is missing from the file.
    """
    import os

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Steering vectors file not found: {file_path}\n"
            f"Please run: python extract_steering_vectors.py"
        )

    print(f"Loading pre-extracted steering vectors from {file_path}...")
    vectors_by_key = torch.load(file_path, map_location="cpu")

    requested = cfg['features']
    if not requested:
        print("No features specified in config.")
        return []

    scale_by_layer = cfg.get('reduced_strengths', False)
    components = []
    for entry in requested:
        layer_idx, feature_idx = entry[0], entry[1]
        strength = entry[2] if len(entry) > 2 else 0.0
        if scale_by_layer:
            strength *= layer_idx

        lookup = (layer_idx, feature_idx)
        if lookup not in vectors_by_key:
            raise KeyError(
                f"Vector for layer {layer_idx}, feature {feature_idx} not found in {file_path}.\n"
                f"Please re-run: python extract_steering_vectors.py"
            )
        vector = vectors_by_key[lookup].to(device, non_blocking=True)

        # Layer 0 has no meaningful reduced strength (would divide by zero).
        if layer_idx <= 0:
            reduced_label = "[N/A]"
        else:
            reduced_label = f"[{strength/layer_idx:.2f}]"
        print(f"Loaded feature {layer_idx} {feature_idx} {strength:.2f} {reduced_label}")

        # NOTE(review): upstream comment says vectors are stored pre-normalized
        # by the extraction script; not checked here.
        components.append(dict(layer=layer_idx, feature=feature_idx,
                               strength=strength, vector=vector))

    print(f"Loaded {len(components)} steering vector(s) from local file")
    return components


def generate_steered_answer(model: LanguageModel,
                            chat,
                            steering_components,
                            max_new_tokens=128,
                            temperature=0.0,
                            repetition_penalty=1.0,
                            clamp_intensity=False):
    """
    Generates an answer from the model given a chat history, applying steering components.

    Uses nnsight tracing: for every steering component, ``strength * vector``
    is added in place to ``model.model.layers[layer].output`` at every
    position. With ``clamp_intensity=True`` the output's existing projection
    onto ``vector`` is subtracted from that amount first.

    Expects steering_components to be a list of dicts with keys:
        'layer': int, layer index to apply steering
        'strength': float, steering intensity
        'vector': torch.Tensor, steering vector

    Args:
        model: nnsight LanguageModel whose decoder layers are reachable as
            ``model.model.layers``.
        chat: Chat history passed to ``model.tokenizer.apply_chat_template``.
        steering_components: See above.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        repetition_penalty: Forwarded to generation.
        clamp_intensity: Remove the existing projection onto each vector
            before adding the steering amount.

    Returns:
        Dict with 'input_ids' (prompt token ids), 'trace' (saved generator
        output; presumably prompt + completion, since the prompt length is
        sliced off before decoding) and 'answer' (decoded completion,
        special tokens skipped).
    """
    # assumes apply_chat_template(tokenize=True) returns a flat list of ids;
    # len(input_ids) is used below to strip the prompt from the output.
    input_ids = model.tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
    # NOTE(review): temperature=0.0 is forwarded even when do_sample=False; HF
    # generate may warn about this. stream_steered_answer_hf substitutes 1.0.
    with model.generate(max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
                        do_sample=temperature > 0.0, temperature=temperature,
                        pad_token_id=model.tokenizer.eos_token_id) as tracer:
        with tracer.invoke(input_ids):
            # Presumably makes the interventions below apply on every generation
            # step rather than only the prompt pass (nnsight semantics; confirm
            # against the installed nnsight version).
            with tracer.all():
                for sc in steering_components:
                    layer, strength, vector = sc["layer"], sc["strength"], sc["vector"]

                    # Ensure vector matches model dtype and device
                    # NOTE(review): assumes the layer output is a single
                    # [batch, seq, hidden] tensor; on transformers versions where
                    # decoder layers return a tuple, .dtype/.shape would fail here
                    # (create_steering_hook handles both forms).
                    layer_output = model.model.layers[layer].output
                    vector = vector.to(dtype=layer_output.dtype, device=layer_output.device)

                    # [1, seq, hidden]: the scaled vector repeated at every position.
                    # clone() materializes the expand() view so the in-place `-=`
                    # below writes into real, non-aliased memory.
                    length = layer_output.shape[1]
                    amount = (strength * vector).unsqueeze(0).expand(length, -1).unsqueeze(0).clone()
                    if clamp_intensity:
                        # (output @ v) -> [batch, seq]; -> [batch, seq, 1] @ [1, hidden]
                        # -> [batch, seq, hidden]: the component of the output along v.
                        projection = (layer_output @ vector).unsqueeze(-1)@(vector.unsqueeze(0))
                        # NOTE(review): in-place subtraction cannot broadcast `amount`
                        # (batch dim 1) up to a larger batch, so this assumes batch size 1.
                        amount -= projection

                    layer_output += amount
        # Separate invoke used only to save the generator's final output.
        with tracer.invoke():
            trace = model.generator.output.save()

    # Drop the prompt tokens so only the newly generated text is decoded.
    answer = model.tokenizer.decode(trace[0][len(input_ids):], skip_special_tokens=True)
    output = {'input_ids': input_ids, 'trace': trace, 'answer': answer}
    return output



def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
    """
    Build a forward hook that steers the hidden states of one layer.

    Only components whose 'layer' equals ``layer_idx`` are applied. Each one
    adds ``strength * vector`` to every position of the layer output. With
    ``clamp_intensity`` the hidden state's current component along ``vector``
    is removed as well: h -> h + strength * v - (h . v) v. For a unit-norm v
    this makes the projection onto v equal ``strength`` (vectors are assumed
    pre-normalized; not checked here).

    Args:
        layer_idx: Layer the hook will be registered on.
        steering_components: Steering component dicts for all layers
            (keys 'layer', 'strength', 'vector').
        clamp_intensity: Remove the existing projection before adding.

    Returns:
        A ``hook(module, input, output)`` callable for
        ``register_forward_hook``, or None when no component targets
        ``layer_idx``.
    """
    selected = [sc for sc in steering_components if sc['layer'] == layer_idx]
    if not selected:
        return None

    def hook(_module, _inputs, output):
        """Return the layer output with every selected steering vector applied."""
        # Layers may return a bare tensor or a tuple whose first item is the hidden state.
        is_tuple = isinstance(output, tuple)
        hidden = output[0] if is_tuple else output

        # 2-D outputs get a temporary length-1 middle axis so the 3-D math below
        # applies. NOTE(review): the previous comment assumed [batch, hidden]
        # during generation; unverified.
        was_2d = hidden.dim() == 2
        if was_2d:
            hidden = hidden.unsqueeze(1)

        for component in selected:
            direction = component['vector'].to(dtype=hidden.dtype, device=hidden.device)
            # Broadcasts over batch and sequence positions.
            shift = component['strength'] * direction
            if clamp_intensity:
                coeffs = torch.einsum('bsh,h->bs', hidden, direction)
                shift = shift - coeffs.unsqueeze(-1) * direction.view(1, 1, -1)
            hidden = hidden + shift

        if was_2d:
            hidden = hidden.squeeze(1)

        return (hidden,) + output[1:] if is_tuple else hidden

    return hook


def stream_steered_answer_hf(model: AutoModelForCausalLM,
                                tokenizer: AutoTokenizer,
                                chat,
                                steering_components,
                                max_new_tokens=128,
                                temperature=0.0,
                                repetition_penalty=1.0,
                                clamp_intensity=False,
                                stream=True):
    """
    Generate steered answer using pure HuggingFace Transformers with streaming.

    Steering is applied through forward hooks (built by create_steering_hook)
    registered on ``model.model.layers[layer]`` for each layer referenced in
    ``steering_components``. ``model.generate`` runs in a background thread
    and feeds a TextIteratorStreamer that this generator consumes.

    The hooks are removed on every exit path: normal completion, an error,
    or the caller closing the generator early (e.g. ``break`` out of the
    loop). Otherwise they would stay registered and silently steer every
    later generation on the same model. On early close, the background
    ``generate`` call is not cancelled; it runs to completion and its
    remaining output is discarded.

    Args:
        model: HuggingFace transformers model
        tokenizer: Tokenizer instance
        chat: Chat history in OpenAI format
        steering_components: List of dicts with 'layer', 'strength', 'vector'
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature (0 = greedy)
        repetition_penalty: Repetition penalty
        clamp_intensity: Whether to clamp steering intensity
        stream: Unused; output is always streamed. Kept for API compatibility.

    Yields:
        The accumulated generated text so far (not only the newest chunk).

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once the stream has ended.
    """
    input_ids_list = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
    input_ids = torch.tensor([input_ids_list]).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        # Single unpadded prompt: every position is real. Passing the mask
        # explicitly spares HF from inferring it while pad_token_id == eos_token_id.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": max_new_tokens,
        # Value is ignored when do_sample is False; 1.0 avoids passing 0.
        "temperature": temperature if temperature > 0 else 1.0,
        "do_sample": temperature > 0,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }

    worker_errors = []

    def _generate():
        """Run generation in the worker thread, forwarding any failure to the consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # generate() may never have signalled end-of-stream; do it here,
            # otherwise the consumer loop below would block forever.
            streamer.end()

    hook_handles = []
    try:
        # Register steering hooks (inside the try so partial registration is undone on error).
        for layer_idx in set(sc['layer'] for sc in steering_components):
            hook_fn = create_steering_hook(layer_idx, steering_components, clamp_intensity)
            if hook_fn is not None:
                layer_module = model.model.layers[layer_idx]
                hook_handles.append(layer_module.register_forward_hook(hook_fn))

        thread = Thread(target=_generate)
        thread.start()

        generated_text = ""
        for token_text in streamer:
            generated_text += token_text
            yield generated_text

        thread.join()
        if worker_errors:
            raise worker_errors[0]
    finally:
        # Runs on completion, on error, and on GeneratorExit at the yield above.
        for handle in hook_handles:
            handle.remove()