---
library_name: transformers
tags:
- generated_from_trainer
datasets:
- HuggingFaceFW/fineweb
metrics:
- accuracy
model-index:
- name: T5LA
  results:
  - task:
      name: Causal Language Modeling
      type: text-generation
    dataset:
      name: HuggingFaceFW/fineweb sample-10BT
      type: HuggingFaceFW/fineweb
      args: sample-10BT
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.032223235792499715
base_model:
- google-t5/t5-base
---


# T5LA

This model is part of the work published in the paper [Interactive Text Games: Lookahead Is All You Need!](https://openreview.net/pdf?id=D38rTnrkal).

Four models are introduced in the above paper:
- [nanoGPTLA](https://huggingface.co/hrezaei/nanoGPTLookAhead)
- [nanoGPTLAA](https://huggingface.co/hrezaei/nanoGPTLookAheadA)
- [nanoGPTLAA2](https://huggingface.co/hrezaei/nanoGPTLookAheadA2)
- [nanoGPTLAE](https://huggingface.co/hrezaei/nanoGPTLookAheadAE)

These models are implemented in [this repository](https://github.com/HRezaei/nanoGPT), which is a customized version of [nanoGPT](https://github.com/karpathy/nanoGPT).

The same variants are also implemented in [this fork](https://github.com/HRezaei/transformers/tree/feature/lookahead_models) of the Transformers library, on top of the [Google-t5/T5](https://github.com/huggingface/transformers/tree/128387757105c7c0b57b519ac2aaff217a20e3f0/src/transformers/models/t5) implementation.
These variants have also been trained and published:
- [T5LA](https://huggingface.co/hrezaei/T5LA)
- [T5LAA](https://huggingface.co/hrezaei/T5LAA)
- [T5LAA2](https://huggingface.co/hrezaei/T5LAA2)
- [T5LAE](https://huggingface.co/hrezaei/T5LAE)

All of the above models are at the scale of GPT-2 (~100M parameters). Work is in progress to train them at larger scales.

## Model description

This model has not been fine-tuned on any instruction or human-feedback dataset; it has only been pre-trained on the `sample-10BT` subset of the HuggingFaceFW/fineweb dataset.
It achieves the following results on the evaluation set:
- Loss: 5.5467
- Accuracy: 0.0322

Since the fork above has not yet been merged into the main Transformers library, loading this model with `AutoModel.from_pretrained()` requires first installing Transformers from [this branch](https://github.com/HRezaei/transformers/tree/feature/lookahead_models),
which contains the code for the T5LA models:

```shell
pip install git+https://github.com/HRezaei/transformers.git@feature/lookahead_models
```

## Intended uses & limitations

The model is designed to predict not only the token immediately following the prompt (as standard LLMs do), but also the second, third, ..., up to the K-th next token, each conditioned on the prompt. These future predictions are useful for approximate ranking,
where a set of candidate responses must be ranked by the approximate probability of their tokens conditioned on the prompt,
rather than conditioned on their preceding tokens.
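To make approximate ranking concrete, here is a minimal sketch in plain Python. It assumes, hypothetically, that the model's lookahead outputs have already been turned into K probability distributions, one per offset k, each conditioned only on the prompt; this is not the API of the Transformers fork. Token k of a candidate is scored with the k-th distribution and the log-probabilities are averaged over positions. The averaging (length normalization), the `floor` for unseen tokens, and the names `approx_score` and `rank` are illustrative choices, not necessarily the scoring used in the paper.

```python
import math
from typing import Dict, List, Sequence

# Hypothetical stand-in for the model output: dists[k - 1] is the distribution the
# model predicts for the k-th token after the prompt (k = 1..K). Every entry is
# conditioned on the prompt only, never on earlier candidate tokens.
LookaheadDists = List[Dict[str, float]]


def approx_score(candidate: Sequence[str], dists: LookaheadDists, floor: float = 1e-9) -> float:
    """Length-normalized approximate log-probability of a candidate response.

    Token k of the candidate is scored with the k-th lookahead distribution;
    tokens beyond position K are ignored. Tokens missing from a distribution get `floor`.
    """
    positions = min(len(candidate), len(dists))
    if positions == 0:
        return float("-inf")
    total = sum(math.log(dists[k].get(candidate[k], floor)) for k in range(positions))
    return total / positions


def rank(candidates: List[Sequence[str]], dists: LookaheadDists) -> List[Sequence[str]]:
    """Return candidates sorted from most to least likely under the approximation."""
    return sorted(candidates, key=lambda c: approx_score(c, dists), reverse=True)


# Toy lookahead output for a text-game prompt such as "A door leads north. >" (K = 2).
# Token strings stand in for token IDs.
dists = [
    {"go": 0.5, "open": 0.3, "take": 0.2},     # k = 1: next immediate token
    {"north": 0.6, "door": 0.3, "lamp": 0.1},  # k = 2: second token after the prompt
]
candidates = [["take", "lamp"], ["open", "door"], ["go", "north"]]
print(rank(candidates, dists))  # [['go', 'north'], ['open', 'door'], ['take', 'lamp']]
```

Because every distribution depends only on the prompt, a single forward pass over the prompt is enough to score all candidates under this approximation, whereas exact scoring with a standard LLM requires feeding each candidate's tokens through the model.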

The main limitation is that future predictions are generally not suitable for generating text, as they do not model dependencies between tokens:
each future token is conditioned on the prompt, not on the tokens before it. Therefore, for generation, only the next immediate token should be used.
That said, next-token prediction quality is degraded as well, because the training loss contains more terms to
minimize: one for the next immediate token, as in a standard LLM, plus one extra term for each future token.
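A toy example (not taken from the paper) shows why independent future predictions can yield incoherent text. Suppose two continuations of a prompt such as "I flew to" are equally likely. Even under the idealizing assumption that each lookahead prediction exactly matches the true marginal distribution of its position, combining positions independently assigns probability to mixtures that never occur:

```python
from collections import defaultdict
from itertools import product
from typing import Dict, Tuple

# Toy joint distribution over two-token continuations of a prompt such as
# "I flew to". A standard autoregressive model can represent this exactly.
joint: Dict[Tuple[str, str], float] = {
    ("New", "York"): 0.5,
    ("Los", "Angeles"): 0.5,
}

# A lookahead prediction for offset k sees only the prompt, so at best it learns
# the marginal distribution of the k-th token, not its dependence on token k-1.
marginals = [defaultdict(float), defaultdict(float)]
for (first, second), p in joint.items():
    marginals[0][first] += p
    marginals[1][second] += p

# Picking each position independently from its marginal mixes the two modes.
independent = {
    (a, b): marginals[0][a] * marginals[1][b]
    for a, b in product(marginals[0], marginals[1])
}
print(independent[("New", "Angeles")])  # 0.25, although the true joint probability is 0
```

An autoregressive next-token predictor avoids this by predicting "York" or "Angeles" only after it has seen "New" or "Los", which is why generation should rely on the next-token prediction alone.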

## Training and evaluation data

The model was pre-trained only on the `sample-10BT` subset of HuggingFaceFW/fineweb; no instruction or human-feedback data was used.
The final evaluation loss and accuracy are listed under [Model description](#model-description), and the full evaluation history is in the table below.

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 8
- eval_batch_size: 8
- seed: 42
- distributed_type: multi-GPU
- num_devices: 2
- total_train_batch_size: 16
- total_eval_batch_size: 16
- optimizer: AdamW (`adamw_torch`) with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
- lr_scheduler_type: linear
- training_steps: 200000
- mixed_precision_training: Native AMP
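As a hedged sketch (the actual training script is not part of this card), the settings above map roughly onto the `transformers.TrainingArguments` keyword arguments below. `fp16=True` is an assumption, since "Native AMP" could also mean bf16, and gradient accumulation is assumed to be 1 because it is not listed. The block also checks two derived quantities: the total batch size (8 per device times 2 GPUs = 16) and the learning rate under a linear schedule with no warmup, which is assumed because none is listed. `linear_lr` mirrors the formula used by Transformers' `get_linear_schedule_with_warmup`.

```python
# Hedged reconstruction of the hyperparameters listed above. The keys are real
# transformers.TrainingArguments parameters, but the original training script is
# not published here; "fp16" is an assumption ("Native AMP" may mean bf16).
training_kwargs = {
    "learning_rate": 5e-05,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "seed": 42,
    "optim": "adamw_torch",
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-08,
    "lr_scheduler_type": "linear",
    "max_steps": 200_000,
    "fp16": True,
}

NUM_DEVICES = 2                  # distributed_type: multi-GPU, num_devices: 2
GRADIENT_ACCUMULATION_STEPS = 1  # assumption: not listed, so taken as 1


def total_batch_size(per_device: int, num_devices: int, grad_accum: int = 1) -> int:
    """Examples consumed per optimizer step under data-parallel training."""
    return per_device * num_devices * grad_accum


def linear_lr(step: int, peak_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Linear warmup to peak_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


batch = total_batch_size(
    training_kwargs["per_device_train_batch_size"], NUM_DEVICES, GRADIENT_ACCUMULATION_STEPS
)
print(batch)  # 16
print(linear_lr(100_000, training_kwargs["learning_rate"], training_kwargs["max_steps"]))  # 2.5e-05
```

On a CUDA machine with Transformers installed, these settings could be passed as `TrainingArguments(output_dir="t5la-out", **training_kwargs)`, where the output directory name is hypothetical.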

### Training results

| Training Loss | Epoch  | Step   | Accuracy | Validation Loss |
|:-------------:|:------:|:------:|:--------:|:---------------:|
| 9.4056        | 0.01   | 1000   | 0.0435   | 9.1215          |
| 8.4062        | 0.02   | 2000   | 0.0443   | 8.1939          |
| 7.7307        | 0.03   | 3000   | 0.0444   | 7.6024          |
| 7.39          | 0.04   | 4000   | 0.0444   | 7.3338          |
| 7.2546        | 0.05   | 5000   | 0.0441   | 7.2452          |
| 7.1985        | 0.06   | 6000   | 0.0369   | 7.1682          |
| 7.1009        | 0.07   | 7000   | 0.0346   | 7.0718          |
| 7.004         | 0.08   | 8000   | 0.0332   | 6.9778          |
| 6.9159        | 0.09   | 9000   | 0.0325   | 6.8964          |
| 6.8548        | 0.1    | 10000  | 0.0325   | 6.8307          |
| 6.7833        | 0.11   | 11000  | 0.0326   | 6.7702          |
| 6.7376        | 0.12   | 12000  | 0.0337   | 6.7163          |
| 6.6821        | 0.13   | 13000  | 0.0346   | 6.6615          |
| 6.6373        | 0.14   | 14000  | 0.0349   | 6.6086          |
| 6.5895        | 0.15   | 15000  | 0.0344   | 6.5569          |
| 6.5421        | 0.16   | 16000  | 0.0354   | 6.5119          |
| 6.5051        | 0.17   | 17000  | 0.0355   | 6.4678          |
| 6.4391        | 0.18   | 18000  | 0.0360   | 6.4324          |
| 6.4242        | 0.19   | 19000  | 0.0355   | 6.4015          |
| 6.3889        | 0.2    | 20000  | 0.0373   | 6.3553          |
| 6.3631        | 0.21   | 21000  | 0.0367   | 6.3285          |
| 6.3296        | 0.22   | 22000  | 0.0369   | 6.3015          |
| 6.3081        | 0.23   | 23000  | 0.0364   | 6.2699          |
| 6.2784        | 0.24   | 24000  | 0.0370   | 6.2454          |
| 6.2589        | 0.25   | 25000  | 0.0374   | 6.2167          |
| 6.2371        | 0.26   | 26000  | 0.0370   | 6.1890          |
| 6.1978        | 0.27   | 27000  | 0.0376   | 6.1660          |
| 6.1895        | 0.28   | 28000  | 0.0375   | 6.1378          |
| 6.1636        | 0.29   | 29000  | 0.0366   | 6.1213          |
| 6.1262        | 0.3    | 30000  | 0.0370   | 6.0967          |
| 6.1345        | 0.31   | 31000  | 0.0361   | 6.0745          |
| 6.1096        | 0.32   | 32000  | 0.0360   | 6.0556          |
| 6.0794        | 0.33   | 33000  | 0.0357   | 6.0413          |
| 6.0643        | 0.34   | 34000  | 0.0363   | 6.0136          |
| 6.057         | 0.35   | 35000  | 0.0362   | 5.9965          |
| 6.0337        | 0.36   | 36000  | 0.0354   | 5.9806          |
| 6.0217        | 0.37   | 37000  | 0.0363   | 5.9584          |
| 6.0045        | 0.38   | 38000  | 0.0359   | 5.9526          |
| 5.9896        | 0.39   | 39000  | 0.0355   | 5.9288          |
| 5.9711        | 0.4    | 40000  | 0.0352   | 5.9152          |
| 5.9629        | 0.41   | 41000  | 0.0349   | 5.8962          |
| 5.9465        | 0.42   | 42000  | 0.0359   | 5.8821          |
| 5.9463        | 0.43   | 43000  | 0.0345   | 5.8692          |
| 5.9317        | 0.44   | 44000  | 0.0343   | 5.8699          |
| 5.9097        | 1.0034 | 45000  | 0.0346   | 5.8483          |
| 5.9107        | 1.0134 | 46000  | 0.0348   | 5.8352          |
| 5.8838        | 1.0234 | 47000  | 0.0343   | 5.8188          |
| 5.887         | 1.0334 | 48000  | 0.0340   | 5.8086          |
| 5.8563        | 1.0434 | 49000  | 0.0338   | 5.7971          |
| 5.8576        | 1.0534 | 50000  | 0.0339   | 5.7968          |
| 5.8567        | 1.0635 | 51000  | 0.0343   | 5.7797          |
| 5.841         | 1.0735 | 52000  | 0.0337   | 5.7677          |
| 5.8192        | 1.0835 | 53000  | 0.0332   | 5.7613          |
| 5.8214        | 1.0935 | 54000  | 0.0338   | 5.7486          |
| 5.8166        | 1.1035 | 55000  | 0.0338   | 5.7409          |
| 5.806         | 1.1135 | 56000  | 0.0333   | 5.7342          |
| 5.7961        | 1.1235 | 57000  | 0.0335   | 5.7236          |
| 5.7847        | 1.1335 | 58000  | 0.0333   | 5.7164          |
| 5.787         | 1.1435 | 59000  | 0.0330   | 5.7096          |
| 5.7711        | 1.1535 | 60000  | 0.0328   | 5.7035          |
| 5.7699        | 1.1635 | 61000  | 0.0331   | 5.6888          |
| 5.763         | 1.1734 | 62000  | 0.0334   | 5.6875          |
| 5.7434        | 1.1835 | 63000  | 0.0330   | 5.6809          |
| 5.7477        | 1.1934 | 64000  | 0.0329   | 5.6686          |
| 5.7409        | 1.2034 | 65000  | 0.0330   | 5.6624          |
| 5.737         | 1.2134 | 66000  | 0.0339   | 5.6758          |
| 5.729         | 1.2234 | 67000  | 0.0326   | 5.6546          |
| 5.7232        | 1.2334 | 68000  | 0.0329   | 5.6467          |
| 5.7127        | 1.2434 | 69000  | 0.0329   | 5.6449          |
| 5.7187        | 1.2534 | 70000  | 0.0329   | 5.6352          |
| 5.717         | 1.2634 | 71000  | 0.0326   | 5.6264          |
| 5.714         | 1.2734 | 72000  | 0.0330   | 5.6219          |
| 5.7079        | 1.2834 | 73000  | 0.0330   | 5.6169          |
| 5.7034        | 1.2934 | 74000  | 0.0326   | 5.6131          |
| 5.6768        | 1.3034 | 75000  | 0.0325   | 5.6125          |
| 5.6955        | 1.3135 | 76000  | 0.0328   | 5.6075          |
| 5.6947        | 1.3235 | 77000  | 0.0325   | 5.6017          |
| 5.7056        | 1.3335 | 78000  | 0.0323   | 5.5956          |
| 5.6636        | 1.3435 | 79000  | 0.0326   | 5.5921          |
| 5.6723        | 1.3535 | 80000  | 0.0326   | 5.5881          |
| 5.659         | 1.3635 | 81000  | 0.0324   | 5.5823          |
| 5.6729        | 1.3735 | 82000  | 0.0326   | 5.5795          |
| 5.6595        | 1.3835 | 83000  | 0.0322   | 5.5794          |
| 5.6565        | 1.3935 | 84000  | 0.0328   | 5.5758          |
| 5.6649        | 1.4034 | 85000  | 0.0325   | 5.5716          |
| 5.6561        | 1.4135 | 86000  | 0.0321   | 5.5695          |
| 5.6405        | 1.4234 | 87000  | 0.0323   | 5.5654          |
| 5.6482        | 1.4335 | 88000  | 0.0321   | 5.5628          |
| 5.6425        | 1.4434 | 89000  | 0.0323   | 5.5622          |
| 5.6379        | 2.0069 | 90000  | 0.0323   | 5.5582          |
| 5.6357        | 2.0169 | 91000  | 0.0322   | 5.5573          |
| 5.6381        | 2.0269 | 92000  | 0.0320   | 5.5568          |
| 5.6427        | 2.0369 | 93000  | 0.0324   | 5.5526          |
| 5.6364        | 2.0469 | 94000  | 0.0323   | 5.5526          |
| 5.626         | 2.0569 | 95000  | 0.0321   | 5.5501          |
| 5.636         | 2.0669 | 96000  | 0.0324   | 5.5492          |
| 5.632         | 2.0769 | 97000  | 0.0323   | 5.5489          |
| 5.6133        | 2.0869 | 98000  | 0.0323   | 5.5479          |
| 5.6291        | 2.0969 | 99000  | 0.0323   | 5.5477          |
| 5.6271        | 2.1069 | 100000 | 0.0322   | 5.5470          |


### Framework versions

- Transformers 4.49.0.dev0
- Pytorch 2.5.1+cu121
- Datasets 3.2.0
- Tokenizers 0.21.0