Commit 34e4e96 · committed by Slicky325 · verified · 1 Parent(s): e9ab5c4

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +85 -0
README.md ADDED
@@ -0,0 +1,85 @@
+ ---
+ tags:
+ - token-importance
+ - attention-classifier
+ - llama
+ ---
+
+ # Token Importance Classifier
+
+ This model is a single-layer attention-based classifier trained to predict token importance in sequences.
+
+ ## Model Details
+
+ - **Architecture**: Single-layer self-attention network with RoPE positional embeddings
+ - **Base Model**: meta-llama/Llama-3.1-8B
+ - **Hidden Dimension**: 4096
+ - **Number of Heads**: 32
+ - **Max Sequence Length**: 131072
+
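To make the RoPE entry above more concrete, here is a minimal pure-Python sketch of rotary position embeddings using the listed hidden dimension and head count, plus the `rope_theta: 500000` value from the training configuration. It is not the author's implementation: the helper names (`rope_inverse_frequencies`, `apply_rope`, `dot`) are hypothetical, and the interleaved pairing of dimensions is an assumption, since implementations differ in how they pair dimensions.

```python
import math

HIDDEN_DIM = 4096      # from Model Details
NUM_HEADS = 32         # from Model Details
ROPE_THETA = 500000.0  # rope_theta in the training configuration
HEAD_DIM = HIDDEN_DIM // NUM_HEADS  # dimensions handled by each attention head


def rope_inverse_frequencies(head_dim=HEAD_DIM, theta=ROPE_THETA):
    """One rotation frequency per pair of dimensions: theta ** (-2i / head_dim)."""
    return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def apply_rope(vec, position, theta=ROPE_THETA):
    """Rotate each pair (vec[2i], vec[2i+1]) by the angle position * freq_i.

    Assumption: interleaved pairing; the model's actual layout is not documented.
    """
    out = []
    for i, freq in enumerate(rope_inverse_frequencies(len(vec), theta)):
        angle = position * freq
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.extend([x * cos_a - y * sin_a, x * sin_a + y * cos_a])
    return out


def dot(a, b):
    """Plain dot product, standing in for one attention score."""
    return sum(x * y for x, y in zip(a, b))


# Toy 4-dimensional query and key (hypothetical values, not model weights).
q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.75, -0.5, 1.0]

# Both position pairs are 3 apart, so the two scores agree up to rounding.
near = dot(apply_rope(q, 10), apply_rope(k, 7))
far = dot(apply_rope(q, 1003), apply_rope(k, 1000))
```

The point of the rotation is that the score between a query and a key depends only on their relative offset, not on their absolute positions, which is why `near` and `far` match.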
+ ## Training Configuration
+
+ ```yaml
+ data:
+   max_seq_len: 131072
+   path: /root/workspace/data_generation/data/sample_output.jsonl
+   tokenizer_path: meta-llama/Llama-3.1-8B
+   valid_split: 0.1
+ final_metrics:
+   accuracy: 0.8365938756296772
+   f1: 0.9094284550391643
+   precision: 0.8365938756296772
+   recall: 1.0
+ huggingface:
+   private: false
+   push_to_hub: true
+   repo_id: Slicky325/token-selector-model
+ model:
+   base_model_dir: meta-llama/Llama-3.1-8B
+   dropout: 0.1
+   hidden_dim: 4096
+   max_seq_len: 131072
+   num_heads: 32
+   rope_theta: 500000
+   save_embeddings: false
+   save_path: models/selector.pt
+   train_embeddings: false
+   use_positional: true
+ system:
+   device: cuda
+   num_workers: 2
+ training:
+   batch_size: 4
+   epochs: 1
+   grad_clip: 1.0
+   learning_rate: 0.001
+   seed: 42
+   weight_decay: 0.0
+
+ ```
+
+ ## Validation Metrics
+
+ - **Accuracy**: 0.8365938756296772
+ - **Precision**: 0.8365938756296772
+ - **Recall**: 1.0
+ - **F1 Score**: 0.9094284550391643
+
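The accuracy and precision listed above are identical, and recall is exactly 1.0. When metrics are computed from pooled confusion-matrix counts, that pattern arises if there are no negative predictions at all (no true negatives and no false negatives), that is, if every token is labeled important. The README does not say how the metrics were computed, so the sketch below only illustrates the arithmetic; the function name `binary_metrics` and the token counts are hypothetical.

```python
def binary_metrics(tp, fp, tn, fn):
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts: 1000 validation tokens, 837 of them truly important,
# scored by a classifier that marks every token as important.
all_positive = binary_metrics(tp=837, fp=163, tn=0, fn=0)

# Harmonic mean of the precision and recall reported in the README.
reported_precision = 0.8365938756296772
reported_recall = 1.0
pooled_f1 = 2 * reported_precision * reported_recall / (reported_precision + reported_recall)
```

With these counts, accuracy and precision coincide and recall is 1.0, matching the shape of the reported numbers. Note also that `pooled_f1` comes out near 0.911, slightly above the reported F1 of about 0.909; one possible explanation is that the reported values were averaged per batch rather than computed from pooled counts, but that is an inference rather than something the README states.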
+ ## Usage
+
+ ```python
+ import torch
+
+ # Load the checkpoint; map_location='cpu' lets it load on machines without a GPU
+ checkpoint = torch.load('selector.pt', map_location='cpu')
+ model_state = checkpoint['model_state_dict']  # learned weights
+ config = checkpoint['config']  # configuration stored with the checkpoint
+
+ # Initialize your model architecture (not included in this snippet), then load the weights
+ # model.load_state_dict(model_state)
+ ```
+
+ ## Citation
+
+ If you use this model in your research, please cite appropriately.