Slicky325 committed on
Commit 4bafda2 · verified · 1 Parent(s): d9e4845

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +2 -77
README.md CHANGED
@@ -1,85 +1,10 @@
- ---
- tags:
- - token-importance
- - attention-classifier
- - llama
- ---
-
  # Token Importance Classifier
 
- This model is a single-layer attention-based classifier trained to predict token importance in sequences.
-
- ## Model Details
-
- - **Architecture**: Single-layer self-attention network with RoPE positional embeddings
- - **Base Model**: meta-llama/Llama-3.1-8B
- - **Hidden Dimension**: 4096
- - **Number of Heads**: 32
- - **Max Sequence Length**: 131072
-
- ## Training Configuration
-
- ```yaml
- data:
-   max_seq_len: 131072
-   path: /root/workspace/data_generation/data/sample_output.jsonl
-   tokenizer_path: meta-llama/Llama-3.1-8B
-   valid_split: 0.1
- final_metrics:
-   accuracy: 0.8365938756296772
-   f1: 0.9094284550391643
-   precision: 0.8365938756296772
-   recall: 1.0
- huggingface:
-   private: false
-   push_to_hub: true
-   repo_id: Slicky325/token-selector-model
- model:
-   base_model_dir: meta-llama/Llama-3.1-8B
-   dropout: 0.1
-   hidden_dim: 4096
-   max_seq_len: 131072
-   num_heads: 32
-   rope_theta: 500000
-   save_embeddings: false
-   save_path: models/selector.pt
-   train_embeddings: false
-   use_positional: true
- system:
-   device: cuda
-   num_workers: 2
- training:
-   batch_size: 4
-   epochs: 1
-   grad_clip: 1.0
-   learning_rate: 0.001
-   seed: 42
-   weight_decay: 0.0
-
- ```
-
- ## Validation Metrics
-
- - **Accuracy**: 0.8365938756296772
- - **Precision**: 0.8365938756296772
- - **Recall**: 1.0
- - **F1 Score**: 0.9094284550391643
+ Trained with F1: 0.9094, Accuracy: 0.8366
 
  ## Usage
-
  ```python
  import torch
- from pathlib import Path
-
- # Load the checkpoint
  checkpoint = torch.load('selector.pt')
- model_state = checkpoint['model_state_dict']
- config = checkpoint['config']
-
- # Initialize your model architecture and load the weights
- # model.load_state_dict(model_state)
+ model.load_state_dict(checkpoint['model_state_dict'])
  ```
-
- ## Citation
-
- If you use this model in your research, please cite appropriately.
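
The updated Usage snippet calls `model.load_state_dict(...)` without ever constructing `model`, so it would fail with a `NameError` as written. Below is a minimal sketch of one way the load could look. The `TinySelector` class is a hypothetical stand-in, not the repository's real architecture, and the flat `config` layout (`hidden_dim`, `num_heads`) is an assumption; only the `model_state_dict` and `config` checkpoint keys come from the README lines shown in the diff.

```python
import os
import tempfile

import torch
from torch import nn


class TinySelector(nn.Module):
    """Hypothetical stand-in model; NOT the actual selector architecture."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)  # one importance logit per token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return self.head(out).squeeze(-1)


def load_selector(path: str) -> TinySelector:
    """Build the model first, then load the saved weights into it."""
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")
    config = checkpoint["config"]  # assumed flat layout, see lead-in
    model = TinySelector(config["hidden_dim"], config["num_heads"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference mode
    return model


# Round trip with a toy checkpoint so the sketch runs without the real selector.pt.
config = {"hidden_dim": 16, "num_heads": 4}
reference = TinySelector(**config)
path = os.path.join(tempfile.mkdtemp(), "selector.pt")
torch.save({"model_state_dict": reference.state_dict(), "config": config}, path)
restored = load_selector(path)
```

To use the real checkpoint, the stand-in class would have to be replaced by the class that produced `selector.pt`, since `load_state_dict` requires matching parameter names and shapes.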