shftan commited on
Commit
af4a860
Β·
1 Parent(s): 685c726

Clean up code

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. __pycache__/utils.cpython-312.pyc +0 -0
  3. app.py +127 -7
  4. utils.py +80 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.32 kB). View file
 
app.py CHANGED
@@ -1,14 +1,134 @@
1
  import gradio as gr
2
  import spaces
 
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' πŸ€”
7
 
8
  @spaces.GPU
9
- def greet(n):
10
- print(zero.device) # <-- 'cuda:0' πŸ€—
11
- return f"Hello {zero + n} Tensor"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
14
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
+ from huggingface_hub import hf_hub_download
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
+ import pyvene as pv
7
+ from utils import get_tokens, select_concepts, get_concepts_dictionary, get_response, plot_tokens_with_highlights
8
+
9
+ #zero = torch.Tensor([0]).cuda()
10
+ #print(zero.device) # <-- 'cpu'
11
+
12
+ #@spaces.GPU
13
+ #def greet(n):
14
+ # print(zero.device) # <-- 'cuda:0'
15
+ # return f"Hello {zero + n} Tensor"
16
+
17
+ # Set model, interpreter, dictionary choices
18
+ model_name = "google/gemma-2-2b-it"
19
+ interpreter_name = "pyvene/gemma-reft-r1-2b-it-res"
20
+ interpreter_path = "l20/weight.pt"
21
+ interpreter_component = "model.layers[20].output"
22
+ dictionary_url = "https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl"
23
+
24
+ # Interpreter class
25
+ class Encoder(pv.CollectIntervention):
26
+ def __init__(self, **kwargs):
27
+ super().__init__(**kwargs, keep_last_dim=True)
28
+ self.proj = torch.nn.Linear(
29
+ self.embed_dim, kwargs["latent_dim"], bias=False)
30
+ def forward(self, base, source=None, subspaces=None):
31
+ return torch.relu(self.proj(base))
32
+
33
+ # Load tokenizer and model
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto').to('cuda')
36
+
37
+ # Load fast model inference pipeline
38
+ pipe = pipeline(
39
+ task="text-generation",
40
+ model=model_name,
41
+ use_fast=True
42
+ )
43
+
44
+ path_to_params = hf_hub_download(
45
+ repo_id=interpreter_name,
46
+ filename=interpreter_path,
47
+ force_download=False,
48
+ )
49
+ params = torch.load(path_to_params)
50
+ encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1]).cuda()
51
+ encoder.proj.weight.data = params.float()
52
+ pv_model = pv.IntervenableModel({
53
+ "component": interpreter_component,
54
+ "intervention": encoder}, model=model).cuda()
55
+
56
+ # Load dictionary
57
+ all_concepts = get_concepts_dictionary(dictionary_url)
58
 
 
 
59
 
60
  @spaces.GPU
61
+ # Function to process user input to the app
62
+ def process_user_input(prompt, concept):
63
+ # Check if prompt or concept are empty
64
+ if not prompt or not concept:
65
+ return f"<h3>Please provide both a prompt and a concept</h3>"
66
+
67
+ # Convert prompt to tokens
68
+ tokens, token_ids = get_tokens(tokenizer, prompt)
69
+
70
+ # Get concept IDs and names
71
+ concept_ids, concept_df = select_concepts(all_concepts, concept)
72
+ if len(concept_ids) == 0:
73
+ concepts_html = f"<h3>No relevant concepts found for '{concept}' in LLM thoughts dictionary. Try another concept.</h3>"
74
+ else:
75
+ concepts_html = f"<h3>using the following in the LLM thoughts dictionary relevant to '{concept}' ({len(concept_ids)} out of {len(all_concepts)} concepts):</h3>"
76
+ styled_table = concept_df.style.hide(axis="index").set_properties(**{'background-color': '#f0f0f0', 'color': 'black', 'border-color': 'white'}).to_html()
77
+ concepts_html += f'<div style="height: 200px; overflow-y: scroll;">{styled_table}</div>'
78
+
79
+ # Get activations
80
+ if len(concept_ids) > 0:
81
+ acts = pv_model.forward({"input_ids": token_ids}, return_dict=True).collected_activations[0]
82
+ vals = acts[0, :, concept_ids].sum(-1).cpu()
83
+
84
+ # Get highlighted tokens
85
+ highlighted_tokens_html = plot_tokens_with_highlights(tokens, vals, concept)
86
+ else:
87
+ highlighted_tokens_html = ""
88
+
89
+ # Get LLM response
90
+ response = get_response(pipe, prompt)
91
+ response_html = f"""<h3>LLM response to your prompt:</h3>
92
+ {response}
93
+ """
94
+
95
+ # Write documentation
96
+ documentation_html = f"""<h3>How does this work?</h3>
97
+ <ul>
98
+ <li>The LLM model is an instruction-tuned model, <a href="https://huggingface.co/google/gemma-2-2b-it">Google gemma-2-2b-it</a>.
99
+ <li>The LLM interpreter, <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res">gemma-reft-r1-2b-it-res</a> (not from Google) is trained on the LLM model's layer 20's residual stream. The choices of layer 20 and the residual stream are arbitrary.
100
+ <li>The LLM interpreter decomposes the layer 20 residual stream activations into a <a href="https://huggingface.co/pyvene/gemma-reft-r1-2b-it-res/raw/main/l20/metadata.jsonl">dictionary</a> of {len(all_concepts)} human-understandable concepts. This dictionary is not comprehensive; it is possible for a concept you input to not be present in this dictionary.
101
+ <li>Each token is highlighted in terms of how much information about a given concept is carried in each token.
102
+ <li>Do you find the results surprising? Any feedback? Any ideas on how I can make this app more useful? Please let me know! Contact: Sarah Tan.
103
+ </ul>
104
+ """
105
+
106
+ # Combine HTMLs
107
+ output_html = highlighted_tokens_html + concepts_html + "<p>&nbsp;</p>" + response_html + "<p>&nbsp;</p>" + documentation_html
108
+
109
+ return output_html
110
+
111
+ if __name__ == "__main__":
112
+ description_text = """
113
+ ## Does an LLM Think Like You?
114
+ Input a prompt and a concept that you think is most relevant for your prompt. See how much (if at all) the LLM uses that concept when processing your prompt.
115
+ Examples:
116
+ - **Prompt**: What is 2+2? **Concept**: math
117
+ - **Prompt**: I really like anchovies on pizza but I know a lot of people don't. **Concept**: food
118
+ """
119
+
120
+ with gr.Blocks() as demo:
121
+ gr.Markdown(description_text)
122
+ with gr.Row():
123
+ prompt_input = gr.Textbox(label="Enter a prompt", value="I really like anchovies on pizza but I know a lot of people don't.")
124
+ concept_input = gr.Textbox(label="Enter a concept that you think is most relevant for your prompt", value="food")
125
+ process_button = gr.Button("See if an LLM thinks like you!")
126
+ output_html = gr.HTML()
127
 
