m-ric committed on
Commit
a6e310c
·
1 Parent(s): f0fc4a6

go to lxt

Browse files
Files changed (2) hide show
  1. app.py +50 -13
  2. requirements.txt +2 -2
app.py CHANGED
@@ -1,13 +1,50 @@
1
- from bertviz.transformers_neuron_view import BertModel, BertTokenizer
2
- from bertviz.neuron_view import show
3
-
4
- model_type = 'bert'
5
- model_version = 'bert-base-uncased'
6
- do_lower_case = True
7
- model = BertModel.from_pretrained(model_version)
8
- tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)
9
- sentence_a = "The cat sat on the mat"
10
- sentence_b = "The cat lay on the rug"
11
- html = show(model, model_type, tokenizer, sentence_a, sentence_b, display_mode='dark', layer=2, head=0, html_action='return')
12
-
13
- gradio.HTML(html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ from lxt.models.llama import LlamaForCausalLM, attnlrp
4
+ from lxt.utils import clean_tokens
5
+ import gradio as gr
6
+
7
+ # Load model and tokenizer
8
+ model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16, device_map="cuda")
9
+ tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
10
+
11
+ # Apply AttnLRP rules
12
+ attnlrp.register(model)
13
+
14
+ def generate_and_visualize(prompt):
15
+ input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
16
+ input_embeds = model.get_input_embeddings()(input_ids)
17
+
18
+ output_logits = model(inputs_embeds=input_embeds.requires_grad_(), use_cache=False).logits
19
+ max_logits, max_indices = torch.max(output_logits[0, -1, :], dim=-1)
20
+
21
+ max_logits.backward(max_logits)
22
+ relevance = input_embeds.grad.float().sum(-1).cpu()[0]
23
+
24
+ # Normalize relevance between [0, 1] for highlighting
25
+ relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min())
26
+
27
+ # Remove '_' characters from token strings
28
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
29
+ tokens = clean_tokens(tokens)
30
+
31
+ # Create list of (token, score) tuples for HighlightedText
32
+ highlighted_tokens = [(token, float(score)) for token, score in zip(tokens, relevance)]
33
+
34
+ return highlighted_tokens
35
+
36
+ # Define Gradio interface
37
+ with gr.Blocks() as demo:
38
+ gr.Markdown("# LLaMA Attention Visualization Demo")
39
+
40
+ with gr.Row():
41
+ input_text = gr.Textbox(label="Input Prompt", lines=5)
42
+ generate_button = gr.Button("Generate and Visualize")
43
+
44
+ output = gr.HighlightedText(label="Attention Visualization")
45
+
46
+ generate_button.click(generate_and_visualize, inputs=input_text, outputs=output)
47
+
48
+ # Launch the demo
49
+ if __name__ == "__main__":
50
+ demo.launch()
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- bertviz==1.4.0
2
- ipython==8.18.1
 
1
+ accelerate
2
+ lxt