atticusg commited on
Commit
bd0fb44
·
verified ·
1 Parent(s): 9564343

Mock submission

Browse files
Files changed (15) hide show
  1. .gitattributes +6 -0
  2. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/4_answer_MCQA_Gemma2ForCausalLM_submission_answer_pointer__results.json +100 -0
  3. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer +3 -0
  4. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_indices +1 -0
  5. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer +3 -0
  6. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer +3 -0
  7. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_indices +1 -0
  8. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer +3 -0
  9. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer +3 -0
  10. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_indices +1 -0
  11. mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer +3 -0
  12. mock_submission/__pycache__/featurizer.cpython-312.pyc +0 -0
  13. mock_submission/__pycache__/token_position.cpython-312.pyc +0 -0
  14. mock_submission/featurizer.py +52 -0
  15. mock_submission/token_position.py +65 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer filter=lfs diff=lfs merge=lfs -text
37
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
38
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer filter=lfs diff=lfs merge=lfs -text
39
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
40
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer filter=lfs diff=lfs merge=lfs -text
41
+ mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer filter=lfs diff=lfs merge=lfs -text
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/4_answer_MCQA_Gemma2ForCausalLM_submission_answer_pointer__results.json ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "method_name": "submission",
3
+ "model_name": "Gemma2ForCausalLM",
4
+ "task_name": "4_answer_MCQA",
5
+ "dataset": {
6
+ "answerPosition_testprivate": {
7
+ "model_unit": {
8
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol)')]]": {
9
+ "metadata": {
10
+ "layer": 0,
11
+ "position": "correct_symbol"
12
+ },
13
+ "answer_pointer": {
14
+ "average_score": 0.68
15
+ }
16
+ },
17
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol_period)')]]": {
18
+ "metadata": {
19
+ "layer": 0,
20
+ "position": "correct_symbol_period"
21
+ },
22
+ "answer_pointer": {
23
+ "average_score": 0.0
24
+ }
25
+ },
26
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-last_token)')]]": {
27
+ "metadata": {
28
+ "layer": 0,
29
+ "position": "last_token"
30
+ },
31
+ "answer_pointer": {
32
+ "average_score": 0.0
33
+ }
34
+ }
35
+ }
36
+ },
37
+ "randomLetter_testprivate": {
38
+ "model_unit": {
39
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol)')]]": {
40
+ "metadata": {
41
+ "layer": 0,
42
+ "position": "correct_symbol"
43
+ },
44
+ "answer_pointer": {
45
+ "average_score": 0.8235294117647058
46
+ }
47
+ },
48
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol_period)')]]": {
49
+ "metadata": {
50
+ "layer": 0,
51
+ "position": "correct_symbol_period"
52
+ },
53
+ "answer_pointer": {
54
+ "average_score": 1.0
55
+ }
56
+ },
57
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-last_token)')]]": {
58
+ "metadata": {
59
+ "layer": 0,
60
+ "position": "last_token"
61
+ },
62
+ "answer_pointer": {
63
+ "average_score": 1.0
64
+ }
65
+ }
66
+ }
67
+ },
68
+ "answerPosition_randomLetter_testprivate": {
69
+ "model_unit": {
70
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol)')]]": {
71
+ "metadata": {
72
+ "layer": 0,
73
+ "position": "correct_symbol"
74
+ },
75
+ "answer_pointer": {
76
+ "average_score": 0.0625
77
+ }
78
+ },
79
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-correct_symbol_period)')]]": {
80
+ "metadata": {
81
+ "layer": 0,
82
+ "position": "correct_symbol_period"
83
+ },
84
+ "answer_pointer": {
85
+ "average_score": 0.0
86
+ }
87
+ },
88
+ "[[AtomicModelUnit(id='ResidualStream(Layer-0,Token-last_token)')]]": {
89
+ "metadata": {
90
+ "layer": 0,
91
+ "position": "last_token"
92
+ },
93
+ "answer_pointer": {
94
+ "average_score": 0.0
95
+ }
96
+ }
97
+ }
98
+ }
99
+ }
100
+ }
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dab9bccd2ea775eb56ad98fa9bac02c8d6a170d41d3a58b4b300c7e97eb80af8
3
+ size 21531300
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_indices ADDED
@@ -0,0 +1 @@
 
 
1
+ null
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol)_inverse_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5f69e8af1b271494715d6d2cf3936a9f1897065b5cd7a1e35417c0eb19a665
3
+ size 21531356
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d3ca3b99e9badc80119a4d711f60f35caf610ae7a8bcf08689385b490a197c0
3
+ size 21531349
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_indices ADDED
@@ -0,0 +1 @@
 
 
1
+ null
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-correct_symbol_period)_inverse_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:630c183b185a53d826f9e17f6932dfa7e7d1011d8fb8435bc29b42fb1ac45189
3
+ size 21531533
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8483d8ff87d3a188f542bcf17e545d63bb2039a644249982d82ef8c45a65964e
3
+ size 21531208
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_indices ADDED
@@ -0,0 +1 @@
 
 
1
+ null
mock_submission/4_answer_MCQA_Gemma2ForCausalLM_answer_pointer/DAS_Gemma2ForCausalLM_answer_pointer/ResidualStream(Layer-0,Token-last_token)_inverse_featurizer ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f02db1835842cbffc192c82973dc8d08dcc9b2f5667f57ac8f11c7af32c684b8
3
+ size 21531328
mock_submission/__pycache__/featurizer.cpython-312.pyc ADDED
Binary file (4.04 kB). View file
 
