shftan committed
Commit ec01901 · 1 Parent(s): e911dac

Make nested functions

Files changed (3)
  1. __pycache__/utils.cpython-312.pyc +0 -0
  2. app.py +98 -94
  3. requirements.txt +1 -0
__pycache__/utils.cpython-312.pyc DELETED
Binary file (4.32 kB)
 
app.py CHANGED
@@ -14,101 +14,102 @@ from utils import get_tokens, select_concepts, get_concepts_dictionary, get_resp
  # print(zero.device) # <-- 'cuda:0'
  # return f"Hello {zero + n} Tensor"

- # Set model, interpreter, dictionary choices
- model_name = "google/gemma-2-2b-it"
- interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
- interpreter_path = "l20/weight.pt"
- interpreter_component = "model.layers[20].output"
- dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"
-
- # Interpreter class
- class Encoder(pv.CollectIntervention):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs, keep_last_dim=True)
-         self.proj = torch.nn.Linear(
-             self.embed_dim, kwargs["latent_dim"], bias=False)
-     def forward(self, base, source=None, subspaces=None):
-         return torch.relu(self.proj(base))
-
- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto').to('cuda')
-
- # Load fast model inference pipeline
- pipe = pipeline(
-     task="text-generation",
-     model=model_name,
-     use_fast=True
- )
-
- path_to_params = hf_hub_download(
-     repo_id=interpreter_name,
-     filename=interpreter_path,
-     force_download=False,
- )
- params = torch.load(path_to_params)
- encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).cuda()
- encoder.proj.weight.data = params.float()
- pv_model = pv.IntervenableModel({
-     "component": interpreter_component,
-     "intervention": encoder}, model=model).cuda()
-
- # Load dictionary
- all_concepts = get_concepts_dictionary(dictionary_url)
-
-
  @spaces.GPU
- # Function to process user input to the app
- def process_user_input(prompt, concept):
-     # Check if prompt or concept are empty
-     if not prompt or not concept:
-         return f"<h3>Please provide both a prompt and a concept</h3>"
-
-     # Convert prompt to tokens
-     tokens, token_ids = get_tokens(tokenizer, prompt)
-
-     # Get concept IDs and names
-     concept_ids, concept_df = select_concepts(all_concepts, concept)
-     if len(concept_ids) == 0:
-         concepts_html = f"<h3>No relevant concepts found for '{concept}' in LLM thoughts dictionary. Try another concept.</h3>"
-     else:
-         concepts_html = f"<h3>using the following in the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
-         styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
-         concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
-
-     # Get activations
-     if len(concept_ids) > 0:
-         acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
-         vals = acts[0, :, concept_ids].sum(-1).cpu()
-
-         # Get highlighted tokens
-         highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
-     else:
-         highlighted_tokens_html = ""
-
-     # Get LLM response
-     response = get_response(pipe, prompt)
-     response_html = f"""<h3>LLM response to your prompt:</h3>
-     {response}
-     """
-
-     # Write documentation
-     documentation_html = f"""<h3>How does this work?</h3>
-     <ul>
-     <li>The LLM model is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
-     <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google) is trained on the LLM model's layer 20's residual stream. The choices of layer 20 and the residual stream are arbitrary.
-     <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; it is possible for a concept you input to not be present in this dictionary.
-     <li>Each token is highlighted in terms of how much information about a given concept is carried in each token.
-     <li>Do you find the results surprising? Any feedback? Any ideas on how I can make this app more useful? Please let me know! Contact: Sarah Tan.
-     </ul>
-     """
-
-     # Combine HTMLs
-     output_html = highlighted_tokens_html + concepts_html + "<p>&nbsp;</p>" + response_html + "<p>&nbsp;</p>" + documentation_html
-
-     return output_html
+ def launch_app():
+
+     @spaces.GPU
+     # Function to process user input to the app
+     def process_user_input(prompt, concept):
+         # Check if prompt or concept are empty
+         if not prompt or not concept:
+             return f"<h3>Please provide both a prompt and a concept</h3>"
+
+         # Convert prompt to tokens
+         tokens, token_ids = get_tokens(tokenizer, prompt)
+
+         # Get concept IDs and names
+         concept_ids, concept_df = select_concepts(all_concepts, concept)
+         if len(concept_ids) == 0:
+             concepts_html = f"<h3>No relevant concepts found for '{concept}' in LLM thoughts dictionary. Try another concept.</h3>"
+         else:
+             concepts_html = f"<h3>using the following in the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
+             styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
+             concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
+
+         # Get activations
+         if len(concept_ids) > 0:
+             acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
+             vals = acts[0, :, concept_ids].sum(-1).cpu()
+
+             # Get highlighted tokens
+             highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
+         else:
+             highlighted_tokens_html = ""
+
+         # Get LLM response
+         response = get_response(pipe, prompt)
+         response_html = f"""<h3>LLM response to your prompt:</h3>
+         {response}
+         """
+
+         # Write documentation
+         documentation_html = f"""<h3>How does this work?</h3>
+         <ul>
+         <li>The LLM model is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
+         <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google) is trained on the LLM model's layer 20's residual stream. The choices of layer 20 and the residual stream are arbitrary.
+         <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; it is possible for a concept you input to not be present in this dictionary.
+         <li>Each token is highlighted in terms of how much information about a given concept is carried in each token.
+         <li>Do you find the results surprising? Any feedback? Any ideas on how I can make this app more useful? Please let me know! Contact: Sarah Tan.
+         </ul>
+         """
+
+         # Combine HTMLs
+         output_html = highlighted_tokens_html + concepts_html + "<p>&nbsp;</p>" + response_html + "<p>&nbsp;</p>" + documentation_html
+
+         return output_html
+
+     # Set model, interpreter, dictionary choices
+     model_name = "google/gemma-2-2b-it"
+     interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
+     interpreter_path = "l20/weight.pt"
+     interpreter_component = "model.layers[20].output"
+     dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"
+
+     # Interpreter class
+     class Encoder(pv.CollectIntervention):
+         def __init__(self, **kwargs):
+             super().__init__(**kwargs, keep_last_dim=True)
+             self.proj = torch.nn.Linear(
+                 self.embed_dim, kwargs["latent_dim"], bias=False)
+         def forward(self, base, source=None, subspaces=None):
+             return torch.relu(self.proj(base))
+
+     # Load tokenizer and model
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto').to('cuda')
+
+     # Load fast model inference pipeline
+     pipe = pipeline(
+         task="text-generation",
+         model=model_name,
+         use_fast=True
+     )
+
+     path_to_params = hf_hub_download(
+         repo_id=interpreter_name,
+         filename=interpreter_path,
+         force_download=False,
+     )
+     params = torch.load(path_to_params)
+     encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).cuda()
+     encoder.proj.weight.data = params.float()
+     pv_model = pv.IntervenableModel({
+         "component": interpreter_component,
+         "intervention": encoder}, model=model).cuda()
+
+     # Load dictionary
+     all_concepts = get_concepts_dictionary(dictionary_url)

- if __name__ == "__main__":
      description_text = """
      ## Does an LLM Think Like You?
      Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.
@@ -131,4 +132,7 @@ if __name__ == "__main__":
          outputs=output_html
      )

-     demo.launch()
+     demo.launch()
+
+ if __name__ == "__main__":
+     launch_app()
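
The refactor above replaces module-level setup with a single `launch_app()` entry point: `process_user_input` becomes a nested function that closes over `tokenizer`, `pipe`, `pv_model`, and `all_concepts` from the enclosing scope instead of reading globals. A minimal sketch of that closure pattern, with a hypothetical `vocab` dict and toy handler standing in for the real model loading:

```python
import gradio as gr

def launch_app():
    # Expensive setup runs once, when launch_app() is called,
    # rather than at module import time.
    vocab = {"hello": 0, "world": 1}  # stand-in for model/tokenizer loading

    # Nested handler: resolves `vocab` from the enclosing scope (a closure),
    # so no module-level globals are needed.
    def process_user_input(prompt):
        ids = [vocab.get(word, -1) for word in prompt.lower().split()]
        return f"token ids: {ids}"

    demo = gr.Interface(fn=process_user_input, inputs="text", outputs="text")
    demo.launch()

if __name__ == "__main__":
    launch_app()
```

Deferring setup into a function also changes when it runs relative to decorators like `@spaces.GPU`, which on ZeroGPU Spaces allocate a GPU around the decorated call.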
requirements.txt ADDED
@@ -0,0 +1 @@
+ transformers
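
For context on what the moved app.py code computes: the `Encoder` intervention projects each token's layer-20 residual vector onto one learned direction per dictionary concept, and the app sums the activations of the concepts matched to the user's query to get one highlight score per token. A self-contained sketch of that arithmetic, with random weights standing in for the trained interpreter (hidden size 2304 matches gemma-2-2b; the other numbers are illustrative):

```python
import torch

hidden_dim = 2304   # gemma-2-2b residual stream width
n_concepts = 16000  # illustrative dictionary size
n_tokens = 7

# Stand-in for encoder.proj with the trained weights from l20/weight.pt
proj = torch.nn.Linear(hidden_dim, n_concepts, bias=False)

# Stand-in for the layer-20 residual stream that pv_model collects
resid = torch.randn(1, n_tokens, hidden_dim)

acts = torch.relu(proj(resid))          # per-token, per-concept activations
concept_ids = [3, 41, 500]              # dictionary rows matched to the user's concept
vals = acts[0, :, concept_ids].sum(-1)  # one highlight score per token
print(vals.shape)                       # torch.Size([7])
```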