128
+ process_button.click(
129
+ process_user_input,
130
+ inputs=[prompt_input, concept_input],
131
+ outputs=output_html
132
+ )
133
+
134
+ demo.launch()
utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import matplotlib.cm as cm
3
+ import matplotlib.colors as clrs
4
+ import requests
5
+ import json
6
+ import pandas as pd
7
+ import torch
8
+
9
+
10
+ # Function to get tokens given text
11
+ def get_tokens(tokenizer, text):
12
+ token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda")
13
+ tokens = tokenizer.convert_ids_to_tokens(token_ids[0])
14
+
15
+ return tokens, token_ids
16
+
17
+ # Function to apply chat template to prompt
18
+ def decorate_prompt(tokenizer, prompt):
19
+ chat = [
20
+ {"role": "user", "content": prompt},
21
+ {"role": "assistant", "content": ""},
22
+ ]
23
+ text = tokenizer.apply_chat_template(chat, tokenize=False)
24
+ token_ids = tokenizer.encode(text, return_tensors="pt", add_special_tokens=False).to("cuda")
25
+
26
+ return token_ids
27
+
28
+ # Function to get response to prompt
29
+ def get_response(model_pipe, prompt):
30
+ response = model_pipe(prompt)[0]['generated_text']
31
+ return response
32
+
33
+ # Function to highlight tokens based on given values
34
+ def plot_tokens_with_highlights(tokens, values, concept, cmap_name='Oranges', vmin=None, vmax=None):
35
+ if len(tokens) != len(values):
36
+ raise ValueError("The number of tokens and values must be the same.")
37
+
38
+ # Set color map
39
+ cmap = cm.get_cmap(cmap_name)
40
+
41
+ norm = clrs.Normalize(vmin=vmin if vmin is not None else values.detach().min(),
42
+ vmax=vmax if vmax is not None else values.detach().max())
43
+
44
+ html_output = f"<h3>How much information about the concept '{concept}' is carried in each token:</h3>"
45
+
46
+ for token, value in zip(tokens, values.detach().numpy()):
47
+ rgba_color = cmap(norm(value))
48
+ hex_color = '#%02x%02x%02x' % (int(rgba_color[0]*255), int(rgba_color[1]*255), int(rgba_color[2]*255))
49
+ html_output += f'<span style="background-color: {hex_color};" title="{value:.4f}">{token}</span> '
50
+
51
+ return html_output
52
+
53
+ # Function to get concepts dictionary
54
+ def get_concepts_dictionary(dictionary_url):
55
+ response = requests.get(dictionary_url, stream=True)
56
+ response.raise_for_status()
57
+ data_dict = {}
58
+ for line in response.iter_lines(decode_unicode=True):
59
+ if line:
60
+ obj = json.loads(line)
61
+ concept_id = obj.get("concept_id")
62
+ concept = obj.get("concept")
63
+ if concept_id and concept:
64
+ data_dict[concept_id] = concept.capitalize()
65
+ return data_dict
66
+
67
+ # Function to get matching concepts
68
+ def select_concepts(all_concepts, desired_concept):
69
+ concept_ids = []
70
+ for k, v in all_concepts.items():
71
+ if desired_concept.lower() in v.lower():
72
+ concept_ids.append(k)
73
+
74
+ concept_data = []
75
+ for concept_id in concept_ids:
76
+ concept_name = all_concepts.get(concept_id, "Unknown Concept")
77
+ concept_data.append({"Concept ID": concept_id, "Concept Name": concept_name})
78
+ concept_df = pd.DataFrame(concept_data)
79
+
80
+ return torch.tensor(concept_ids), concept_df