mock_submission/__pycache__/token_position.cpython-312.pyc ADDED
Binary file (3.17 kB). View file
 
mock_submission/featurizer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copy of the existing SubspaceFeaturizer implementation for submission.
3
+ This file provides the same SubspaceFeaturizer functionality in a self-contained format.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import pyvene as pv
9
+ from CausalAbstraction.neural.featurizers import Featurizer
10
+
11
+
12
class SubspaceFeaturizerModuleCopy(torch.nn.Module):
    """Project activations onto a learned rotated subspace.

    Wraps a rotation layer and exposes the forward featurization: the
    coordinates of the input inside the subspace, plus the residual the
    subspace does not capture (needed for exact reconstruction later).
    """

    def __init__(self, rotate_layer):
        super().__init__()
        # Rotation whose weight columns span the feature subspace.
        self.rotate = rotate_layer

    def forward(self, x):
        """Return ``(features, error)`` for activations ``x``.

        features: coordinates of ``x`` in the rotated subspace.
        error:    component of ``x`` not explained by the subspace, so the
                  inverse module can rebuild ``x`` exactly.
        """
        basis = self.rotate.weight.T  # rows span the subspace
        features = x.to(basis.dtype) @ basis.T
        reconstruction = (features @ basis).to(x.dtype)
        residual = x - reconstruction
        return features, residual
22
+
23
+
24
class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module):
    """Map subspace features (plus residual) back into activation space.

    Inverse of ``SubspaceFeaturizerModuleCopy``: lifts the feature
    coordinates through the same rotation and re-adds the stored error so
    the original activation is reconstructed exactly.
    """

    def __init__(self, rotate_layer):
        super().__init__()
        # Must be the same rotation used by the forward featurizer.
        self.rotate = rotate_layer

    def forward(self, f, error):
        """Reconstruct the activation from features ``f`` and ``error``."""
        basis = self.rotate.weight.T
        lifted = (f.to(basis.dtype) @ basis).to(f.dtype)
        return lifted + error.to(f.dtype)
32
+
33
+
34
class SubspaceFeaturizerCopy(Featurizer):
    """Featurizer that rotates activations into an orthogonal subspace.

    Exactly one of ``shape`` or ``rotation_subspace`` should be supplied:
    ``shape`` builds a freshly orthogonally-initialized rotation, while
    ``rotation_subspace`` seeds the rotation from an existing matrix.
    """

    def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"):
        assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided."
        if shape is not None:
            # Fresh rotation of the requested shape, orthogonal at init.
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
        elif rotation_subspace is not None:
            # Seed the weights from the provided subspace; orthogonality is
            # enforced by the parametrization applied below.
            rotate = pv.models.layers.LowRankRotateLayer(*rotation_subspace.shape, init_orth=False)
            rotate.weight.data.copy_(rotation_subspace)
        # Keep the rotation orthogonal throughout training.
        self.rotate = torch.nn.utils.parametrizations.orthogonal(rotate)

        if not trainable:
            self.rotate.requires_grad_(False)

        # Forward and inverse modules share the single rotation instance.
        featurizer = SubspaceFeaturizerModuleCopy(self.rotate)
        inverse_featurizer = SubspaceInverseFeaturizerModuleCopy(self.rotate)

        super().__init__(featurizer, inverse_featurizer, n_features=self.rotate.weight.shape[1], id=id)
mock_submission/token_position.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Token position definitions for MCQA task submission.
3
+ This file provides token position functions that identify key tokens in MCQA prompts.
4
+ """
5
+
6
+ import re
7
+ from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index
8
+
9
+
10
def get_token_positions(pipeline, causal_model):
    """Build TokenPosition objects for the simple MCQA task.

    Args:
        pipeline: Language-model pipeline exposing tokenization via ``load``.
        causal_model: Causal model used to resolve the correct answer symbol.

    Returns:
        list[TokenPosition]: positions for the correct answer symbol, the
        token immediately after it (the period), and the final prompt token.
    """
    def get_correct_symbol_index(input, pipeline, causal_model):
        """Return a one-element list holding the token index of the correct
        answer symbol within the tokenized prompt.

        Args:
            input (Dict): Input dictionary for the causal model.
            pipeline: The tokenizer pipeline.
            causal_model: The causal model.

        Raises:
            ValueError: if the symbol chosen by the causal model does not
                appear as a standalone capital letter in the prompt.
        """
        # Ask the causal model which answer symbol is correct for this input.
        setting = causal_model.run_forward(input)
        pointer = setting["answer_pointer"]
        correct_symbol = setting[f"symbol{pointer}"]
        prompt = input["raw_input"]

        # Locate the standalone capital letter matching the correct symbol.
        symbol_match = next(
            (m for m in re.finditer(r"\b[A-Z]\b", prompt) if m.group() == correct_symbol),
            None,
        )
        if symbol_match is None:
            raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}")

        # Tokenize the prefix ending at the symbol; the symbol is the last
        # token of that prefix, so its index is the prefix length minus one.
        prefix_ids = list(pipeline.load(prompt[:symbol_match.end()])["input_ids"][0])
        return [len(prefix_ids) - 1]

    # One TokenPosition per intervention site of interest.
    return [
        TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"),
        TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0] + 1], pipeline, id="correct_symbol_period"),
        TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token"),
    ]