Add files using upload-large-folder tool
Browse files- __pycache__/train_aramix_h100_full.cpython-313.pyc +0 -0
- aramix_h100/config.json +10 -0
- aramix_h100/model.pt +3 -0
- aramix_h100/model_best.pt +3 -0
- aramix_h100/qa_test_report_simple.json +134 -0
- aramix_h100/qa_test_report_simple.txt +80 -0
- aramix_h100/tokenizer_32k/tokenizer.json +0 -0
- aramix_h100/tokenizer_32k/tokenizer_config.json +9 -0
- aramix_h100/train_log.jsonl +44 -0
- aramix_h100/train_state.pt +3 -0
- donner +0 -0
- nlp_1b_h100_maxvram/config.json +10 -0
- nlp_1b_h100_maxvram/tokenizer_32k/tokenizer.json +0 -0
- nlp_1b_h100_maxvram/tokenizer_32k/tokenizer_config.json +9 -0
- nlp_1b_h100_opt/config.json +10 -0
- nlp_1b_h100_opt/model.pt +3 -0
- nlp_1b_h100_opt/model_best.pt +3 -0
- nlp_1b_h100_opt/tokenizer_32k/tokenizer.json +0 -0
- nlp_1b_h100_opt/tokenizer_32k/tokenizer_config.json +9 -0
- nlp_1b_h100_opt/train_state.pt +3 -0
- nlp_1b_wiki_en_fr_ar/config.json +10 -0
- nlp_1b_wiki_en_fr_ar/model_best.pt +3 -0
- nlp_1b_wiki_en_fr_ar/model_epoch_02.pt +3 -0
- nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer.json +0 -0
- nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer_config.json +9 -0
- simple_qa_test_aramix.py +504 -0
- simple_qa_test_aramix_v2.py +472 -0
- simple_qa_test_aramix_v3.py +583 -0
- simple_qa_test_finished_model (1).py +309 -0
- simple_qa_test_finished_model.py +309 -0
- test.py +428 -0
- top_p +0 -0
- train.py +859 -0
- train2.py +852 -0
- train_aramix_h100_full.py +1055 -0
- train_nlp_h100_maxvram_v6.py +1046 -0
- train_nlp_h100_maxvram_v7.py +805 -0
- upload.py +189 -0
- wikipedia_ar_h100/config.json +10 -0
- wikipedia_ar_h100/tokenizer_32k/tokenizer.json +0 -0
- wikipedia_ar_h100/tokenizer_32k/tokenizer_config.json +9 -0
- wikipedia_ar_h100/train_state.pt +3 -0
- wikipedia_ar_h100_agri_30gb/config.json +10 -0
- wikipedia_ar_h100_codealpaca/config.json +10 -0
- wikipedia_ar_h100_env_fr_ar_77gb/config.json +10 -0
- wikipedia_ar_h100_env_fr_ar_77gb/model_epoch_03.pt +3 -0
- wikipedia_ar_h100_multicode/config.json +10 -0
- wikipedia_ar_h100_multicode/train_state.pt +3 -0
- wikipedia_ar_h100_multicode_10x2000/config.json +10 -0
- wikipedia_ar_h100_multicode_10x2000/model_round_06.pt +3 -0
__pycache__/train_aramix_h100_full.cpython-313.pyc
ADDED
|
Binary file (48.4 kB). View file
|
|
|
aramix_h100/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
aramix_h100/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78d6b75f8a8a079e5ba8b85def0c733dd808a075993fcb8fbe21a07350e0d8dc
|
| 3 |
+
size 5225851307
|
aramix_h100/model_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1300c7623453b9d5a03fa6b2142ebe0f840676233ec122678429b8f2b7cd2ce
|
| 3 |
+
size 5225876625
|
aramix_h100/qa_test_report_simple.json
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"repo_dir": "/workspace/FirstChat/aramix_h100",
|
| 3 |
+
"train_script": "/workspace/FirstChat/train_aramix_h100_full.py",
|
| 4 |
+
"checkpoint": "/workspace/FirstChat/aramix_h100/model_best.pt",
|
| 5 |
+
"config_path": "/workspace/FirstChat/aramix_h100/config.json",
|
| 6 |
+
"tokenizer_dir": "/workspace/FirstChat/aramix_h100/tokenizer_32k",
|
| 7 |
+
"device": "cuda",
|
| 8 |
+
"total_questions": 10,
|
| 9 |
+
"avg_overlap_score": 0.1778,
|
| 10 |
+
"exact_match_rate": 0.0,
|
| 11 |
+
"avg_latency_s": 1.195,
|
| 12 |
+
"avg_words_generated": 28.6,
|
| 13 |
+
"scores_by_category": {
|
| 14 |
+
"Géographie": 0.25,
|
| 15 |
+
"Science": 0.05,
|
| 16 |
+
"Histoire": 0.0,
|
| 17 |
+
"Mathématiques": 0.0,
|
| 18 |
+
"Langage": 0.0,
|
| 19 |
+
"Raisonnement": 1.0,
|
| 20 |
+
"Dialogue": 0.0
|
| 21 |
+
},
|
| 22 |
+
"results": [
|
| 23 |
+
{
|
| 24 |
+
"id": 1,
|
| 25 |
+
"category": "Géographie",
|
| 26 |
+
"question": "Quelle est la capitale de la France ?",
|
| 27 |
+
"answer": "Au total, le personnage d'une société financière (dont le but à une institution) est une commission de la musique. Il a également considéré plusieurs sociétés, notamment les p",
|
| 28 |
+
"reference": "Paris",
|
| 29 |
+
"latency_s": 1.6,
|
| 30 |
+
"tokens_generated_approx": 28,
|
| 31 |
+
"overlap_score": 0.0,
|
| 32 |
+
"exact_match": false
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"id": 2,
|
| 36 |
+
"category": "Géographie",
|
| 37 |
+
"question": "Quel est le plus long fleuve d'Afrique ?",
|
| 38 |
+
"answer": "Les petits chemins de fer français sont un type de classe moyenne de . Le groupe est le plus long, de poids lourds. Historique Le , le groupe est également connu comme secondaire d'une man",
|
| 39 |
+
"reference": "Le Nil",
|
| 40 |
+
"latency_s": 1.25,
|
| 41 |
+
"tokens_generated_approx": 35,
|
| 42 |
+
"overlap_score": 0.5,
|
| 43 |
+
"exact_match": false
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"id": 3,
|
| 47 |
+
"category": "Science",
|
| 48 |
+
"question": "Qu'est-ce que la photosynthèse ?",
|
| 49 |
+
"answer": "Die verlore, die wat gevra het nie beslis nie! Chefsel en sy kragtige tempo van die opwindende akademiese medium. Hierdie photographies is gebaseer, asook hulle nie eens uitges",
|
| 50 |
+
"reference": "Processus par lequel les plantes convertissent la lumière en énergie",
|
| 51 |
+
"latency_s": 1.242,
|
| 52 |
+
"tokens_generated_approx": 28,
|
| 53 |
+
"overlap_score": 0.1,
|
| 54 |
+
"exact_match": false
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"id": 4,
|
| 58 |
+
"category": "Science",
|
| 59 |
+
"question": "Combien d'os compte le corps humain adulte ?",
|
| 60 |
+
"answer": "Die volgende plek is available, sal ons beslis nie opgelewing te koop nie.<eos>",
|
| 61 |
+
"reference": "206",
|
| 62 |
+
"latency_s": 0.518,
|
| 63 |
+
"tokens_generated_approx": 13,
|
| 64 |
+
"overlap_score": 0.0,
|
| 65 |
+
"exact_match": false
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"id": 5,
|
| 69 |
+
"category": "Histoire",
|
| 70 |
+
"question": "En quelle année a eu lieu la Révolution française ?",
|
| 71 |
+
"answer": "Les premiers élèves sont des trois fils de Jacques Salamine et Julien Fernandez. Leur nombre d'exemplaires est installée, confirmant la principale section du groupe de musique",
|
| 72 |
+
"reference": "1789",
|
| 73 |
+
"latency_s": 1.229,
|
| 74 |
+
"tokens_generated_approx": 26,
|
| 75 |
+
"overlap_score": 0.0,
|
| 76 |
+
"exact_match": false
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"id": 6,
|
| 80 |
+
"category": "Histoire",
|
| 81 |
+
"question": "Qui a écrit Les Misérables ?",
|
| 82 |
+
"answer": "Historia et soutenue, L'Advention d'une partie de la mise en place des règnes du terme de cette matière. Le phénomène ou la présence d'un système de variable est écrit : «",
|
| 83 |
+
"reference": "Victor Hugo",
|
| 84 |
+
"latency_s": 1.22,
|
| 85 |
+
"tokens_generated_approx": 31,
|
| 86 |
+
"overlap_score": 0.0,
|
| 87 |
+
"exact_match": false
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"id": 7,
|
| 91 |
+
"category": "Mathématiques",
|
| 92 |
+
"question": "Quelle est la formule de l'aire d'un cercle ?",
|
| 93 |
+
"answer": "Remarque : évoquant le premier au sens interdisant les règles différentes, on utilise les mots de type latin. Le cercle est un facteur plus ou moins électrique et une sorte de poids (",
|
| 94 |
+
"reference": "pi r carre",
|
| 95 |
+
"latency_s": 1.224,
|
| 96 |
+
"tokens_generated_approx": 33,
|
| 97 |
+
"overlap_score": 0.0,
|
| 98 |
+
"exact_match": false
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"id": 8,
|
| 102 |
+
"category": "Langage",
|
| 103 |
+
"question": "Donne un synonyme du mot heureux.",
|
| 104 |
+
"answer": "le premier à créer une fois la relation de l'homme, qui serait naturellement réduite. Son grand-père est d'une famille enveloppant le chevalier (ou encore mythique). Les barrag",
|
| 105 |
+
"reference": "joyeux",
|
| 106 |
+
"latency_s": 1.228,
|
| 107 |
+
"tokens_generated_approx": 27,
|
| 108 |
+
"overlap_score": 0.0,
|
| 109 |
+
"exact_match": false
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"id": 9,
|
| 113 |
+
"category": "Raisonnement",
|
| 114 |
+
"question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
|
| 115 |
+
"answer": "Loi de la section 1 à 3 par exemple : Dans les différentes factions qui vont à cette occasion, une section de terminaisons n'est pas connue. En janvier 1940, le site est introduit dans la m",
|
| 116 |
+
"reference": "3",
|
| 117 |
+
"latency_s": 1.226,
|
| 118 |
+
"tokens_generated_approx": 36,
|
| 119 |
+
"overlap_score": 1.0,
|
| 120 |
+
"exact_match": false
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"id": 10,
|
| 124 |
+
"category": "Dialogue",
|
| 125 |
+
"question": "Comment vas-tu aujourd'hui ?",
|
| 126 |
+
"answer": "Die verskaffinge en maandelikse digitale oor hierdie raadslede is, is nie gebruik van die Vrystaatse kombuisies nie. Dit was die swart-regering in die omgewing aan die beste metode te",
|
| 127 |
+
"reference": null,
|
| 128 |
+
"latency_s": 1.212,
|
| 129 |
+
"tokens_generated_approx": 29,
|
| 130 |
+
"overlap_score": null,
|
| 131 |
+
"exact_match": null
|
| 132 |
+
}
|
| 133 |
+
]
|
| 134 |
+
}
|
aramix_h100/qa_test_report_simple.txt
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ
|
| 2 |
+
============================================================
|
| 3 |
+
|
| 4 |
+
[01] Géographie
|
| 5 |
+
User : Quelle est la capitale de la France ?
|
| 6 |
+
Assistant : Au total, le personnage d'une société financière (dont le but à une institution) est une commission de la musique. Il a également considéré plusieurs sociétés, notamment les p
|
| 7 |
+
Référence : Paris
|
| 8 |
+
Overlap : 0%
|
| 9 |
+
ExactMatch: False
|
| 10 |
+
Latence : 1.6s
|
| 11 |
+
|
| 12 |
+
[02] Géographie
|
| 13 |
+
User : Quel est le plus long fleuve d'Afrique ?
|
| 14 |
+
Assistant : Les petits chemins de fer français sont un type de classe moyenne de . Le groupe est le plus long, de poids lourds. Historique Le , le groupe est également connu comme secondaire d'une man
|
| 15 |
+
Référence : Le Nil
|
| 16 |
+
Overlap : 50%
|
| 17 |
+
ExactMatch: False
|
| 18 |
+
Latence : 1.25s
|
| 19 |
+
|
| 20 |
+
[03] Science
|
| 21 |
+
User : Qu'est-ce que la photosynthèse ?
|
| 22 |
+
Assistant : Die verlore, die wat gevra het nie beslis nie! Chefsel en sy kragtige tempo van die opwindende akademiese medium. Hierdie photographies is gebaseer, asook hulle nie eens uitges
|
| 23 |
+
Référence : Processus par lequel les plantes convertissent la lumière en énergie
|
| 24 |
+
Overlap : 10%
|
| 25 |
+
ExactMatch: False
|
| 26 |
+
Latence : 1.242s
|
| 27 |
+
|
| 28 |
+
[04] Science
|
| 29 |
+
User : Combien d'os compte le corps humain adulte ?
|
| 30 |
+
Assistant : Die volgende plek is available, sal ons beslis nie opgelewing te koop nie.<eos>
|
| 31 |
+
Référence : 206
|
| 32 |
+
Overlap : 0%
|
| 33 |
+
ExactMatch: False
|
| 34 |
+
Latence : 0.518s
|
| 35 |
+
|
| 36 |
+
[05] Histoire
|
| 37 |
+
User : En quelle année a eu lieu la Révolution française ?
|
| 38 |
+
Assistant : Les premiers élèves sont des trois fils de Jacques Salamine et Julien Fernandez. Leur nombre d'exemplaires est installée, confirmant la principale section du groupe de musique
|
| 39 |
+
Référence : 1789
|
| 40 |
+
Overlap : 0%
|
| 41 |
+
ExactMatch: False
|
| 42 |
+
Latence : 1.229s
|
| 43 |
+
|
| 44 |
+
[06] Histoire
|
| 45 |
+
User : Qui a écrit Les Misérables ?
|
| 46 |
+
Assistant : Historia et soutenue, L'Advention d'une partie de la mise en place des règnes du terme de cette matière. Le phénomène ou la présence d'un système de variable est écrit : «
|
| 47 |
+
Référence : Victor Hugo
|
| 48 |
+
Overlap : 0%
|
| 49 |
+
ExactMatch: False
|
| 50 |
+
Latence : 1.22s
|
| 51 |
+
|
| 52 |
+
[07] Mathématiques
|
| 53 |
+
User : Quelle est la formule de l'aire d'un cercle ?
|
| 54 |
+
Assistant : Remarque : évoquant le premier au sens interdisant les règles différentes, on utilise les mots de type latin. Le cercle est un facteur plus ou moins électrique et une sorte de poids (
|
| 55 |
+
Référence : pi r carre
|
| 56 |
+
Overlap : 0%
|
| 57 |
+
ExactMatch: False
|
| 58 |
+
Latence : 1.224s
|
| 59 |
+
|
| 60 |
+
[08] Langage
|
| 61 |
+
User : Donne un synonyme du mot heureux.
|
| 62 |
+
Assistant : le premier à créer une fois la relation de l'homme, qui serait naturellement réduite. Son grand-père est d'une famille enveloppant le chevalier (ou encore mythique). Les barrag
|
| 63 |
+
Référence : joyeux
|
| 64 |
+
Overlap : 0%
|
| 65 |
+
ExactMatch: False
|
| 66 |
+
Latence : 1.228s
|
| 67 |
+
|
| 68 |
+
[09] Raisonnement
|
| 69 |
+
User : Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?
|
| 70 |
+
Assistant : Loi de la section 1 à 3 par exemple : Dans les différentes factions qui vont à cette occasion, une section de terminaisons n'est pas connue. En janvier 1940, le site est introduit dans la m
|
| 71 |
+
Référence : 3
|
| 72 |
+
Overlap : 100%
|
| 73 |
+
ExactMatch: False
|
| 74 |
+
Latence : 1.226s
|
| 75 |
+
|
| 76 |
+
[10] Dialogue
|
| 77 |
+
User : Comment vas-tu aujourd'hui ?
|
| 78 |
+
Assistant : Die verskaffinge en maandelikse digitale oor hierdie raadslede is, is nie gebruik van die Vrystaatse kombuisies nie. Dit was die swart-regering in die omgewing aan die beste metode te
|
| 79 |
+
Latence : 1.212s
|
| 80 |
+
|
aramix_h100/tokenizer_32k/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
aramix_h100/tokenizer_32k/tokenizer_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"tokenizer_class": "TokenizersBackend",
|
| 8 |
+
"unk_token": "<unk>"
|
| 9 |
+
}
|
aramix_h100/train_log.jsonl
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"step": 3050, "loss": 3.571237907409668, "ppl": 35.5605866297653, "lr": 0.00012589612961614406, "tok_s": 11144.981072882043, "time": 147.00788259506226}
|
| 2 |
+
{"step": 3100, "loss": 2.8578320932388306, "ppl": 17.423712977652347, "lr": 0.00012168934965573969, "tok_s": 57222.53834418063, "time": 175.63995742797852}
|
| 3 |
+
{"step": 3150, "loss": 2.83519070148468, "ppl": 17.03364833271062, "lr": 0.00011752894782453746, "tok_s": 57001.54441684658, "time": 204.3830382823944}
|
| 4 |
+
{"step": 3200, "loss": 2.9635899543762205, "ppl": 19.36737509653468, "lr": 0.00011341937918502494, "tok_s": 56964.73242461735, "time": 233.14469361305237}
|
| 5 |
+
{"step": 3250, "loss": 2.8826186561584475, "ppl": 17.860983768575373, "lr": 0.0001093650443662347, "tok_s": 57001.06640251769, "time": 261.88801550865173}
|
| 6 |
+
{"step": 3300, "loss": 2.703218548297882, "ppl": 14.927700012805772, "lr": 0.00010537028485144083, "tok_s": 60419.98038551336, "time": 289.00487303733826}
|
| 7 |
+
{"step": 3350, "loss": 2.745184621810913, "ppl": 15.56748776467281, "lr": 0.00010143937832918955, "tok_s": 57034.40571613923, "time": 317.7313930988312}
|
| 8 |
+
{"step": 3400, "loss": 2.7973192358016967, "ppl": 16.400621587712635, "lr": 9.757653411264333e-05, "tok_s": 61073.53656045087, "time": 344.55806946754456}
|
| 9 |
+
{"step": 3450, "loss": 2.7036573982238767, "ppl": 14.934252470519295, "lr": 9.378588863214297e-05, "tok_s": 56988.909283348396, "time": 373.30752301216125}
|
| 10 |
+
{"step": 3500, "loss": 2.5318048357963563, "ppl": 12.576183612317518, "lr": 9.007150100581427e-05, "tok_s": 57049.3650586449, "time": 402.02651047706604}
|
| 11 |
+
{"step": 3500, "val_loss": 2.7343945503234863, "val_ppl": 15.400416435550476, "per_domain": {"arabic_aramix": 2.896586099727042, "french_wiki": 2.638401048920083, "arabic_wiki": 2.9592110232303015, "math_stackexchange": 2.6955147483132103, "multilingual_cc": 2.75966284356334, "stories": 2.5633833775153527, "medical_pubmed": 2.8006927967071533, "medical_flashcards": 2.939271628856659}}
|
| 12 |
+
{"step": 3550, "loss": 2.622337305545807, "ppl": 13.767865716347135, "lr": 8.643734869296278e-05, "tok_s": 11633.872778271438, "time": 542.8566563129425}
|
| 13 |
+
{"step": 3600, "loss": 2.5885059988498687, "ppl": 13.30987177797193, "lr": 8.288732323491074e-05, "tok_s": 56913.92629532449, "time": 571.6439867019653}
|
| 14 |
+
{"step": 3650, "loss": 2.6797654819488526, "ppl": 14.581673229274589, "lr": 7.942522608783706e-05, "tok_s": 57007.48882356021, "time": 600.3840703964233}
|
| 15 |
+
{"step": 3700, "loss": 2.530478653907776, "ppl": 12.559516359730464, "lr": 7.605476455208276e-05, "tok_s": 56725.416636375594, "time": 629.2670667171478}
|
| 16 |
+
{"step": 3750, "loss": 2.569008586406708, "ppl": 13.052877224060135, "lr": 7.277954780228142e-05, "tok_s": 56930.38215870419, "time": 658.0460760593414}
|
| 17 |
+
{"step": 3800, "loss": 2.564538803100586, "ppl": 12.994663888764459, "lr": 6.960308302256383e-05, "tok_s": 56738.58728395681, "time": 686.922367811203}
|
| 18 |
+
{"step": 3850, "loss": 2.656132850646973, "ppl": 14.241109975260919, "lr": 6.652877165097785e-05, "tok_s": 56904.04773033927, "time": 715.7146956920624}
|
| 19 |
+
{"step": 3900, "loss": 2.603713195323944, "ppl": 13.51382445690133, "lr": 6.355990573714333e-05, "tok_s": 56984.61454772562, "time": 744.466315984726}
|
| 20 |
+
{"step": 3950, "loss": 2.4559333181381224, "ppl": 11.657308450637194, "lr": 6.069966441704281e-05, "tok_s": 56748.50220511147, "time": 773.3375625610352}
|
| 21 |
+
{"step": 4000, "loss": 2.4442277646064756, "ppl": 11.521648737554232, "lr": 5.795111050872301e-05, "tok_s": 59304.739097463025, "time": 800.9643597602844}
|
| 22 |
+
{"step": 4000, "val_loss": 2.6130046784877776, "val_ppl": 13.639973075654948, "per_domain": {"arabic_aramix": 2.794638446513438, "french_wiki": 2.522582699096084, "arabic_wiki": 2.849381112424951, "math_stackexchange": 2.5690312363884664, "multilingual_cc": 2.6081621836532247, "stories": 2.4209399079228495, "medical_pubmed": 2.6950340270996094, "medical_flashcards": 2.819118231534958}}
|
| 23 |
+
{"step": 4050, "loss": 2.616236004829407, "ppl": 13.684119567401414, "lr": 5.531718723255281e-05, "tok_s": 11053.809939264273, "time": 949.1847479343414}
|
| 24 |
+
{"step": 4100, "loss": 2.470917589664459, "ppl": 11.833299985286175, "lr": 5.280071505954885e-05, "tok_s": 56891.6055540174, "time": 977.9833726882935}
|
| 25 |
+
{"step": 4150, "loss": 2.51880806684494, "ppl": 12.413791432319185, "lr": 5.0404388691144755e-05, "tok_s": 56854.72699648875, "time": 1006.8006775379181}
|
| 26 |
+
{"step": 4200, "loss": 2.4818945503234864, "ppl": 11.96390918827006, "lr": 4.813077417363728e-05, "tok_s": 56728.39765714226, "time": 1035.682156085968}
|
| 27 |
+
{"step": 4250, "loss": 2.491965615749359, "ppl": 12.085007270264622, "lr": 4.5982306150399575e-05, "tok_s": 56775.80030404367, "time": 1064.5395212173462}
|
| 28 |
+
{"step": 4300, "loss": 2.5507710337638856, "ppl": 12.81698229994284, "lr": 4.3961285254804134e-05, "tok_s": 56859.33243706391, "time": 1093.3544919490814}
|
| 29 |
+
{"step": 4350, "loss": 2.5808962631225585, "ppl": 13.208971570046705, "lr": 4.206987564664711e-05, "tok_s": 58036.5193491192, "time": 1121.584992647171}
|
| 30 |
+
{"step": 4400, "loss": 2.7079242062568665, "ppl": 14.998110196375103, "lr": 4.031010269471151e-05, "tok_s": 57002.2214993296, "time": 1150.3277320861816}
|
| 31 |
+
{"step": 4450, "loss": 2.5953755283355715, "ppl": 13.401619104283894, "lr": 3.868385080795177e-05, "tok_s": 56785.49267694181, "time": 1179.1801717281342}
|
| 32 |
+
{"step": 4500, "loss": 2.5976397037506103, "ppl": 13.43199709835749, "lr": 3.71928614176214e-05, "tok_s": 56872.48957171817, "time": 1207.9884762763977}
|
| 33 |
+
{"step": 4500, "val_loss": 2.547382290661335, "val_ppl": 12.773622348939783, "per_domain": {"arabic_aramix": 2.7358595652868285, "french_wiki": 2.4578250470351537, "arabic_wiki": 2.7882657515375238, "math_stackexchange": 2.504125434702093, "multilingual_cc": 2.5334029245105656, "stories": 2.348106792994908, "medical_pubmed": 2.6415319442749023, "medical_flashcards": 2.757194072008133}}
|
| 34 |
+
{"step": 4550, "loss": 2.638249659538269, "ppl": 13.988697184009707, "lr": 3.583873111250479e-05, "tok_s": 11385.546618796197, "time": 1351.8902189731598}
|
| 35 |
+
{"step": 4600, "loss": 2.582155728340149, "ppl": 13.225618291081892, "lr": 3.462290992924992e-05, "tok_s": 56835.088159816194, "time": 1380.7174813747406}
|
| 36 |
+
{"step": 4650, "loss": 2.5022823572158814, "ppl": 12.210330518102197, "lr": 3.354669979963281e-05, "tok_s": 56918.493716463534, "time": 1409.5025017261505}
|
| 37 |
+
{"step": 4700, "loss": 2.4462553191185, "ppl": 11.545033207070272, "lr": 3.261125315641639e-05, "tok_s": 56923.2943292039, "time": 1438.285094499588}
|
| 38 |
+
{"step": 4750, "loss": 2.3780724477767943, "ppl": 10.784095909082684, "lr": 3.1817571699296604e-05, "tok_s": 57047.299242243054, "time": 1467.0051219463348}
|
| 39 |
+
{"step": 4800, "loss": 2.583693208694458, "ppl": 13.245968059053615, "lr": 3.116650532225727e-05, "tok_s": 56903.56946504383, "time": 1495.797691822052}
|
| 40 |
+
{"step": 4850, "loss": 2.5454819345474244, "ppl": 12.749370968040534, "lr": 3.065875120348237e-05, "tok_s": 56909.14470005539, "time": 1524.5874409675598}
|
| 41 |
+
{"step": 4900, "loss": 2.5480848932266236, "ppl": 12.782600282360375, "lr": 3.029485305880013e-05, "tok_s": 57001.02054004117, "time": 1553.3307859897614}
|
| 42 |
+
{"step": 4950, "loss": 2.5183346843719483, "ppl": 12.407916351721552, "lr": 3.007520055945856e-05, "tok_s": 56954.47416190712, "time": 1582.097621679306}
|
| 43 |
+
{"step": 5000, "loss": 2.488962812423706, "ppl": 12.048772799963553, "lr": 3.0000028914855615e-05, "tok_s": 56955.4739533805, "time": 1610.8639523983002}
|
| 44 |
+
{"step": 5000, "val_loss": 2.5137671560049055, "val_ppl": 12.351372073909298, "per_domain": {"arabic_aramix": 2.7070358647596118, "french_wiki": 2.427249958942895, "arabic_wiki": 2.755421937766828, "math_stackexchange": 2.469518039443276, "multilingual_cc": 2.489755442873998, "stories": 2.308292391536, "medical_pubmed": 2.6124343872070312, "medical_flashcards": 2.7211887538433075}}
|
aramix_h100/train_state.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dc13a38fb25354acd0eba8d965216a73cd9fc292a9fa767ad6f27924eb855ac9
|
| 3 |
+
size 5225877311
|
donner
ADDED
|
File without changes
|
nlp_1b_h100_maxvram/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 2048,
|
| 4 |
+
"d_model": 1536,
|
| 5 |
+
"n_heads": 24,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 6144,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
nlp_1b_h100_maxvram/tokenizer_32k/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
nlp_1b_h100_maxvram/tokenizer_32k/tokenizer_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"tokenizer_class": "TokenizersBackend",
|
| 8 |
+
"unk_token": "<unk>"
|
| 9 |
+
}
|
nlp_1b_h100_opt/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1536,
|
| 5 |
+
"n_heads": 24,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 6144,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
nlp_1b_h100_opt/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:404b4fea027e28fbcabe19077a7ffebb8830de27ca94db6908992d55dcd85e6d
|
| 3 |
+
size 4415622541
|
nlp_1b_h100_opt/model_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77b5bc51912e146de6f8855909bc84564dc6c30daa361613c88a11cc41ceb049
|
| 3 |
+
size 4415675901
|
nlp_1b_h100_opt/tokenizer_32k/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
nlp_1b_h100_opt/tokenizer_32k/tokenizer_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"tokenizer_class": "TokenizersBackend",
|
| 8 |
+
"unk_token": "<unk>"
|
| 9 |
+
}
|
nlp_1b_h100_opt/train_state.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a749a6a3258c1a8d1bd9d27d4f1bf0e45454c70bf598e88741b471b8f2afa088
|
| 3 |
+
size 4415677037
|
nlp_1b_wiki_en_fr_ar/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1536,
|
| 5 |
+
"n_heads": 24,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 6144,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
nlp_1b_wiki_en_fr_ar/model_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dc91b83ddaf6edf5e20142465b16b89b4505fc04eb5ef26680dae6839c030118
|
| 3 |
+
size 11462571709
|
nlp_1b_wiki_en_fr_ar/model_epoch_02.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:243fcb18b86fdcc825307bc552f4392aee2a996217a48e56e650c3fd00257fd3
|
| 3 |
+
size 11462574453
|
nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"tokenizer_class": "TokenizersBackend",
|
| 8 |
+
"unk_token": "<unk>"
|
| 9 |
+
}
|
simple_qa_test_aramix.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
simple_qa_test_aramix.py
|
| 6 |
+
|
| 7 |
+
Test QA simple pour un modèle déjà entraîné dans une repo de type :
|
| 8 |
+
- train_aramix_h100_full.py
|
| 9 |
+
- aramix_h100/
|
| 10 |
+
- config.json
|
| 11 |
+
- model_best.pt
|
| 12 |
+
- model.pt
|
| 13 |
+
- tokenizer_32k/
|
| 14 |
+
|
| 15 |
+
Hypothèses alignées avec ton repo :
|
| 16 |
+
- le module d'entraînement expose : GPT, GPTConfig, train_or_load_tokenizer,
|
| 17 |
+
load_checkpoint, DOMAINS
|
| 18 |
+
- le tokenizer est géré par train_or_load_tokenizer(DOMAINS)
|
| 19 |
+
- le checkpoint se recharge avec load_checkpoint(model, opt, ckpt_path, device)
|
| 20 |
+
|
| 21 |
+
Usage
|
| 22 |
+
-----
|
| 23 |
+
python simple_qa_test_aramix.py
|
| 24 |
+
python simple_qa_test_aramix.py --repo_dir ./aramix_h100
|
| 25 |
+
python simple_qa_test_aramix.py --ckpt ./aramix_h100/model.pt
|
| 26 |
+
python simple_qa_test_aramix.py --questions qa_questions.json
|
| 27 |
+
python simple_qa_test_aramix.py --max_new_tokens 96 --temperature 0.4 --top_k 40
|
| 28 |
+
python simple_qa_test_aramix.py --save_report
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
import importlib.util
|
| 35 |
+
import json
|
| 36 |
+
import os
|
| 37 |
+
import re
|
| 38 |
+
import sys
|
| 39 |
+
import time
|
| 40 |
+
import unicodedata
|
| 41 |
+
from pathlib import Path
|
| 42 |
+
from typing import Any, Dict, List, Optional
|
| 43 |
+
|
| 44 |
+
import torch
|
| 45 |
+
import torch.nn.functional as F
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
DEFAULT_QUESTIONS = [
|
| 49 |
+
{
|
| 50 |
+
"category": "Géographie",
|
| 51 |
+
"question": "Quelle est la capitale de la France ?",
|
| 52 |
+
"reference": "Paris",
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"category": "Géographie",
|
| 56 |
+
"question": "Quel est le plus long fleuve d'Afrique ?",
|
| 57 |
+
"reference": "Le Nil",
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"category": "Science",
|
| 61 |
+
"question": "Qu'est-ce que la photosynthèse ?",
|
| 62 |
+
"reference": "Processus par lequel les plantes convertissent la lumière en énergie",
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"category": "Science",
|
| 66 |
+
"question": "Combien d'os compte le corps humain adulte ?",
|
| 67 |
+
"reference": "206",
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"category": "Histoire",
|
| 71 |
+
"question": "En quelle année a eu lieu la Révolution française ?",
|
| 72 |
+
"reference": "1789",
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"category": "Histoire",
|
| 76 |
+
"question": "Qui a écrit Les Misérables ?",
|
| 77 |
+
"reference": "Victor Hugo",
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"category": "Mathématiques",
|
| 81 |
+
"question": "Quelle est la formule de l'aire d'un cercle ?",
|
| 82 |
+
"reference": "pi r carre",
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"category": "Langage",
|
| 86 |
+
"question": "Donne un synonyme du mot heureux.",
|
| 87 |
+
"reference": "joyeux",
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"category": "Raisonnement",
|
| 91 |
+
"question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
|
| 92 |
+
"reference": "3",
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"category": "Dialogue",
|
| 96 |
+
"question": "Comment vas-tu aujourd'hui ?",
|
| 97 |
+
"reference": None,
|
| 98 |
+
},
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_module_from_file(py_path: Path):
|
| 103 |
+
spec = importlib.util.spec_from_file_location(py_path.stem, py_path)
|
| 104 |
+
if spec is None or spec.loader is None:
|
| 105 |
+
raise RuntimeError(f"Impossible de charger le module: {py_path}")
|
| 106 |
+
module = importlib.util.module_from_spec(spec)
|
| 107 |
+
sys.modules[py_path.stem] = module
|
| 108 |
+
spec.loader.exec_module(module)
|
| 109 |
+
return module
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def normalize_text(text: str) -> str:
|
| 113 |
+
text = (text or "").strip().lower()
|
| 114 |
+
text = unicodedata.normalize("NFKD", text)
|
| 115 |
+
text = "".join(ch for ch in text if not unicodedata.combining(ch))
|
| 116 |
+
text = text.replace("π", "pi")
|
| 117 |
+
text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE)
|
| 118 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 119 |
+
return text
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
|
| 123 |
+
if not reference:
|
| 124 |
+
return None
|
| 125 |
+
ref = set(normalize_text(reference).split())
|
| 126 |
+
ans = set(normalize_text(answer).split())
|
| 127 |
+
if not ref:
|
| 128 |
+
return None
|
| 129 |
+
return len(ref & ans) / len(ref)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 133 |
+
if not reference:
|
| 134 |
+
return None
|
| 135 |
+
return normalize_text(reference) == normalize_text(answer)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def infer_repo_defaults(repo_dir: Path):
|
| 139 |
+
train_script = repo_dir.parent / "train_aramix_h100_full.py"
|
| 140 |
+
if not train_script.exists():
|
| 141 |
+
train_script = repo_dir / "train_aramix_h100_full.py"
|
| 142 |
+
|
| 143 |
+
ckpt = repo_dir / "model_best.pt"
|
| 144 |
+
if not ckpt.exists():
|
| 145 |
+
ckpt = repo_dir / "model.pt"
|
| 146 |
+
|
| 147 |
+
config = repo_dir / "config.json"
|
| 148 |
+
tokenizer_dir = repo_dir / "tokenizer_32k"
|
| 149 |
+
return train_script, ckpt, config, tokenizer_dir
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
|
| 153 |
+
for name in names:
|
| 154 |
+
if name in cfg:
|
| 155 |
+
return cfg[name]
|
| 156 |
+
return default
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
|
| 160 |
+
block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512)
|
| 161 |
+
d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768)
|
| 162 |
+
n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12)
|
| 163 |
+
n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12)
|
| 164 |
+
d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4)
|
| 165 |
+
|
| 166 |
+
return {
|
| 167 |
+
"vocab_size": vocab_size,
|
| 168 |
+
"block_size": int(block_size),
|
| 169 |
+
"d_model": int(d_model),
|
| 170 |
+
"n_heads": int(n_heads),
|
| 171 |
+
"n_layers": int(n_layers),
|
| 172 |
+
"d_ff": int(d_ff),
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class AramixChatTester:
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
repo_dir: Path,
|
| 180 |
+
train_script: Path,
|
| 181 |
+
ckpt_path: Path,
|
| 182 |
+
config_path: Path,
|
| 183 |
+
device: Optional[str] = None,
|
| 184 |
+
):
|
| 185 |
+
self.repo_dir = repo_dir
|
| 186 |
+
self.train_script = train_script
|
| 187 |
+
self.ckpt_path = ckpt_path
|
| 188 |
+
self.config_path = config_path
|
| 189 |
+
|
| 190 |
+
self.device = torch.device(
|
| 191 |
+
device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.M = load_module_from_file(self.train_script)
|
| 195 |
+
|
| 196 |
+
required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "load_checkpoint", "DOMAINS"]
|
| 197 |
+
missing = [x for x in required if not hasattr(self.M, x)]
|
| 198 |
+
if missing:
|
| 199 |
+
raise RuntimeError(
|
| 200 |
+
f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
self.cfg_json: Dict[str, Any] = {}
|
| 204 |
+
if self.config_path.exists():
|
| 205 |
+
with open(self.config_path, "r", encoding="utf-8") as f:
|
| 206 |
+
self.cfg_json = json.load(f)
|
| 207 |
+
|
| 208 |
+
self.tokenizer = self._load_tokenizer()
|
| 209 |
+
self.model = self._load_model()
|
| 210 |
+
|
| 211 |
+
def _load_tokenizer(self):
|
| 212 |
+
old_cwd = Path.cwd()
|
| 213 |
+
try:
|
| 214 |
+
os.chdir(self.repo_dir.parent)
|
| 215 |
+
tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
|
| 216 |
+
finally:
|
| 217 |
+
os.chdir(old_cwd)
|
| 218 |
+
return tok
|
| 219 |
+
|
| 220 |
+
def _make_gpt_config(self):
|
| 221 |
+
kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
|
| 222 |
+
try:
|
| 223 |
+
return self.M.GPTConfig(**kwargs)
|
| 224 |
+
except TypeError:
|
| 225 |
+
return self.M.GPTConfig(vocab_size=len(self.tokenizer))
|
| 226 |
+
|
| 227 |
+
def _load_model(self):
|
| 228 |
+
cfg = self._make_gpt_config()
|
| 229 |
+
model = self.M.GPT(cfg).to(self.device)
|
| 230 |
+
|
| 231 |
+
try:
|
| 232 |
+
self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
|
| 233 |
+
except TypeError:
|
| 234 |
+
try:
|
| 235 |
+
self.M.load_checkpoint(model, self.ckpt_path, self.device)
|
| 236 |
+
except TypeError:
|
| 237 |
+
ckpt = torch.load(self.ckpt_path, map_location=self.device)
|
| 238 |
+
state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
|
| 239 |
+
if any(k.startswith("_orig_mod.") for k in state):
|
| 240 |
+
state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}
|
| 241 |
+
model.load_state_dict(state, strict=False)
|
| 242 |
+
|
| 243 |
+
model.eval()
|
| 244 |
+
return model
|
| 245 |
+
|
| 246 |
+
def encode_prompt(self, question: str) -> List[int]:
|
| 247 |
+
bos = getattr(self.tokenizer, "bos_token_id", None)
|
| 248 |
+
eos = getattr(self.tokenizer, "eos_token_id", None)
|
| 249 |
+
|
| 250 |
+
prompt = f"Question: {question}\nRéponse:"
|
| 251 |
+
ids = self.tokenizer.encode(prompt, add_special_tokens=False)
|
| 252 |
+
|
| 253 |
+
if bos is not None:
|
| 254 |
+
ids = [bos] + ids
|
| 255 |
+
if eos is not None and len(ids) > 0 and ids[-1] == eos:
|
| 256 |
+
ids = ids[:-1]
|
| 257 |
+
return ids
|
| 258 |
+
|
| 259 |
+
@torch.no_grad()
|
| 260 |
+
def generate(
|
| 261 |
+
self,
|
| 262 |
+
question: str,
|
| 263 |
+
max_new_tokens: int = 96,
|
| 264 |
+
temperature: float = 0.4,
|
| 265 |
+
top_k: int = 40,
|
| 266 |
+
repetition_penalty: float = 1.12,
|
| 267 |
+
) -> str:
|
| 268 |
+
ids = self.encode_prompt(question)
|
| 269 |
+
x = torch.tensor([ids], dtype=torch.long, device=self.device)
|
| 270 |
+
|
| 271 |
+
eos_id = getattr(self.tokenizer, "eos_token_id", None)
|
| 272 |
+
block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
|
| 273 |
+
if block_size is None:
|
| 274 |
+
block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
|
| 275 |
+
|
| 276 |
+
for step in range(max_new_tokens):
|
| 277 |
+
x_ctx = x[:, -int(block_size):]
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
logits, _ = self.model(x_ctx)
|
| 281 |
+
except TypeError:
|
| 282 |
+
out = self.model(x_ctx)
|
| 283 |
+
logits = out[0] if isinstance(out, tuple) else out
|
| 284 |
+
|
| 285 |
+
logits = logits[:, -1, :]
|
| 286 |
+
|
| 287 |
+
recent = x[0, -64:].tolist()
|
| 288 |
+
for tok in set(recent):
|
| 289 |
+
logits[0, tok] /= repetition_penalty
|
| 290 |
+
|
| 291 |
+
if temperature <= 0:
|
| 292 |
+
next_tok = torch.argmax(logits, dim=-1, keepdim=True)
|
| 293 |
+
else:
|
| 294 |
+
logits = logits / max(temperature, 1e-5)
|
| 295 |
+
|
| 296 |
+
if top_k is not None and top_k > 0:
|
| 297 |
+
values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
|
| 298 |
+
kth = values[:, -1].unsqueeze(-1)
|
| 299 |
+
logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
|
| 300 |
+
|
| 301 |
+
probs = F.softmax(logits, dim=-1)
|
| 302 |
+
next_tok = torch.multinomial(probs, num_samples=1)
|
| 303 |
+
|
| 304 |
+
x = torch.cat([x, next_tok], dim=1)
|
| 305 |
+
|
| 306 |
+
if eos_id is not None and next_tok.item() == eos_id and step >= 2:
|
| 307 |
+
break
|
| 308 |
+
|
| 309 |
+
new_ids = x[0, len(ids):].tolist()
|
| 310 |
+
text = self.tokenizer.decode(new_ids).strip()
|
| 311 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 312 |
+
text = text.replace("Réponse :", "").replace("Réponse:", "").strip()
|
| 313 |
+
return text
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
|
| 317 |
+
if not path:
|
| 318 |
+
return DEFAULT_QUESTIONS
|
| 319 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 320 |
+
data = json.load(f)
|
| 321 |
+
if not isinstance(data, list):
|
| 322 |
+
raise ValueError("Le fichier questions doit contenir une liste JSON.")
|
| 323 |
+
return data
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def format_bar(score: float, width: int = 20) -> str:
|
| 327 |
+
n = max(0, min(width, int(round(score * width))))
|
| 328 |
+
return "█" * n + "░" * (width - n)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
|
| 332 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 333 |
+
|
| 334 |
+
json_path = output_dir / "qa_test_report_simple.json"
|
| 335 |
+
txt_path = output_dir / "qa_test_report_simple.txt"
|
| 336 |
+
|
| 337 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
| 338 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 339 |
+
|
| 340 |
+
with open(txt_path, "w", encoding="utf-8") as f:
|
| 341 |
+
f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n")
|
| 342 |
+
f.write("=" * 60 + "\n\n")
|
| 343 |
+
for r in summary["results"]:
|
| 344 |
+
f.write(f"[{r['id']:02d}] {r['category']}\n")
|
| 345 |
+
f.write(f" User : {r['question']}\n")
|
| 346 |
+
f.write(f" Assistant : {r['answer']}\n")
|
| 347 |
+
if r["reference"]:
|
| 348 |
+
f.write(f" Référence : {r['reference']}\n")
|
| 349 |
+
if r["overlap_score"] is not None:
|
| 350 |
+
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
|
| 351 |
+
if r["exact_match"] is not None:
|
| 352 |
+
f.write(f" ExactMatch: {r['exact_match']}\n")
|
| 353 |
+
f.write(f" Latence : {r['latency_s']}s\n\n")
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def main():
|
| 357 |
+
parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné")
|
| 358 |
+
parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
|
| 359 |
+
parser.add_argument("--train_script", type=str, default=None)
|
| 360 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 361 |
+
parser.add_argument("--config", type=str, default=None)
|
| 362 |
+
parser.add_argument("--questions", type=str, default=None)
|
| 363 |
+
parser.add_argument("--max_new_tokens", type=int, default=96)
|
| 364 |
+
parser.add_argument("--temperature", type=float, default=0.4)
|
| 365 |
+
parser.add_argument("--top_k", type=int, default=40)
|
| 366 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.12)
|
| 367 |
+
parser.add_argument("--device", type=str, default=None)
|
| 368 |
+
parser.add_argument("--save_report", action="store_true")
|
| 369 |
+
args = parser.parse_args()
|
| 370 |
+
|
| 371 |
+
repo_dir = Path(args.repo_dir).resolve()
|
| 372 |
+
train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)
|
| 373 |
+
|
| 374 |
+
if args.train_script:
|
| 375 |
+
train_script = Path(args.train_script).resolve()
|
| 376 |
+
if args.ckpt:
|
| 377 |
+
ckpt_path = Path(args.ckpt).resolve()
|
| 378 |
+
if args.config:
|
| 379 |
+
config_path = Path(args.config).resolve()
|
| 380 |
+
|
| 381 |
+
if not train_script.exists():
|
| 382 |
+
raise FileNotFoundError(f"Script train introuvable: {train_script}")
|
| 383 |
+
if not ckpt_path.exists():
|
| 384 |
+
raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
|
| 385 |
+
if not config_path.exists():
|
| 386 |
+
print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")
|
| 387 |
+
|
| 388 |
+
questions = load_questions(args.questions)
|
| 389 |
+
tester = AramixChatTester(
|
| 390 |
+
repo_dir=repo_dir,
|
| 391 |
+
train_script=train_script,
|
| 392 |
+
ckpt_path=ckpt_path,
|
| 393 |
+
config_path=config_path,
|
| 394 |
+
device=args.device,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
results: List[Dict[str, Any]] = []
|
| 398 |
+
categories: Dict[str, List[Dict[str, Any]]] = {}
|
| 399 |
+
|
| 400 |
+
print("\n" + "═" * 70)
|
| 401 |
+
print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
|
| 402 |
+
print("═" * 70)
|
| 403 |
+
print(f"Repo : {repo_dir}")
|
| 404 |
+
print(f"Train script: {train_script}")
|
| 405 |
+
print(f"Checkpoint : {ckpt_path}")
|
| 406 |
+
print(f"Config : {config_path}")
|
| 407 |
+
print(f"Tokenizer : {tokenizer_dir}")
|
| 408 |
+
print(f"Device : {tester.device}")
|
| 409 |
+
print(f"Questions : {len(questions)}")
|
| 410 |
+
print("═" * 70 + "\n")
|
| 411 |
+
|
| 412 |
+
for i, item in enumerate(questions, 1):
|
| 413 |
+
q = item["question"]
|
| 414 |
+
ref = item.get("reference")
|
| 415 |
+
cat = item.get("category", "Général")
|
| 416 |
+
|
| 417 |
+
t0 = time.time()
|
| 418 |
+
ans = tester.generate(
|
| 419 |
+
q,
|
| 420 |
+
max_new_tokens=args.max_new_tokens,
|
| 421 |
+
temperature=args.temperature,
|
| 422 |
+
top_k=args.top_k,
|
| 423 |
+
repetition_penalty=args.repetition_penalty,
|
| 424 |
+
)
|
| 425 |
+
latency = time.time() - t0
|
| 426 |
+
|
| 427 |
+
overlap = lexical_overlap(ref, ans)
|
| 428 |
+
em = exact_match(ref, ans)
|
| 429 |
+
|
| 430 |
+
entry = {
|
| 431 |
+
"id": i,
|
| 432 |
+
"category": cat,
|
| 433 |
+
"question": q,
|
| 434 |
+
"answer": ans,
|
| 435 |
+
"reference": ref,
|
| 436 |
+
"latency_s": round(latency, 3),
|
| 437 |
+
"tokens_generated_approx": len(ans.split()),
|
| 438 |
+
"overlap_score": None if overlap is None else round(overlap, 4),
|
| 439 |
+
"exact_match": em,
|
| 440 |
+
}
|
| 441 |
+
results.append(entry)
|
| 442 |
+
categories.setdefault(cat, []).append(entry)
|
| 443 |
+
|
| 444 |
+
overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
|
| 445 |
+
em_str = "✓" if em else ("✗" if em is not None else "n/a")
|
| 446 |
+
|
| 447 |
+
print("─" * 70)
|
| 448 |
+
print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}")
|
| 449 |
+
print(f" User : {q}")
|
| 450 |
+
print(f" Assistant : {ans}")
|
| 451 |
+
if ref:
|
| 452 |
+
print(f" Référence : {ref}")
|
| 453 |
+
print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")
|
| 454 |
+
|
| 455 |
+
scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
|
| 456 |
+
scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
|
| 457 |
+
|
| 458 |
+
avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
|
| 459 |
+
em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
|
| 460 |
+
avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
|
| 461 |
+
avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0
|
| 462 |
+
|
| 463 |
+
cat_scores: Dict[str, float] = {}
|
| 464 |
+
for cat, items in categories.items():
|
| 465 |
+
vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
|
| 466 |
+
cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0
|
| 467 |
+
|
| 468 |
+
summary = {
|
| 469 |
+
"repo_dir": str(repo_dir),
|
| 470 |
+
"train_script": str(train_script),
|
| 471 |
+
"checkpoint": str(ckpt_path),
|
| 472 |
+
"config_path": str(config_path),
|
| 473 |
+
"tokenizer_dir": str(tokenizer_dir),
|
| 474 |
+
"device": str(tester.device),
|
| 475 |
+
"total_questions": len(results),
|
| 476 |
+
"avg_overlap_score": round(avg_overlap, 4),
|
| 477 |
+
"exact_match_rate": round(em_rate, 4),
|
| 478 |
+
"avg_latency_s": round(avg_latency, 3),
|
| 479 |
+
"avg_words_generated": round(avg_words, 1),
|
| 480 |
+
"scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
|
| 481 |
+
"results": results,
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
print("═" * 70)
|
| 485 |
+
print(" RÉSUMÉ")
|
| 486 |
+
print("═" * 70)
|
| 487 |
+
print(f"Questions testées : {len(results)}")
|
| 488 |
+
print(f"Overlap moyen : {avg_overlap:.1%}")
|
| 489 |
+
print(f"Exact match : {em_rate:.1%}")
|
| 490 |
+
print(f"Latence moyenne : {avg_latency:.2f}s")
|
| 491 |
+
print(f"Mots moyens : {avg_words:.1f}")
|
| 492 |
+
print("Scores / catégorie:")
|
| 493 |
+
for cat, score in sorted(cat_scores.items()):
|
| 494 |
+
print(f" {cat:<15} {format_bar(score)} {score:.0%}")
|
| 495 |
+
print("═" * 70)
|
| 496 |
+
|
| 497 |
+
if args.save_report:
|
| 498 |
+
save_reports(repo_dir, summary)
|
| 499 |
+
print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
|
| 500 |
+
print(f" {repo_dir / 'qa_test_report_simple.txt'}")
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
if __name__ == "__main__":
|
| 504 |
+
main()
|
simple_qa_test_aramix_v2.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
simple_qa_test_aramix_v2.py
|
| 6 |
+
|
| 7 |
+
Version corrigée du test QA simple pour repo Aramix.
|
| 8 |
+
Correction principale :
|
| 9 |
+
- gère les checkpoints sauvés depuis torch.compile() avec préfixe "_orig_mod."
|
| 10 |
+
- contourne load_checkpoint() du script train si celui-ci échoue sur ce cas
|
| 11 |
+
|
| 12 |
+
Usage
|
| 13 |
+
-----
|
| 14 |
+
python simple_qa_test_aramix_v2.py --repo_dir ./aramix_h100 --save_report
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import importlib.util
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
import unicodedata
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Any, Dict, List, Optional
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DEFAULT_QUESTIONS = [
|
| 35 |
+
{"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"},
|
| 36 |
+
{"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"},
|
| 37 |
+
{"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"},
|
| 38 |
+
{"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"},
|
| 39 |
+
{"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"},
|
| 40 |
+
{"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"},
|
| 41 |
+
{"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre"},
|
| 42 |
+
{"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"},
|
| 43 |
+
{"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"},
|
| 44 |
+
{"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None},
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_module_from_file(py_path: Path):
|
| 49 |
+
spec = importlib.util.spec_from_file_location(py_path.stem, py_path)
|
| 50 |
+
if spec is None or spec.loader is None:
|
| 51 |
+
raise RuntimeError(f"Impossible de charger le module: {py_path}")
|
| 52 |
+
module = importlib.util.module_from_spec(spec)
|
| 53 |
+
sys.modules[py_path.stem] = module
|
| 54 |
+
spec.loader.exec_module(module)
|
| 55 |
+
return module
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def normalize_text(text: str) -> str:
|
| 59 |
+
text = (text or "").strip().lower()
|
| 60 |
+
text = unicodedata.normalize("NFKD", text)
|
| 61 |
+
text = "".join(ch for ch in text if not unicodedata.combining(ch))
|
| 62 |
+
text = text.replace("π", "pi")
|
| 63 |
+
text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE)
|
| 64 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 65 |
+
return text
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
|
| 69 |
+
if not reference:
|
| 70 |
+
return None
|
| 71 |
+
ref = set(normalize_text(reference).split())
|
| 72 |
+
ans = set(normalize_text(answer).split())
|
| 73 |
+
if not ref:
|
| 74 |
+
return None
|
| 75 |
+
return len(ref & ans) / len(ref)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 79 |
+
if not reference:
|
| 80 |
+
return None
|
| 81 |
+
return normalize_text(reference) == normalize_text(answer)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def infer_repo_defaults(repo_dir: Path):
|
| 85 |
+
train_script = repo_dir.parent / "train_aramix_h100_full.py"
|
| 86 |
+
if not train_script.exists():
|
| 87 |
+
train_script = repo_dir / "train_aramix_h100_full.py"
|
| 88 |
+
|
| 89 |
+
ckpt = repo_dir / "model_best.pt"
|
| 90 |
+
if not ckpt.exists():
|
| 91 |
+
ckpt = repo_dir / "model.pt"
|
| 92 |
+
|
| 93 |
+
config = repo_dir / "config.json"
|
| 94 |
+
tokenizer_dir = repo_dir / "tokenizer_32k"
|
| 95 |
+
return train_script, ckpt, config, tokenizer_dir
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
|
| 99 |
+
for name in names:
|
| 100 |
+
if name in cfg:
|
| 101 |
+
return cfg[name]
|
| 102 |
+
return default
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
|
| 106 |
+
block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512)
|
| 107 |
+
d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768)
|
| 108 |
+
n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12)
|
| 109 |
+
n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12)
|
| 110 |
+
d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4)
|
| 111 |
+
return {
|
| 112 |
+
"vocab_size": vocab_size,
|
| 113 |
+
"block_size": int(block_size),
|
| 114 |
+
"d_model": int(d_model),
|
| 115 |
+
"n_heads": int(n_heads),
|
| 116 |
+
"n_layers": int(n_layers),
|
| 117 |
+
"d_ff": int(d_ff),
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 122 |
+
if any(k.startswith("_orig_mod.") for k in state.keys()):
|
| 123 |
+
return {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
|
| 124 |
+
return state
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class AramixChatTester:
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
repo_dir: Path,
|
| 131 |
+
train_script: Path,
|
| 132 |
+
ckpt_path: Path,
|
| 133 |
+
config_path: Path,
|
| 134 |
+
device: Optional[str] = None,
|
| 135 |
+
):
|
| 136 |
+
self.repo_dir = repo_dir
|
| 137 |
+
self.train_script = train_script
|
| 138 |
+
self.ckpt_path = ckpt_path
|
| 139 |
+
self.config_path = config_path
|
| 140 |
+
self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 141 |
+
|
| 142 |
+
self.M = load_module_from_file(self.train_script)
|
| 143 |
+
|
| 144 |
+
required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
|
| 145 |
+
missing = [x for x in required if not hasattr(self.M, x)]
|
| 146 |
+
if missing:
|
| 147 |
+
raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")
|
| 148 |
+
|
| 149 |
+
self.cfg_json: Dict[str, Any] = {}
|
| 150 |
+
if self.config_path.exists():
|
| 151 |
+
with open(self.config_path, "r", encoding="utf-8") as f:
|
| 152 |
+
self.cfg_json = json.load(f)
|
| 153 |
+
|
| 154 |
+
self.tokenizer = self._load_tokenizer()
|
| 155 |
+
self.model = self._load_model()
|
| 156 |
+
|
| 157 |
+
def _load_tokenizer(self):
|
| 158 |
+
old_cwd = Path.cwd()
|
| 159 |
+
try:
|
| 160 |
+
os.chdir(self.repo_dir.parent)
|
| 161 |
+
tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
|
| 162 |
+
finally:
|
| 163 |
+
os.chdir(old_cwd)
|
| 164 |
+
return tok
|
| 165 |
+
|
| 166 |
+
def _make_gpt_config(self):
|
| 167 |
+
kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
|
| 168 |
+
try:
|
| 169 |
+
return self.M.GPTConfig(**kwargs)
|
| 170 |
+
except TypeError:
|
| 171 |
+
return self.M.GPTConfig(vocab_size=len(self.tokenizer))
|
| 172 |
+
|
| 173 |
+
def _manual_load_state(self, model: torch.nn.Module):
|
| 174 |
+
ckpt = torch.load(self.ckpt_path, map_location=self.device)
|
| 175 |
+
if isinstance(ckpt, dict) and "model" in ckpt:
|
| 176 |
+
state = ckpt["model"]
|
| 177 |
+
else:
|
| 178 |
+
state = ckpt
|
| 179 |
+
|
| 180 |
+
if not isinstance(state, dict):
|
| 181 |
+
raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")
|
| 182 |
+
|
| 183 |
+
state = strip_orig_mod_prefix(state)
|
| 184 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 185 |
+
|
| 186 |
+
# tolérance uniquement sur lm_head/tied weights éventuels, sinon on échoue
|
| 187 |
+
hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
|
| 188 |
+
hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
|
| 189 |
+
if hard_missing or hard_unexpected:
|
| 190 |
+
raise RuntimeError(
|
| 191 |
+
"Chargement manuel incomplet.\n"
|
| 192 |
+
f"Missing: {hard_missing[:20]}\n"
|
| 193 |
+
f"Unexpected: {hard_unexpected[:20]}"
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def _load_model(self):
|
| 197 |
+
cfg = self._make_gpt_config()
|
| 198 |
+
model = self.M.GPT(cfg).to(self.device)
|
| 199 |
+
|
| 200 |
+
# 1) tentative via load_checkpoint du script train
|
| 201 |
+
if hasattr(self.M, "load_checkpoint"):
|
| 202 |
+
try:
|
| 203 |
+
try:
|
| 204 |
+
self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
|
| 205 |
+
model.eval()
|
| 206 |
+
return model
|
| 207 |
+
except TypeError:
|
| 208 |
+
self.M.load_checkpoint(model, self.ckpt_path, self.device)
|
| 209 |
+
model.eval()
|
| 210 |
+
return model
|
| 211 |
+
except RuntimeError as e:
|
| 212 |
+
msg = str(e)
|
| 213 |
+
if "_orig_mod." not in msg:
|
| 214 |
+
raise
|
| 215 |
+
|
| 216 |
+
# 2) fallback robuste
|
| 217 |
+
self._manual_load_state(model)
|
| 218 |
+
model.eval()
|
| 219 |
+
return model
|
| 220 |
+
|
| 221 |
+
def encode_prompt(self, question: str) -> List[int]:
|
| 222 |
+
bos = getattr(self.tokenizer, "bos_token_id", None)
|
| 223 |
+
eos = getattr(self.tokenizer, "eos_token_id", None)
|
| 224 |
+
|
| 225 |
+
prompt = f"Question: {question}\nRéponse:"
|
| 226 |
+
ids = self.tokenizer.encode(prompt, add_special_tokens=False)
|
| 227 |
+
|
| 228 |
+
if bos is not None:
|
| 229 |
+
ids = [bos] + ids
|
| 230 |
+
if eos is not None and ids and ids[-1] == eos:
|
| 231 |
+
ids = ids[:-1]
|
| 232 |
+
return ids
|
| 233 |
+
|
| 234 |
+
@torch.no_grad()
|
| 235 |
+
def generate(
|
| 236 |
+
self,
|
| 237 |
+
question: str,
|
| 238 |
+
max_new_tokens: int = 96,
|
| 239 |
+
temperature: float = 0.4,
|
| 240 |
+
top_k: int = 40,
|
| 241 |
+
repetition_penalty: float = 1.12,
|
| 242 |
+
) -> str:
|
| 243 |
+
ids = self.encode_prompt(question)
|
| 244 |
+
x = torch.tensor([ids], dtype=torch.long, device=self.device)
|
| 245 |
+
|
| 246 |
+
eos_id = getattr(self.tokenizer, "eos_token_id", None)
|
| 247 |
+
block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
|
| 248 |
+
if block_size is None:
|
| 249 |
+
block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
|
| 250 |
+
|
| 251 |
+
for step in range(max_new_tokens):
|
| 252 |
+
x_ctx = x[:, -int(block_size):]
|
| 253 |
+
|
| 254 |
+
out = self.model(x_ctx)
|
| 255 |
+
logits = out[0] if isinstance(out, tuple) else out
|
| 256 |
+
logits = logits[:, -1, :]
|
| 257 |
+
|
| 258 |
+
recent = x[0, -64:].tolist()
|
| 259 |
+
for tok in set(recent):
|
| 260 |
+
logits[0, tok] /= repetition_penalty
|
| 261 |
+
|
| 262 |
+
if temperature <= 0:
|
| 263 |
+
next_tok = torch.argmax(logits, dim=-1, keepdim=True)
|
| 264 |
+
else:
|
| 265 |
+
logits = logits / max(temperature, 1e-5)
|
| 266 |
+
if top_k is not None and top_k > 0:
|
| 267 |
+
values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
|
| 268 |
+
kth = values[:, -1].unsqueeze(-1)
|
| 269 |
+
logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
|
| 270 |
+
probs = F.softmax(logits, dim=-1)
|
| 271 |
+
next_tok = torch.multinomial(probs, num_samples=1)
|
| 272 |
+
|
| 273 |
+
x = torch.cat([x, next_tok], dim=1)
|
| 274 |
+
if eos_id is not None and next_tok.item() == eos_id and step >= 2:
|
| 275 |
+
break
|
| 276 |
+
|
| 277 |
+
new_ids = x[0, len(ids):].tolist()
|
| 278 |
+
text = self.tokenizer.decode(new_ids).strip()
|
| 279 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 280 |
+
text = text.replace("Réponse :", "").replace("Réponse:", "").strip()
|
| 281 |
+
return text
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
|
| 285 |
+
if not path:
|
| 286 |
+
return DEFAULT_QUESTIONS
|
| 287 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 288 |
+
data = json.load(f)
|
| 289 |
+
if not isinstance(data, list):
|
| 290 |
+
raise ValueError("Le fichier questions doit contenir une liste JSON.")
|
| 291 |
+
return data
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def format_bar(score: float, width: int = 20) -> str:
|
| 295 |
+
n = max(0, min(width, int(round(score * width))))
|
| 296 |
+
return "█" * n + "░" * (width - n)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
|
| 300 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 301 |
+
|
| 302 |
+
json_path = output_dir / "qa_test_report_simple.json"
|
| 303 |
+
txt_path = output_dir / "qa_test_report_simple.txt"
|
| 304 |
+
|
| 305 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
| 306 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 307 |
+
|
| 308 |
+
with open(txt_path, "w", encoding="utf-8") as f:
|
| 309 |
+
f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n")
|
| 310 |
+
f.write("=" * 60 + "\n\n")
|
| 311 |
+
for r in summary["results"]:
|
| 312 |
+
f.write(f"[{r['id']:02d}] {r['category']}\n")
|
| 313 |
+
f.write(f" User : {r['question']}\n")
|
| 314 |
+
f.write(f" Assistant : {r['answer']}\n")
|
| 315 |
+
if r["reference"]:
|
| 316 |
+
f.write(f" Référence : {r['reference']}\n")
|
| 317 |
+
if r["overlap_score"] is not None:
|
| 318 |
+
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
|
| 319 |
+
if r["exact_match"] is not None:
|
| 320 |
+
f.write(f" ExactMatch: {r['exact_match']}\n")
|
| 321 |
+
f.write(f" Latence : {r['latency_s']}s\n\n")
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def main():
|
| 325 |
+
parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné")
|
| 326 |
+
parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
|
| 327 |
+
parser.add_argument("--train_script", type=str, default=None)
|
| 328 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 329 |
+
parser.add_argument("--config", type=str, default=None)
|
| 330 |
+
parser.add_argument("--questions", type=str, default=None)
|
| 331 |
+
parser.add_argument("--max_new_tokens", type=int, default=96)
|
| 332 |
+
parser.add_argument("--temperature", type=float, default=0.4)
|
| 333 |
+
parser.add_argument("--top_k", type=int, default=40)
|
| 334 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.12)
|
| 335 |
+
parser.add_argument("--device", type=str, default=None)
|
| 336 |
+
parser.add_argument("--save_report", action="store_true")
|
| 337 |
+
args = parser.parse_args()
|
| 338 |
+
|
| 339 |
+
repo_dir = Path(args.repo_dir).resolve()
|
| 340 |
+
train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)
|
| 341 |
+
|
| 342 |
+
if args.train_script:
|
| 343 |
+
train_script = Path(args.train_script).resolve()
|
| 344 |
+
if args.ckpt:
|
| 345 |
+
ckpt_path = Path(args.ckpt).resolve()
|
| 346 |
+
if args.config:
|
| 347 |
+
config_path = Path(args.config).resolve()
|
| 348 |
+
|
| 349 |
+
if not train_script.exists():
|
| 350 |
+
raise FileNotFoundError(f"Script train introuvable: {train_script}")
|
| 351 |
+
if not ckpt_path.exists():
|
| 352 |
+
raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
|
| 353 |
+
if not config_path.exists():
|
| 354 |
+
print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")
|
| 355 |
+
|
| 356 |
+
questions = load_questions(args.questions)
|
| 357 |
+
tester = AramixChatTester(
|
| 358 |
+
repo_dir=repo_dir,
|
| 359 |
+
train_script=train_script,
|
| 360 |
+
ckpt_path=ckpt_path,
|
| 361 |
+
config_path=config_path,
|
| 362 |
+
device=args.device,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
results: List[Dict[str, Any]] = []
|
| 366 |
+
categories: Dict[str, List[Dict[str, Any]]] = {}
|
| 367 |
+
|
| 368 |
+
print("\n" + "═" * 70)
|
| 369 |
+
print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
|
| 370 |
+
print("═" * 70)
|
| 371 |
+
print(f"Repo : {repo_dir}")
|
| 372 |
+
print(f"Train script: {train_script}")
|
| 373 |
+
print(f"Checkpoint : {ckpt_path}")
|
| 374 |
+
print(f"Config : {config_path}")
|
| 375 |
+
print(f"Tokenizer : {tokenizer_dir}")
|
| 376 |
+
print(f"Device : {tester.device}")
|
| 377 |
+
print(f"Questions : {len(questions)}")
|
| 378 |
+
print("═" * 70 + "\n")
|
| 379 |
+
|
| 380 |
+
for i, item in enumerate(questions, 1):
|
| 381 |
+
q = item["question"]
|
| 382 |
+
ref = item.get("reference")
|
| 383 |
+
cat = item.get("category", "Général")
|
| 384 |
+
|
| 385 |
+
t0 = time.time()
|
| 386 |
+
ans = tester.generate(
|
| 387 |
+
q,
|
| 388 |
+
max_new_tokens=args.max_new_tokens,
|
| 389 |
+
temperature=args.temperature,
|
| 390 |
+
top_k=args.top_k,
|
| 391 |
+
repetition_penalty=args.repetition_penalty,
|
| 392 |
+
)
|
| 393 |
+
latency = time.time() - t0
|
| 394 |
+
|
| 395 |
+
overlap = lexical_overlap(ref, ans)
|
| 396 |
+
em = exact_match(ref, ans)
|
| 397 |
+
|
| 398 |
+
entry = {
|
| 399 |
+
"id": i,
|
| 400 |
+
"category": cat,
|
| 401 |
+
"question": q,
|
| 402 |
+
"answer": ans,
|
| 403 |
+
"reference": ref,
|
| 404 |
+
"latency_s": round(latency, 3),
|
| 405 |
+
"tokens_generated_approx": len(ans.split()),
|
| 406 |
+
"overlap_score": None if overlap is None else round(overlap, 4),
|
| 407 |
+
"exact_match": em,
|
| 408 |
+
}
|
| 409 |
+
results.append(entry)
|
| 410 |
+
categories.setdefault(cat, []).append(entry)
|
| 411 |
+
|
| 412 |
+
overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
|
| 413 |
+
em_str = "✓" if em else ("✗" if em is not None else "n/a")
|
| 414 |
+
|
| 415 |
+
print("─" * 70)
|
| 416 |
+
print(f"[{i:02d}] [{cat}] overlap={overlap_str} | EM={em_str}")
|
| 417 |
+
print(f" User : {q}")
|
| 418 |
+
print(f" Assistant : {ans}")
|
| 419 |
+
if ref:
|
| 420 |
+
print(f" Référence : {ref}")
|
| 421 |
+
print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")
|
| 422 |
+
|
| 423 |
+
scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
|
| 424 |
+
scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
|
| 425 |
+
|
| 426 |
+
avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
|
| 427 |
+
em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
|
| 428 |
+
avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
|
| 429 |
+
avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0
|
| 430 |
+
|
| 431 |
+
cat_scores: Dict[str, float] = {}
|
| 432 |
+
for cat, items in categories.items():
|
| 433 |
+
vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
|
| 434 |
+
cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0
|
| 435 |
+
|
| 436 |
+
summary = {
|
| 437 |
+
"repo_dir": str(repo_dir),
|
| 438 |
+
"train_script": str(train_script),
|
| 439 |
+
"checkpoint": str(ckpt_path),
|
| 440 |
+
"config_path": str(config_path),
|
| 441 |
+
"tokenizer_dir": str(tokenizer_dir),
|
| 442 |
+
"device": str(tester.device),
|
| 443 |
+
"total_questions": len(results),
|
| 444 |
+
"avg_overlap_score": round(avg_overlap, 4),
|
| 445 |
+
"exact_match_rate": round(em_rate, 4),
|
| 446 |
+
"avg_latency_s": round(avg_latency, 3),
|
| 447 |
+
"avg_words_generated": round(avg_words, 1),
|
| 448 |
+
"scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
|
| 449 |
+
"results": results,
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
print("═" * 70)
|
| 453 |
+
print(" RÉSUMÉ")
|
| 454 |
+
print("═" * 70)
|
| 455 |
+
print(f"Questions testées : {len(results)}")
|
| 456 |
+
print(f"Overlap moyen : {avg_overlap:.1%}")
|
| 457 |
+
print(f"Exact match : {em_rate:.1%}")
|
| 458 |
+
print(f"Latence moyenne : {avg_latency:.2f}s")
|
| 459 |
+
print(f"Mots moyens : {avg_words:.1f}")
|
| 460 |
+
print("Scores / catégorie:")
|
| 461 |
+
for cat, score in sorted(cat_scores.items()):
|
| 462 |
+
print(f" {cat:<15} {format_bar(score)} {score:.0%}")
|
| 463 |
+
print("═" * 70)
|
| 464 |
+
|
| 465 |
+
if args.save_report:
|
| 466 |
+
save_reports(repo_dir, summary)
|
| 467 |
+
print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
|
| 468 |
+
print(f" {repo_dir / 'qa_test_report_simple.txt'}")
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == "__main__":
|
| 472 |
+
main()
|
simple_qa_test_aramix_v3.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
simple_qa_test_aramix_v3.py
|
| 6 |
+
|
| 7 |
+
Test QA simple et strict pour un modèle déjà entraîné dans une repo de type :
|
| 8 |
+
- train_aramix_h100_full.py
|
| 9 |
+
- aramix_h100/
|
| 10 |
+
- config.json
|
| 11 |
+
- model_best.pt
|
| 12 |
+
- model.pt
|
| 13 |
+
- tokenizer_32k/
|
| 14 |
+
|
| 15 |
+
Cette version améliore le test QA en :
|
| 16 |
+
- gérant les checkpoints torch.compile() avec préfixe "_orig_mod."
|
| 17 |
+
- utilisant un prompt plus directif pour des réponses courtes
|
| 18 |
+
- privilégiant une génération greedy / quasi-greedy
|
| 19 |
+
- tronquant proprement les réponses trop longues
|
| 20 |
+
- ajoutant une métrique "contains_reference"
|
| 21 |
+
|
| 22 |
+
Usage
|
| 23 |
+
-----
|
| 24 |
+
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --save_report
|
| 25 |
+
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --temperature 0 --max_new_tokens 16
|
| 26 |
+
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --ckpt ./aramix_h100/model_best.pt
|
| 27 |
+
python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --questions qa_questions.json
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import argparse
|
| 33 |
+
import importlib.util
|
| 34 |
+
import json
|
| 35 |
+
import os
|
| 36 |
+
import re
|
| 37 |
+
import sys
|
| 38 |
+
import time
|
| 39 |
+
import unicodedata
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 42 |
+
|
| 43 |
+
import torch
|
| 44 |
+
import torch.nn.functional as F
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
DEFAULT_QUESTIONS = [
|
| 48 |
+
{"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"},
|
| 49 |
+
{"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"},
|
| 50 |
+
{"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"},
|
| 51 |
+
{"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"},
|
| 52 |
+
{"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"},
|
| 53 |
+
{"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"},
|
| 54 |
+
{"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "pi r carre"},
|
| 55 |
+
{"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"},
|
| 56 |
+
{"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"},
|
| 57 |
+
{"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None},
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_module_from_file(py_path: Path):
|
| 62 |
+
spec = importlib.util.spec_from_file_location(py_path.stem, py_path)
|
| 63 |
+
if spec is None or spec.loader is None:
|
| 64 |
+
raise RuntimeError(f"Impossible de charger le module: {py_path}")
|
| 65 |
+
module = importlib.util.module_from_spec(spec)
|
| 66 |
+
sys.modules[py_path.stem] = module
|
| 67 |
+
spec.loader.exec_module(module)
|
| 68 |
+
return module
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def normalize_text(text: str) -> str:
|
| 72 |
+
text = (text or "").strip().lower()
|
| 73 |
+
text = unicodedata.normalize("NFKD", text)
|
| 74 |
+
text = "".join(ch for ch in text if not unicodedata.combining(ch))
|
| 75 |
+
text = text.replace("π", "pi")
|
| 76 |
+
text = re.sub(r"[\W_]+", " ", text, flags=re.UNICODE)
|
| 77 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 78 |
+
return text
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
|
| 82 |
+
if not reference:
|
| 83 |
+
return None
|
| 84 |
+
ref = set(normalize_text(reference).split())
|
| 85 |
+
ans = set(normalize_text(answer).split())
|
| 86 |
+
if not ref:
|
| 87 |
+
return None
|
| 88 |
+
return len(ref & ans) / len(ref)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 92 |
+
if not reference:
|
| 93 |
+
return None
|
| 94 |
+
return normalize_text(reference) == normalize_text(answer)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def contains_reference(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 98 |
+
if not reference:
|
| 99 |
+
return None
|
| 100 |
+
ref = normalize_text(reference)
|
| 101 |
+
ans = normalize_text(answer)
|
| 102 |
+
if not ref:
|
| 103 |
+
return None
|
| 104 |
+
return ref in ans
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def infer_repo_defaults(repo_dir: Path):
|
| 108 |
+
train_script = repo_dir.parent / "train_aramix_h100_full.py"
|
| 109 |
+
if not train_script.exists():
|
| 110 |
+
train_script = repo_dir / "train_aramix_h100_full.py"
|
| 111 |
+
|
| 112 |
+
ckpt = repo_dir / "model_best.pt"
|
| 113 |
+
if not ckpt.exists():
|
| 114 |
+
ckpt = repo_dir / "model.pt"
|
| 115 |
+
|
| 116 |
+
config = repo_dir / "config.json"
|
| 117 |
+
tokenizer_dir = repo_dir / "tokenizer_32k"
|
| 118 |
+
return train_script, ckpt, config, tokenizer_dir
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
|
| 122 |
+
for name in names:
|
| 123 |
+
if name in cfg:
|
| 124 |
+
return cfg[name]
|
| 125 |
+
return default
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
|
| 129 |
+
block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512)
|
| 130 |
+
d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768)
|
| 131 |
+
n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12)
|
| 132 |
+
n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12)
|
| 133 |
+
d_ff = safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4)
|
| 134 |
+
return {
|
| 135 |
+
"vocab_size": vocab_size,
|
| 136 |
+
"block_size": int(block_size),
|
| 137 |
+
"d_model": int(d_model),
|
| 138 |
+
"n_heads": int(n_heads),
|
| 139 |
+
"n_layers": int(n_layers),
|
| 140 |
+
"d_ff": int(d_ff),
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 145 |
+
if any(k.startswith("_orig_mod.") for k in state.keys()):
|
| 146 |
+
return {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
|
| 147 |
+
return state
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def clean_answer_text(text: str, max_words: int = 16) -> str:
|
| 151 |
+
text = (text or "").strip()
|
| 152 |
+
|
| 153 |
+
# Retire quelques marqueurs fréquents
|
| 154 |
+
text = text.replace("<eos>", " ")
|
| 155 |
+
text = text.replace("</s>", " ")
|
| 156 |
+
text = text.replace("<pad>", " ")
|
| 157 |
+
text = text.replace("Réponse :", " ")
|
| 158 |
+
text = text.replace("Réponse:", " ")
|
| 159 |
+
text = text.replace("Answer:", " ")
|
| 160 |
+
|
| 161 |
+
# Garde la première ligne
|
| 162 |
+
text = text.split("\n")[0].strip()
|
| 163 |
+
|
| 164 |
+
# Coupe à la première vraie fin de phrase courte
|
| 165 |
+
m = re.search(r"([.!?])", text)
|
| 166 |
+
if m and m.start() < 120:
|
| 167 |
+
text = text[: m.start() + 1]
|
| 168 |
+
|
| 169 |
+
# Compacte les espaces
|
| 170 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 171 |
+
|
| 172 |
+
# Tronque au nombre de mots voulu
|
| 173 |
+
words = text.split()
|
| 174 |
+
if len(words) > max_words:
|
| 175 |
+
text = " ".join(words[:max_words]).strip()
|
| 176 |
+
|
| 177 |
+
# Supprime ponctuation finale excessive
|
| 178 |
+
text = re.sub(r"[,\s;:]+$", "", text).strip()
|
| 179 |
+
|
| 180 |
+
return text
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def safe_top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
|
| 184 |
+
if top_k is None or top_k <= 0:
|
| 185 |
+
return logits
|
| 186 |
+
k = min(int(top_k), logits.size(-1))
|
| 187 |
+
values, _ = torch.topk(logits, k=k)
|
| 188 |
+
kth = values[:, -1].unsqueeze(-1)
|
| 189 |
+
return torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class AramixChatTester:
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
repo_dir: Path,
|
| 196 |
+
train_script: Path,
|
| 197 |
+
ckpt_path: Path,
|
| 198 |
+
config_path: Path,
|
| 199 |
+
device: Optional[str] = None,
|
| 200 |
+
prompt_style: str = "strict_qa",
|
| 201 |
+
):
|
| 202 |
+
self.repo_dir = repo_dir
|
| 203 |
+
self.train_script = train_script
|
| 204 |
+
self.ckpt_path = ckpt_path
|
| 205 |
+
self.config_path = config_path
|
| 206 |
+
self.prompt_style = prompt_style
|
| 207 |
+
self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 208 |
+
|
| 209 |
+
self.M = load_module_from_file(self.train_script)
|
| 210 |
+
|
| 211 |
+
required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
|
| 212 |
+
missing = [x for x in required if not hasattr(self.M, x)]
|
| 213 |
+
if missing:
|
| 214 |
+
raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")
|
| 215 |
+
|
| 216 |
+
self.cfg_json: Dict[str, Any] = {}
|
| 217 |
+
if self.config_path.exists():
|
| 218 |
+
with open(self.config_path, "r", encoding="utf-8") as f:
|
| 219 |
+
self.cfg_json = json.load(f)
|
| 220 |
+
|
| 221 |
+
self.tokenizer = self._load_tokenizer()
|
| 222 |
+
self.model = self._load_model()
|
| 223 |
+
|
| 224 |
+
def _load_tokenizer(self):
|
| 225 |
+
old_cwd = Path.cwd()
|
| 226 |
+
try:
|
| 227 |
+
os.chdir(self.repo_dir.parent)
|
| 228 |
+
tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
|
| 229 |
+
finally:
|
| 230 |
+
os.chdir(old_cwd)
|
| 231 |
+
return tok
|
| 232 |
+
|
| 233 |
+
def _make_gpt_config(self):
|
| 234 |
+
kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
|
| 235 |
+
try:
|
| 236 |
+
return self.M.GPTConfig(**kwargs)
|
| 237 |
+
except TypeError:
|
| 238 |
+
return self.M.GPTConfig(vocab_size=len(self.tokenizer))
|
| 239 |
+
|
| 240 |
+
def _manual_load_state(self, model: torch.nn.Module):
|
| 241 |
+
ckpt = torch.load(self.ckpt_path, map_location=self.device)
|
| 242 |
+
if isinstance(ckpt, dict) and "model" in ckpt:
|
| 243 |
+
state = ckpt["model"]
|
| 244 |
+
else:
|
| 245 |
+
state = ckpt
|
| 246 |
+
|
| 247 |
+
if not isinstance(state, dict):
|
| 248 |
+
raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")
|
| 249 |
+
|
| 250 |
+
state = strip_orig_mod_prefix(state)
|
| 251 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 252 |
+
|
| 253 |
+
# tolérance minimale uniquement sur le tying éventuel
|
| 254 |
+
hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
|
| 255 |
+
hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
|
| 256 |
+
if hard_missing or hard_unexpected:
|
| 257 |
+
raise RuntimeError(
|
| 258 |
+
"Chargement manuel incomplet.\n"
|
| 259 |
+
f"Missing: {hard_missing[:20]}\n"
|
| 260 |
+
f"Unexpected: {hard_unexpected[:20]}"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def _load_model(self):
|
| 264 |
+
cfg = self._make_gpt_config()
|
| 265 |
+
model = self.M.GPT(cfg).to(self.device)
|
| 266 |
+
|
| 267 |
+
if hasattr(self.M, "load_checkpoint"):
|
| 268 |
+
try:
|
| 269 |
+
try:
|
| 270 |
+
self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
|
| 271 |
+
model.eval()
|
| 272 |
+
return model
|
| 273 |
+
except TypeError:
|
| 274 |
+
self.M.load_checkpoint(model, self.ckpt_path, self.device)
|
| 275 |
+
model.eval()
|
| 276 |
+
return model
|
| 277 |
+
except RuntimeError as e:
|
| 278 |
+
if "_orig_mod." not in str(e):
|
| 279 |
+
raise
|
| 280 |
+
|
| 281 |
+
self._manual_load_state(model)
|
| 282 |
+
model.eval()
|
| 283 |
+
return model
|
| 284 |
+
|
| 285 |
+
def build_prompt(self, question: str) -> str:
|
| 286 |
+
if self.prompt_style == "strict_qa":
|
| 287 |
+
return (
|
| 288 |
+
"Réponds très brièvement et uniquement en français.\n"
|
| 289 |
+
"Donne seulement la réponse finale, sans explication.\n\n"
|
| 290 |
+
f"Question : {question}\n"
|
| 291 |
+
"Réponse :"
|
| 292 |
+
)
|
| 293 |
+
if self.prompt_style == "qa":
|
| 294 |
+
return f"Question: {question}\nRéponse:"
|
| 295 |
+
return question
|
| 296 |
+
|
| 297 |
+
def encode_prompt(self, question: str) -> List[int]:
|
| 298 |
+
bos = getattr(self.tokenizer, "bos_token_id", None)
|
| 299 |
+
eos = getattr(self.tokenizer, "eos_token_id", None)
|
| 300 |
+
|
| 301 |
+
prompt = self.build_prompt(question)
|
| 302 |
+
ids = self.tokenizer.encode(prompt, add_special_tokens=False)
|
| 303 |
+
|
| 304 |
+
if bos is not None:
|
| 305 |
+
ids = [bos] + ids
|
| 306 |
+
if eos is not None and ids and ids[-1] == eos:
|
| 307 |
+
ids = ids[:-1]
|
| 308 |
+
return ids
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
|
| 311 |
+
def generate(
|
| 312 |
+
self,
|
| 313 |
+
question: str,
|
| 314 |
+
max_new_tokens: int = 16,
|
| 315 |
+
temperature: float = 0.0,
|
| 316 |
+
top_k: int = 1,
|
| 317 |
+
repetition_penalty: float = 1.10,
|
| 318 |
+
max_answer_words: int = 16,
|
| 319 |
+
) -> str:
|
| 320 |
+
ids = self.encode_prompt(question)
|
| 321 |
+
x = torch.tensor([ids], dtype=torch.long, device=self.device)
|
| 322 |
+
|
| 323 |
+
eos_id = getattr(self.tokenizer, "eos_token_id", None)
|
| 324 |
+
block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
|
| 325 |
+
if block_size is None:
|
| 326 |
+
block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
|
| 327 |
+
|
| 328 |
+
generated_word_budget_hit = False
|
| 329 |
+
|
| 330 |
+
for step in range(max_new_tokens):
|
| 331 |
+
x_ctx = x[:, -int(block_size):]
|
| 332 |
+
out = self.model(x_ctx)
|
| 333 |
+
logits = out[0] if isinstance(out, tuple) else out
|
| 334 |
+
logits = logits[:, -1, :]
|
| 335 |
+
|
| 336 |
+
# Pénalité légère de répétition
|
| 337 |
+
recent = x[0, -48:].tolist()
|
| 338 |
+
for tok in set(recent):
|
| 339 |
+
logits[0, tok] /= repetition_penalty
|
| 340 |
+
|
| 341 |
+
# Greedy par défaut pour QA courte
|
| 342 |
+
if temperature <= 0:
|
| 343 |
+
if top_k is not None and top_k > 1:
|
| 344 |
+
masked = safe_top_k_logits(logits, top_k)
|
| 345 |
+
probs = F.softmax(masked, dim=-1)
|
| 346 |
+
next_tok = torch.multinomial(probs, num_samples=1)
|
| 347 |
+
else:
|
| 348 |
+
next_tok = torch.argmax(logits, dim=-1, keepdim=True)
|
| 349 |
+
else:
|
| 350 |
+
logits = logits / max(temperature, 1e-5)
|
| 351 |
+
logits = safe_top_k_logits(logits, top_k)
|
| 352 |
+
probs = F.softmax(logits, dim=-1)
|
| 353 |
+
next_tok = torch.multinomial(probs, num_samples=1)
|
| 354 |
+
|
| 355 |
+
x = torch.cat([x, next_tok], dim=1)
|
| 356 |
+
|
| 357 |
+
if eos_id is not None and next_tok.item() == eos_id and step >= 1:
|
| 358 |
+
break
|
| 359 |
+
|
| 360 |
+
# arrêt anticipé si la réponse devient déjà trop longue
|
| 361 |
+
partial = self.tokenizer.decode(x[0, len(ids):].tolist()).strip()
|
| 362 |
+
partial = clean_answer_text(partial, max_words=max_answer_words)
|
| 363 |
+
if len(partial.split()) >= max_answer_words:
|
| 364 |
+
generated_word_budget_hit = True
|
| 365 |
+
break
|
| 366 |
+
|
| 367 |
+
new_ids = x[0, len(ids):].tolist()
|
| 368 |
+
text = self.tokenizer.decode(new_ids).strip()
|
| 369 |
+
text = clean_answer_text(text, max_words=max_answer_words)
|
| 370 |
+
|
| 371 |
+
if generated_word_budget_hit:
|
| 372 |
+
text = clean_answer_text(text, max_words=max_answer_words)
|
| 373 |
+
|
| 374 |
+
return text
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
|
| 378 |
+
if not path:
|
| 379 |
+
return DEFAULT_QUESTIONS
|
| 380 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 381 |
+
data = json.load(f)
|
| 382 |
+
if not isinstance(data, list):
|
| 383 |
+
raise ValueError("Le fichier questions doit contenir une liste JSON.")
|
| 384 |
+
return data
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def format_bar(score: float, width: int = 20) -> str:
|
| 388 |
+
n = max(0, min(width, int(round(score * width))))
|
| 389 |
+
return "█" * n + "░" * (width - n)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> Tuple[Path, Path]:
|
| 393 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 394 |
+
|
| 395 |
+
json_path = output_dir / "qa_test_report_simple_v3.json"
|
| 396 |
+
txt_path = output_dir / "qa_test_report_simple_v3.txt"
|
| 397 |
+
|
| 398 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
| 399 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 400 |
+
|
| 401 |
+
with open(txt_path, "w", encoding="utf-8") as f:
|
| 402 |
+
f.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ (V3)\n")
|
| 403 |
+
f.write("=" * 60 + "\n\n")
|
| 404 |
+
for r in summary["results"]:
|
| 405 |
+
f.write(f"[{r['id']:02d}] {r['category']}\n")
|
| 406 |
+
f.write(f" User : {r['question']}\n")
|
| 407 |
+
f.write(f" Assistant : {r['answer']}\n")
|
| 408 |
+
if r["reference"]:
|
| 409 |
+
f.write(f" Référence : {r['reference']}\n")
|
| 410 |
+
if r["overlap_score"] is not None:
|
| 411 |
+
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
|
| 412 |
+
if r["exact_match"] is not None:
|
| 413 |
+
f.write(f" ExactMatch: {r['exact_match']}\n")
|
| 414 |
+
if r["contains_reference"] is not None:
|
| 415 |
+
f.write(f" Contains : {r['contains_reference']}\n")
|
| 416 |
+
f.write(f" Latence : {r['latency_s']}s\n\n")
|
| 417 |
+
return json_path, txt_path
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def main():
|
| 421 |
+
parser = argparse.ArgumentParser("Test QA simple et strict pour modèle Aramix déjà entraîné")
|
| 422 |
+
parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
|
| 423 |
+
parser.add_argument("--train_script", type=str, default=None)
|
| 424 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 425 |
+
parser.add_argument("--config", type=str, default=None)
|
| 426 |
+
parser.add_argument("--questions", type=str, default=None)
|
| 427 |
+
parser.add_argument("--max_new_tokens", type=int, default=16)
|
| 428 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
| 429 |
+
parser.add_argument("--top_k", type=int, default=1)
|
| 430 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.10)
|
| 431 |
+
parser.add_argument("--max_answer_words", type=int, default=16)
|
| 432 |
+
parser.add_argument("--prompt_style", type=str, default="strict_qa", choices=["strict_qa", "qa", "raw"])
|
| 433 |
+
parser.add_argument("--device", type=str, default=None)
|
| 434 |
+
parser.add_argument("--save_report", action="store_true")
|
| 435 |
+
args = parser.parse_args()
|
| 436 |
+
|
| 437 |
+
repo_dir = Path(args.repo_dir).resolve()
|
| 438 |
+
train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)
|
| 439 |
+
|
| 440 |
+
if args.train_script:
|
| 441 |
+
train_script = Path(args.train_script).resolve()
|
| 442 |
+
if args.ckpt:
|
| 443 |
+
ckpt_path = Path(args.ckpt).resolve()
|
| 444 |
+
if args.config:
|
| 445 |
+
config_path = Path(args.config).resolve()
|
| 446 |
+
|
| 447 |
+
if not train_script.exists():
|
| 448 |
+
raise FileNotFoundError(f"Script train introuvable: {train_script}")
|
| 449 |
+
if not ckpt_path.exists():
|
| 450 |
+
raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
|
| 451 |
+
if not config_path.exists():
|
| 452 |
+
print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")
|
| 453 |
+
|
| 454 |
+
questions = load_questions(args.questions)
|
| 455 |
+
tester = AramixChatTester(
|
| 456 |
+
repo_dir=repo_dir,
|
| 457 |
+
train_script=train_script,
|
| 458 |
+
ckpt_path=ckpt_path,
|
| 459 |
+
config_path=config_path,
|
| 460 |
+
device=args.device,
|
| 461 |
+
prompt_style=args.prompt_style,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
results: List[Dict[str, Any]] = []
|
| 465 |
+
categories: Dict[str, List[Dict[str, Any]]] = {}
|
| 466 |
+
|
| 467 |
+
print("\n" + "=" * 60)
|
| 468 |
+
print("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
|
| 469 |
+
print("=" * 60)
|
| 470 |
+
print(f"Repo : {repo_dir}")
|
| 471 |
+
print(f"Train script: {train_script}")
|
| 472 |
+
print(f"Checkpoint : {ckpt_path}")
|
| 473 |
+
print(f"Config : {config_path}")
|
| 474 |
+
print(f"Tokenizer : {tokenizer_dir}")
|
| 475 |
+
print(f"Device : {tester.device}")
|
| 476 |
+
print(f"Prompt : {args.prompt_style}")
|
| 477 |
+
print(f"Questions : {len(questions)}")
|
| 478 |
+
print("=" * 60 + "\n")
|
| 479 |
+
|
| 480 |
+
for i, item in enumerate(questions, 1):
|
| 481 |
+
q = item["question"]
|
| 482 |
+
ref = item.get("reference")
|
| 483 |
+
cat = item.get("category", "Général")
|
| 484 |
+
|
| 485 |
+
t0 = time.time()
|
| 486 |
+
ans = tester.generate(
|
| 487 |
+
q,
|
| 488 |
+
max_new_tokens=args.max_new_tokens,
|
| 489 |
+
temperature=args.temperature,
|
| 490 |
+
top_k=args.top_k,
|
| 491 |
+
repetition_penalty=args.repetition_penalty,
|
| 492 |
+
max_answer_words=args.max_answer_words,
|
| 493 |
+
)
|
| 494 |
+
latency = time.time() - t0
|
| 495 |
+
|
| 496 |
+
overlap = lexical_overlap(ref, ans)
|
| 497 |
+
em = exact_match(ref, ans)
|
| 498 |
+
contains_ref = contains_reference(ref, ans)
|
| 499 |
+
|
| 500 |
+
entry = {
|
| 501 |
+
"id": i,
|
| 502 |
+
"category": cat,
|
| 503 |
+
"question": q,
|
| 504 |
+
"answer": ans,
|
| 505 |
+
"reference": ref,
|
| 506 |
+
"latency_s": round(latency, 3),
|
| 507 |
+
"tokens_generated_approx": len(ans.split()),
|
| 508 |
+
"overlap_score": None if overlap is None else round(overlap, 4),
|
| 509 |
+
"exact_match": em,
|
| 510 |
+
"contains_reference": contains_ref,
|
| 511 |
+
}
|
| 512 |
+
results.append(entry)
|
| 513 |
+
categories.setdefault(cat, []).append(entry)
|
| 514 |
+
|
| 515 |
+
overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
|
| 516 |
+
em_str = "✓" if em else ("✗" if em is not None else "n/a")
|
| 517 |
+
contains_str = "✓" if contains_ref else ("✗" if contains_ref is not None else "n/a")
|
| 518 |
+
|
| 519 |
+
print(f"[{i:02d}] {cat}")
|
| 520 |
+
print(f" User : {q}")
|
| 521 |
+
print(f" Assistant : {ans}")
|
| 522 |
+
if ref:
|
| 523 |
+
print(f" Référence : {ref}")
|
| 524 |
+
print(f" Overlap : {overlap_str}")
|
| 525 |
+
print(f" ExactMatch: {em_str}")
|
| 526 |
+
print(f" Contains : {contains_str}")
|
| 527 |
+
print(f" Latence : {latency:.3f}s\n")
|
| 528 |
+
|
| 529 |
+
scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
|
| 530 |
+
scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
|
| 531 |
+
scored_contains = [r["contains_reference"] for r in results if r["contains_reference"] is not None]
|
| 532 |
+
|
| 533 |
+
avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
|
| 534 |
+
em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
|
| 535 |
+
contains_rate = sum(1 for x in scored_contains if x) / len(scored_contains) if scored_contains else 0.0
|
| 536 |
+
avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
|
| 537 |
+
avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0
|
| 538 |
+
|
| 539 |
+
cat_scores: Dict[str, float] = {}
|
| 540 |
+
for cat, items in categories.items():
|
| 541 |
+
vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
|
| 542 |
+
cat_scores[cat] = (sum(vals) / len(vals)) if vals else 0.0
|
| 543 |
+
|
| 544 |
+
summary = {
|
| 545 |
+
"repo_dir": str(repo_dir),
|
| 546 |
+
"train_script": str(train_script),
|
| 547 |
+
"checkpoint": str(ckpt_path),
|
| 548 |
+
"config_path": str(config_path),
|
| 549 |
+
"tokenizer_dir": str(tokenizer_dir),
|
| 550 |
+
"device": str(tester.device),
|
| 551 |
+
"prompt_style": args.prompt_style,
|
| 552 |
+
"total_questions": len(results),
|
| 553 |
+
"avg_overlap_score": round(avg_overlap, 4),
|
| 554 |
+
"exact_match_rate": round(em_rate, 4),
|
| 555 |
+
"contains_reference_rate": round(contains_rate, 4),
|
| 556 |
+
"avg_latency_s": round(avg_latency, 3),
|
| 557 |
+
"avg_words_generated": round(avg_words, 1),
|
| 558 |
+
"scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
|
| 559 |
+
"results": results,
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
print("=" * 60)
|
| 563 |
+
print("RÉSUMÉ")
|
| 564 |
+
print("=" * 60)
|
| 565 |
+
print(f"Questions testées : {len(results)}")
|
| 566 |
+
print(f"Overlap moyen : {avg_overlap:.1%}")
|
| 567 |
+
print(f"Exact match : {em_rate:.1%}")
|
| 568 |
+
print(f"Contains ref : {contains_rate:.1%}")
|
| 569 |
+
print(f"Latence moyenne : {avg_latency:.2f}s")
|
| 570 |
+
print(f"Mots moyens : {avg_words:.1f}")
|
| 571 |
+
print("Scores / catégorie:")
|
| 572 |
+
for cat, score in sorted(cat_scores.items()):
|
| 573 |
+
print(f" {cat:<15} {format_bar(score)} {score:.0%}")
|
| 574 |
+
print("=" * 60)
|
| 575 |
+
|
| 576 |
+
if args.save_report:
|
| 577 |
+
json_path, txt_path = save_reports(repo_dir, summary)
|
| 578 |
+
print(f"Rapports sauvegardés dans : {json_path}")
|
| 579 |
+
print(f" {txt_path}")
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
if __name__ == "__main__":
|
| 583 |
+
main()
|
simple_qa_test_finished_model (1).py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Test QA simple pour un modèle déjà entraîné.
|
| 6 |
+
|
| 7 |
+
Fonctionne avec le script : train_chatbot_100m_large.py
|
| 8 |
+
- charge la config depuis output_dir/train_config.json si disponible
|
| 9 |
+
- charge un checkpoint fini (par défaut: output_dir/sft_best.pt)
|
| 10 |
+
- pose une petite liste de questions QA
|
| 11 |
+
- calcule un score simple d'overlap lexical
|
| 12 |
+
- sauvegarde un rapport JSON + TXT
|
| 13 |
+
|
| 14 |
+
Exemples
|
| 15 |
+
--------
|
| 16 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m
|
| 17 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
|
| 18 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import importlib.util
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
import unicodedata
|
| 30 |
+
import re
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DEFAULT_QUESTIONS = [
|
| 36 |
+
{
|
| 37 |
+
"category": "Géographie",
|
| 38 |
+
"question": "Quelle est la capitale de la France ?",
|
| 39 |
+
"reference": "Paris",
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"category": "Géographie",
|
| 43 |
+
"question": "Quel est le plus long fleuve d'Afrique ?",
|
| 44 |
+
"reference": "Le Nil",
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"category": "Science",
|
| 48 |
+
"question": "Qu'est-ce que la photosynthèse ?",
|
| 49 |
+
"reference": "Processus par lequel les plantes convertissent la lumière en énergie",
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"category": "Science",
|
| 53 |
+
"question": "Combien d'os compte le corps humain adulte ?",
|
| 54 |
+
"reference": "206",
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"category": "Histoire",
|
| 58 |
+
"question": "En quelle année a eu lieu la Révolution française ?",
|
| 59 |
+
"reference": "1789",
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"category": "Histoire",
|
| 63 |
+
"question": "Qui a écrit Les Misérables ?",
|
| 64 |
+
"reference": "Victor Hugo",
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"category": "Mathématiques",
|
| 68 |
+
"question": "Quelle est la formule de l'aire d'un cercle ?",
|
| 69 |
+
"reference": "π × r²",
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"category": "Langage",
|
| 73 |
+
"question": "Donne un synonyme du mot heureux.",
|
| 74 |
+
"reference": "joyeux",
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"category": "Raisonnement",
|
| 78 |
+
"question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
|
| 79 |
+
"reference": "3",
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"category": "Dialogue",
|
| 83 |
+
"question": "Comment vas-tu aujourd'hui ?",
|
| 84 |
+
"reference": None,
|
| 85 |
+
},
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def normalize_text(s: str) -> str:
|
| 90 |
+
s = (s or "").strip().lower()
|
| 91 |
+
s = unicodedata.normalize("NFKD", s)
|
| 92 |
+
s = "".join(ch for ch in s if not unicodedata.combining(ch))
|
| 93 |
+
s = re.sub(r"[^\w\s]", " ", s, flags=re.UNICODE)
|
| 94 |
+
s = re.sub(r"\s+", " ", s).strip()
|
| 95 |
+
return s
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
|
| 99 |
+
if not reference:
|
| 100 |
+
return None
|
| 101 |
+
ref_tokens = set(normalize_text(reference).split())
|
| 102 |
+
ans_tokens = set(normalize_text(answer).split())
|
| 103 |
+
if not ref_tokens:
|
| 104 |
+
return 0.0
|
| 105 |
+
return len(ref_tokens & ans_tokens) / len(ref_tokens)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 109 |
+
if not reference:
|
| 110 |
+
return None
|
| 111 |
+
return normalize_text(reference) == normalize_text(answer)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def import_train_module(train_script_path: str):
|
| 115 |
+
path = Path(train_script_path)
|
| 116 |
+
if not path.exists():
|
| 117 |
+
raise FileNotFoundError(f"Script d'entraînement introuvable: {path}")
|
| 118 |
+
|
| 119 |
+
spec = importlib.util.spec_from_file_location("train_module", str(path))
|
| 120 |
+
if spec is None or spec.loader is None:
|
| 121 |
+
raise ImportError(f"Impossible de charger le module: {path}")
|
| 122 |
+
|
| 123 |
+
module = importlib.util.module_from_spec(spec)
|
| 124 |
+
sys.modules["train_module"] = module
|
| 125 |
+
spec.loader.exec_module(module)
|
| 126 |
+
return module
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def build_cfg(train_module, output_dir: str):
|
| 130 |
+
cfg_path = Path(output_dir) / "train_config.json"
|
| 131 |
+
if cfg_path.exists():
|
| 132 |
+
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 133 |
+
saved = json.load(f)
|
| 134 |
+
cfg = train_module.TrainConfig(**saved)
|
| 135 |
+
else:
|
| 136 |
+
cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=f"{output_dir}/tokenizer")
|
| 137 |
+
|
| 138 |
+
cfg.output_dir = output_dir
|
| 139 |
+
cfg.tokenizer_prefix = f"{output_dir}/tokenizer"
|
| 140 |
+
return cfg
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def run_test(
|
| 144 |
+
train_script: str,
|
| 145 |
+
output_dir: str,
|
| 146 |
+
ckpt_path: str,
|
| 147 |
+
questions: List[Dict],
|
| 148 |
+
save_report: bool,
|
| 149 |
+
):
|
| 150 |
+
train_module = import_train_module(train_script)
|
| 151 |
+
cfg = build_cfg(train_module, output_dir)
|
| 152 |
+
bot = train_module.Chatbot(cfg, ckpt_path)
|
| 153 |
+
|
| 154 |
+
results = []
|
| 155 |
+
categories: Dict[str, List[Dict]] = {}
|
| 156 |
+
|
| 157 |
+
sep = "─" * 64
|
| 158 |
+
print(f"\n{'═'*64}")
|
| 159 |
+
print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ")
|
| 160 |
+
print(f" Checkpoint : {ckpt_path}")
|
| 161 |
+
print(f" Questions : {len(questions)}")
|
| 162 |
+
print(f"{'═'*64}\n")
|
| 163 |
+
|
| 164 |
+
for i, item in enumerate(questions, 1):
|
| 165 |
+
q = item["question"]
|
| 166 |
+
ref = item.get("reference")
|
| 167 |
+
cat = item.get("category", "Général")
|
| 168 |
+
ctx = item.get("context", "")
|
| 169 |
+
|
| 170 |
+
t0 = time.time()
|
| 171 |
+
ans = bot.chat(q, context=ctx)
|
| 172 |
+
latency = time.time() - t0
|
| 173 |
+
|
| 174 |
+
overlap = lexical_overlap(ref, ans)
|
| 175 |
+
em = exact_match(ref, ans)
|
| 176 |
+
|
| 177 |
+
row = {
|
| 178 |
+
"id": i,
|
| 179 |
+
"category": cat,
|
| 180 |
+
"question": q,
|
| 181 |
+
"context": ctx,
|
| 182 |
+
"reference": ref,
|
| 183 |
+
"answer": ans,
|
| 184 |
+
"overlap_score": overlap,
|
| 185 |
+
"exact_match": em,
|
| 186 |
+
"latency_s": round(latency, 3),
|
| 187 |
+
"tokens_generated_approx": len(ans.split()),
|
| 188 |
+
}
|
| 189 |
+
results.append(row)
|
| 190 |
+
categories.setdefault(cat, []).append(row)
|
| 191 |
+
|
| 192 |
+
score_text = []
|
| 193 |
+
if overlap is not None:
|
| 194 |
+
score_text.append(f"overlap={overlap:.0%}")
|
| 195 |
+
if em is not None:
|
| 196 |
+
score_text.append(f"EM={'oui' if em else 'non'}")
|
| 197 |
+
score_str = f" [{' | '.join(score_text)}]" if score_text else ""
|
| 198 |
+
|
| 199 |
+
print(sep)
|
| 200 |
+
print(f"[{i:02d}] [{cat}]{score_str}")
|
| 201 |
+
if ctx:
|
| 202 |
+
print(f" Contexte : {ctx[:120]}{'...' if len(ctx) > 120 else ''}")
|
| 203 |
+
print(f" User : {q}")
|
| 204 |
+
print(f" Assistant : {ans}")
|
| 205 |
+
if ref:
|
| 206 |
+
print(f" Référence : {ref}")
|
| 207 |
+
print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n")
|
| 208 |
+
|
| 209 |
+
scored = [r for r in results if r["overlap_score"] is not None]
|
| 210 |
+
avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0
|
| 211 |
+
em_rows = [r for r in results if r["exact_match"] is not None]
|
| 212 |
+
em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0
|
| 213 |
+
avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results))
|
| 214 |
+
avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results))
|
| 215 |
+
|
| 216 |
+
scores_by_category = {}
|
| 217 |
+
for cat, items in categories.items():
|
| 218 |
+
cat_scored = [x for x in items if x["overlap_score"] is not None]
|
| 219 |
+
cat_em = [x for x in items if x["exact_match"] is not None]
|
| 220 |
+
scores_by_category[cat] = {
|
| 221 |
+
"avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None,
|
| 222 |
+
"exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None,
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
summary = {
|
| 226 |
+
"checkpoint": ckpt_path,
|
| 227 |
+
"total_questions": len(results),
|
| 228 |
+
"avg_overlap_score": round(avg_overlap, 4),
|
| 229 |
+
"exact_match_rate": round(em_rate, 4),
|
| 230 |
+
"avg_latency_s": round(avg_latency, 3),
|
| 231 |
+
"avg_tokens_generated_approx": round(avg_tokens, 1),
|
| 232 |
+
"scores_by_category": scores_by_category,
|
| 233 |
+
"results": results,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
print(f"{'═'*64}")
|
| 237 |
+
print(" RÉSUMÉ")
|
| 238 |
+
print(f"{'═'*64}")
|
| 239 |
+
print(f" Questions testées : {len(results)}")
|
| 240 |
+
print(f" Overlap moyen : {avg_overlap:.1%}")
|
| 241 |
+
print(f" Exact match : {em_rate:.1%}")
|
| 242 |
+
print(f" Latence moyenne : {avg_latency:.2f}s")
|
| 243 |
+
print(f" Mots moyens : {avg_tokens:.0f}")
|
| 244 |
+
print(" Scores / catégorie :")
|
| 245 |
+
for cat, sc in scores_by_category.items():
|
| 246 |
+
ov = sc["avg_overlap"]
|
| 247 |
+
emc = sc["exact_match_rate"]
|
| 248 |
+
print(f" - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
|
| 249 |
+
print(f"{'═'*64}\n")
|
| 250 |
+
|
| 251 |
+
if save_report:
|
| 252 |
+
report_json = Path(output_dir) / "qa_test_simple_report.json"
|
| 253 |
+
report_txt = Path(output_dir) / "qa_test_simple_report.txt"
|
| 254 |
+
|
| 255 |
+
with open(report_json, "w", encoding="utf-8") as f:
|
| 256 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 257 |
+
|
| 258 |
+
with open(report_txt, "w", encoding="utf-8") as f:
|
| 259 |
+
f.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n")
|
| 260 |
+
f.write(f"Checkpoint : {ckpt_path}\n\n")
|
| 261 |
+
for r in results:
|
| 262 |
+
f.write(f"[{r['id']:02d}] {r['category']}\n")
|
| 263 |
+
if r["context"]:
|
| 264 |
+
f.write(f" Contexte : {r['context']}\n")
|
| 265 |
+
f.write(f" User : {r['question']}\n")
|
| 266 |
+
f.write(f" Assistant : {r['answer']}\n")
|
| 267 |
+
if r["reference"]:
|
| 268 |
+
f.write(f" Référence : {r['reference']}\n")
|
| 269 |
+
if r["overlap_score"] is not None:
|
| 270 |
+
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
|
| 271 |
+
if r["exact_match"] is not None:
|
| 272 |
+
f.write(f" EM : {'oui' if r['exact_match'] else 'non'}\n")
|
| 273 |
+
f.write(f" Latence : {r['latency_s']}s\n\n")
|
| 274 |
+
|
| 275 |
+
print(f"Rapport JSON -> {report_json}")
|
| 276 |
+
print(f"Rapport TXT -> {report_txt}")
|
| 277 |
+
|
| 278 |
+
return summary
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
parser = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné")
|
| 283 |
+
parser.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py")
|
| 284 |
+
parser.add_argument("--output_dir", type=str, default="./fr_100m")
|
| 285 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 286 |
+
parser.add_argument("--questions", type=str, default=None, help="JSON optionnel [{question, reference, category, context?}]")
|
| 287 |
+
parser.add_argument("--no_save", action="store_true")
|
| 288 |
+
args = parser.parse_args()
|
| 289 |
+
|
| 290 |
+
ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt")
|
| 291 |
+
if not Path(ckpt_path).exists():
|
| 292 |
+
raise FileNotFoundError(
|
| 293 |
+
f"Checkpoint introuvable: {ckpt_path}\n"
|
| 294 |
+
f"Exemple: --ckpt {args.output_dir}/sft_final.pt"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if args.questions:
|
| 298 |
+
with open(args.questions, "r", encoding="utf-8") as f:
|
| 299 |
+
questions = json.load(f)
|
| 300 |
+
else:
|
| 301 |
+
questions = DEFAULT_QUESTIONS
|
| 302 |
+
|
| 303 |
+
run_test(
|
| 304 |
+
train_script=args.train_script,
|
| 305 |
+
output_dir=args.output_dir,
|
| 306 |
+
ckpt_path=ckpt_path,
|
| 307 |
+
questions=questions,
|
| 308 |
+
save_report=not args.no_save,
|
| 309 |
+
)
|
simple_qa_test_finished_model.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Test QA simple pour un modèle déjà entraîné.
|
| 6 |
+
|
| 7 |
+
Fonctionne avec le script : train_chatbot_100m_large.py
|
| 8 |
+
- charge la config depuis output_dir/train_config.json si disponible
|
| 9 |
+
- charge un checkpoint fini (par défaut: output_dir/sft_best.pt)
|
| 10 |
+
- pose une petite liste de questions QA
|
| 11 |
+
- calcule un score simple d'overlap lexical
|
| 12 |
+
- sauvegarde un rapport JSON + TXT
|
| 13 |
+
|
| 14 |
+
Exemples
|
| 15 |
+
--------
|
| 16 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m
|
| 17 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
|
| 18 |
+
python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import importlib.util
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
import unicodedata
|
| 30 |
+
import re
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
DEFAULT_QUESTIONS = [
|
| 36 |
+
{
|
| 37 |
+
"category": "Géographie",
|
| 38 |
+
"question": "Quelle est la capitale de la France ?",
|
| 39 |
+
"reference": "Paris",
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"category": "Géographie",
|
| 43 |
+
"question": "Quel est le plus long fleuve d'Afrique ?",
|
| 44 |
+
"reference": "Le Nil",
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"category": "Science",
|
| 48 |
+
"question": "Qu'est-ce que la photosynthèse ?",
|
| 49 |
+
"reference": "Processus par lequel les plantes convertissent la lumière en énergie",
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"category": "Science",
|
| 53 |
+
"question": "Combien d'os compte le corps humain adulte ?",
|
| 54 |
+
"reference": "206",
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"category": "Histoire",
|
| 58 |
+
"question": "En quelle année a eu lieu la Révolution française ?",
|
| 59 |
+
"reference": "1789",
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"category": "Histoire",
|
| 63 |
+
"question": "Qui a écrit Les Misérables ?",
|
| 64 |
+
"reference": "Victor Hugo",
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"category": "Mathématiques",
|
| 68 |
+
"question": "Quelle est la formule de l'aire d'un cercle ?",
|
| 69 |
+
"reference": "π × r²",
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"category": "Langage",
|
| 73 |
+
"question": "Donne un synonyme du mot heureux.",
|
| 74 |
+
"reference": "joyeux",
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"category": "Raisonnement",
|
| 78 |
+
"question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
|
| 79 |
+
"reference": "3",
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"category": "Dialogue",
|
| 83 |
+
"question": "Comment vas-tu aujourd'hui ?",
|
| 84 |
+
"reference": None,
|
| 85 |
+
},
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def normalize_text(s: str) -> str:
|
| 90 |
+
s = (s or "").strip().lower()
|
| 91 |
+
s = unicodedata.normalize("NFKD", s)
|
| 92 |
+
s = "".join(ch for ch in s if not unicodedata.combining(ch))
|
| 93 |
+
s = re.sub(r"[^\w\s]", " ", s, flags=re.UNICODE)
|
| 94 |
+
s = re.sub(r"\s+", " ", s).strip()
|
| 95 |
+
return s
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
|
| 99 |
+
if not reference:
|
| 100 |
+
return None
|
| 101 |
+
ref_tokens = set(normalize_text(reference).split())
|
| 102 |
+
ans_tokens = set(normalize_text(answer).split())
|
| 103 |
+
if not ref_tokens:
|
| 104 |
+
return 0.0
|
| 105 |
+
return len(ref_tokens & ans_tokens) / len(ref_tokens)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
|
| 109 |
+
if not reference:
|
| 110 |
+
return None
|
| 111 |
+
return normalize_text(reference) == normalize_text(answer)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def import_train_module(train_script_path: str):
|
| 115 |
+
path = Path(train_script_path)
|
| 116 |
+
if not path.exists():
|
| 117 |
+
raise FileNotFoundError(f"Script d'entraînement introuvable: {path}")
|
| 118 |
+
|
| 119 |
+
spec = importlib.util.spec_from_file_location("train_module", str(path))
|
| 120 |
+
if spec is None or spec.loader is None:
|
| 121 |
+
raise ImportError(f"Impossible de charger le module: {path}")
|
| 122 |
+
|
| 123 |
+
module = importlib.util.module_from_spec(spec)
|
| 124 |
+
sys.modules["train_module"] = module
|
| 125 |
+
spec.loader.exec_module(module)
|
| 126 |
+
return module
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def build_cfg(train_module, output_dir: str):
|
| 130 |
+
cfg_path = Path(output_dir) / "train_config.json"
|
| 131 |
+
if cfg_path.exists():
|
| 132 |
+
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 133 |
+
saved = json.load(f)
|
| 134 |
+
cfg = train_module.TrainConfig(**saved)
|
| 135 |
+
else:
|
| 136 |
+
cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=f"{output_dir}/tokenizer")
|
| 137 |
+
|
| 138 |
+
cfg.output_dir = output_dir
|
| 139 |
+
cfg.tokenizer_prefix = f"{output_dir}/tokenizer"
|
| 140 |
+
return cfg
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def run_test(
|
| 144 |
+
train_script: str,
|
| 145 |
+
output_dir: str,
|
| 146 |
+
ckpt_path: str,
|
| 147 |
+
questions: List[Dict],
|
| 148 |
+
save_report: bool,
|
| 149 |
+
):
|
| 150 |
+
train_module = import_train_module(train_script)
|
| 151 |
+
cfg = build_cfg(train_module, output_dir)
|
| 152 |
+
bot = train_module.Chatbot(cfg, ckpt_path)
|
| 153 |
+
|
| 154 |
+
results = []
|
| 155 |
+
categories: Dict[str, List[Dict]] = {}
|
| 156 |
+
|
| 157 |
+
sep = "─" * 64
|
| 158 |
+
print(f"\n{'═'*64}")
|
| 159 |
+
print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ")
|
| 160 |
+
print(f" Checkpoint : {ckpt_path}")
|
| 161 |
+
print(f" Questions : {len(questions)}")
|
| 162 |
+
print(f"{'═'*64}\n")
|
| 163 |
+
|
| 164 |
+
for i, item in enumerate(questions, 1):
|
| 165 |
+
q = item["question"]
|
| 166 |
+
ref = item.get("reference")
|
| 167 |
+
cat = item.get("category", "Général")
|
| 168 |
+
ctx = item.get("context", "")
|
| 169 |
+
|
| 170 |
+
t0 = time.time()
|
| 171 |
+
ans = bot.chat(q, context=ctx)
|
| 172 |
+
latency = time.time() - t0
|
| 173 |
+
|
| 174 |
+
overlap = lexical_overlap(ref, ans)
|
| 175 |
+
em = exact_match(ref, ans)
|
| 176 |
+
|
| 177 |
+
row = {
|
| 178 |
+
"id": i,
|
| 179 |
+
"category": cat,
|
| 180 |
+
"question": q,
|
| 181 |
+
"context": ctx,
|
| 182 |
+
"reference": ref,
|
| 183 |
+
"answer": ans,
|
| 184 |
+
"overlap_score": overlap,
|
| 185 |
+
"exact_match": em,
|
| 186 |
+
"latency_s": round(latency, 3),
|
| 187 |
+
"tokens_generated_approx": len(ans.split()),
|
| 188 |
+
}
|
| 189 |
+
results.append(row)
|
| 190 |
+
categories.setdefault(cat, []).append(row)
|
| 191 |
+
|
| 192 |
+
score_text = []
|
| 193 |
+
if overlap is not None:
|
| 194 |
+
score_text.append(f"overlap={overlap:.0%}")
|
| 195 |
+
if em is not None:
|
| 196 |
+
score_text.append(f"EM={'oui' if em else 'non'}")
|
| 197 |
+
score_str = f" [{' | '.join(score_text)}]" if score_text else ""
|
| 198 |
+
|
| 199 |
+
print(sep)
|
| 200 |
+
print(f"[{i:02d}] [{cat}]{score_str}")
|
| 201 |
+
if ctx:
|
| 202 |
+
print(f" Contexte : {ctx[:120]}{'...' if len(ctx) > 120 else ''}")
|
| 203 |
+
print(f" User : {q}")
|
| 204 |
+
print(f" Assistant : {ans}")
|
| 205 |
+
if ref:
|
| 206 |
+
print(f" Référence : {ref}")
|
| 207 |
+
print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n")
|
| 208 |
+
|
| 209 |
+
scored = [r for r in results if r["overlap_score"] is not None]
|
| 210 |
+
avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0
|
| 211 |
+
em_rows = [r for r in results if r["exact_match"] is not None]
|
| 212 |
+
em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0
|
| 213 |
+
avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results))
|
| 214 |
+
avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results))
|
| 215 |
+
|
| 216 |
+
scores_by_category = {}
|
| 217 |
+
for cat, items in categories.items():
|
| 218 |
+
cat_scored = [x for x in items if x["overlap_score"] is not None]
|
| 219 |
+
cat_em = [x for x in items if x["exact_match"] is not None]
|
| 220 |
+
scores_by_category[cat] = {
|
| 221 |
+
"avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None,
|
| 222 |
+
"exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None,
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
summary = {
|
| 226 |
+
"checkpoint": ckpt_path,
|
| 227 |
+
"total_questions": len(results),
|
| 228 |
+
"avg_overlap_score": round(avg_overlap, 4),
|
| 229 |
+
"exact_match_rate": round(em_rate, 4),
|
| 230 |
+
"avg_latency_s": round(avg_latency, 3),
|
| 231 |
+
"avg_tokens_generated_approx": round(avg_tokens, 1),
|
| 232 |
+
"scores_by_category": scores_by_category,
|
| 233 |
+
"results": results,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
print(f"{'═'*64}")
|
| 237 |
+
print(" RÉSUMÉ")
|
| 238 |
+
print(f"{'═'*64}")
|
| 239 |
+
print(f" Questions testées : {len(results)}")
|
| 240 |
+
print(f" Overlap moyen : {avg_overlap:.1%}")
|
| 241 |
+
print(f" Exact match : {em_rate:.1%}")
|
| 242 |
+
print(f" Latence moyenne : {avg_latency:.2f}s")
|
| 243 |
+
print(f" Mots moyens : {avg_tokens:.0f}")
|
| 244 |
+
print(" Scores / catégorie :")
|
| 245 |
+
for cat, sc in scores_by_category.items():
|
| 246 |
+
ov = sc["avg_overlap"]
|
| 247 |
+
emc = sc["exact_match_rate"]
|
| 248 |
+
print(f" - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
|
| 249 |
+
print(f"{'═'*64}\n")
|
| 250 |
+
|
| 251 |
+
if save_report:
|
| 252 |
+
report_json = Path(output_dir) / "qa_test_simple_report.json"
|
| 253 |
+
report_txt = Path(output_dir) / "qa_test_simple_report.txt"
|
| 254 |
+
|
| 255 |
+
with open(report_json, "w", encoding="utf-8") as f:
|
| 256 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 257 |
+
|
| 258 |
+
with open(report_txt, "w", encoding="utf-8") as f:
|
| 259 |
+
f.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n")
|
| 260 |
+
f.write(f"Checkpoint : {ckpt_path}\n\n")
|
| 261 |
+
for r in results:
|
| 262 |
+
f.write(f"[{r['id']:02d}] {r['category']}\n")
|
| 263 |
+
if r["context"]:
|
| 264 |
+
f.write(f" Contexte : {r['context']}\n")
|
| 265 |
+
f.write(f" User : {r['question']}\n")
|
| 266 |
+
f.write(f" Assistant : {r['answer']}\n")
|
| 267 |
+
if r["reference"]:
|
| 268 |
+
f.write(f" Référence : {r['reference']}\n")
|
| 269 |
+
if r["overlap_score"] is not None:
|
| 270 |
+
f.write(f" Overlap : {r['overlap_score']:.0%}\n")
|
| 271 |
+
if r["exact_match"] is not None:
|
| 272 |
+
f.write(f" EM : {'oui' if r['exact_match'] else 'non'}\n")
|
| 273 |
+
f.write(f" Latence : {r['latency_s']}s\n\n")
|
| 274 |
+
|
| 275 |
+
print(f"Rapport JSON -> {report_json}")
|
| 276 |
+
print(f"Rapport TXT -> {report_txt}")
|
| 277 |
+
|
| 278 |
+
return summary
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
parser = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné")
|
| 283 |
+
parser.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py")
|
| 284 |
+
parser.add_argument("--output_dir", type=str, default="./fr_100m")
|
| 285 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 286 |
+
parser.add_argument("--questions", type=str, default=None, help="JSON optionnel [{question, reference, category, context?}]")
|
| 287 |
+
parser.add_argument("--no_save", action="store_true")
|
| 288 |
+
args = parser.parse_args()
|
| 289 |
+
|
| 290 |
+
ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt")
|
| 291 |
+
if not Path(ckpt_path).exists():
|
| 292 |
+
raise FileNotFoundError(
|
| 293 |
+
f"Checkpoint introuvable: {ckpt_path}\n"
|
| 294 |
+
f"Exemple: --ckpt {args.output_dir}/sft_final.pt"
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if args.questions:
|
| 298 |
+
with open(args.questions, "r", encoding="utf-8") as f:
|
| 299 |
+
questions = json.load(f)
|
| 300 |
+
else:
|
| 301 |
+
questions = DEFAULT_QUESTIONS
|
| 302 |
+
|
| 303 |
+
run_test(
|
| 304 |
+
train_script=args.train_script,
|
| 305 |
+
output_dir=args.output_dir,
|
| 306 |
+
ckpt_path=ckpt_path,
|
| 307 |
+
questions=questions,
|
| 308 |
+
save_report=not args.no_save,
|
| 309 |
+
)
|
test.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
from contextlib import nullcontext
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from transformers import PreTrainedTokenizerFast
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ============================================================
|
| 21 |
+
# Paths par défaut
|
| 22 |
+
# ============================================================
|
| 23 |
+
|
| 24 |
+
MODEL_DIR = Path("./nlp_1b_wiki_en_fr_ar")
|
| 25 |
+
|
| 26 |
+
DEFAULT_CHECKPOINT = MODEL_DIR / "model_best.pt"
|
| 27 |
+
DEFAULT_CONFIG = MODEL_DIR / "config.json"
|
| 28 |
+
DEFAULT_TOKENIZER_DIR = MODEL_DIR / "tokenizer_32k"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ============================================================
|
| 32 |
+
# Utils
|
| 33 |
+
# ============================================================
|
| 34 |
+
|
| 35 |
+
def get_device() -> torch.device:
|
| 36 |
+
if torch.cuda.is_available():
|
| 37 |
+
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 38 |
+
return torch.device("cpu")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def autocast_context(device: torch.device):
|
| 42 |
+
if device.type == "cuda":
|
| 43 |
+
return torch.autocast("cuda", dtype=torch.bfloat16)
|
| 44 |
+
return nullcontext()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
|
| 48 |
+
normalized = OrderedDict()
|
| 49 |
+
for k, v in state_dict.items():
|
| 50 |
+
nk = k
|
| 51 |
+
if nk.startswith("module._orig_mod."):
|
| 52 |
+
nk = nk[len("module._orig_mod."):]
|
| 53 |
+
elif nk.startswith("_orig_mod."):
|
| 54 |
+
nk = nk[len("_orig_mod."):]
|
| 55 |
+
elif nk.startswith("module."):
|
| 56 |
+
nk = nk[len("module."):]
|
| 57 |
+
normalized[nk] = v
|
| 58 |
+
return normalized
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def postprocess_text(text: str) -> str:
|
| 62 |
+
return text.strip()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ============================================================
|
| 66 |
+
# Architecture
|
| 67 |
+
# ============================================================
|
| 68 |
+
|
| 69 |
+
@dataclass
|
| 70 |
+
class GPTConfig:
|
| 71 |
+
vocab_size: int
|
| 72 |
+
block_size: int
|
| 73 |
+
d_model: int
|
| 74 |
+
n_heads: int
|
| 75 |
+
n_layers: int
|
| 76 |
+
d_ff: int
|
| 77 |
+
dropout: float = 0.0
|
| 78 |
+
use_checkpointing: bool = False
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RMSNorm(nn.Module):
|
| 82 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 85 |
+
self.eps = eps
|
| 86 |
+
|
| 87 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class RotaryEmbedding(nn.Module):
|
| 92 |
+
def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
|
| 93 |
+
super().__init__()
|
| 94 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 95 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 96 |
+
|
| 97 |
+
t = torch.arange(max_seq).float()
|
| 98 |
+
freqs = torch.outer(t, inv_freq)
|
| 99 |
+
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"cos_cache",
|
| 102 |
+
torch.repeat_interleave(freqs.cos(), 2, dim=-1),
|
| 103 |
+
persistent=False,
|
| 104 |
+
)
|
| 105 |
+
self.register_buffer(
|
| 106 |
+
"sin_cache",
|
| 107 |
+
torch.repeat_interleave(freqs.sin(), 2, dim=-1),
|
| 108 |
+
persistent=False,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def forward(self, seq_len: int, dtype: torch.dtype):
|
| 112 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 116 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 117 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 122 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 123 |
+
return x * cos + rotate_half(x) * sin
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class CausalSelfAttention(nn.Module):
|
| 127 |
+
def __init__(self, cfg: GPTConfig):
|
| 128 |
+
super().__init__()
|
| 129 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 130 |
+
self.n_heads = cfg.n_heads
|
| 131 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 132 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 133 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 134 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 135 |
+
|
| 136 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 137 |
+
b, t, c = x.shape
|
| 138 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 139 |
+
|
| 140 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 141 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 142 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 143 |
+
|
| 144 |
+
cos, sin = self.rope(t, x.dtype)
|
| 145 |
+
q = apply_rope(q, cos, sin)
|
| 146 |
+
k = apply_rope(k, cos, sin)
|
| 147 |
+
|
| 148 |
+
y = F.scaled_dot_product_attention(
|
| 149 |
+
q, k, v,
|
| 150 |
+
dropout_p=0.0,
|
| 151 |
+
is_causal=True,
|
| 152 |
+
)
|
| 153 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 154 |
+
return self.proj(y)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class SwiGLU(nn.Module):
|
| 158 |
+
def __init__(self, cfg: GPTConfig):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 161 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 162 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 163 |
+
|
| 164 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Block(nn.Module):
|
| 169 |
+
def __init__(self, cfg: GPTConfig):
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 172 |
+
self.attn = CausalSelfAttention(cfg)
|
| 173 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 174 |
+
self.ff = SwiGLU(cfg)
|
| 175 |
+
|
| 176 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 177 |
+
x = x + self.attn(self.ln1(x))
|
| 178 |
+
x = x + self.ff(self.ln2(x))
|
| 179 |
+
return x
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class GPT(nn.Module):
|
| 183 |
+
def __init__(self, cfg: GPTConfig):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.cfg = cfg
|
| 186 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 187 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 188 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 189 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 190 |
+
self.lm_head.weight = self.tok_emb.weight
|
| 191 |
+
|
| 192 |
+
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 193 |
+
x = self.tok_emb(input_ids)
|
| 194 |
+
for block in self.blocks:
|
| 195 |
+
x = block(x)
|
| 196 |
+
return self.lm_head(self.ln_f(x))
|
| 197 |
+
|
| 198 |
+
@torch.inference_mode()
|
| 199 |
+
def generate(
|
| 200 |
+
self,
|
| 201 |
+
input_ids: torch.Tensor,
|
| 202 |
+
max_new_tokens: int = 160,
|
| 203 |
+
temperature: float = 0.8,
|
| 204 |
+
top_k: int = 50,
|
| 205 |
+
top_p: float = 0.95,
|
| 206 |
+
repetition_penalty: float = 1.05,
|
| 207 |
+
eos_token_id: Optional[int] = None,
|
| 208 |
+
) -> torch.Tensor:
|
| 209 |
+
self.eval()
|
| 210 |
+
|
| 211 |
+
for _ in range(max_new_tokens):
|
| 212 |
+
idx_cond = input_ids[:, -self.cfg.block_size :]
|
| 213 |
+
logits = self(idx_cond)
|
| 214 |
+
logits = logits[:, -1, :]
|
| 215 |
+
|
| 216 |
+
if repetition_penalty != 1.0:
|
| 217 |
+
for b in range(input_ids.size(0)):
|
| 218 |
+
seen = torch.unique(input_ids[b])
|
| 219 |
+
logits[b, seen] /= repetition_penalty
|
| 220 |
+
|
| 221 |
+
if temperature <= 0:
|
| 222 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
| 223 |
+
else:
|
| 224 |
+
logits = logits / temperature
|
| 225 |
+
|
| 226 |
+
if top_k > 0:
|
| 227 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 228 |
+
logits[logits < v[:, [-1]]] = -float("inf")
|
| 229 |
+
|
| 230 |
+
if 0 < top_p < 1.0:
|
| 231 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 232 |
+
probs = F.softmax(sorted_logits, dim=-1)
|
| 233 |
+
cumulative_probs = torch.cumsum(probs, dim=-1)
|
| 234 |
+
|
| 235 |
+
sorted_mask = cumulative_probs > top_p
|
| 236 |
+
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
|
| 237 |
+
sorted_mask[..., 0] = False
|
| 238 |
+
|
| 239 |
+
mask = torch.zeros_like(logits, dtype=torch.bool)
|
| 240 |
+
mask.scatter_(1, sorted_indices, sorted_mask)
|
| 241 |
+
logits = logits.masked_fill(mask, -float("inf"))
|
| 242 |
+
|
| 243 |
+
probs = F.softmax(logits, dim=-1)
|
| 244 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 245 |
+
|
| 246 |
+
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 247 |
+
|
| 248 |
+
if eos_token_id is not None and (next_token == eos_token_id).all():
|
| 249 |
+
break
|
| 250 |
+
|
| 251 |
+
return input_ids
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ============================================================
|
| 255 |
+
# Load / generate
|
| 256 |
+
# ============================================================
|
| 257 |
+
|
| 258 |
+
def load_model_and_tokenizer(
|
| 259 |
+
checkpoint_path: Path,
|
| 260 |
+
config_path: Path,
|
| 261 |
+
tokenizer_dir: Path,
|
| 262 |
+
device: torch.device,
|
| 263 |
+
use_compile: bool = False,
|
| 264 |
+
):
|
| 265 |
+
if not checkpoint_path.exists():
|
| 266 |
+
raise FileNotFoundError(f"Checkpoint introuvable: {checkpoint_path}")
|
| 267 |
+
if not config_path.exists():
|
| 268 |
+
raise FileNotFoundError(f"Config introuvable: {config_path}")
|
| 269 |
+
if not tokenizer_dir.exists():
|
| 270 |
+
raise FileNotFoundError(f"Tokenizer introuvable: {tokenizer_dir}")
|
| 271 |
+
|
| 272 |
+
cfg_dict = json.loads(config_path.read_text(encoding="utf-8"))
|
| 273 |
+
cfg = GPTConfig(**cfg_dict)
|
| 274 |
+
|
| 275 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
|
| 276 |
+
model = GPT(cfg).to(device)
|
| 277 |
+
|
| 278 |
+
ckpt = torch.load(checkpoint_path, map_location=device)
|
| 279 |
+
state_dict = normalize_state_dict_keys(ckpt["model"])
|
| 280 |
+
model.load_state_dict(state_dict, strict=True)
|
| 281 |
+
model.eval()
|
| 282 |
+
|
| 283 |
+
if use_compile and hasattr(torch, "compile"):
|
| 284 |
+
model = torch.compile(model, mode="default")
|
| 285 |
+
|
| 286 |
+
return model, tokenizer, ckpt
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def generate_text(
|
| 290 |
+
model: GPT,
|
| 291 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 292 |
+
prompt: str,
|
| 293 |
+
device: torch.device,
|
| 294 |
+
max_new_tokens: int,
|
| 295 |
+
temperature: float,
|
| 296 |
+
top_k: int,
|
| 297 |
+
top_p: float,
|
| 298 |
+
repetition_penalty: float,
|
| 299 |
+
) -> str:
|
| 300 |
+
encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
|
| 301 |
+
input_ids = encoded["input_ids"].to(device)
|
| 302 |
+
|
| 303 |
+
if tokenizer.bos_token_id is not None:
|
| 304 |
+
bos = torch.tensor([[tokenizer.bos_token_id]], device=device, dtype=input_ids.dtype)
|
| 305 |
+
input_ids = torch.cat([bos, input_ids], dim=1)
|
| 306 |
+
|
| 307 |
+
prompt_len = input_ids.shape[1]
|
| 308 |
+
|
| 309 |
+
with autocast_context(device):
|
| 310 |
+
output_ids = model.generate(
|
| 311 |
+
input_ids=input_ids,
|
| 312 |
+
max_new_tokens=max_new_tokens,
|
| 313 |
+
temperature=temperature,
|
| 314 |
+
top_k=top_k,
|
| 315 |
+
top_p=top_p,
|
| 316 |
+
repetition_penalty=repetition_penalty,
|
| 317 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
generated_ids = output_ids[0][prompt_len:]
|
| 321 |
+
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 322 |
+
return postprocess_text(text)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def main():
|
| 326 |
+
parser = argparse.ArgumentParser()
|
| 327 |
+
parser.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT))
|
| 328 |
+
parser.add_argument("--config", type=str, default=str(DEFAULT_CONFIG))
|
| 329 |
+
parser.add_argument("--tokenizer_dir", type=str, default=str(DEFAULT_TOKENIZER_DIR))
|
| 330 |
+
parser.add_argument("--prompt", type=str, default="Wikipedia is a free online encyclopedia")
|
| 331 |
+
parser.add_argument("--max_new_tokens", type=int, default=160)
|
| 332 |
+
parser.add_argument("--temperature", type=float, default=0.8)
|
| 333 |
+
parser.add_argument("--top_k", type=int, default=50)
|
| 334 |
+
parser.add_argument("--top_p", type=float, default=0.95)
|
| 335 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.05)
|
| 336 |
+
parser.add_argument("--interactive", action="store_true")
|
| 337 |
+
parser.add_argument("--show_examples", action="store_true")
|
| 338 |
+
parser.add_argument("--compile", action="store_true")
|
| 339 |
+
args = parser.parse_args()
|
| 340 |
+
|
| 341 |
+
device = get_device()
|
| 342 |
+
|
| 343 |
+
if device.type == "cuda":
|
| 344 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 345 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 346 |
+
torch.set_float32_matmul_precision("high")
|
| 347 |
+
|
| 348 |
+
model, tokenizer, ckpt = load_model_and_tokenizer(
|
| 349 |
+
checkpoint_path=Path(args.checkpoint),
|
| 350 |
+
config_path=Path(args.config),
|
| 351 |
+
tokenizer_dir=Path(args.tokenizer_dir),
|
| 352 |
+
device=device,
|
| 353 |
+
use_compile=args.compile,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
print(f"Device: {device}")
|
| 357 |
+
print(f"Checkpoint: {args.checkpoint}")
|
| 358 |
+
print(f"epoch={ckpt.get('epoch', 'N/A')} | step={ckpt.get('step', 'N/A')} | best_loss={ckpt.get('best_loss', 'N/A')}")
|
| 359 |
+
|
| 360 |
+
if args.show_examples:
|
| 361 |
+
examples = [
|
| 362 |
+
"Wikipedia is a free online encyclopedia",
|
| 363 |
+
"La France est un pays d'Europe",
|
| 364 |
+
"الزراعة من أهم القطاعات الاقتصادية",
|
| 365 |
+
"Machine learning is a field of artificial intelligence",
|
| 366 |
+
]
|
| 367 |
+
for ex in examples:
|
| 368 |
+
print("\n--- Prompt ---")
|
| 369 |
+
print(ex)
|
| 370 |
+
print("\n--- Output ---")
|
| 371 |
+
print(
|
| 372 |
+
generate_text(
|
| 373 |
+
model=model,
|
| 374 |
+
tokenizer=tokenizer,
|
| 375 |
+
prompt=ex,
|
| 376 |
+
device=device,
|
| 377 |
+
max_new_tokens=args.max_new_tokens,
|
| 378 |
+
temperature=args.temperature,
|
| 379 |
+
top_k=args.top_k,
|
| 380 |
+
top_p=args.top_p,
|
| 381 |
+
repetition_penalty=args.repetition_penalty,
|
| 382 |
+
)
|
| 383 |
+
)
|
| 384 |
+
return
|
| 385 |
+
|
| 386 |
+
if args.interactive:
|
| 387 |
+
print("Mode interactif. Tape 'exit' pour quitter.\n")
|
| 388 |
+
while True:
|
| 389 |
+
prompt = input("Prompt> ").strip()
|
| 390 |
+
if prompt.lower() in {"exit", "quit"}:
|
| 391 |
+
break
|
| 392 |
+
if not prompt:
|
| 393 |
+
continue
|
| 394 |
+
|
| 395 |
+
print("\n=== Output ===")
|
| 396 |
+
print(
|
| 397 |
+
generate_text(
|
| 398 |
+
model=model,
|
| 399 |
+
tokenizer=tokenizer,
|
| 400 |
+
prompt=prompt,
|
| 401 |
+
device=device,
|
| 402 |
+
max_new_tokens=args.max_new_tokens,
|
| 403 |
+
temperature=args.temperature,
|
| 404 |
+
top_k=args.top_k,
|
| 405 |
+
top_p=args.top_p,
|
| 406 |
+
repetition_penalty=args.repetition_penalty,
|
| 407 |
+
)
|
| 408 |
+
)
|
| 409 |
+
print()
|
| 410 |
+
return
|
| 411 |
+
|
| 412 |
+
print(
|
| 413 |
+
generate_text(
|
| 414 |
+
model=model,
|
| 415 |
+
tokenizer=tokenizer,
|
| 416 |
+
prompt=args.prompt,
|
| 417 |
+
device=device,
|
| 418 |
+
max_new_tokens=args.max_new_tokens,
|
| 419 |
+
temperature=args.temperature,
|
| 420 |
+
top_k=args.top_k,
|
| 421 |
+
top_p=args.top_p,
|
| 422 |
+
repetition_penalty=args.repetition_penalty,
|
| 423 |
+
)
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
if __name__ == "__main__":
|
| 428 |
+
main()
|
top_p
ADDED
|
File without changes
|
train.py
ADDED
|
@@ -0,0 +1,859 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import time
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from contextlib import nullcontext
|
| 13 |
+
from dataclasses import asdict, dataclass
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Iterator, Optional
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
from transformers import PreTrainedTokenizerFast
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ============================================================
|
| 27 |
+
# Base model / tokenizer / config
|
| 28 |
+
# ============================================================
|
| 29 |
+
|
| 30 |
+
BASE_CHECKPOINT = Path("./wikipedia_ar_h100_codealpaca/model_best.pt")
|
| 31 |
+
BASE_TOKENIZER_DIR = Path("./wikipedia_ar_h100/tokenizer_32k")
|
| 32 |
+
BASE_CONFIG_FILE = Path("./wikipedia_ar_h100/config.json")
|
| 33 |
+
|
| 34 |
+
OUT_DIR = Path("./wikipedia_ar_h100_multicode_10x2000")
|
| 35 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
MODEL_FILE = OUT_DIR / "model.pt"
|
| 38 |
+
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
|
| 39 |
+
STATE_FILE = OUT_DIR / "train_state.pt"
|
| 40 |
+
CONFIG_FILE = OUT_DIR / "config.json"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============================================================
|
| 44 |
+
# Datasets
|
| 45 |
+
# ============================================================
|
| 46 |
+
|
| 47 |
+
TRAIN_SOURCES = [
|
| 48 |
+
{
|
| 49 |
+
"name": "HuggingFaceH4/CodeAlpaca_20K",
|
| 50 |
+
"subset": None,
|
| 51 |
+
"split": "train",
|
| 52 |
+
"kind": "codealpaca",
|
| 53 |
+
"weight": 0.45,
|
| 54 |
+
"streaming": False,
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"name": "open-r1/codeforces",
|
| 58 |
+
"subset": "verifiable-prompts",
|
| 59 |
+
"split": "train",
|
| 60 |
+
"kind": "codeforces_python",
|
| 61 |
+
"weight": 0.35,
|
| 62 |
+
"streaming": False,
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"name": "wikimedia/wikipedia",
|
| 66 |
+
"subset": "20231101.ar",
|
| 67 |
+
"split": "train",
|
| 68 |
+
"kind": "wikipedia_ar",
|
| 69 |
+
"weight": 0.20,
|
| 70 |
+
"streaming": True,
|
| 71 |
+
},
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
EVAL_SOURCE = {
|
| 75 |
+
"name": "HuggingFaceH4/CodeAlpaca_20K",
|
| 76 |
+
"subset": None,
|
| 77 |
+
"split": "test",
|
| 78 |
+
"kind": "codealpaca",
|
| 79 |
+
"streaming": False,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
CODEFORCES_LANGUAGE = "python"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ============================================================
|
| 86 |
+
# Hyperparamètres
|
| 87 |
+
# ============================================================
|
| 88 |
+
|
| 89 |
+
SEED = 42
|
| 90 |
+
TARGET_VRAM_GIB = 75.0
|
| 91 |
+
|
| 92 |
+
LEARNING_RATE = 5e-5
|
| 93 |
+
MIN_LR = 5e-6
|
| 94 |
+
WEIGHT_DECAY = 0.1
|
| 95 |
+
WARMUP_STEPS = 200
|
| 96 |
+
|
| 97 |
+
NUM_ROUNDS = 10
|
| 98 |
+
STEPS_PER_ROUND = 2000
|
| 99 |
+
MAX_STEPS = NUM_ROUNDS * STEPS_PER_ROUND # 20000
|
| 100 |
+
|
| 101 |
+
BATCH_SIZE = 24
|
| 102 |
+
GRAD_ACCUM_STEPS = 1
|
| 103 |
+
MAX_GRAD_NORM = 1.0
|
| 104 |
+
|
| 105 |
+
EVAL_EVERY = 250
|
| 106 |
+
SAVE_EVERY = 500
|
| 107 |
+
MAX_EVAL_EXAMPLES = 2000
|
| 108 |
+
TEXT_CHAR_LIMIT = 6000
|
| 109 |
+
|
| 110 |
+
DTYPE = torch.bfloat16
|
| 111 |
+
USE_COMPILE = True
|
| 112 |
+
COMPILE_MODE = "default"
|
| 113 |
+
USE_CHECKPOINTING = False
|
| 114 |
+
|
| 115 |
+
TRAIN_NUM_WORKERS = 0
|
| 116 |
+
EVAL_NUM_WORKERS = 0
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# ============================================================
|
| 120 |
+
# Helpers
|
| 121 |
+
# ============================================================
|
| 122 |
+
|
| 123 |
+
def is_distributed() -> bool:
|
| 124 |
+
return dist.is_available() and dist.is_initialized()
|
| 125 |
+
|
| 126 |
+
def get_rank() -> int:
|
| 127 |
+
return dist.get_rank() if is_distributed() else 0
|
| 128 |
+
|
| 129 |
+
def get_world_size() -> int:
|
| 130 |
+
return dist.get_world_size() if is_distributed() else 1
|
| 131 |
+
|
| 132 |
+
def is_main() -> bool:
|
| 133 |
+
return get_rank() == 0
|
| 134 |
+
|
| 135 |
+
def init_distributed() -> Optional[torch.device]:
|
| 136 |
+
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 137 |
+
if local_rank == -1:
|
| 138 |
+
return None
|
| 139 |
+
dist.init_process_group("nccl")
|
| 140 |
+
torch.cuda.set_device(local_rank)
|
| 141 |
+
return torch.device(f"cuda:{local_rank}")
|
| 142 |
+
|
| 143 |
+
def set_seed(seed: int) -> None:
|
| 144 |
+
random.seed(seed)
|
| 145 |
+
torch.manual_seed(seed)
|
| 146 |
+
if torch.cuda.is_available():
|
| 147 |
+
torch.cuda.manual_seed_all(seed)
|
| 148 |
+
|
| 149 |
+
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
|
| 150 |
+
if ddp_device is not None:
|
| 151 |
+
return ddp_device
|
| 152 |
+
if torch.cuda.is_available():
|
| 153 |
+
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 154 |
+
return torch.device("cpu")
|
| 155 |
+
|
| 156 |
+
def current_cuda_index(device: torch.device) -> int:
|
| 157 |
+
if device.type != "cuda":
|
| 158 |
+
raise ValueError("Device non CUDA")
|
| 159 |
+
return device.index if device.index is not None else torch.cuda.current_device()
|
| 160 |
+
|
| 161 |
+
def autocast_context(device: torch.device):
|
| 162 |
+
if device.type == "cuda":
|
| 163 |
+
return torch.autocast("cuda", dtype=DTYPE)
|
| 164 |
+
return nullcontext()
|
| 165 |
+
|
| 166 |
+
def unwrap_model(model: nn.Module) -> nn.Module:
|
| 167 |
+
m = model.module if isinstance(model, DDP) else model
|
| 168 |
+
if hasattr(m, "_orig_mod"):
|
| 169 |
+
return m._orig_mod
|
| 170 |
+
return m
|
| 171 |
+
|
| 172 |
+
def count_parameters(model: nn.Module) -> int:
|
| 173 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 174 |
+
|
| 175 |
+
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
|
| 176 |
+
normalized = OrderedDict()
|
| 177 |
+
for k, v in state_dict.items():
|
| 178 |
+
nk = k
|
| 179 |
+
if nk.startswith("module._orig_mod."):
|
| 180 |
+
nk = nk[len("module._orig_mod."):]
|
| 181 |
+
elif nk.startswith("_orig_mod."):
|
| 182 |
+
nk = nk[len("_orig_mod."):]
|
| 183 |
+
elif nk.startswith("module."):
|
| 184 |
+
nk = nk[len("module."):]
|
| 185 |
+
normalized[nk] = v
|
| 186 |
+
return normalized
|
| 187 |
+
|
| 188 |
+
def normalize_text(text: str) -> str:
|
| 189 |
+
return " ".join(text.strip().split())
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ============================================================
|
| 193 |
+
# Dataset loading / formatting
|
| 194 |
+
# ============================================================
|
| 195 |
+
|
| 196 |
+
def load_one_dataset(spec: dict):
|
| 197 |
+
kwargs = {
|
| 198 |
+
"path": spec["name"],
|
| 199 |
+
"split": spec["split"],
|
| 200 |
+
"streaming": spec["streaming"],
|
| 201 |
+
}
|
| 202 |
+
if spec["subset"] is not None:
|
| 203 |
+
kwargs["name"] = spec["subset"]
|
| 204 |
+
return load_dataset(**kwargs)
|
| 205 |
+
|
| 206 |
+
def format_record(row: dict, kind: str) -> str:
|
| 207 |
+
if kind == "codealpaca":
|
| 208 |
+
prompt = row.get("prompt", "")
|
| 209 |
+
completion = row.get("completion", "")
|
| 210 |
+
if not isinstance(prompt, str):
|
| 211 |
+
prompt = str(prompt)
|
| 212 |
+
if not isinstance(completion, str):
|
| 213 |
+
completion = str(completion)
|
| 214 |
+
text = (
|
| 215 |
+
"### Instruction\n"
|
| 216 |
+
f"{prompt.strip()}\n\n"
|
| 217 |
+
"### Response\n"
|
| 218 |
+
f"{completion.strip()}"
|
| 219 |
+
)
|
| 220 |
+
return normalize_text(text)
|
| 221 |
+
|
| 222 |
+
if kind == "codeforces_python":
|
| 223 |
+
language = row.get("language", "")
|
| 224 |
+
if language != CODEFORCES_LANGUAGE:
|
| 225 |
+
return ""
|
| 226 |
+
|
| 227 |
+
prompt = row.get("prompt", "")
|
| 228 |
+
title = row.get("title", "")
|
| 229 |
+
if not isinstance(prompt, str):
|
| 230 |
+
prompt = str(prompt)
|
| 231 |
+
if not isinstance(title, str):
|
| 232 |
+
title = str(title)
|
| 233 |
+
|
| 234 |
+
text = (
|
| 235 |
+
f"### Competitive Programming Problem ({language})\n"
|
| 236 |
+
f"{title.strip()}\n\n"
|
| 237 |
+
f"{prompt.strip()}"
|
| 238 |
+
)
|
| 239 |
+
return normalize_text(text)
|
| 240 |
+
|
| 241 |
+
if kind == "wikipedia_ar":
|
| 242 |
+
text = row.get("text", "")
|
| 243 |
+
if not isinstance(text, str):
|
| 244 |
+
text = str(text)
|
| 245 |
+
return normalize_text(text)
|
| 246 |
+
|
| 247 |
+
return ""
|
| 248 |
+
|
| 249 |
+
def example_text_iter(spec: dict, max_examples: Optional[int] = None) -> Iterator[str]:
|
| 250 |
+
ds = load_one_dataset(spec)
|
| 251 |
+
n = 0
|
| 252 |
+
for row in ds:
|
| 253 |
+
text = format_record(row, spec["kind"])
|
| 254 |
+
if not text or len(text) < 20:
|
| 255 |
+
continue
|
| 256 |
+
if TEXT_CHAR_LIMIT is not None:
|
| 257 |
+
text = text[:TEXT_CHAR_LIMIT]
|
| 258 |
+
yield text
|
| 259 |
+
n += 1
|
| 260 |
+
if max_examples is not None and n >= max_examples:
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class MixedTextSource:
|
| 265 |
+
def __init__(self, specs: list[dict]):
|
| 266 |
+
self.specs = specs
|
| 267 |
+
self.weights = [s["weight"] for s in specs]
|
| 268 |
+
self.streams = [example_text_iter(s) for s in specs]
|
| 269 |
+
|
| 270 |
+
def next_text(self) -> str:
|
| 271 |
+
while True:
|
| 272 |
+
idx = random.choices(range(len(self.specs)), weights=self.weights, k=1)[0]
|
| 273 |
+
try:
|
| 274 |
+
return next(self.streams[idx])
|
| 275 |
+
except StopIteration:
|
| 276 |
+
self.streams[idx] = example_text_iter(self.specs[idx])
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def packed_block_stream_mixed(
|
| 280 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 281 |
+
specs: list[dict],
|
| 282 |
+
block_size: int,
|
| 283 |
+
) -> Iterator[list[int]]:
|
| 284 |
+
bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
|
| 285 |
+
buffer: list[int] = []
|
| 286 |
+
source = MixedTextSource(specs)
|
| 287 |
+
|
| 288 |
+
while True:
|
| 289 |
+
text = source.next_text()
|
| 290 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 291 |
+
if not ids:
|
| 292 |
+
continue
|
| 293 |
+
|
| 294 |
+
buffer.extend([bos] + ids + [eos])
|
| 295 |
+
|
| 296 |
+
while len(buffer) >= block_size + 1:
|
| 297 |
+
yield buffer[: block_size + 1]
|
| 298 |
+
buffer = buffer[block_size + 1:]
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class PackedMixedBlocks(torch.utils.data.IterableDataset):
|
| 302 |
+
def __init__(self, tokenizer, specs, block_size):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.tokenizer = tokenizer
|
| 305 |
+
self.specs = specs
|
| 306 |
+
self.block_size = block_size
|
| 307 |
+
|
| 308 |
+
def __iter__(self):
|
| 309 |
+
worker = torch.utils.data.get_worker_info()
|
| 310 |
+
rank = get_rank()
|
| 311 |
+
world_size = get_world_size()
|
| 312 |
+
|
| 313 |
+
if worker is None:
|
| 314 |
+
shard_mod = world_size
|
| 315 |
+
shard_id = rank
|
| 316 |
+
else:
|
| 317 |
+
shard_mod = worker.num_workers * world_size
|
| 318 |
+
shard_id = rank * worker.num_workers + worker.id
|
| 319 |
+
|
| 320 |
+
for idx, chunk in enumerate(
|
| 321 |
+
packed_block_stream_mixed(
|
| 322 |
+
tokenizer=self.tokenizer,
|
| 323 |
+
specs=self.specs,
|
| 324 |
+
block_size=self.block_size,
|
| 325 |
+
)
|
| 326 |
+
):
|
| 327 |
+
if idx % shard_mod != shard_id:
|
| 328 |
+
continue
|
| 329 |
+
|
| 330 |
+
yield {
|
| 331 |
+
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
|
| 332 |
+
"labels": torch.tensor(chunk[1:], dtype=torch.long),
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class PackedEvalBlocks(torch.utils.data.IterableDataset):
|
| 337 |
+
def __init__(self, tokenizer, spec, block_size, max_examples):
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.tokenizer = tokenizer
|
| 340 |
+
self.spec = spec
|
| 341 |
+
self.block_size = block_size
|
| 342 |
+
self.max_examples = max_examples
|
| 343 |
+
|
| 344 |
+
def __iter__(self):
|
| 345 |
+
worker = torch.utils.data.get_worker_info()
|
| 346 |
+
rank = get_rank()
|
| 347 |
+
world_size = get_world_size()
|
| 348 |
+
|
| 349 |
+
if worker is None:
|
| 350 |
+
shard_mod = world_size
|
| 351 |
+
shard_id = rank
|
| 352 |
+
else:
|
| 353 |
+
shard_mod = worker.num_workers * world_size
|
| 354 |
+
shard_id = rank * worker.num_workers + worker.id
|
| 355 |
+
|
| 356 |
+
bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
| 357 |
+
buffer: list[int] = []
|
| 358 |
+
|
| 359 |
+
for ex_idx, text in enumerate(example_text_iter(self.spec, max_examples=self.max_examples)):
|
| 360 |
+
if ex_idx % shard_mod != shard_id:
|
| 361 |
+
continue
|
| 362 |
+
|
| 363 |
+
ids = self.tokenizer.encode(text, add_special_tokens=False)
|
| 364 |
+
if not ids:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
buffer.extend([bos] + ids + [eos])
|
| 368 |
+
|
| 369 |
+
while len(buffer) >= self.block_size + 1:
|
| 370 |
+
chunk = buffer[: self.block_size + 1]
|
| 371 |
+
buffer = buffer[self.block_size + 1:]
|
| 372 |
+
yield {
|
| 373 |
+
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
|
| 374 |
+
"labels": torch.tensor(chunk[1:], dtype=torch.long),
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# ============================================================
|
| 379 |
+
# Architecture
|
| 380 |
+
# ============================================================
|
| 381 |
+
|
| 382 |
+
@dataclass
|
| 383 |
+
class GPTConfig:
|
| 384 |
+
vocab_size: int
|
| 385 |
+
block_size: int
|
| 386 |
+
d_model: int
|
| 387 |
+
n_heads: int
|
| 388 |
+
n_layers: int
|
| 389 |
+
d_ff: int
|
| 390 |
+
dropout: float = 0.0
|
| 391 |
+
use_checkpointing: bool = False
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class RMSNorm(nn.Module):
|
| 395 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 398 |
+
self.eps = eps
|
| 399 |
+
|
| 400 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 401 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class RotaryEmbedding(nn.Module):
|
| 405 |
+
def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
|
| 406 |
+
super().__init__()
|
| 407 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 408 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 409 |
+
|
| 410 |
+
t = torch.arange(max_seq).float()
|
| 411 |
+
freqs = torch.outer(t, inv_freq)
|
| 412 |
+
|
| 413 |
+
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
|
| 414 |
+
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
|
| 415 |
+
|
| 416 |
+
def forward(self, seq_len: int, dtype: torch.dtype):
|
| 417 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 421 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 422 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
| 423 |
+
|
| 424 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 425 |
+
cos = cos.unsqueeze(0).unsqueeze(0)
|
| 426 |
+
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 427 |
+
return x * cos + rotate_half(x) * sin
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class CausalSelfAttention(nn.Module):
|
| 431 |
+
def __init__(self, cfg: GPTConfig):
|
| 432 |
+
super().__init__()
|
| 433 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 434 |
+
self.n_heads = cfg.n_heads
|
| 435 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 436 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 437 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 438 |
+
self.dropout_p = cfg.dropout
|
| 439 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 440 |
+
|
| 441 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 442 |
+
b, t, c = x.shape
|
| 443 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 444 |
+
|
| 445 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 446 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 447 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 448 |
+
|
| 449 |
+
cos, sin = self.rope(t, x.dtype)
|
| 450 |
+
q = apply_rope(q, cos, sin)
|
| 451 |
+
k = apply_rope(k, cos, sin)
|
| 452 |
+
|
| 453 |
+
y = F.scaled_dot_product_attention(
|
| 454 |
+
q, k, v,
|
| 455 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 456 |
+
is_causal=True,
|
| 457 |
+
)
|
| 458 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 459 |
+
return self.proj(y)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class SwiGLU(nn.Module):
|
| 463 |
+
def __init__(self, cfg: GPTConfig):
|
| 464 |
+
super().__init__()
|
| 465 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 466 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 467 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 468 |
+
|
| 469 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 470 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
class Block(nn.Module):
|
| 474 |
+
def __init__(self, cfg: GPTConfig):
|
| 475 |
+
super().__init__()
|
| 476 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 477 |
+
self.attn = CausalSelfAttention(cfg)
|
| 478 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 479 |
+
self.ff = SwiGLU(cfg)
|
| 480 |
+
|
| 481 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 482 |
+
x = x + self.attn(self.ln1(x))
|
| 483 |
+
x = x + self.ff(self.ln2(x))
|
| 484 |
+
return x
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class GPT(nn.Module):
|
| 488 |
+
def __init__(self, cfg: GPTConfig):
|
| 489 |
+
super().__init__()
|
| 490 |
+
self.cfg = cfg
|
| 491 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 492 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 493 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 494 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 495 |
+
self.lm_head.weight = self.tok_emb.weight
|
| 496 |
+
self.apply(self._init_weights)
|
| 497 |
+
|
| 498 |
+
@staticmethod
|
| 499 |
+
def _init_weights(m: nn.Module) -> None:
|
| 500 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 501 |
+
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 502 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 503 |
+
nn.init.zeros_(m.bias)
|
| 504 |
+
|
| 505 |
+
def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
|
| 506 |
+
x = self.tok_emb(input_ids)
|
| 507 |
+
|
| 508 |
+
for block in self.blocks:
|
| 509 |
+
if self.cfg.use_checkpointing and self.training:
|
| 510 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 511 |
+
else:
|
| 512 |
+
x = block(x)
|
| 513 |
+
|
| 514 |
+
logits = self.lm_head(self.ln_f(x))
|
| 515 |
+
loss = None
|
| 516 |
+
|
| 517 |
+
if labels is not None:
|
| 518 |
+
loss = F.cross_entropy(
|
| 519 |
+
logits.reshape(-1, logits.size(-1)),
|
| 520 |
+
labels.reshape(-1),
|
| 521 |
+
ignore_index=-100,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
return logits, loss
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# ============================================================
|
| 528 |
+
# Optimizer / LR
|
| 529 |
+
# ============================================================
|
| 530 |
+
|
| 531 |
+
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
|
| 532 |
+
decay, no_decay = [], []
|
| 533 |
+
for name, p in unwrap_model(model).named_parameters():
|
| 534 |
+
if not p.requires_grad:
|
| 535 |
+
continue
|
| 536 |
+
(decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
|
| 537 |
+
|
| 538 |
+
return torch.optim.AdamW(
|
| 539 |
+
[
|
| 540 |
+
{"params": decay, "weight_decay": WEIGHT_DECAY},
|
| 541 |
+
{"params": no_decay, "weight_decay": 0.0},
|
| 542 |
+
],
|
| 543 |
+
lr=LEARNING_RATE,
|
| 544 |
+
betas=(0.9, 0.95),
|
| 545 |
+
eps=1e-8,
|
| 546 |
+
fused=torch.cuda.is_available(),
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
def cosine_lr(step: int) -> float:
|
| 550 |
+
if step < WARMUP_STEPS:
|
| 551 |
+
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
|
| 552 |
+
p = min(1.0, (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS))
|
| 553 |
+
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
# ============================================================
|
| 557 |
+
# Checkpoints
|
| 558 |
+
# ============================================================
|
| 559 |
+
|
| 560 |
+
def load_base_config() -> GPTConfig:
|
| 561 |
+
cfg_dict = json.loads(BASE_CONFIG_FILE.read_text(encoding="utf-8"))
|
| 562 |
+
cfg_dict["use_checkpointing"] = USE_CHECKPOINTING
|
| 563 |
+
return GPTConfig(**cfg_dict)
|
| 564 |
+
|
| 565 |
+
def initialize_model_from_base(model: nn.Module, device: torch.device) -> None:
|
| 566 |
+
if not BASE_CHECKPOINT.exists():
|
| 567 |
+
raise FileNotFoundError(f"Checkpoint de base introuvable: {BASE_CHECKPOINT}")
|
| 568 |
+
ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
|
| 569 |
+
state_dict = normalize_state_dict_keys(ckpt["model"])
|
| 570 |
+
unwrap_model(model).load_state_dict(state_dict, strict=True)
|
| 571 |
+
|
| 572 |
+
def save_checkpoint(model, optimizer, step, best_loss, path):
|
| 573 |
+
raw = unwrap_model(model)
|
| 574 |
+
model_state = normalize_state_dict_keys(raw.state_dict())
|
| 575 |
+
torch.save(
|
| 576 |
+
{
|
| 577 |
+
"model": model_state,
|
| 578 |
+
"optimizer": optimizer.state_dict(),
|
| 579 |
+
"step": step,
|
| 580 |
+
"best_loss": best_loss,
|
| 581 |
+
"config": asdict(raw.cfg),
|
| 582 |
+
},
|
| 583 |
+
path,
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
def load_resume_checkpoint(model, optimizer, path, device) -> tuple[int, float]:
|
| 587 |
+
ckpt = torch.load(path, map_location=device)
|
| 588 |
+
raw = unwrap_model(model)
|
| 589 |
+
model_state = normalize_state_dict_keys(ckpt["model"])
|
| 590 |
+
raw.load_state_dict(model_state, strict=True)
|
| 591 |
+
|
| 592 |
+
try:
|
| 593 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 594 |
+
except Exception as e:
|
| 595 |
+
print(f"[warn] Optimizer state non repris: {e}")
|
| 596 |
+
|
| 597 |
+
return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
# ============================================================
|
| 601 |
+
# Evaluation
|
| 602 |
+
# ============================================================
|
| 603 |
+
|
| 604 |
+
@torch.no_grad()
|
| 605 |
+
def evaluate(model, loader, device, max_batches: int = 100) -> float:
|
| 606 |
+
model.eval()
|
| 607 |
+
losses = []
|
| 608 |
+
|
| 609 |
+
for i, batch in enumerate(loader):
|
| 610 |
+
if i >= max_batches:
|
| 611 |
+
break
|
| 612 |
+
|
| 613 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 614 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 615 |
+
|
| 616 |
+
with autocast_context(device):
|
| 617 |
+
_, loss = model(inp, lbl)
|
| 618 |
+
|
| 619 |
+
losses.append(loss.item())
|
| 620 |
+
|
| 621 |
+
model.train()
|
| 622 |
+
return sum(losses) / max(1, len(losses))
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
# ============================================================
|
| 626 |
+
# Main
|
| 627 |
+
# ============================================================
|
| 628 |
+
|
| 629 |
+
def main() -> None:
|
| 630 |
+
ddp_device = init_distributed()
|
| 631 |
+
set_seed(SEED + get_rank())
|
| 632 |
+
device = get_device(ddp_device)
|
| 633 |
+
|
| 634 |
+
cuda_device_index = None
|
| 635 |
+
vram_fraction = None
|
| 636 |
+
|
| 637 |
+
if device.type == "cuda":
|
| 638 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 639 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 640 |
+
torch.set_float32_matmul_precision("high")
|
| 641 |
+
|
| 642 |
+
cuda_device_index = current_cuda_index(device)
|
| 643 |
+
_, total_mem_bytes = torch.cuda.mem_get_info(cuda_device_index)
|
| 644 |
+
target_bytes = int(TARGET_VRAM_GIB * (1024 ** 3))
|
| 645 |
+
vram_fraction = min(target_bytes / total_mem_bytes, 0.999)
|
| 646 |
+
|
| 647 |
+
torch.cuda.memory.set_per_process_memory_fraction(
|
| 648 |
+
vram_fraction,
|
| 649 |
+
device=cuda_device_index,
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
if is_main():
|
| 653 |
+
print("=" * 60)
|
| 654 |
+
print(" Re-train même modèle | 10 x 2000 steps")
|
| 655 |
+
print("=" * 60)
|
| 656 |
+
print(f"Device: {device} | World: {get_world_size()} GPU(s)")
|
| 657 |
+
if device.type == "cuda":
|
| 658 |
+
free_mem, total_mem = torch.cuda.mem_get_info(cuda_device_index)
|
| 659 |
+
print(f"GPU: {torch.cuda.get_device_name(cuda_device_index)}")
|
| 660 |
+
print(f"VRAM cible: {TARGET_VRAM_GIB:.1f} GiB")
|
| 661 |
+
print(f"Fraction PyTorch: {vram_fraction:.4f}")
|
| 662 |
+
print(f"GPU total: {total_mem / 1024**3:.2f} GiB | libre: {free_mem / 1024**3:.2f} GiB")
|
| 663 |
+
print(f"Rounds: {NUM_ROUNDS} | Steps/round: {STEPS_PER_ROUND} | MAX_STEPS: {MAX_STEPS}")
|
| 664 |
+
|
| 665 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(str(BASE_TOKENIZER_DIR))
|
| 666 |
+
cfg = load_base_config()
|
| 667 |
+
cfg.vocab_size = len(tokenizer)
|
| 668 |
+
|
| 669 |
+
if is_main():
|
| 670 |
+
CONFIG_FILE.write_text(
|
| 671 |
+
json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
|
| 672 |
+
encoding="utf-8",
|
| 673 |
+
)
|
| 674 |
+
print(f"Base checkpoint: {BASE_CHECKPOINT}")
|
| 675 |
+
print(f"Tokenizer: {BASE_TOKENIZER_DIR}")
|
| 676 |
+
|
| 677 |
+
model = GPT(cfg).to(device)
|
| 678 |
+
initialize_model_from_base(model, device)
|
| 679 |
+
|
| 680 |
+
if USE_COMPILE and hasattr(torch, "compile"):
|
| 681 |
+
model = torch.compile(model, mode=COMPILE_MODE)
|
| 682 |
+
if is_main():
|
| 683 |
+
print(f"torch.compile activé ({COMPILE_MODE})")
|
| 684 |
+
|
| 685 |
+
if is_distributed():
|
| 686 |
+
model = DDP(model, device_ids=[device.index])
|
| 687 |
+
|
| 688 |
+
optimizer = build_optimizer(model)
|
| 689 |
+
|
| 690 |
+
start_step, best_eval = 0, 1e9
|
| 691 |
+
if STATE_FILE.exists():
|
| 692 |
+
try:
|
| 693 |
+
if is_main():
|
| 694 |
+
print(f"Reprise depuis {STATE_FILE}")
|
| 695 |
+
start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
|
| 696 |
+
except Exception as e:
|
| 697 |
+
if is_main():
|
| 698 |
+
bad_path = STATE_FILE.with_suffix(".corrupt.pt")
|
| 699 |
+
print(f"[warn] Checkpoint illisible: {e}")
|
| 700 |
+
try:
|
| 701 |
+
STATE_FILE.rename(bad_path)
|
| 702 |
+
print(f"[warn] Checkpoint corrompu renommé vers {bad_path}")
|
| 703 |
+
except Exception:
|
| 704 |
+
pass
|
| 705 |
+
print("[warn] Reprise ignorée, démarrage depuis le checkpoint de base.")
|
| 706 |
+
start_step, best_eval = 0, 1e9
|
| 707 |
+
|
| 708 |
+
if start_step >= MAX_STEPS:
|
| 709 |
+
if is_main():
|
| 710 |
+
print(f"[warn] start_step={start_step} >= MAX_STEPS={MAX_STEPS}")
|
| 711 |
+
print("[warn] Rien à entraîner.")
|
| 712 |
+
return
|
| 713 |
+
|
| 714 |
+
train_ds = PackedMixedBlocks(
|
| 715 |
+
tokenizer=tokenizer,
|
| 716 |
+
specs=TRAIN_SOURCES,
|
| 717 |
+
block_size=cfg.block_size,
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
eval_ds = PackedEvalBlocks(
|
| 721 |
+
tokenizer=tokenizer,
|
| 722 |
+
spec=EVAL_SOURCE,
|
| 723 |
+
block_size=cfg.block_size,
|
| 724 |
+
max_examples=MAX_EVAL_EXAMPLES,
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
train_loader = torch.utils.data.DataLoader(
|
| 728 |
+
train_ds,
|
| 729 |
+
batch_size=BATCH_SIZE,
|
| 730 |
+
num_workers=TRAIN_NUM_WORKERS,
|
| 731 |
+
pin_memory=(device.type == "cuda"),
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
eval_loader = torch.utils.data.DataLoader(
|
| 735 |
+
eval_ds,
|
| 736 |
+
batch_size=BATCH_SIZE,
|
| 737 |
+
num_workers=EVAL_NUM_WORKERS,
|
| 738 |
+
pin_memory=(device.type == "cuda"),
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
if is_main():
|
| 742 |
+
raw_model = unwrap_model(model)
|
| 743 |
+
n_params = count_parameters(raw_model)
|
| 744 |
+
print(f"Paramètres: {n_params / 1e6:.1f}M")
|
| 745 |
+
print(f"Architecture: d={cfg.d_model} | heads={cfg.n_heads} | layers={cfg.n_layers} | block={cfg.block_size}")
|
| 746 |
+
print(f"Batch size: {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS}")
|
| 747 |
+
print(f"Dtype: {DTYPE} | Compile: {USE_COMPILE} ({COMPILE_MODE if USE_COMPILE else 'off'})")
|
| 748 |
+
|
| 749 |
+
model.train()
|
| 750 |
+
optimizer.zero_grad(set_to_none=True)
|
| 751 |
+
|
| 752 |
+
train_iter = iter(train_loader)
|
| 753 |
+
step = start_step
|
| 754 |
+
t0 = time.time()
|
| 755 |
+
log_loss_sum = 0.0
|
| 756 |
+
log_loss_count = 0
|
| 757 |
+
tokens_since_log = 0
|
| 758 |
+
last_log = time.time()
|
| 759 |
+
|
| 760 |
+
if device.type == "cuda":
|
| 761 |
+
torch.cuda.reset_peak_memory_stats(cuda_device_index)
|
| 762 |
+
|
| 763 |
+
current_round = (step // STEPS_PER_ROUND) + 1
|
| 764 |
+
|
| 765 |
+
while step < MAX_STEPS:
|
| 766 |
+
for _ in range(GRAD_ACCUM_STEPS):
|
| 767 |
+
batch = next(train_iter)
|
| 768 |
+
|
| 769 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 770 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 771 |
+
|
| 772 |
+
with autocast_context(device):
|
| 773 |
+
_, loss = model(inp, lbl)
|
| 774 |
+
|
| 775 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 776 |
+
|
| 777 |
+
log_loss_sum += loss.item()
|
| 778 |
+
log_loss_count += 1
|
| 779 |
+
tokens_since_log += inp.numel()
|
| 780 |
+
|
| 781 |
+
lr = cosine_lr(step)
|
| 782 |
+
for group in optimizer.param_groups:
|
| 783 |
+
group["lr"] = lr
|
| 784 |
+
|
| 785 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
|
| 786 |
+
optimizer.step()
|
| 787 |
+
optimizer.zero_grad(set_to_none=True)
|
| 788 |
+
step += 1
|
| 789 |
+
|
| 790 |
+
new_round = ((step - 1) // STEPS_PER_ROUND) + 1
|
| 791 |
+
if new_round != current_round and is_main():
|
| 792 |
+
current_round = new_round
|
| 793 |
+
print(f"\n===== Round {current_round}/{NUM_ROUNDS} =====")
|
| 794 |
+
|
| 795 |
+
if step % 50 == 0 and is_main():
|
| 796 |
+
now = time.time()
|
| 797 |
+
elapsed = max(1e-6, now - last_log)
|
| 798 |
+
tok_s = tokens_since_log / elapsed
|
| 799 |
+
avg_loss = log_loss_sum / max(1, log_loss_count)
|
| 800 |
+
round_idx = ((step - 1) // STEPS_PER_ROUND) + 1
|
| 801 |
+
step_in_round = ((step - 1) % STEPS_PER_ROUND) + 1
|
| 802 |
+
|
| 803 |
+
print(
|
| 804 |
+
f"round {round_idx:2d}/{NUM_ROUNDS} | "
|
| 805 |
+
f"step {step_in_round:4d}/{STEPS_PER_ROUND} | "
|
| 806 |
+
f"global {step:5d}/{MAX_STEPS} | "
|
| 807 |
+
f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
if device.type == "cuda":
|
| 811 |
+
allocated = torch.cuda.memory_allocated(cuda_device_index) / 1024**3
|
| 812 |
+
reserved = torch.cuda.memory_reserved(cuda_device_index) / 1024**3
|
| 813 |
+
max_alloc = torch.cuda.max_memory_allocated(cuda_device_index) / 1024**3
|
| 814 |
+
max_reserved = torch.cuda.max_memory_reserved(cuda_device_index) / 1024**3
|
| 815 |
+
print(
|
| 816 |
+
f"GPU mem | alloc={allocated:.2f} GiB | reserved={reserved:.2f} GiB | "
|
| 817 |
+
f"max_alloc={max_alloc:.2f} GiB | max_reserved={max_reserved:.2f} GiB"
|
| 818 |
+
)
|
| 819 |
+
|
| 820 |
+
last_log = now
|
| 821 |
+
tokens_since_log = 0
|
| 822 |
+
log_loss_sum = 0.0
|
| 823 |
+
log_loss_count = 0
|
| 824 |
+
|
| 825 |
+
if step % EVAL_EVERY == 0 and is_main():
|
| 826 |
+
val_loss = evaluate(model, eval_loader, device)
|
| 827 |
+
print(f"[eval] global step {step:5d} | val_loss={val_loss:.4f}")
|
| 828 |
+
|
| 829 |
+
if val_loss < best_eval:
|
| 830 |
+
best_eval = val_loss
|
| 831 |
+
save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
|
| 832 |
+
print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
|
| 833 |
+
|
| 834 |
+
if step % SAVE_EVERY == 0 and is_main():
|
| 835 |
+
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
|
| 836 |
+
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
|
| 837 |
+
print(f"✓ Checkpoint → {MODEL_FILE}")
|
| 838 |
+
|
| 839 |
+
if step % STEPS_PER_ROUND == 0 and is_main():
|
| 840 |
+
round_no = step // STEPS_PER_ROUND
|
| 841 |
+
round_ckpt = OUT_DIR / f"model_round_{round_no:02d}.pt"
|
| 842 |
+
save_checkpoint(model, optimizer, step, best_eval, round_ckpt)
|
| 843 |
+
print(f"✓ Fin round {round_no}/{NUM_ROUNDS} → {round_ckpt}")
|
| 844 |
+
|
| 845 |
+
if is_main():
|
| 846 |
+
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
|
| 847 |
+
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
|
| 848 |
+
total = (time.time() - t0) / 60
|
| 849 |
+
print(f"\nModèle final → {MODEL_FILE}")
|
| 850 |
+
print(f"Meilleur modèle → {BEST_MODEL_FILE}")
|
| 851 |
+
print(f"Temps total : {total:.1f} min")
|
| 852 |
+
print(f"Steps effectués : {step}")
|
| 853 |
+
|
| 854 |
+
if is_distributed():
|
| 855 |
+
dist.destroy_process_group()
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
if __name__ == "__main__":
|
| 859 |
+
main()
|
train2.py
ADDED
|
@@ -0,0 +1,852 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
train_nlp_h100_optimized.py — v2 (bugfix device mismatch)
|
| 5 |
+
===========================================================
|
| 6 |
+
Corrections vs v1 :
|
| 7 |
+
• apply_qlora() appelé APRÈS model.to(device) → lora_A/lora_B naissent sur CUDA
|
| 8 |
+
• LoRALinear.__init__ : move explicite des adaptateurs sur le device du base_layer
|
| 9 |
+
• torch.compile désactivé quand USE_CHECKPOINTING=True (conflict dynamo+checkpoint
|
| 10 |
+
avec sous-modules custom) — on utilise COMPILE_AFTER_CKPT pour les cas où on
|
| 11 |
+
veut quand même compiler (USE_CHECKPOINTING=False)
|
| 12 |
+
• Ajout d'un fallback propre : si compile crash, on continue sans compile
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import itertools
|
| 18 |
+
import json
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import time
|
| 23 |
+
from collections import OrderedDict
|
| 24 |
+
from contextlib import nullcontext
|
| 25 |
+
from dataclasses import asdict, dataclass
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Iterator, Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.distributed as dist
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import bitsandbytes as bnb
|
| 36 |
+
from bitsandbytes.nn import Params4bit
|
| 37 |
+
HAS_BNB = True
|
| 38 |
+
except ImportError:
|
| 39 |
+
HAS_BNB = False
|
| 40 |
+
print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée")
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from flash_attn import flash_attn_func
|
| 44 |
+
HAS_FLASH = True
|
| 45 |
+
except ImportError:
|
| 46 |
+
HAS_FLASH = False
|
| 47 |
+
print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention")
|
| 48 |
+
|
| 49 |
+
from datasets import load_dataset
|
| 50 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 51 |
+
from tokenizers import (
|
| 52 |
+
Tokenizer, decoders, models, normalizers,
|
| 53 |
+
pre_tokenizers, processors, trainers,
|
| 54 |
+
)
|
| 55 |
+
from transformers import PreTrainedTokenizerFast
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 59 |
+
# ║ CHEMINS ║
|
| 60 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 61 |
+
|
| 62 |
+
OUT_DIR = Path("./nlp_1b_h100_opt")
|
| 63 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 64 |
+
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
|
| 65 |
+
CONFIG_FILE = OUT_DIR / "config.json"
|
| 66 |
+
MODEL_FILE = OUT_DIR / "model.pt"
|
| 67 |
+
BEST_MODEL_FILE= OUT_DIR / "model_best.pt"
|
| 68 |
+
STATE_FILE = OUT_DIR / "train_state.pt"
|
| 69 |
+
BASE_CHECKPOINT: Optional[Path] = None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 73 |
+
# ║ HYPERPARAMÈTRES ║
|
| 74 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 75 |
+
|
| 76 |
+
SEED = 42
|
| 77 |
+
TARGET_VRAM_GIB= 78.0
|
| 78 |
+
|
| 79 |
+
BLOCK_SIZE = 1024
|
| 80 |
+
VOCAB_SIZE = 32_000
|
| 81 |
+
D_MODEL = 1536
|
| 82 |
+
N_HEADS = 24
|
| 83 |
+
N_LAYERS = 24
|
| 84 |
+
D_FF = 6144
|
| 85 |
+
DROPOUT = 0.0
|
| 86 |
+
|
| 87 |
+
USE_QLORA = True
|
| 88 |
+
LORA_R = 64
|
| 89 |
+
LORA_ALPHA = 128
|
| 90 |
+
LORA_DROPOUT = 0.05
|
| 91 |
+
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]
|
| 92 |
+
|
| 93 |
+
NUM_EPOCHS = 10
|
| 94 |
+
LEARNING_RATE = 3e-4
|
| 95 |
+
MIN_LR = 3e-5
|
| 96 |
+
WEIGHT_DECAY = 0.1
|
| 97 |
+
WARMUP_STEPS = 500
|
| 98 |
+
|
| 99 |
+
# ┌─────────────────────────────────────────────────────────────────────────────┐
|
| 100 |
+
# │ RÉGLAGE BATCH SIZE → 78 Go VRAM │
|
| 101 |
+
# │ Démarrer : BATCH_SIZE=8, GRAD_ACCUM_STEPS=2 │
|
| 102 |
+
# │ Augmenter BATCH_SIZE par +2 jusqu'à max_reserved ≈ 77 Go dans les logs │
|
| 103 |
+
# │ Si OOM : BATCH_SIZE -= 1 ou USE_CHECKPOINTING=True │
|
| 104 |
+
# └─────────────────────────────────────────────────────────────────────────────┘
|
| 105 |
+
BATCH_SIZE = 16
|
| 106 |
+
GRAD_ACCUM_STEPS = 1
|
| 107 |
+
MAX_GRAD_NORM = 1.0
|
| 108 |
+
EVAL_EVERY = 500
|
| 109 |
+
SAVE_EVERY = 1_000
|
| 110 |
+
|
| 111 |
+
DTYPE = torch.bfloat16
|
| 112 |
+
|
| 113 |
+
# ── Compile : désactivé quand USE_CHECKPOINTING=True pour éviter le conflict
|
| 114 |
+
# dynamo ↔ checkpoint ↔ sous-modules custom (LoRALinear).
|
| 115 |
+
# Mettre USE_CHECKPOINTING=False ET USE_COMPILE=True pour vitesse max.
|
| 116 |
+
USE_CHECKPOINTING = False # économise ~8× activations VRAM
|
| 117 |
+
USE_COMPILE = True # ← mettre True seulement si USE_CHECKPOINTING=False
|
| 118 |
+
COMPILE_MODE = "reduce-overhead"
|
| 119 |
+
|
| 120 |
+
TRAIN_NUM_WORKERS = 4
|
| 121 |
+
EVAL_NUM_WORKERS = 2
|
| 122 |
+
PREFETCH_FACTOR = 2
|
| 123 |
+
|
| 124 |
+
TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
|
| 125 |
+
TOKENIZER_CHAR_LIMIT = 2_000
|
| 126 |
+
TEXT_CHAR_LIMIT = 4_000
|
| 127 |
+
|
| 128 |
+
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
|
| 129 |
+
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
|
| 130 |
+
|
| 131 |
+
WIKI_CONFIGS = ["20231101.en", "20231101.fr", "20231101.ar"]
|
| 132 |
+
FINEWEB_CONFIG = "sample-10BT"
|
| 133 |
+
DEV_DOCS_PER_WIKI_CONFIG = 1_500
|
| 134 |
+
DEV_DOCS_FINEWEB = 3_000
|
| 135 |
+
TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
|
| 136 |
+
TRAIN_DOCS_FINEWEB = 60_000
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 140 |
+
# ║ DISTRIBUTED ║
|
| 141 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 142 |
+
|
| 143 |
+
def is_distributed() -> bool:
|
| 144 |
+
return dist.is_available() and dist.is_initialized()
|
| 145 |
+
|
| 146 |
+
def get_rank() -> int:
|
| 147 |
+
return dist.get_rank() if is_distributed() else 0
|
| 148 |
+
|
| 149 |
+
def get_world_size() -> int:
|
| 150 |
+
return dist.get_world_size() if is_distributed() else 1
|
| 151 |
+
|
| 152 |
+
def is_main() -> bool:
|
| 153 |
+
return get_rank() == 0
|
| 154 |
+
|
| 155 |
+
def init_distributed() -> Optional[torch.device]:
|
| 156 |
+
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 157 |
+
if local_rank == -1:
|
| 158 |
+
return None
|
| 159 |
+
dist.init_process_group("nccl")
|
| 160 |
+
torch.cuda.set_device(local_rank)
|
| 161 |
+
return torch.device(f"cuda:{local_rank}")
|
| 162 |
+
|
| 163 |
+
def set_seed(seed: int) -> None:
|
| 164 |
+
random.seed(seed)
|
| 165 |
+
torch.manual_seed(seed)
|
| 166 |
+
if torch.cuda.is_available():
|
| 167 |
+
torch.cuda.manual_seed_all(seed)
|
| 168 |
+
|
| 169 |
+
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
|
| 170 |
+
if ddp_device is not None:
|
| 171 |
+
return ddp_device
|
| 172 |
+
if torch.cuda.is_available():
|
| 173 |
+
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 174 |
+
return torch.device("cpu")
|
| 175 |
+
|
| 176 |
+
def current_cuda_index(device: torch.device) -> int:
|
| 177 |
+
return device.index if device.index is not None else torch.cuda.current_device()
|
| 178 |
+
|
| 179 |
+
def autocast_context(device: torch.device):
|
| 180 |
+
if device.type == "cuda":
|
| 181 |
+
return torch.autocast("cuda", dtype=DTYPE)
|
| 182 |
+
return nullcontext()
|
| 183 |
+
|
| 184 |
+
def unwrap_model(model: nn.Module) -> nn.Module:
|
| 185 |
+
m = model.module if isinstance(model, DDP) else model
|
| 186 |
+
return m._orig_mod if hasattr(m, "_orig_mod") else m
|
| 187 |
+
|
| 188 |
+
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
|
| 189 |
+
return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad)
|
| 190 |
+
|
| 191 |
+
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
|
| 192 |
+
out = OrderedDict()
|
| 193 |
+
for k, v in sd.items():
|
| 194 |
+
for prefix in ("module._orig_mod.", "_orig_mod.", "module."):
|
| 195 |
+
if k.startswith(prefix):
|
| 196 |
+
k = k[len(prefix):]
|
| 197 |
+
break
|
| 198 |
+
out[k] = v
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
def normalize_text(t: str) -> str:
|
| 202 |
+
return " ".join(t.strip().split())
|
| 203 |
+
|
| 204 |
+
def safe_str(x) -> str:
|
| 205 |
+
return x if isinstance(x, str) else ("" if x is None else str(x))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 209 |
+
# ║ DATASETS ║
|
| 210 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 211 |
+
|
| 212 |
+
def load_wiki_stream(cfg_name: str):
|
| 213 |
+
return load_dataset("wikimedia/wikipedia", cfg_name, split="train", streaming=True)
|
| 214 |
+
|
| 215 |
+
def load_fineweb_stream():
|
| 216 |
+
return load_dataset("HuggingFaceFW/fineweb-edu", FINEWEB_CONFIG, split="train", streaming=True)
|
| 217 |
+
|
| 218 |
+
def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
|
| 219 |
+
for row in itertools.islice(ds, start, start + count):
|
| 220 |
+
text = normalize_text(safe_str(row.get("text", "")))
|
| 221 |
+
if len(text) >= 20:
|
| 222 |
+
yield text[:char_limit]
|
| 223 |
+
|
| 224 |
+
def tokenizer_training_iterator() -> Iterator[str]:
|
| 225 |
+
for c in WIKI_CONFIGS:
|
| 226 |
+
yield from stream_texts(load_wiki_stream(c), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT)
|
| 227 |
+
yield from stream_texts(load_fineweb_stream(), 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT)
|
| 228 |
+
|
| 229 |
+
def build_epoch_train_texts(epoch: int) -> list[str]:
|
| 230 |
+
texts: list[str] = []
|
| 231 |
+
for c in WIKI_CONFIGS:
|
| 232 |
+
start = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
|
| 233 |
+
texts.extend(stream_texts(load_wiki_stream(c), start, TRAIN_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
|
| 234 |
+
start = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB
|
| 235 |
+
texts.extend(stream_texts(load_fineweb_stream(), start, TRAIN_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
|
| 236 |
+
random.Random(SEED + epoch).shuffle(texts)
|
| 237 |
+
return texts
|
| 238 |
+
|
| 239 |
+
def build_eval_texts() -> list[str]:
|
| 240 |
+
texts: list[str] = []
|
| 241 |
+
for c in WIKI_CONFIGS:
|
| 242 |
+
texts.extend(stream_texts(load_wiki_stream(c), 0, DEV_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
|
| 243 |
+
texts.extend(stream_texts(load_fineweb_stream(), 0, DEV_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
|
| 244 |
+
return texts
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 248 |
+
# ║ PACKED DATASET ║
|
| 249 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 250 |
+
|
| 251 |
+
class PackedTextList(torch.utils.data.IterableDataset):
|
| 252 |
+
def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.texts = texts
|
| 255 |
+
self.tokenizer = tokenizer
|
| 256 |
+
self.block_size = block_size
|
| 257 |
+
self.epoch_seed = epoch_seed
|
| 258 |
+
|
| 259 |
+
def __iter__(self):
|
| 260 |
+
worker = torch.utils.data.get_worker_info()
|
| 261 |
+
rank, ws = get_rank(), get_world_size()
|
| 262 |
+
if worker is None:
|
| 263 |
+
shard_mod, shard_id = ws, rank
|
| 264 |
+
else:
|
| 265 |
+
shard_mod = worker.num_workers * ws
|
| 266 |
+
shard_id = rank * worker.num_workers + worker.id
|
| 267 |
+
|
| 268 |
+
rng = random.Random(self.epoch_seed)
|
| 269 |
+
indices = list(range(len(self.texts)))
|
| 270 |
+
rng.shuffle(indices)
|
| 271 |
+
|
| 272 |
+
bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
| 273 |
+
buf: list[int] = []
|
| 274 |
+
|
| 275 |
+
for li, ti in enumerate(indices):
|
| 276 |
+
if li % shard_mod != shard_id:
|
| 277 |
+
continue
|
| 278 |
+
ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False)
|
| 279 |
+
if not ids:
|
| 280 |
+
continue
|
| 281 |
+
buf.extend([bos] + ids + [eos])
|
| 282 |
+
while len(buf) >= self.block_size + 1:
|
| 283 |
+
chunk = buf[: self.block_size + 1]
|
| 284 |
+
buf = buf[self.block_size + 1 :]
|
| 285 |
+
yield {
|
| 286 |
+
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
|
| 287 |
+
"labels": torch.tensor(chunk[1:], dtype=torch.long),
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 292 |
+
# ║ TOKENIZER ║
|
| 293 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 294 |
+
|
| 295 |
+
def tokenizer_ready() -> bool:
|
| 296 |
+
return (TOKENIZER_DIR / "tokenizer.json").exists() and (TOKENIZER_DIR / "tokenizer_config.json").exists()
|
| 297 |
+
|
| 298 |
+
def train_tokenizer_once() -> None:
|
| 299 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 300 |
+
tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
|
| 301 |
+
tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
|
| 302 |
+
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 303 |
+
tok.decoder = decoders.ByteLevel()
|
| 304 |
+
trainer = trainers.BpeTrainer(
|
| 305 |
+
vocab_size=VOCAB_SIZE, min_frequency=2, show_progress=is_main(),
|
| 306 |
+
special_tokens=SPECIAL_TOKENS, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
| 307 |
+
)
|
| 308 |
+
tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer)
|
| 309 |
+
bos_id, eos_id = tok.token_to_id(BOS_TOKEN), tok.token_to_id(EOS_TOKEN)
|
| 310 |
+
tok.post_processor = processors.TemplateProcessing(
|
| 311 |
+
single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
|
| 312 |
+
pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
|
| 313 |
+
special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
|
| 314 |
+
)
|
| 315 |
+
tok.save(str(TOKENIZER_DIR / "tokenizer.json"))
|
| 316 |
+
fast = PreTrainedTokenizerFast(
|
| 317 |
+
tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"),
|
| 318 |
+
bos_token=BOS_TOKEN, eos_token=EOS_TOKEN, unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
|
| 319 |
+
)
|
| 320 |
+
fast.save_pretrained(str(TOKENIZER_DIR))
|
| 321 |
+
|
| 322 |
+
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
|
| 323 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 324 |
+
if not tokenizer_ready():
|
| 325 |
+
if is_distributed():
|
| 326 |
+
if is_main():
|
| 327 |
+
print("Entraînement tokenizer 32k…"); train_tokenizer_once()
|
| 328 |
+
dist.barrier()
|
| 329 |
+
else:
|
| 330 |
+
print("Entraînement tokenizer 32k…"); train_tokenizer_once()
|
| 331 |
+
return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 335 |
+
# ║ MODÈLE ║
|
| 336 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 337 |
+
|
| 338 |
+
@dataclass
|
| 339 |
+
class GPTConfig:
|
| 340 |
+
vocab_size: int = VOCAB_SIZE
|
| 341 |
+
block_size: int = BLOCK_SIZE
|
| 342 |
+
d_model: int = D_MODEL
|
| 343 |
+
n_heads: int = N_HEADS
|
| 344 |
+
n_layers: int = N_LAYERS
|
| 345 |
+
d_ff: int = D_FF
|
| 346 |
+
dropout: float = DROPOUT
|
| 347 |
+
use_checkpointing: bool = USE_CHECKPOINTING
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
class RMSNorm(nn.Module):
|
| 351 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 352 |
+
super().__init__()
|
| 353 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 354 |
+
self.eps = eps
|
| 355 |
+
def forward(self, x):
|
| 356 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class RotaryEmbedding(nn.Module):
|
| 360 |
+
def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
|
| 361 |
+
super().__init__()
|
| 362 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 363 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 364 |
+
t = torch.arange(max_seq).float()
|
| 365 |
+
freqs = torch.outer(t, inv_freq)
|
| 366 |
+
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
|
| 367 |
+
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
|
| 368 |
+
def forward(self, seq_len: int, dtype: torch.dtype):
|
| 369 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def rotate_half(x):
|
| 373 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 374 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
| 375 |
+
|
| 376 |
+
def apply_rope(x, cos, sin):
|
| 377 |
+
return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class CausalSelfAttention(nn.Module):
|
| 381 |
+
def __init__(self, cfg: GPTConfig):
|
| 382 |
+
super().__init__()
|
| 383 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 384 |
+
self.n_heads = cfg.n_heads
|
| 385 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 386 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 387 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 388 |
+
self.dropout_p = cfg.dropout
|
| 389 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 390 |
+
|
| 391 |
+
def forward(self, x):
|
| 392 |
+
b, t, c = x.shape
|
| 393 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 394 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 395 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 396 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 397 |
+
cos, sin = self.rope(t, x.dtype)
|
| 398 |
+
q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
|
| 399 |
+
|
| 400 |
+
if HAS_FLASH:
|
| 401 |
+
# Flash Attention 2 attend (b, t, nh, hd)
|
| 402 |
+
q = q.transpose(1, 2)
|
| 403 |
+
k = k.transpose(1, 2)
|
| 404 |
+
v = v.transpose(1, 2)
|
| 405 |
+
y = flash_attn_func(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, causal=True)
|
| 406 |
+
y = y.reshape(b, t, c)
|
| 407 |
+
else:
|
| 408 |
+
y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, is_causal=True)
|
| 409 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 410 |
+
|
| 411 |
+
return self.proj(y)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class SwiGLU(nn.Module):
|
| 415 |
+
def __init__(self, cfg: GPTConfig):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 418 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 419 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 420 |
+
def forward(self, x):
|
| 421 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class Block(nn.Module):
|
| 425 |
+
def __init__(self, cfg: GPTConfig):
|
| 426 |
+
super().__init__()
|
| 427 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 428 |
+
self.attn = CausalSelfAttention(cfg)
|
| 429 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 430 |
+
self.ff = SwiGLU(cfg)
|
| 431 |
+
def forward(self, x):
|
| 432 |
+
x = x + self.attn(self.ln1(x))
|
| 433 |
+
x = x + self.ff(self.ln2(x))
|
| 434 |
+
return x
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
class GPT(nn.Module):
|
| 438 |
+
def __init__(self, cfg: GPTConfig):
|
| 439 |
+
super().__init__()
|
| 440 |
+
self.cfg = cfg
|
| 441 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 442 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 443 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 444 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 445 |
+
self.lm_head.weight = self.tok_emb.weight # weight tying
|
| 446 |
+
self.apply(self._init_weights)
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def _init_weights(m):
|
| 450 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 451 |
+
nn.init.normal_(m.weight, 0.0, 0.02)
|
| 452 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 453 |
+
nn.init.zeros_(m.bias)
|
| 454 |
+
|
| 455 |
+
def forward(self, input_ids, labels=None):
|
| 456 |
+
x = self.tok_emb(input_ids)
|
| 457 |
+
for block in self.blocks:
|
| 458 |
+
if self.cfg.use_checkpointing and self.training:
|
| 459 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 460 |
+
else:
|
| 461 |
+
x = block(x)
|
| 462 |
+
logits = self.lm_head(self.ln_f(x))
|
| 463 |
+
loss = None
|
| 464 |
+
if labels is not None:
|
| 465 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
|
| 466 |
+
return logits, loss
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 470 |
+
# ║ QLORA ║
|
| 471 |
+
# ║ ║
|
| 472 |
+
# ║ CORRECTIF CLÉ : apply_qlora() DOIT être appelé APRÈS model.to(device). ║
|
| 473 |
+
# ║ LoRALinear détecte automatiquement le device du base_layer et y crée ║
|
| 474 |
+
# ║ lora_A / lora_B directement, sans besoin de .to() séparé. ║
|
| 475 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 476 |
+
|
| 477 |
+
class LoRALinear(nn.Module):
|
| 478 |
+
"""
|
| 479 |
+
Adaptateur LoRA autour d'un nn.Linear existant.
|
| 480 |
+
|
| 481 |
+
IMPORTANT : les sous-modules lora_A et lora_B sont créés sur le MÊME
|
| 482 |
+
device que base_layer.weight via le move explicite ci-dessous.
|
| 483 |
+
C'est la correction du bug 'cuda:0 vs cpu' de la v1.
|
| 484 |
+
"""
|
| 485 |
+
def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT):
|
| 486 |
+
super().__init__()
|
| 487 |
+
self.base = base_layer
|
| 488 |
+
self.r = r
|
| 489 |
+
self.scale = alpha / r
|
| 490 |
+
in_f, out_f = base_layer.in_features, base_layer.out_features
|
| 491 |
+
|
| 492 |
+
# ── Détecter le device du base_layer ──────────────────────────────────
|
| 493 |
+
# base_layer.weight peut être un Params4bit (pas de .device direct)
|
| 494 |
+
try:
|
| 495 |
+
dev = next(base_layer.parameters()).device
|
| 496 |
+
except StopIteration:
|
| 497 |
+
dev = torch.device("cpu")
|
| 498 |
+
|
| 499 |
+
# Créer les adaptateurs DIRECTEMENT sur le bon device
|
| 500 |
+
self.lora_A = nn.Linear(in_f, r, bias=False, device=dev)
|
| 501 |
+
self.lora_B = nn.Linear(r, out_f, bias=False, device=dev)
|
| 502 |
+
self.drop = nn.Dropout(dropout)
|
| 503 |
+
|
| 504 |
+
# Initialisation standard LoRA
|
| 505 |
+
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| 506 |
+
nn.init.zeros_(self.lora_B.weight)
|
| 507 |
+
|
| 508 |
+
# Geler les poids de base
|
| 509 |
+
for p in self.base.parameters():
|
| 510 |
+
p.requires_grad = False
|
| 511 |
+
|
| 512 |
+
def forward(self, x):
|
| 513 |
+
return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def apply_qlora(model: GPT, device: torch.device) -> GPT:
|
| 517 |
+
"""
|
| 518 |
+
Remplace les couches cibles par LoRALinear.
|
| 519 |
+
À appeler IMPÉRATIVEMENT après model.to(device).
|
| 520 |
+
"""
|
| 521 |
+
if not USE_QLORA:
|
| 522 |
+
return model
|
| 523 |
+
|
| 524 |
+
replaced = 0
|
| 525 |
+
# Collecter d'abord pour éviter de modifier le dict pendant l'itération
|
| 526 |
+
targets = []
|
| 527 |
+
for name, module in model.named_modules():
|
| 528 |
+
parts = name.split(".")
|
| 529 |
+
if parts[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear):
|
| 530 |
+
targets.append((name, module))
|
| 531 |
+
|
| 532 |
+
for name, module in targets:
|
| 533 |
+
parts = name.split(".")
|
| 534 |
+
parent = model
|
| 535 |
+
for part in parts[:-1]:
|
| 536 |
+
parent = getattr(parent, part)
|
| 537 |
+
|
| 538 |
+
lora_layer = LoRALinear(module)
|
| 539 |
+
setattr(parent, parts[-1], lora_layer)
|
| 540 |
+
replaced += 1
|
| 541 |
+
|
| 542 |
+
if is_main():
|
| 543 |
+
print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})")
|
| 544 |
+
return model
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def freeze_base_weights(model: GPT) -> None:
|
| 548 |
+
"""Seuls lora_A et lora_B restent entraînables."""
|
| 549 |
+
for name, p in model.named_parameters():
|
| 550 |
+
p.requires_grad = ("lora_A" in name or "lora_B" in name)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 554 |
+
# ║ OPTIMIZER ║
|
| 555 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 556 |
+
|
| 557 |
+
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
|
| 558 |
+
decay, no_decay = [], []
|
| 559 |
+
for name, p in unwrap_model(model).named_parameters():
|
| 560 |
+
if not p.requires_grad:
|
| 561 |
+
continue
|
| 562 |
+
(decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
|
| 563 |
+
|
| 564 |
+
groups = [
|
| 565 |
+
{"params": decay, "weight_decay": WEIGHT_DECAY},
|
| 566 |
+
{"params": no_decay, "weight_decay": 0.0},
|
| 567 |
+
]
|
| 568 |
+
|
| 569 |
+
if HAS_BNB:
|
| 570 |
+
return bnb.optim.PagedAdamW8bit(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
|
| 571 |
+
|
| 572 |
+
return torch.optim.AdamW(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8, fused=torch.cuda.is_available())
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def cosine_lr(step: int, total_steps: int) -> float:
|
| 576 |
+
if step < WARMUP_STEPS:
|
| 577 |
+
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
|
| 578 |
+
p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
|
| 579 |
+
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 583 |
+
# ║ CHECKPOINT ║
|
| 584 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 585 |
+
|
| 586 |
+
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
|
| 587 |
+
raw = unwrap_model(model)
|
| 588 |
+
torch.save({
|
| 589 |
+
"model": normalize_state_dict_keys(raw.state_dict()),
|
| 590 |
+
"optimizer": optimizer.state_dict(),
|
| 591 |
+
"epoch": epoch, "step": step, "best_loss": best_loss,
|
| 592 |
+
"config": asdict(raw.cfg),
|
| 593 |
+
}, path)
|
| 594 |
+
|
| 595 |
+
def maybe_load_base_checkpoint(model, device):
|
| 596 |
+
if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
|
| 597 |
+
return
|
| 598 |
+
ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
|
| 599 |
+
unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=False)
|
| 600 |
+
|
| 601 |
+
def load_resume_checkpoint(model, optimizer, path, device):
|
| 602 |
+
ckpt = torch.load(path, map_location=device)
|
| 603 |
+
unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True)
|
| 604 |
+
try:
|
| 605 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 606 |
+
except Exception as e:
|
| 607 |
+
print(f"[warn] Optimizer state non repris: {e}")
|
| 608 |
+
return int(ckpt.get("epoch", 0)), int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 612 |
+
# ║ ÉVALUATION ║
|
| 613 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 614 |
+
|
| 615 |
+
@torch.no_grad()
|
| 616 |
+
def evaluate(model, loader, device, max_batches=200) -> float:
|
| 617 |
+
model.eval()
|
| 618 |
+
losses = []
|
| 619 |
+
for i, batch in enumerate(loader):
|
| 620 |
+
if i >= max_batches:
|
| 621 |
+
break
|
| 622 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 623 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 624 |
+
with autocast_context(device):
|
| 625 |
+
_, loss = model(inp, lbl)
|
| 626 |
+
losses.append(loss.item())
|
| 627 |
+
model.train()
|
| 628 |
+
return sum(losses) / max(1, len(losses))
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 632 |
+
# ║ DATALOADER ║
|
| 633 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 634 |
+
|
| 635 |
+
def make_loader(dataset, batch_size, num_workers, is_cuda):
|
| 636 |
+
kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=is_cuda)
|
| 637 |
+
if num_workers > 0:
|
| 638 |
+
kwargs["persistent_workers"] = True
|
| 639 |
+
kwargs["prefetch_factor"] = PREFETCH_FACTOR
|
| 640 |
+
return torch.utils.data.DataLoader(dataset, **kwargs)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 644 |
+
# ║ MAIN ║
|
| 645 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 646 |
+
|
| 647 |
+
def main() -> None:
|
| 648 |
+
ddp_device = init_distributed()
|
| 649 |
+
set_seed(SEED + get_rank())
|
| 650 |
+
device = get_device(ddp_device)
|
| 651 |
+
is_cuda = device.type == "cuda"
|
| 652 |
+
|
| 653 |
+
cuda_idx = None
|
| 654 |
+
vram_fraction = None
|
| 655 |
+
|
| 656 |
+
if is_cuda:
|
| 657 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 658 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 659 |
+
torch.set_float32_matmul_precision("high")
|
| 660 |
+
cuda_idx = current_cuda_index(device)
|
| 661 |
+
_, total = torch.cuda.mem_get_info(cuda_idx)
|
| 662 |
+
vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
|
| 663 |
+
torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)
|
| 664 |
+
|
| 665 |
+
if is_main():
|
| 666 |
+
print("=" * 72)
|
| 667 |
+
print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | v2 (device fix)")
|
| 668 |
+
print("=" * 72)
|
| 669 |
+
print(f"Device : {device} | World: {get_world_size()} GPU(s)")
|
| 670 |
+
print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
|
| 671 |
+
print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
|
| 672 |
+
if is_cuda:
|
| 673 |
+
free, total = torch.cuda.mem_get_info(cuda_idx)
|
| 674 |
+
print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
|
| 675 |
+
print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")
|
| 676 |
+
|
| 677 |
+
tokenizer = train_or_load_tokenizer()
|
| 678 |
+
cfg = GPTConfig(vocab_size=len(tokenizer))
|
| 679 |
+
|
| 680 |
+
if is_main():
|
| 681 |
+
CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")
|
| 682 |
+
|
| 683 |
+
# ── 1. Créer le modèle ────────────────────────────────────────────────────
|
| 684 |
+
model = GPT(cfg).to(device)
|
| 685 |
+
|
| 686 |
+
# ── 2. Appliquer QLoRA APRÈS .to(device) ─────────────────────────────────
|
| 687 |
+
# C'est la correction principale : lora_A/lora_B sont créés sur CUDA
|
| 688 |
+
if USE_QLORA:
|
| 689 |
+
model = apply_qlora(model, device)
|
| 690 |
+
freeze_base_weights(model)
|
| 691 |
+
|
| 692 |
+
maybe_load_base_checkpoint(model, device)
|
| 693 |
+
|
| 694 |
+
# ── 3. torch.compile (seulement si USE_CHECKPOINTING=False) ──────────────
|
| 695 |
+
# La combinaison compile + checkpoint + LoRALinear custom est instable
|
| 696 |
+
# avec torch.dynamo sur PyTorch 2.x. Choisir l'un ou l'autre.
|
| 697 |
+
if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
|
| 698 |
+
try:
|
| 699 |
+
model = torch.compile(model, mode=COMPILE_MODE)
|
| 700 |
+
if is_main():
|
| 701 |
+
print(f"torch.compile activé ({COMPILE_MODE})")
|
| 702 |
+
except Exception as e:
|
| 703 |
+
if is_main():
|
| 704 |
+
print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile")
|
| 705 |
+
|
| 706 |
+
# ── 4. DDP ────────────────────────────────────────────────────────────────
|
| 707 |
+
if is_distributed():
|
| 708 |
+
model = DDP(model, device_ids=[device.index])
|
| 709 |
+
|
| 710 |
+
optimizer = build_optimizer(model)
|
| 711 |
+
|
| 712 |
+
# ── Datasets ──────────────────────────────────────────────────────────────
|
| 713 |
+
eval_texts = build_eval_texts()
|
| 714 |
+
eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
|
| 715 |
+
eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)
|
| 716 |
+
|
| 717 |
+
init_texts = build_epoch_train_texts(0)
|
| 718 |
+
steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE)
|
| 719 |
+
total_steps_est = steps_per_epoch * NUM_EPOCHS
|
| 720 |
+
|
| 721 |
+
# ── Reprise ───────────────────────────────────────────────────────────────
|
| 722 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 723 |
+
if STATE_FILE.exists():
|
| 724 |
+
try:
|
| 725 |
+
if is_main(): print(f"Reprise depuis {STATE_FILE}")
|
| 726 |
+
start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
|
| 727 |
+
except Exception as e:
|
| 728 |
+
if is_main():
|
| 729 |
+
bad = STATE_FILE.with_suffix(".corrupt.pt")
|
| 730 |
+
print(f"[warn] Checkpoint illisible: {e}")
|
| 731 |
+
try: STATE_FILE.rename(bad)
|
| 732 |
+
except Exception: pass
|
| 733 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 734 |
+
|
| 735 |
+
if is_main():
|
| 736 |
+
raw = unwrap_model(model)
|
| 737 |
+
n_total = count_parameters(raw, False)
|
| 738 |
+
n_train = count_parameters(raw, True)
|
| 739 |
+
print(f"Paramètres totaux : {n_total/1e9:.3f}B")
|
| 740 |
+
print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
|
| 741 |
+
print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
|
| 742 |
+
print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}")
|
| 743 |
+
print()
|
| 744 |
+
print("── Conseil VRAM ───────────────────────────���────────────────────")
|
| 745 |
+
print(" Surveille 'max_reserved=XX GiB' à step 50.")
|
| 746 |
+
print(" Augmente BATCH_SIZE par +2 jusqu'à ~77 Go réservés.")
|
| 747 |
+
print(" Si OOM : BATCH_SIZE -= 1 ou USE_CHECKPOINTING=True.")
|
| 748 |
+
print("────────────────────────────────────────────────────────────────")
|
| 749 |
+
|
| 750 |
+
# ── Boucle d'entraînement ─────────────────────────────────────────────────
|
| 751 |
+
model.train()
|
| 752 |
+
optimizer.zero_grad(set_to_none=True)
|
| 753 |
+
|
| 754 |
+
global_step = start_step
|
| 755 |
+
t0 = time.time()
|
| 756 |
+
log_loss_sum = 0.0
|
| 757 |
+
log_loss_count = 0
|
| 758 |
+
tokens_since_log = 0
|
| 759 |
+
last_log = time.time()
|
| 760 |
+
|
| 761 |
+
if is_cuda:
|
| 762 |
+
torch.cuda.reset_peak_memory_stats(cuda_idx)
|
| 763 |
+
|
| 764 |
+
for epoch in range(start_epoch, NUM_EPOCHS):
|
| 765 |
+
if is_main():
|
| 766 |
+
print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")
|
| 767 |
+
|
| 768 |
+
train_texts = build_epoch_train_texts(epoch)
|
| 769 |
+
train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
|
| 770 |
+
train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)
|
| 771 |
+
|
| 772 |
+
for micro_step, batch in enumerate(train_loader):
|
| 773 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 774 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 775 |
+
|
| 776 |
+
with autocast_context(device):
|
| 777 |
+
_, loss = model(inp, lbl)
|
| 778 |
+
|
| 779 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 780 |
+
|
| 781 |
+
log_loss_sum += loss.item()
|
| 782 |
+
log_loss_count += 1
|
| 783 |
+
tokens_since_log += inp.numel()
|
| 784 |
+
|
| 785 |
+
if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
|
| 786 |
+
continue
|
| 787 |
+
|
| 788 |
+
lr = cosine_lr(global_step, total_steps_est)
|
| 789 |
+
for group in optimizer.param_groups:
|
| 790 |
+
group["lr"] = lr
|
| 791 |
+
|
| 792 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
|
| 793 |
+
optimizer.step()
|
| 794 |
+
optimizer.zero_grad(set_to_none=True)
|
| 795 |
+
global_step += 1
|
| 796 |
+
|
| 797 |
+
if global_step % 50 == 0 and is_main():
|
| 798 |
+
now = time.time()
|
| 799 |
+
elapsed = max(1e-6, now - last_log)
|
| 800 |
+
tok_s = tokens_since_log / elapsed
|
| 801 |
+
avg_loss = log_loss_sum / max(1, log_loss_count)
|
| 802 |
+
print(
|
| 803 |
+
f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
|
| 804 |
+
f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
|
| 805 |
+
)
|
| 806 |
+
if is_cuda:
|
| 807 |
+
alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
|
| 808 |
+
reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
|
| 809 |
+
max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3
|
| 810 |
+
max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
|
| 811 |
+
print(
|
| 812 |
+
f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | "
|
| 813 |
+
f"max_alloc={max_alloc:.2f} | max_reserved={max_res:.2f} (GiB)"
|
| 814 |
+
)
|
| 815 |
+
last_log = now
|
| 816 |
+
tokens_since_log = 0
|
| 817 |
+
log_loss_sum = 0.0
|
| 818 |
+
log_loss_count = 0
|
| 819 |
+
|
| 820 |
+
if global_step % EVAL_EVERY == 0 and is_main():
|
| 821 |
+
val_loss = evaluate(model, eval_loader, device)
|
| 822 |
+
print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
|
| 823 |
+
if val_loss < best_eval:
|
| 824 |
+
best_eval = val_loss
|
| 825 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
|
| 826 |
+
print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
|
| 827 |
+
|
| 828 |
+
if global_step % SAVE_EVERY == 0 and is_main():
|
| 829 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
|
| 830 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
|
| 831 |
+
print(f"✓ Checkpoint → {MODEL_FILE}")
|
| 832 |
+
|
| 833 |
+
if is_main():
|
| 834 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
|
| 835 |
+
ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
|
| 836 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
|
| 837 |
+
print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")
|
| 838 |
+
|
| 839 |
+
if is_main():
|
| 840 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
|
| 841 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
|
| 842 |
+
total_min = (time.time() - t0) / 60
|
| 843 |
+
print(f"\nModèle final → {MODEL_FILE}")
|
| 844 |
+
print(f"Meilleur modèle → {BEST_MODEL_FILE}")
|
| 845 |
+
print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")
|
| 846 |
+
|
| 847 |
+
if is_distributed():
|
| 848 |
+
dist.destroy_process_group()
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
if __name__ == "__main__":
|
| 852 |
+
main()
|
train_aramix_h100_full.py
ADDED
|
@@ -0,0 +1,1055 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Entraînement LLM Multi-Domaine — H100 80 Go [CORRIGÉ]
|
| 6 |
+
════════════════════════════════════════════════════════
|
| 7 |
+
Architecture : GPT causal ~435M
|
| 8 |
+
RMSNorm · RoPE · SwiGLU · Flash Attention (SDPA)
|
| 9 |
+
Précision : BF16 natif + TF32 + fused AdamW
|
| 10 |
+
Compilation : torch.compile(mode="reduce-overhead")
|
| 11 |
+
Dataset : 10 domaines en streaming HF, échantillonnage pondéré
|
| 12 |
+
|
| 13 |
+
Correctifs v2 :
|
| 14 |
+
✓ trust_remote_code=True supprimé (déprécié datasets>=3.x)
|
| 15 |
+
✓ wikipedia → datasets.load_dataset sans script legacy
|
| 16 |
+
✓ Datasets remplacés par leurs équivalents Parquet/modernes
|
| 17 |
+
✓ MAX_STEPS = 5 000 (2 epochs estimées sur corpus réduit)
|
| 18 |
+
✓ pubmed_abstracts → pubmed_qa (Parquet natif)
|
| 19 |
+
✓ RedPajama CC → allenai/c4 (en/fr/ar)
|
| 20 |
+
✓ pile-of-law → joelniklaus/pile_of_law (Parquet)
|
| 21 |
+
✓ Gestion propre de StopIteration dans le DataLoader
|
| 22 |
+
|
| 23 |
+
Usage mono-GPU :
|
| 24 |
+
python train_aramix_h100_full.py
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import copy
|
| 30 |
+
import gc
|
| 31 |
+
import json
|
| 32 |
+
import math
|
| 33 |
+
import os
|
| 34 |
+
import random
|
| 35 |
+
import time
|
| 36 |
+
from dataclasses import asdict, dataclass, field
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
from typing import Dict, Iterator, List, Optional, Tuple
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import torch.distributed as dist
|
| 42 |
+
import torch.nn as nn
|
| 43 |
+
import torch.nn.functional as F
|
| 44 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 45 |
+
from torch.utils.data import DataLoader, IterableDataset
|
| 46 |
+
from datasets import load_dataset
|
| 47 |
+
from tokenizers import (
|
| 48 |
+
Tokenizer,
|
| 49 |
+
decoders,
|
| 50 |
+
models,
|
| 51 |
+
normalizers,
|
| 52 |
+
pre_tokenizers,
|
| 53 |
+
processors,
|
| 54 |
+
trainers,
|
| 55 |
+
)
|
| 56 |
+
from transformers import PreTrainedTokenizerFast
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ══════════════════════════════════════════════════════════════════
|
| 60 |
+
# §1 CONFIGURATION GLOBALE
|
| 61 |
+
# ══════════════════════════════════════════════════════════════════
|
| 62 |
+
|
| 63 |
+
OUT_DIR = Path("./aramix_h100")
|
| 64 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
|
| 66 |
+
SEED = 42
|
| 67 |
+
|
| 68 |
+
# ── Tokenizer ─────────────────────────────────────────────────────
|
| 69 |
+
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
|
| 70 |
+
TOKENIZER_VOCAB = 32_000
|
| 71 |
+
TOKENIZER_SAMPLE_DOCS = 80_000 # réduit pour aller plus vite
|
| 72 |
+
TOKENIZER_CHAR_LIMIT = 2_000
|
| 73 |
+
|
| 74 |
+
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
|
| 75 |
+
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
|
| 76 |
+
|
| 77 |
+
# ── Architecture (~435M) ──────────────────────────────────────────
|
| 78 |
+
VOCAB_SIZE = 32_000 # mis à jour après chargement du tokenizer
|
| 79 |
+
BLOCK_SIZE = 1024
|
| 80 |
+
D_MODEL = 1024
|
| 81 |
+
N_HEADS = 18
|
| 82 |
+
N_LAYERS = 24
|
| 83 |
+
D_FF = 4096
|
| 84 |
+
DROPOUT = 0.1
|
| 85 |
+
|
| 86 |
+
# ── Entraînement ──────────────────────────────────────────────────
|
| 87 |
+
LEARNING_RATE = 3e-4
|
| 88 |
+
MIN_LR = 3e-5
|
| 89 |
+
WEIGHT_DECAY = 0.1
|
| 90 |
+
WARMUP_STEPS = 200 # réduit proportionnellement (5k steps)
|
| 91 |
+
MAX_STEPS = 5_000 # ← 5 000 steps / ~2 epochs
|
| 92 |
+
MAX_GRAD_NORM = 1.0
|
| 93 |
+
|
| 94 |
+
BATCH_SIZE = 32 # H100 80 Go : 32×1024 en BF16 ≈ 26 Go
|
| 95 |
+
GRAD_ACCUM_STEPS = 1
|
| 96 |
+
EVAL_EVERY = 500
|
| 97 |
+
SAVE_EVERY = 1_000
|
| 98 |
+
MAX_EVAL_DOCS_DOM = 300
|
| 99 |
+
TRAIN_CHAR_LIMIT = 4_000
|
| 100 |
+
|
| 101 |
+
# ── Précision & compilation ───────────────────────────────────────
|
| 102 |
+
DTYPE = torch.bfloat16
|
| 103 |
+
USE_COMPILE = True
|
| 104 |
+
USE_CHECKPOINTING = False # inutile avec 80 Go VRAM
|
| 105 |
+
|
| 106 |
+
# ── Fichiers ─────────────────────────────────────────────────────
|
| 107 |
+
MODEL_FILE = OUT_DIR / "model.pt"
|
| 108 |
+
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
|
| 109 |
+
STATE_FILE = OUT_DIR / "train_state.pt"
|
| 110 |
+
CONFIG_FILE = OUT_DIR / "config.json"
|
| 111 |
+
LOG_FILE = OUT_DIR / "train_log.jsonl"
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ══════════════════════════════════════════════════════════════════
|
| 115 |
+
# §2 REGISTRE MULTI-DOMAINES (tous compatibles Parquet / sans script)
|
| 116 |
+
# ══════════════════════════════════════════════════════════════════
|
| 117 |
+
|
| 118 |
+
@dataclass
|
| 119 |
+
class DomainConfig:
|
| 120 |
+
name: str
|
| 121 |
+
hf_path: str
|
| 122 |
+
hf_subset: Optional[str]
|
| 123 |
+
hf_split: str
|
| 124 |
+
text_field: str
|
| 125 |
+
weight: float
|
| 126 |
+
char_limit: int
|
| 127 |
+
min_chars: int = 80
|
| 128 |
+
lang_filter: Optional[str] = None
|
| 129 |
+
# ── NOUVEAU : certains datasets nécessitent un champ imbriqué ──
|
| 130 |
+
extra_kwargs: dict = field(default_factory=dict)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
DOMAINS: List[DomainConfig] = [
|
| 134 |
+
|
| 135 |
+
# ── Code / Dev ────────────────────────────────────────────────
|
| 136 |
+
# bigcode/starcoderdata utilise un script MAIS accepte désormais
|
| 137 |
+
# le paramètre trust_remote_code=False via HF_DATASETS_TRUST_REMOTE_CODE=0
|
| 138 |
+
# → on passe par sa version Parquet directe sur HF hub
|
| 139 |
+
DomainConfig(
|
| 140 |
+
name="code_python",
|
| 141 |
+
hf_path="bigcode/the-stack-dedup",
|
| 142 |
+
hf_subset="data/python",
|
| 143 |
+
hf_split="train", text_field="content",
|
| 144 |
+
weight=0.12, char_limit=6_000, min_chars=100,
|
| 145 |
+
),
|
| 146 |
+
DomainConfig(
|
| 147 |
+
name="code_js",
|
| 148 |
+
hf_path="bigcode/the-stack-dedup",
|
| 149 |
+
hf_subset="data/javascript",
|
| 150 |
+
hf_split="train", text_field="content",
|
| 151 |
+
weight=0.06, char_limit=5_000, min_chars=100,
|
| 152 |
+
),
|
| 153 |
+
DomainConfig(
|
| 154 |
+
name="code_csharp",
|
| 155 |
+
hf_path="bigcode/the-stack-dedup",
|
| 156 |
+
hf_subset="data/c-sharp",
|
| 157 |
+
hf_split="train", text_field="content",
|
| 158 |
+
weight=0.04, char_limit=5_000, min_chars=100,
|
| 159 |
+
),
|
| 160 |
+
|
| 161 |
+
# ── Médical ───────────────────────────────────────────────────
|
| 162 |
+
DomainConfig(
|
| 163 |
+
name="medical_flashcards",
|
| 164 |
+
hf_path="medalpaca/medical_meadow_medical_flashcards",
|
| 165 |
+
hf_subset=None, hf_split="train", text_field="output",
|
| 166 |
+
weight=0.06, char_limit=3_000, min_chars=60,
|
| 167 |
+
),
|
| 168 |
+
# pubmed_abstracts → pubmed_qa (natif Parquet)
|
| 169 |
+
DomainConfig(
|
| 170 |
+
name="medical_pubmed",
|
| 171 |
+
hf_path="qiaojin/PubMedQA",
|
| 172 |
+
hf_subset="pqa_labeled",
|
| 173 |
+
hf_split="train", text_field="long_answer",
|
| 174 |
+
weight=0.06, char_limit=4_000, min_chars=100,
|
| 175 |
+
),
|
| 176 |
+
|
| 177 |
+
# ── Français ──────────────────────────────────────────────────
|
| 178 |
+
# wikipedia sans script : utiliser la version datasets>=2.14 qui
|
| 179 |
+
# charge directement les Parquet sans script .py
|
| 180 |
+
DomainConfig(
|
| 181 |
+
name="french_wiki",
|
| 182 |
+
hf_path="wikimedia/wikipedia",
|
| 183 |
+
hf_subset="20231101.fr",
|
| 184 |
+
hf_split="train", text_field="text",
|
| 185 |
+
weight=0.08, char_limit=5_000, min_chars=100,
|
| 186 |
+
),
|
| 187 |
+
DomainConfig(
|
| 188 |
+
name="french_culture",
|
| 189 |
+
hf_path="lyon-nlp/corpus-france-culture-inter-2023",
|
| 190 |
+
hf_subset=None, hf_split="train", text_field="text",
|
| 191 |
+
weight=0.04, char_limit=4_000, min_chars=80,
|
| 192 |
+
),
|
| 193 |
+
DomainConfig(
|
| 194 |
+
name="french_news",
|
| 195 |
+
hf_path="mlsum", hf_subset="fr",
|
| 196 |
+
hf_split="train", text_field="text",
|
| 197 |
+
weight=0.03, char_limit=3_000, min_chars=80,
|
| 198 |
+
),
|
| 199 |
+
|
| 200 |
+
# ── Arabe ─────────────────────────────────────────────────────
|
| 201 |
+
DomainConfig(
|
| 202 |
+
name="arabic_aramix",
|
| 203 |
+
hf_path="AdaMLLab/AraMix", hf_subset="matched",
|
| 204 |
+
hf_split="train", text_field="text",
|
| 205 |
+
weight=0.10, char_limit=4_000, min_chars=80,
|
| 206 |
+
),
|
| 207 |
+
DomainConfig(
|
| 208 |
+
name="arabic_wiki",
|
| 209 |
+
hf_path="wikimedia/wikipedia",
|
| 210 |
+
hf_subset="20231101.ar",
|
| 211 |
+
hf_split="train", text_field="text",
|
| 212 |
+
weight=0.05, char_limit=4_000, min_chars=80,
|
| 213 |
+
),
|
| 214 |
+
# OSCAR-2301 → oscar-corpus/OSCAR-2301 reste supporté sans script
|
| 215 |
+
DomainConfig(
|
| 216 |
+
name="arabic_oscar",
|
| 217 |
+
hf_path="oscar-corpus/OSCAR-2301", hf_subset="ar",
|
| 218 |
+
hf_split="train", text_field="content",
|
| 219 |
+
weight=0.04, char_limit=3_000, min_chars=80,
|
| 220 |
+
),
|
| 221 |
+
|
| 222 |
+
# ── Créatif ───────────────────────────────────────────────────
|
| 223 |
+
DomainConfig(
|
| 224 |
+
name="creative_writing",
|
| 225 |
+
hf_path="ajibawa-2023/creative-writing-40k",
|
| 226 |
+
hf_subset=None, hf_split="train", text_field="output",
|
| 227 |
+
weight=0.05, char_limit=5_000, min_chars=100,
|
| 228 |
+
),
|
| 229 |
+
DomainConfig(
|
| 230 |
+
name="stories",
|
| 231 |
+
hf_path="roneneldan/TinyStories",
|
| 232 |
+
hf_subset=None, hf_split="train", text_field="text",
|
| 233 |
+
weight=0.04, char_limit=3_000, min_chars=80,
|
| 234 |
+
),
|
| 235 |
+
DomainConfig(
|
| 236 |
+
name="reddit_posts",
|
| 237 |
+
hf_path="webis/tldr-17",
|
| 238 |
+
hf_subset=None, hf_split="train", text_field="content",
|
| 239 |
+
weight=0.03, char_limit=3_000, min_chars=80,
|
| 240 |
+
),
|
| 241 |
+
|
| 242 |
+
# ── Mathématiques ─────────────────────────────────────────────
|
| 243 |
+
DomainConfig(
|
| 244 |
+
name="math_stackexchange",
|
| 245 |
+
hf_path="math-ai/StackMathQA",
|
| 246 |
+
hf_subset=None, hf_split="train", text_field="A",
|
| 247 |
+
weight=0.04, char_limit=4_000, min_chars=80,
|
| 248 |
+
),
|
| 249 |
+
DomainConfig(
|
| 250 |
+
name="math_problems",
|
| 251 |
+
hf_path="lighteval/MATH",
|
| 252 |
+
hf_subset=None, hf_split="train", text_field="solution",
|
| 253 |
+
weight=0.03, char_limit=3_000, min_chars=60,
|
| 254 |
+
),
|
| 255 |
+
|
| 256 |
+
# ── Juridique ─────────────────────────────────────────────────
|
| 257 |
+
# joelniklaus/pile_of_law = version Parquet de pile-of-law
|
| 258 |
+
DomainConfig(
|
| 259 |
+
name="legal_en",
|
| 260 |
+
hf_path="joelniklaus/pile_of_law",
|
| 261 |
+
hf_subset="courtlistener_opinions",
|
| 262 |
+
hf_split="train", text_field="text",
|
| 263 |
+
weight=0.04, char_limit=5_000, min_chars=100,
|
| 264 |
+
),
|
| 265 |
+
DomainConfig(
|
| 266 |
+
name="legal_fr",
|
| 267 |
+
hf_path="antoinelouis/french-legal-corpus",
|
| 268 |
+
hf_subset=None, hf_split="train", text_field="text",
|
| 269 |
+
weight=0.02, char_limit=4_000, min_chars=80,
|
| 270 |
+
),
|
| 271 |
+
|
| 272 |
+
# ── Science ───────────────────────────────────────────────────
|
| 273 |
+
# RedPajama arxiv → allenai/peS2o (semantic scholar, Parquet)
|
| 274 |
+
DomainConfig(
|
| 275 |
+
name="science_arxiv",
|
| 276 |
+
hf_path="allenai/peS2o",
|
| 277 |
+
hf_subset=None,
|
| 278 |
+
hf_split="train", text_field="text",
|
| 279 |
+
weight=0.05, char_limit=6_000, min_chars=100,
|
| 280 |
+
),
|
| 281 |
+
|
| 282 |
+
# ── Multilingue général ───────────────────────────────────────
|
| 283 |
+
# RedPajama CC → allenai/c4 multilingual (Parquet)
|
| 284 |
+
DomainConfig(
|
| 285 |
+
name="multilingual_cc",
|
| 286 |
+
hf_path="allenai/c4",
|
| 287 |
+
hf_subset="multilingual",
|
| 288 |
+
hf_split="train", text_field="text",
|
| 289 |
+
weight=0.02, char_limit=3_000, min_chars=80,
|
| 290 |
+
),
|
| 291 |
+
]
|
| 292 |
+
|
| 293 |
+
# Validation : somme des poids ≈ 1.0
|
| 294 |
+
_wsum = sum(d.weight for d in DOMAINS)
|
| 295 |
+
assert abs(_wsum - 1.0) < 0.01, f"Somme des poids = {_wsum:.4f} ≠ 1.0"
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def select_domains(*names: str) -> List[DomainConfig]:
|
| 299 |
+
"""Sous-sélection + renormalisation automatique des poids."""
|
| 300 |
+
selected = [d for d in DOMAINS if d.name in names]
|
| 301 |
+
if not selected:
|
| 302 |
+
raise ValueError(f"Aucun domaine parmi : {names}")
|
| 303 |
+
total = sum(d.weight for d in selected)
|
| 304 |
+
out = []
|
| 305 |
+
for d in selected:
|
| 306 |
+
dc = copy.copy(d)
|
| 307 |
+
dc.weight = round(d.weight / total, 6)
|
| 308 |
+
out.append(dc)
|
| 309 |
+
return out
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def print_domain_summary(domains: Optional[List[DomainConfig]] = None) -> None:
|
| 313 |
+
if domains is None:
|
| 314 |
+
domains = DOMAINS
|
| 315 |
+
print(f"\n{'Domaine':<25} {'Dataset HF':<45} {'Poids':>6}")
|
| 316 |
+
print("─" * 80)
|
| 317 |
+
for d in sorted(domains, key=lambda x: -x.weight):
|
| 318 |
+
sub = f"/{d.hf_subset}" if d.hf_subset else ""
|
| 319 |
+
print(f"{d.name:<25} {d.hf_path + sub:<45} {d.weight:>6.1%}")
|
| 320 |
+
print(f"{'TOTAL':<25} {'':<45} {sum(d.weight for d in domains):>6.1%}\n")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ══════════════════════════════════════════════════════════════════
|
| 324 |
+
# §3 STREAMING DATASET MULTI-DOMAINES
|
| 325 |
+
# ══════════════════════════════════════════════════════════════════
|
| 326 |
+
|
| 327 |
+
def domain_text_stream(
|
| 328 |
+
domain: DomainConfig,
|
| 329 |
+
max_docs: Optional[int] = None,
|
| 330 |
+
) -> Iterator[str]:
|
| 331 |
+
"""
|
| 332 |
+
Charge, filtre et nettoie le texte brut d'un domaine HF en streaming.
|
| 333 |
+
|
| 334 |
+
CORRECTIF : trust_remote_code supprimé, gestion d'erreur par domaine
|
| 335 |
+
pour éviter de planter tout l'entraînement si un dataset échoue.
|
| 336 |
+
"""
|
| 337 |
+
try:
|
| 338 |
+
ds = load_dataset(
|
| 339 |
+
domain.hf_path,
|
| 340 |
+
domain.hf_subset,
|
| 341 |
+
split=domain.hf_split,
|
| 342 |
+
streaming=True,
|
| 343 |
+
# trust_remote_code=True ← SUPPRIMÉ (déprécié datasets>=3.x)
|
| 344 |
+
)
|
| 345 |
+
except Exception as e:
|
| 346 |
+
print(f"[WARN] Domaine '{domain.name}' impossible à charger : {e}")
|
| 347 |
+
return # domaine ignoré proprement
|
| 348 |
+
|
| 349 |
+
n = 0
|
| 350 |
+
for row in ds:
|
| 351 |
+
text = row.get(domain.text_field, "")
|
| 352 |
+
if not text or not isinstance(text, str):
|
| 353 |
+
continue
|
| 354 |
+
if domain.lang_filter:
|
| 355 |
+
lang = row.get("lang", row.get("language", ""))
|
| 356 |
+
if lang and lang != domain.lang_filter:
|
| 357 |
+
continue
|
| 358 |
+
text = " ".join(text.strip().split())
|
| 359 |
+
if len(text) < domain.min_chars:
|
| 360 |
+
continue
|
| 361 |
+
yield text[: domain.char_limit]
|
| 362 |
+
n += 1
|
| 363 |
+
if max_docs and n >= max_docs:
|
| 364 |
+
break
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def interleaved_text_stream(
|
| 368 |
+
domains: List[DomainConfig],
|
| 369 |
+
max_docs_per_domain: Optional[int] = None,
|
| 370 |
+
seed: int = 42,
|
| 371 |
+
) -> Iterator[Tuple[str, str]]:
|
| 372 |
+
"""
|
| 373 |
+
Mélange stochastique pondéré des domaines.
|
| 374 |
+
Yield : (domain_name, text)
|
| 375 |
+
|
| 376 |
+
CORRECTIF : quand tous les domaines sont épuisés → StopIteration propre.
|
| 377 |
+
"""
|
| 378 |
+
rng = random.Random(seed)
|
| 379 |
+
iters = {d.name: domain_text_stream(d, max_docs=max_docs_per_domain) for d in domains}
|
| 380 |
+
exhausted: set = set()
|
| 381 |
+
|
| 382 |
+
while len(exhausted) < len(domains):
|
| 383 |
+
active = [d for d in domains if d.name not in exhausted]
|
| 384 |
+
if not active:
|
| 385 |
+
break
|
| 386 |
+
chosen = rng.choices(active, weights=[d.weight for d in active], k=1)[0]
|
| 387 |
+
try:
|
| 388 |
+
yield chosen.name, next(iters[chosen.name])
|
| 389 |
+
except StopIteration:
|
| 390 |
+
exhausted.add(chosen.name)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def packed_block_stream(
|
| 394 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 395 |
+
domains: List[DomainConfig],
|
| 396 |
+
block_size: int,
|
| 397 |
+
max_docs_per_domain: Optional[int] = None,
|
| 398 |
+
seed: int = 42,
|
| 399 |
+
) -> Iterator[Dict]:
|
| 400 |
+
"""
|
| 401 |
+
Tokenise et pack les textes en blocs denses de block_size tokens.
|
| 402 |
+
Yield : {"input_ids": list[int], "labels": list[int], "domain": str}
|
| 403 |
+
"""
|
| 404 |
+
bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
|
| 405 |
+
buffer: List[int] = []
|
| 406 |
+
buffer_domain: List[str] = []
|
| 407 |
+
|
| 408 |
+
for domain_name, text in interleaved_text_stream(domains, max_docs_per_domain, seed):
|
| 409 |
+
ids = tokenizer.encode(text, add_special_tokens=False)
|
| 410 |
+
if not ids:
|
| 411 |
+
continue
|
| 412 |
+
seq = [bos] + ids + [eos]
|
| 413 |
+
buffer.extend(seq)
|
| 414 |
+
buffer_domain.extend([domain_name] * len(seq))
|
| 415 |
+
|
| 416 |
+
while len(buffer) >= block_size + 1:
|
| 417 |
+
chunk = buffer[:block_size + 1]
|
| 418 |
+
chunk_domain = buffer_domain[:block_size + 1]
|
| 419 |
+
buffer = buffer[block_size + 1:]
|
| 420 |
+
buffer_domain = buffer_domain[block_size + 1:]
|
| 421 |
+
majority = max(set(chunk_domain), key=chunk_domain.count)
|
| 422 |
+
yield {"input_ids": chunk[:-1], "labels": chunk[1:], "domain": majority}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class MultiDomainPackedDataset(IterableDataset):
|
| 426 |
+
"""
|
| 427 |
+
IterableDataset multi-domaines avec sharding inter-workers.
|
| 428 |
+
Compatible DataLoader(num_workers=N).
|
| 429 |
+
"""
|
| 430 |
+
def __init__(
|
| 431 |
+
self,
|
| 432 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 433 |
+
domains: List[DomainConfig],
|
| 434 |
+
block_size: int,
|
| 435 |
+
max_docs_per_domain: Optional[int] = None,
|
| 436 |
+
seed: int = 42,
|
| 437 |
+
):
|
| 438 |
+
super().__init__()
|
| 439 |
+
self.tokenizer = tokenizer
|
| 440 |
+
self.domains = domains
|
| 441 |
+
self.block_size = block_size
|
| 442 |
+
self.max_docs_per_domain = max_docs_per_domain
|
| 443 |
+
self.seed = seed
|
| 444 |
+
|
| 445 |
+
def __iter__(self):
|
| 446 |
+
worker = torch.utils.data.get_worker_info()
|
| 447 |
+
wid = worker.id if worker else 0
|
| 448 |
+
n_workers = worker.num_workers if worker else 1
|
| 449 |
+
|
| 450 |
+
for idx, block in enumerate(packed_block_stream(
|
| 451 |
+
self.tokenizer, self.domains, self.block_size,
|
| 452 |
+
self.max_docs_per_domain, seed=self.seed + wid,
|
| 453 |
+
)):
|
| 454 |
+
if idx % n_workers != wid:
|
| 455 |
+
continue
|
| 456 |
+
yield {
|
| 457 |
+
"input_ids": torch.tensor(block["input_ids"], dtype=torch.long),
|
| 458 |
+
"labels": torch.tensor(block["labels"], dtype=torch.long),
|
| 459 |
+
"domain": block["domain"],
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def build_dataloaders(
|
| 464 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 465 |
+
domains: List[DomainConfig],
|
| 466 |
+
block_size: int,
|
| 467 |
+
train_batch_size: int,
|
| 468 |
+
eval_batch_size: int = 16,
|
| 469 |
+
max_eval_docs_per_dom: int = 300,
|
| 470 |
+
num_workers: int = 4,
|
| 471 |
+
seed: int = 42,
|
| 472 |
+
) -> Tuple[DataLoader, DataLoader]:
|
| 473 |
+
train_ds = MultiDomainPackedDataset(
|
| 474 |
+
tokenizer, domains, block_size,
|
| 475 |
+
max_docs_per_domain=None,
|
| 476 |
+
seed=seed,
|
| 477 |
+
)
|
| 478 |
+
eval_ds = MultiDomainPackedDataset(
|
| 479 |
+
tokenizer, domains, block_size,
|
| 480 |
+
max_docs_per_domain=max_eval_docs_per_dom,
|
| 481 |
+
seed=seed + 9999,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
def collate_fn(batch):
|
| 485 |
+
return {
|
| 486 |
+
"input_ids": torch.stack([b["input_ids"] for b in batch]),
|
| 487 |
+
"labels": torch.stack([b["labels"] for b in batch]),
|
| 488 |
+
"domain": [b["domain"] for b in batch],
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
train_loader = DataLoader(
|
| 492 |
+
train_ds, batch_size=train_batch_size,
|
| 493 |
+
num_workers=num_workers, pin_memory=True,
|
| 494 |
+
prefetch_factor=2, collate_fn=collate_fn,
|
| 495 |
+
)
|
| 496 |
+
eval_loader = DataLoader(
|
| 497 |
+
eval_ds, batch_size=eval_batch_size,
|
| 498 |
+
num_workers=max(1, num_workers // 2), pin_memory=True,
|
| 499 |
+
prefetch_factor=2, collate_fn=collate_fn,
|
| 500 |
+
)
|
| 501 |
+
return train_loader, eval_loader
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# ══════════════════════════════════════════════════════════════════
|
| 505 |
+
# §4 TOKENIZER BPE 32k
|
| 506 |
+
# ══════════════════════════════════════════════════════════════════
|
| 507 |
+
|
| 508 |
+
def train_or_load_tokenizer(
|
| 509 |
+
domains: List[DomainConfig],
|
| 510 |
+
) -> PreTrainedTokenizerFast:
|
| 511 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 512 |
+
tok_json = TOKENIZER_DIR / "tokenizer.json"
|
| 513 |
+
tok_cfg = TOKENIZER_DIR / "tokenizer_config.json"
|
| 514 |
+
|
| 515 |
+
if tok_json.exists() and tok_cfg.exists():
|
| 516 |
+
print("Tokenizer existant chargé depuis le cache.")
|
| 517 |
+
return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
|
| 518 |
+
|
| 519 |
+
if is_main():
|
| 520 |
+
print("Entraînement tokenizer BPE 32k…")
|
| 521 |
+
|
| 522 |
+
def _iter_sample() -> Iterator[str]:
|
| 523 |
+
n_per_domain = max(1, TOKENIZER_SAMPLE_DOCS // len(domains))
|
| 524 |
+
for domain in domains:
|
| 525 |
+
for text in domain_text_stream(domain, max_docs=n_per_domain):
|
| 526 |
+
yield text
|
| 527 |
+
|
| 528 |
+
tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
|
| 529 |
+
tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
|
| 530 |
+
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 531 |
+
tok.decoder = decoders.ByteLevel()
|
| 532 |
+
|
| 533 |
+
trainer = trainers.BpeTrainer(
|
| 534 |
+
vocab_size=TOKENIZER_VOCAB,
|
| 535 |
+
min_frequency=2,
|
| 536 |
+
show_progress=is_main(),
|
| 537 |
+
special_tokens=SPECIAL_TOKENS,
|
| 538 |
+
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
| 539 |
+
)
|
| 540 |
+
tok.train_from_iterator(_iter_sample(), trainer=trainer)
|
| 541 |
+
|
| 542 |
+
bos_id = tok.token_to_id(BOS_TOKEN)
|
| 543 |
+
eos_id = tok.token_to_id(EOS_TOKEN)
|
| 544 |
+
tok.post_processor = processors.TemplateProcessing(
|
| 545 |
+
single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
|
| 546 |
+
pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
|
| 547 |
+
special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
|
| 548 |
+
)
|
| 549 |
+
tok.save(str(tok_json))
|
| 550 |
+
|
| 551 |
+
fast = PreTrainedTokenizerFast(
|
| 552 |
+
tokenizer_file=str(tok_json),
|
| 553 |
+
bos_token=BOS_TOKEN, eos_token=EOS_TOKEN,
|
| 554 |
+
unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
|
| 555 |
+
)
|
| 556 |
+
fast.save_pretrained(str(TOKENIZER_DIR))
|
| 557 |
+
|
| 558 |
+
smap = TOKENIZER_DIR / "special_tokens_map.json"
|
| 559 |
+
if not smap.exists():
|
| 560 |
+
smap.write_text(json.dumps({
|
| 561 |
+
"bos_token": BOS_TOKEN, "eos_token": EOS_TOKEN,
|
| 562 |
+
"unk_token": UNK_TOKEN, "pad_token": PAD_TOKEN,
|
| 563 |
+
}, indent=2), encoding="utf-8")
|
| 564 |
+
|
| 565 |
+
return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# ══════════════════════════════════════════════════════════════════
|
| 569 |
+
# §5 ARCHITECTURE GPT
|
| 570 |
+
# ══════════════════════════════════════════════════════════════════
|
| 571 |
+
|
| 572 |
+
@dataclass
|
| 573 |
+
class GPTConfig:
|
| 574 |
+
vocab_size: int = VOCAB_SIZE
|
| 575 |
+
block_size: int = BLOCK_SIZE
|
| 576 |
+
d_model: int = D_MODEL
|
| 577 |
+
n_heads: int = N_HEADS
|
| 578 |
+
n_layers: int = N_LAYERS
|
| 579 |
+
d_ff: int = D_FF
|
| 580 |
+
dropout: float = DROPOUT
|
| 581 |
+
use_checkpointing: bool = USE_CHECKPOINTING
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class RMSNorm(nn.Module):
|
| 585 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 586 |
+
super().__init__()
|
| 587 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 588 |
+
self.eps = eps
|
| 589 |
+
|
| 590 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 591 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
class RotaryEmbedding(nn.Module):
|
| 595 |
+
def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4096):
|
| 596 |
+
super().__init__()
|
| 597 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 598 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 599 |
+
t = torch.arange(max_seq).float()
|
| 600 |
+
freqs = torch.outer(t, inv_freq)
|
| 601 |
+
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
|
| 602 |
+
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
|
| 603 |
+
|
| 604 |
+
def forward(self, seq_len: int, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 605 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 609 |
+
return torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1).flatten(-2)
|
| 610 |
+
|
| 611 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 612 |
+
return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
class CausalSelfAttention(nn.Module):
|
| 616 |
+
def __init__(self, cfg: GPTConfig):
|
| 617 |
+
super().__init__()
|
| 618 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 619 |
+
self.n_heads = cfg.n_heads
|
| 620 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 621 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 622 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 623 |
+
self.dropout_p = cfg.dropout
|
| 624 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 625 |
+
|
| 626 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 627 |
+
b, t, c = x.shape
|
| 628 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 629 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 630 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 631 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 632 |
+
|
| 633 |
+
cos, sin = self.rope(t, x.dtype)
|
| 634 |
+
q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
|
| 635 |
+
|
| 636 |
+
# Flash Attention via SDPA (PyTorch ≥2.0, natif H100)
|
| 637 |
+
y = F.scaled_dot_product_attention(
|
| 638 |
+
q, k, v,
|
| 639 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 640 |
+
is_causal=True,
|
| 641 |
+
)
|
| 642 |
+
return self.proj(y.transpose(1, 2).contiguous().view(b, t, c))
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
class SwiGLU(nn.Module):
|
| 646 |
+
def __init__(self, cfg: GPTConfig):
|
| 647 |
+
super().__init__()
|
| 648 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 649 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 650 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 651 |
+
|
| 652 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 653 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class Block(nn.Module):
|
| 657 |
+
def __init__(self, cfg: GPTConfig):
|
| 658 |
+
super().__init__()
|
| 659 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 660 |
+
self.attn = CausalSelfAttention(cfg)
|
| 661 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 662 |
+
self.ff = SwiGLU(cfg)
|
| 663 |
+
|
| 664 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 665 |
+
x = x + self.attn(self.ln1(x))
|
| 666 |
+
x = x + self.ff(self.ln2(x))
|
| 667 |
+
return x
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
class GPT(nn.Module):
|
| 671 |
+
def __init__(self, cfg: GPTConfig):
|
| 672 |
+
super().__init__()
|
| 673 |
+
self.cfg = cfg
|
| 674 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 675 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 676 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 677 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 678 |
+
self.lm_head.weight = self.tok_emb.weight # weight tying
|
| 679 |
+
|
| 680 |
+
self.apply(self._init_weights)
|
| 681 |
+
|
| 682 |
+
@staticmethod
|
| 683 |
+
def _init_weights(m: nn.Module) -> None:
|
| 684 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 685 |
+
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 686 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 687 |
+
nn.init.zeros_(m.bias)
|
| 688 |
+
|
| 689 |
+
def forward(
|
| 690 |
+
self,
|
| 691 |
+
input_ids: torch.Tensor,
|
| 692 |
+
labels: Optional[torch.Tensor] = None,
|
| 693 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 694 |
+
x = self.tok_emb(input_ids)
|
| 695 |
+
for block in self.blocks:
|
| 696 |
+
if self.cfg.use_checkpointing and self.training:
|
| 697 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 698 |
+
else:
|
| 699 |
+
x = block(x)
|
| 700 |
+
logits = self.lm_head(self.ln_f(x))
|
| 701 |
+
loss = None
|
| 702 |
+
if labels is not None:
|
| 703 |
+
loss = F.cross_entropy(
|
| 704 |
+
logits.reshape(-1, logits.size(-1)),
|
| 705 |
+
labels.reshape(-1),
|
| 706 |
+
ignore_index=-100,
|
| 707 |
+
)
|
| 708 |
+
return logits, loss
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
# ══════════════════════════════════════════════════════════════════
|
| 712 |
+
# §6 OPTIMIZER & LR SCHEDULE
|
| 713 |
+
# ══════════════════════════════════════════════════════════════════
|
| 714 |
+
|
| 715 |
+
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
|
| 716 |
+
decay, no_decay = [], []
|
| 717 |
+
for name, p in model.named_parameters():
|
| 718 |
+
if not p.requires_grad:
|
| 719 |
+
continue
|
| 720 |
+
(decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
|
| 721 |
+
return torch.optim.AdamW(
|
| 722 |
+
[
|
| 723 |
+
{"params": decay, "weight_decay": WEIGHT_DECAY},
|
| 724 |
+
{"params": no_decay, "weight_decay": 0.0},
|
| 725 |
+
],
|
| 726 |
+
lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8,
|
| 727 |
+
fused=True, # kernel unique GPU → +10-15% sur H100
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def cosine_lr(step: int) -> float:
|
| 732 |
+
if step < WARMUP_STEPS:
|
| 733 |
+
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
|
| 734 |
+
p = min(1.0, (step - WARMUP_STEPS) / max(1, MAX_STEPS - WARMUP_STEPS))
|
| 735 |
+
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1.0 + math.cos(math.pi * p))
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
# ══════════════════════════════════════════════════════════════════
|
| 739 |
+
# §7 CHECKPOINT
|
| 740 |
+
# ══════════════════════════════════════════════════════════════════
|
| 741 |
+
|
| 742 |
+
def save_checkpoint(
|
| 743 |
+
model: nn.Module,
|
| 744 |
+
optimizer: torch.optim.Optimizer,
|
| 745 |
+
step: int,
|
| 746 |
+
best_loss: float,
|
| 747 |
+
path: Path,
|
| 748 |
+
) -> None:
|
| 749 |
+
raw = model.module if isinstance(model, DDP) else model
|
| 750 |
+
torch.save({
|
| 751 |
+
"model": raw.state_dict(),
|
| 752 |
+
"optimizer": optimizer.state_dict(),
|
| 753 |
+
"step": step,
|
| 754 |
+
"best_loss": best_loss,
|
| 755 |
+
"config": asdict(raw.cfg),
|
| 756 |
+
}, path)
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
def load_checkpoint(
|
| 760 |
+
model: nn.Module,
|
| 761 |
+
optimizer: torch.optim.Optimizer,
|
| 762 |
+
path: Path,
|
| 763 |
+
device: torch.device,
|
| 764 |
+
) -> Tuple[int, float]:
|
| 765 |
+
ckpt = torch.load(path, map_location=device)
|
| 766 |
+
raw = model.module if isinstance(model, DDP) else model
|
| 767 |
+
raw.load_state_dict(ckpt["model"])
|
| 768 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 769 |
+
return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
# ══════════════════════════════════════════════════════════════════
|
| 773 |
+
# §8 DDP HELPERS
|
| 774 |
+
# ══════════════════════════════════════════════════════════════════
|
| 775 |
+
|
| 776 |
+
def init_distributed() -> Optional[torch.device]:
|
| 777 |
+
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 778 |
+
if local_rank == -1:
|
| 779 |
+
return None
|
| 780 |
+
dist.init_process_group("nccl")
|
| 781 |
+
torch.cuda.set_device(local_rank)
|
| 782 |
+
return torch.device(f"cuda:{local_rank}")
|
| 783 |
+
|
| 784 |
+
def is_distributed() -> bool:
|
| 785 |
+
return dist.is_available() and dist.is_initialized()
|
| 786 |
+
|
| 787 |
+
def get_rank() -> int:
|
| 788 |
+
return dist.get_rank() if is_distributed() else 0
|
| 789 |
+
|
| 790 |
+
def get_world_size() -> int:
|
| 791 |
+
return dist.get_world_size() if is_distributed() else 1
|
| 792 |
+
|
| 793 |
+
def is_main() -> bool:
|
| 794 |
+
return get_rank() == 0
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
# ══════════════════════════════════════════════════════════════════
|
| 798 |
+
# §9 ÉVALUATION
|
| 799 |
+
# ══════════════════════════════════════════════════════════════════
|
| 800 |
+
|
| 801 |
+
@torch.no_grad()
|
| 802 |
+
def evaluate(
|
| 803 |
+
model: nn.Module,
|
| 804 |
+
loader: DataLoader,
|
| 805 |
+
device: torch.device,
|
| 806 |
+
max_batches: int = 80,
|
| 807 |
+
) -> Tuple[float, Dict[str, float]]:
|
| 808 |
+
model.eval()
|
| 809 |
+
total_loss, total_n = 0.0, 0
|
| 810 |
+
domain_losses: Dict[str, list] = {}
|
| 811 |
+
|
| 812 |
+
for i, batch in enumerate(loader):
|
| 813 |
+
if i >= max_batches:
|
| 814 |
+
break
|
| 815 |
+
inp = batch["input_ids"].to(device)
|
| 816 |
+
lbl = batch["labels"].to(device)
|
| 817 |
+
domains_batch = batch["domain"]
|
| 818 |
+
|
| 819 |
+
with torch.autocast("cuda", dtype=DTYPE):
|
| 820 |
+
_, loss = model(inp, lbl)
|
| 821 |
+
|
| 822 |
+
lv = loss.item()
|
| 823 |
+
total_loss += lv
|
| 824 |
+
total_n += 1
|
| 825 |
+
|
| 826 |
+
for dom in domains_batch:
|
| 827 |
+
domain_losses.setdefault(dom, []).append(lv)
|
| 828 |
+
|
| 829 |
+
model.train()
|
| 830 |
+
global_loss = total_loss / max(1, total_n)
|
| 831 |
+
per_domain = {k: sum(v) / len(v) for k, v in domain_losses.items()}
|
| 832 |
+
return global_loss, per_domain
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
# ══════════════════════════════════════════════════════════════════
|
| 836 |
+
# §10 LOGGING
|
| 837 |
+
# ══════════════════════════════════════════════════════════════════
|
| 838 |
+
|
| 839 |
+
def log_jsonl(path: Path, record: dict) -> None:
|
| 840 |
+
with open(path, "a", encoding="utf-8") as f:
|
| 841 |
+
f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
# ══════════════════════════════════════════════════════════════════
|
| 845 |
+
# §11 MAIN
|
| 846 |
+
# ══════════════════════════════════════════════════════════════════
|
| 847 |
+
|
| 848 |
+
def main() -> None:
|
| 849 |
+
|
| 850 |
+
# ── DDP init ──────────────────────────────────────────────────
|
| 851 |
+
ddp_device = init_distributed()
|
| 852 |
+
set_seed_fn(SEED + get_rank())
|
| 853 |
+
device = ddp_device if ddp_device else torch.device(
|
| 854 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
# ── Optimisations H100 ────────────────────────────────────────
|
| 858 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 859 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 860 |
+
|
| 861 |
+
if is_main():
|
| 862 |
+
print("=" * 60)
|
| 863 |
+
print(" LLM Multi-Domaine — H100 Training [v2 CORRIGÉ]")
|
| 864 |
+
print("=" * 60)
|
| 865 |
+
gpu = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"
|
| 866 |
+
print(f"Device : {device} ({gpu})")
|
| 867 |
+
print(f"GPUs : {get_world_size()}")
|
| 868 |
+
print(f"Steps : {MAX_STEPS} (~2 epochs sur corpus réduit)")
|
| 869 |
+
print_domain_summary()
|
| 870 |
+
|
| 871 |
+
# ── Tokenizer ─────────────────────────────────────────────────
|
| 872 |
+
tokenizer = train_or_load_tokenizer(DOMAINS)
|
| 873 |
+
vocab_size = len(tokenizer)
|
| 874 |
+
if is_main():
|
| 875 |
+
print(f"Tokenizer : {TOKENIZER_DIR} | vocab={vocab_size}")
|
| 876 |
+
|
| 877 |
+
# ── Modèle ────────────────────────────────────────────────────
|
| 878 |
+
cfg = GPTConfig(vocab_size=vocab_size)
|
| 879 |
+
if is_main():
|
| 880 |
+
CONFIG_FILE.write_text(
|
| 881 |
+
json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
|
| 882 |
+
encoding="utf-8",
|
| 883 |
+
)
|
| 884 |
+
|
| 885 |
+
model = GPT(cfg).to(device)
|
| 886 |
+
|
| 887 |
+
if USE_COMPILE and device.type == "cuda":
|
| 888 |
+
model = torch.compile(model, mode="reduce-overhead")
|
| 889 |
+
if is_main():
|
| 890 |
+
print("torch.compile : reduce-overhead ✓")
|
| 891 |
+
|
| 892 |
+
if is_distributed():
|
| 893 |
+
model = DDP(model, device_ids=[device.index])
|
| 894 |
+
|
| 895 |
+
optimizer = build_optimizer(model)
|
| 896 |
+
|
| 897 |
+
# ── Reprise depuis checkpoint ──────────────────────────────────
|
| 898 |
+
start_step, best_eval = 0, 1e9
|
| 899 |
+
if STATE_FILE.exists():
|
| 900 |
+
if is_main():
|
| 901 |
+
print(f"Reprise depuis {STATE_FILE}")
|
| 902 |
+
start_step, best_eval = load_checkpoint(model, optimizer, STATE_FILE, device)
|
| 903 |
+
if is_main():
|
| 904 |
+
print(f" → reprise à step {start_step}, best_loss={best_eval:.4f}")
|
| 905 |
+
|
| 906 |
+
# ── DataLoaders ───────────────────────────────────────────────
|
| 907 |
+
train_loader, eval_loader = build_dataloaders(
|
| 908 |
+
tokenizer = tokenizer,
|
| 909 |
+
domains = DOMAINS,
|
| 910 |
+
block_size = BLOCK_SIZE,
|
| 911 |
+
train_batch_size = BATCH_SIZE,
|
| 912 |
+
eval_batch_size = max(1, BATCH_SIZE // 2),
|
| 913 |
+
max_eval_docs_per_dom = MAX_EVAL_DOCS_DOM,
|
| 914 |
+
num_workers = 4,
|
| 915 |
+
seed = SEED,
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
# ── Résumé ────────────────────────────────────────────────────
|
| 919 |
+
if is_main():
|
| 920 |
+
raw = model.module if isinstance(model, DDP) else model
|
| 921 |
+
# comptage sans modules compilés
|
| 922 |
+
try:
|
| 923 |
+
n_params = sum(p.numel() for p in raw.parameters() if p.requires_grad)
|
| 924 |
+
except Exception:
|
| 925 |
+
n_params = -1
|
| 926 |
+
eff_batch = BATCH_SIZE * GRAD_ACCUM_STEPS * get_world_size()
|
| 927 |
+
print(f"Paramètres : {n_params/1e6:.1f}M" if n_params > 0 else "Paramètres : N/A (compilé)")
|
| 928 |
+
print(f"Architecture : d={D_MODEL} | heads={N_HEADS} | layers={N_LAYERS} | ctx={BLOCK_SIZE}")
|
| 929 |
+
print(f"Batch effectif: {eff_batch} séq × {BLOCK_SIZE} tok = {eff_batch*BLOCK_SIZE:,} tok/step")
|
| 930 |
+
print(f"Dtype : {DTYPE} | Steps : {MAX_STEPS} | Warmup : {WARMUP_STEPS}")
|
| 931 |
+
print("=" * 60)
|
| 932 |
+
|
| 933 |
+
# ── Boucle d'entraînement ─────────────────────────────────────
|
| 934 |
+
model.train()
|
| 935 |
+
optimizer.zero_grad(set_to_none=True)
|
| 936 |
+
|
| 937 |
+
train_iter = iter(train_loader)
|
| 938 |
+
step = start_step
|
| 939 |
+
t0 = time.time()
|
| 940 |
+
log_loss_sum = 0.0
|
| 941 |
+
log_loss_n = 0
|
| 942 |
+
tokens_log = 0
|
| 943 |
+
last_log = time.time()
|
| 944 |
+
|
| 945 |
+
while step < MAX_STEPS:
|
| 946 |
+
|
| 947 |
+
# ── gradient accumulation ──────────────────────────────────
|
| 948 |
+
for micro in range(GRAD_ACCUM_STEPS):
|
| 949 |
+
try:
|
| 950 |
+
batch = next(train_iter)
|
| 951 |
+
except StopIteration:
|
| 952 |
+
# ← CORRECTIF : relance l'itérateur proprement
|
| 953 |
+
train_iter = iter(train_loader)
|
| 954 |
+
try:
|
| 955 |
+
batch = next(train_iter)
|
| 956 |
+
except StopIteration:
|
| 957 |
+
print("[WARN] Dataset entièrement épuisé avant MAX_STEPS.")
|
| 958 |
+
break
|
| 959 |
+
|
| 960 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 961 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 962 |
+
|
| 963 |
+
with torch.autocast("cuda", dtype=DTYPE):
|
| 964 |
+
_, loss = model(inp, lbl)
|
| 965 |
+
|
| 966 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 967 |
+
log_loss_sum += loss.item()
|
| 968 |
+
log_loss_n += 1
|
| 969 |
+
tokens_log += inp.numel()
|
| 970 |
+
|
| 971 |
+
# ── optimizer step ────────────────────────────────────────
|
| 972 |
+
lr = cosine_lr(step)
|
| 973 |
+
for g in optimizer.param_groups:
|
| 974 |
+
g["lr"] = lr
|
| 975 |
+
|
| 976 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
|
| 977 |
+
optimizer.step()
|
| 978 |
+
optimizer.zero_grad(set_to_none=True)
|
| 979 |
+
step += 1
|
| 980 |
+
|
| 981 |
+
# ── logging toutes les 50 steps ───────────────────────────
|
| 982 |
+
if step % 50 == 0 and is_main():
|
| 983 |
+
now = time.time()
|
| 984 |
+
elapsed = max(1e-6, now - last_log)
|
| 985 |
+
tok_s = tokens_log / elapsed
|
| 986 |
+
avg_l = log_loss_sum / max(1, log_loss_n)
|
| 987 |
+
ppl = math.exp(min(avg_l, 20))
|
| 988 |
+
print(
|
| 989 |
+
f"step {step:5d}/{MAX_STEPS} | "
|
| 990 |
+
f"loss={avg_l:.4f} | ppl={ppl:.1f} | "
|
| 991 |
+
f"lr={lr:.2e} | {tok_s:,.0f} tok/s"
|
| 992 |
+
)
|
| 993 |
+
log_jsonl(LOG_FILE, {
|
| 994 |
+
"step": step, "loss": avg_l, "ppl": ppl,
|
| 995 |
+
"lr": lr, "tok_s": tok_s, "time": now - t0,
|
| 996 |
+
})
|
| 997 |
+
last_log = now
|
| 998 |
+
tokens_log = 0
|
| 999 |
+
log_loss_sum = 0.0
|
| 1000 |
+
log_loss_n = 0
|
| 1001 |
+
|
| 1002 |
+
# ── évaluation ────────────────────────────────────────────
|
| 1003 |
+
if step % EVAL_EVERY == 0 and is_main():
|
| 1004 |
+
val_loss, per_dom = evaluate(model, eval_loader, device)
|
| 1005 |
+
ppl_val = math.exp(min(val_loss, 20))
|
| 1006 |
+
print(f"\n[eval] step {step} | val_loss={val_loss:.4f} | ppl={ppl_val:.1f}")
|
| 1007 |
+
print(" Perplexité par domaine :")
|
| 1008 |
+
for dom, dl in sorted(per_dom.items(), key=lambda x: -x[1]):
|
| 1009 |
+
print(f" {dom:<25} loss={dl:.4f} ppl={math.exp(min(dl,20)):.1f}")
|
| 1010 |
+
print()
|
| 1011 |
+
|
| 1012 |
+
log_jsonl(LOG_FILE, {
|
| 1013 |
+
"step": step, "val_loss": val_loss, "val_ppl": ppl_val,
|
| 1014 |
+
"per_domain": per_dom,
|
| 1015 |
+
})
|
| 1016 |
+
|
| 1017 |
+
if val_loss < best_eval:
|
| 1018 |
+
best_eval = val_loss
|
| 1019 |
+
save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
|
| 1020 |
+
print(f" ✓ Meilleur modèle → {BEST_MODEL_FILE}\n")
|
| 1021 |
+
|
| 1022 |
+
# ── checkpoint périodique ─────────────────────────────────
|
| 1023 |
+
if step % SAVE_EVERY == 0 and is_main():
|
| 1024 |
+
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
|
| 1025 |
+
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
|
| 1026 |
+
print(f" ✓ Checkpoint step {step} → {MODEL_FILE}")
|
| 1027 |
+
|
| 1028 |
+
# ── Fin ───────────────────────────────────────────────────────
|
| 1029 |
+
if is_main():
|
| 1030 |
+
save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
|
| 1031 |
+
save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
|
| 1032 |
+
total_min = (time.time() - t0) / 60
|
| 1033 |
+
print(f"\n{'='*60}")
|
| 1034 |
+
print(f"Modèle final → {MODEL_FILE}")
|
| 1035 |
+
print(f"Meilleur modèle→ {BEST_MODEL_FILE}")
|
| 1036 |
+
print(f"Steps réalisés : {step}")
|
| 1037 |
+
print(f"Temps total : {total_min:.1f} min")
|
| 1038 |
+
print(f"{'='*60}")
|
| 1039 |
+
|
| 1040 |
+
if is_distributed():
|
| 1041 |
+
dist.destroy_process_group()
|
| 1042 |
+
|
| 1043 |
+
|
| 1044 |
+
# ══════════════════════════════════════════════════════════════════
|
| 1045 |
+
# §12 UTILS
|
| 1046 |
+
# ══════════════════════════════════════════════════════════════════
|
| 1047 |
+
|
| 1048 |
+
def set_seed_fn(seed: int) -> None:
|
| 1049 |
+
random.seed(seed)
|
| 1050 |
+
torch.manual_seed(seed)
|
| 1051 |
+
torch.cuda.manual_seed_all(seed)
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
if __name__ == "__main__":
|
| 1055 |
+
main()
|
train_nlp_h100_maxvram_v6.py
ADDED
|
@@ -0,0 +1,1046 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
train_nlp_h100_maxvram.py — v6 (démarrage sûr sur H100 80G)
|
| 5 |
+
============================================================
|
| 6 |
+
|
| 7 |
+
Corrections principales vs v5
|
| 8 |
+
-----------------------------
|
| 9 |
+
1. Le log v5 montre que, même sans cap VRAM logiciel,
|
| 10 |
+
BATCH_SIZE=32 provoque un vrai OOM matériel dès le premier
|
| 11 |
+
forward compilé sur H100 80G.
|
| 12 |
+
|
| 13 |
+
2. On abaisse donc le réglage de départ à :
|
| 14 |
+
BATCH_SIZE = 28
|
| 15 |
+
pour démarrer dans une zone sûre, puis remonter ensuite si le
|
| 16 |
+
log montre encore de la marge.
|
| 17 |
+
|
| 18 |
+
3. Le cap logiciel reste désactivé par défaut :
|
| 19 |
+
TARGET_VRAM_GIB = None
|
| 20 |
+
afin d'éviter tout faux OOM dû à PyTorch.
|
| 21 |
+
|
| 22 |
+
4. Le mode torch.compile reste sur :
|
| 23 |
+
COMPILE_MODE = "max-autotune-no-cudagraphs"
|
| 24 |
+
qui garde un bon compromis perf / mémoire sans surcoût CUDA graphs.
|
| 25 |
+
|
| 26 |
+
5. BASE_CHECKPOINT est chargé AVANT l'injection LoRA,
|
| 27 |
+
et l'estimation des steps reste corrigée pour le scheduler LR.
|
| 28 |
+
|
| 29 |
+
Conseils de réglage
|
| 30 |
+
-------------------
|
| 31 |
+
- Démarre avec :
|
| 32 |
+
BATCH_SIZE = 28
|
| 33 |
+
TARGET_VRAM_GIB = None
|
| 34 |
+
COMPILE_MODE = "max-autotune-no-cudagraphs"
|
| 35 |
+
- Si stable et max_reserved < 72 GiB après quelques logs :
|
| 36 |
+
BATCH_SIZE += 2
|
| 37 |
+
- Si vrai OOM matériel :
|
| 38 |
+
BATCH_SIZE -= 2 puis relance
|
| 39 |
+
- Si tu veux encore plus de marge au premier essai :
|
| 40 |
+
BATCH_SIZE = 24
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
from __future__ import annotations
|
| 44 |
+
|
| 45 |
+
import itertools
|
| 46 |
+
import json
|
| 47 |
+
import math
|
| 48 |
+
import os
|
| 49 |
+
import random
|
| 50 |
+
import time
|
| 51 |
+
from collections import OrderedDict
|
| 52 |
+
from contextlib import nullcontext
|
| 53 |
+
from dataclasses import asdict, dataclass
|
| 54 |
+
from pathlib import Path
|
| 55 |
+
from typing import Iterator, Optional
|
| 56 |
+
|
| 57 |
+
# A définir AVANT import torch
|
| 58 |
+
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
import torch.distributed as dist
|
| 62 |
+
import torch.nn as nn
|
| 63 |
+
import torch.nn.functional as F
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
import bitsandbytes as bnb
|
| 67 |
+
HAS_BNB = True
|
| 68 |
+
except ImportError:
|
| 69 |
+
HAS_BNB = False
|
| 70 |
+
print("[warn] bitsandbytes non disponible")
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
from flash_attn import flash_attn_func
|
| 74 |
+
HAS_FLASH = True
|
| 75 |
+
except ImportError:
|
| 76 |
+
HAS_FLASH = False
|
| 77 |
+
print("[warn] flash-attn non disponible – fallback SDPA (toujours fusionné sur H100)")
|
| 78 |
+
|
| 79 |
+
from datasets import load_dataset
|
| 80 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 81 |
+
from tokenizers import (
|
| 82 |
+
Tokenizer, decoders, models, normalizers,
|
| 83 |
+
pre_tokenizers, processors, trainers,
|
| 84 |
+
)
|
| 85 |
+
from transformers import PreTrainedTokenizerFast
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 89 |
+
# ║ CHEMINS ║
|
| 90 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 91 |
+
|
| 92 |
+
OUT_DIR = Path("./nlp_1b_h100_maxvram")
|
| 93 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 94 |
+
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
|
| 95 |
+
CONFIG_FILE = OUT_DIR / "config.json"
|
| 96 |
+
MODEL_FILE = OUT_DIR / "model.pt"
|
| 97 |
+
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
|
| 98 |
+
STATE_FILE = OUT_DIR / "train_state.pt"
|
| 99 |
+
BASE_CHECKPOINT: Optional[Path] = None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 103 |
+
# ║ HYPERPARAMÈTRES — H100 ║
|
| 104 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 105 |
+
|
| 106 |
+
SEED = 42
|
| 107 |
+
TARGET_VRAM_GIB = None # None = pas de cap logiciel ; évite les faux OOM dus au cap PyTorch
|
| 108 |
+
|
| 109 |
+
# ── Architecture ~1B ──────────────────────────────────────────────────────────
|
| 110 |
+
BLOCK_SIZE = 2048
|
| 111 |
+
VOCAB_SIZE = 32_000
|
| 112 |
+
D_MODEL = 1536
|
| 113 |
+
N_HEADS = 24
|
| 114 |
+
N_LAYERS = 24
|
| 115 |
+
D_FF = 6144
|
| 116 |
+
DROPOUT = 0.0
|
| 117 |
+
|
| 118 |
+
# ── LoRA / "QLoRA" si BNB dispo côté optimiseur ──────────────────────────────
|
| 119 |
+
USE_QLORA = True
|
| 120 |
+
LORA_R = 64
|
| 121 |
+
LORA_ALPHA = 128
|
| 122 |
+
LORA_DROPOUT = 0.05
|
| 123 |
+
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]
|
| 124 |
+
|
| 125 |
+
# ── Entraînement ──────────────────────────────────────────────────────────────
|
| 126 |
+
NUM_EPOCHS = 10
|
| 127 |
+
LEARNING_RATE = 3e-4
|
| 128 |
+
MIN_LR = 3e-5
|
| 129 |
+
WEIGHT_DECAY = 0.1
|
| 130 |
+
WARMUP_STEPS = 500
|
| 131 |
+
|
| 132 |
+
# Réglage de départ conseillé
|
| 133 |
+
BATCH_SIZE = 28
|
| 134 |
+
GRAD_ACCUM_STEPS = 1
|
| 135 |
+
|
| 136 |
+
MAX_GRAD_NORM = 1.0
|
| 137 |
+
EVAL_EVERY = 500
|
| 138 |
+
SAVE_EVERY = 1_000
|
| 139 |
+
|
| 140 |
+
DTYPE = torch.bfloat16
|
| 141 |
+
|
| 142 |
+
# Compile : version plus robuste au démarrage
|
| 143 |
+
USE_CHECKPOINTING = False
|
| 144 |
+
USE_COMPILE = True
|
| 145 |
+
COMPILE_MODE = "max-autotune-no-cudagraphs"
|
| 146 |
+
|
| 147 |
+
# ── DataLoader ────────────────────────────────────────────────────────────────
|
| 148 |
+
TRAIN_NUM_WORKERS = 8
|
| 149 |
+
EVAL_NUM_WORKERS = 4
|
| 150 |
+
PREFETCH_FACTOR = 4
|
| 151 |
+
|
| 152 |
+
# ── Textes ────────────────────────────────────────────────────────────────────
|
| 153 |
+
TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
|
| 154 |
+
TOKENIZER_CHAR_LIMIT = 2_000
|
| 155 |
+
TEXT_CHAR_LIMIT = 8_000
|
| 156 |
+
|
| 157 |
+
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
|
| 158 |
+
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
|
| 159 |
+
|
| 160 |
+
WIKI_CONFIGS = ["20231101.en", "20231101.fr", "20231101.ar"]
|
| 161 |
+
FINEWEB_CONFIG = "sample-10BT"
|
| 162 |
+
DEV_DOCS_PER_WIKI_CONFIG = 1_500
|
| 163 |
+
DEV_DOCS_FINEWEB = 3_000
|
| 164 |
+
TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
|
| 165 |
+
TRAIN_DOCS_FINEWEB = 60_000
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 169 |
+
# ║ DISTRIBUTED ║
|
| 170 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 171 |
+
|
| 172 |
+
def is_distributed() -> bool:
|
| 173 |
+
return dist.is_available() and dist.is_initialized()
|
| 174 |
+
|
| 175 |
+
def get_rank() -> int:
|
| 176 |
+
return dist.get_rank() if is_distributed() else 0
|
| 177 |
+
|
| 178 |
+
def get_world_size() -> int:
|
| 179 |
+
return dist.get_world_size() if is_distributed() else 1
|
| 180 |
+
|
| 181 |
+
def is_main() -> bool:
|
| 182 |
+
return get_rank() == 0
|
| 183 |
+
|
| 184 |
+
def init_distributed() -> Optional[torch.device]:
|
| 185 |
+
lr = int(os.environ.get("LOCAL_RANK", -1))
|
| 186 |
+
if lr == -1:
|
| 187 |
+
return None
|
| 188 |
+
dist.init_process_group("nccl")
|
| 189 |
+
torch.cuda.set_device(lr)
|
| 190 |
+
return torch.device(f"cuda:{lr}")
|
| 191 |
+
|
| 192 |
+
def set_seed(seed: int) -> None:
|
| 193 |
+
random.seed(seed)
|
| 194 |
+
torch.manual_seed(seed)
|
| 195 |
+
if torch.cuda.is_available():
|
| 196 |
+
torch.cuda.manual_seed_all(seed)
|
| 197 |
+
|
| 198 |
+
def get_device(ddp=None) -> torch.device:
|
| 199 |
+
if ddp is not None:
|
| 200 |
+
return ddp
|
| 201 |
+
if torch.cuda.is_available():
|
| 202 |
+
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 203 |
+
return torch.device("cpu")
|
| 204 |
+
|
| 205 |
+
def current_cuda_index(device: torch.device) -> int:
|
| 206 |
+
return device.index if device.index is not None else torch.cuda.current_device()
|
| 207 |
+
|
| 208 |
+
def autocast_context(device: torch.device):
|
| 209 |
+
return torch.autocast("cuda", dtype=DTYPE) if device.type == "cuda" else nullcontext()
|
| 210 |
+
|
| 211 |
+
def unwrap_model(model: nn.Module) -> nn.Module:
|
| 212 |
+
m = model.module if isinstance(model, DDP) else model
|
| 213 |
+
return m._orig_mod if hasattr(m, "_orig_mod") else m
|
| 214 |
+
|
| 215 |
+
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
|
| 216 |
+
return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad)
|
| 217 |
+
|
| 218 |
+
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
|
| 219 |
+
out = OrderedDict()
|
| 220 |
+
for k, v in sd.items():
|
| 221 |
+
for prefix in ("module._orig_mod.", "_orig_mod.", "module."):
|
| 222 |
+
if k.startswith(prefix):
|
| 223 |
+
k = k[len(prefix):]
|
| 224 |
+
break
|
| 225 |
+
out[k] = v
|
| 226 |
+
return out
|
| 227 |
+
|
| 228 |
+
def normalize_text(t: str) -> str:
|
| 229 |
+
return " ".join(t.strip().split())
|
| 230 |
+
|
| 231 |
+
def safe_str(x) -> str:
|
| 232 |
+
return x if isinstance(x, str) else ("" if x is None else str(x))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 236 |
+
# ║ DATASETS ║
|
| 237 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 238 |
+
|
| 239 |
+
def load_wiki_stream(cfg_name: str):
|
| 240 |
+
return load_dataset("wikimedia/wikipedia", cfg_name, split="train", streaming=True)
|
| 241 |
+
|
| 242 |
+
def load_fineweb_stream():
|
| 243 |
+
return load_dataset("HuggingFaceFW/fineweb-edu", FINEWEB_CONFIG, split="train", streaming=True)
|
| 244 |
+
|
| 245 |
+
def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
|
| 246 |
+
for row in itertools.islice(ds, start, start + count):
|
| 247 |
+
text = normalize_text(safe_str(row.get("text", "")))
|
| 248 |
+
if len(text) >= 20:
|
| 249 |
+
yield text[:char_limit]
|
| 250 |
+
|
| 251 |
+
def tokenizer_training_iterator() -> Iterator[str]:
|
| 252 |
+
for c in WIKI_CONFIGS:
|
| 253 |
+
yield from stream_texts(
|
| 254 |
+
load_wiki_stream(c),
|
| 255 |
+
0,
|
| 256 |
+
TOKENIZER_SAMPLE_DOCS_PER_SOURCE,
|
| 257 |
+
TOKENIZER_CHAR_LIMIT,
|
| 258 |
+
)
|
| 259 |
+
yield from stream_texts(
|
| 260 |
+
load_fineweb_stream(),
|
| 261 |
+
0,
|
| 262 |
+
TOKENIZER_SAMPLE_DOCS_PER_SOURCE,
|
| 263 |
+
TOKENIZER_CHAR_LIMIT,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def build_epoch_train_texts(epoch: int) -> list[str]:
|
| 267 |
+
texts: list[str] = []
|
| 268 |
+
for c in WIKI_CONFIGS:
|
| 269 |
+
start = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
|
| 270 |
+
texts.extend(stream_texts(load_wiki_stream(c), start, TRAIN_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
|
| 271 |
+
start = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB
|
| 272 |
+
texts.extend(stream_texts(load_fineweb_stream(), start, TRAIN_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
|
| 273 |
+
random.Random(SEED + epoch).shuffle(texts)
|
| 274 |
+
return texts
|
| 275 |
+
|
| 276 |
+
def build_eval_texts() -> list[str]:
|
| 277 |
+
texts: list[str] = []
|
| 278 |
+
for c in WIKI_CONFIGS:
|
| 279 |
+
texts.extend(stream_texts(load_wiki_stream(c), 0, DEV_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
|
| 280 |
+
texts.extend(stream_texts(load_fineweb_stream(), 0, DEV_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
|
| 281 |
+
return texts
|
| 282 |
+
|
| 283 |
+
def estimate_steps_per_epoch(
|
| 284 |
+
texts: list[str],
|
| 285 |
+
tokenizer: PreTrainedTokenizerFast,
|
| 286 |
+
block_size: int,
|
| 287 |
+
batch_size: int,
|
| 288 |
+
sample_size: int = 512,
|
| 289 |
+
) -> int:
|
| 290 |
+
"""
|
| 291 |
+
Estimation des steps réels pour le scheduler LR.
|
| 292 |
+
Chaque texte contribue environ len(ids)+2 tokens (BOS/EOS).
|
| 293 |
+
Un exemple packed consomme block_size+1 tokens.
|
| 294 |
+
"""
|
| 295 |
+
if not texts:
|
| 296 |
+
return 1
|
| 297 |
+
|
| 298 |
+
rng = random.Random(SEED)
|
| 299 |
+
if len(texts) > sample_size:
|
| 300 |
+
sample = rng.sample(texts, sample_size)
|
| 301 |
+
else:
|
| 302 |
+
sample = texts
|
| 303 |
+
|
| 304 |
+
total_tokens = 0
|
| 305 |
+
valid = 0
|
| 306 |
+
for txt in sample:
|
| 307 |
+
ids = tokenizer.encode(txt, add_special_tokens=False)
|
| 308 |
+
if ids:
|
| 309 |
+
total_tokens += len(ids) + 2
|
| 310 |
+
valid += 1
|
| 311 |
+
|
| 312 |
+
avg_tokens_per_text = total_tokens / max(1, valid)
|
| 313 |
+
est_epoch_tokens = avg_tokens_per_text * len(texts)
|
| 314 |
+
tokens_per_step = (block_size + 1) * batch_size * get_world_size()
|
| 315 |
+
return max(1, int(est_epoch_tokens // max(1, tokens_per_step)))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 319 |
+
# ║ PACKED DATASET ║
|
| 320 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 321 |
+
|
| 322 |
+
class PackedTextList(torch.utils.data.IterableDataset):
|
| 323 |
+
"""
|
| 324 |
+
Packing dense sans padding.
|
| 325 |
+
drop_last=True dans le DataLoader → shapes constantes, utile pour compile.
|
| 326 |
+
"""
|
| 327 |
+
def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
|
| 328 |
+
super().__init__()
|
| 329 |
+
self.texts = texts
|
| 330 |
+
self.tokenizer = tokenizer
|
| 331 |
+
self.block_size = block_size
|
| 332 |
+
self.epoch_seed = epoch_seed
|
| 333 |
+
|
| 334 |
+
def __iter__(self):
|
| 335 |
+
worker = torch.utils.data.get_worker_info()
|
| 336 |
+
rank, ws = get_rank(), get_world_size()
|
| 337 |
+
|
| 338 |
+
if worker is None:
|
| 339 |
+
shard_mod, shard_id = ws, rank
|
| 340 |
+
else:
|
| 341 |
+
shard_mod = worker.num_workers * ws
|
| 342 |
+
shard_id = rank * worker.num_workers + worker.id
|
| 343 |
+
|
| 344 |
+
rng = random.Random(self.epoch_seed)
|
| 345 |
+
indices = list(range(len(self.texts)))
|
| 346 |
+
rng.shuffle(indices)
|
| 347 |
+
|
| 348 |
+
bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
| 349 |
+
buf: list[int] = []
|
| 350 |
+
|
| 351 |
+
for li, ti in enumerate(indices):
|
| 352 |
+
if li % shard_mod != shard_id:
|
| 353 |
+
continue
|
| 354 |
+
|
| 355 |
+
ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False)
|
| 356 |
+
if not ids:
|
| 357 |
+
continue
|
| 358 |
+
|
| 359 |
+
buf.extend([bos] + ids + [eos])
|
| 360 |
+
|
| 361 |
+
while len(buf) >= self.block_size + 1:
|
| 362 |
+
chunk = buf[: self.block_size + 1]
|
| 363 |
+
buf = buf[self.block_size + 1 :]
|
| 364 |
+
yield {
|
| 365 |
+
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
|
| 366 |
+
"labels": torch.tensor(chunk[1:], dtype=torch.long),
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 371 |
+
# ║ TOKENIZER ║
|
| 372 |
+
# ╚═══════════════════════════════════════════════════════════════════════════��══╝
|
| 373 |
+
|
| 374 |
+
def tokenizer_ready() -> bool:
|
| 375 |
+
return (
|
| 376 |
+
(TOKENIZER_DIR / "tokenizer.json").exists()
|
| 377 |
+
and (TOKENIZER_DIR / "tokenizer_config.json").exists()
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
def train_tokenizer_once() -> None:
|
| 381 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 382 |
+
|
| 383 |
+
tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
|
| 384 |
+
tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
|
| 385 |
+
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 386 |
+
tok.decoder = decoders.ByteLevel()
|
| 387 |
+
|
| 388 |
+
trainer = trainers.BpeTrainer(
|
| 389 |
+
vocab_size=VOCAB_SIZE,
|
| 390 |
+
min_frequency=2,
|
| 391 |
+
show_progress=is_main(),
|
| 392 |
+
special_tokens=SPECIAL_TOKENS,
|
| 393 |
+
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer)
|
| 397 |
+
|
| 398 |
+
bos_id = tok.token_to_id(BOS_TOKEN)
|
| 399 |
+
eos_id = tok.token_to_id(EOS_TOKEN)
|
| 400 |
+
|
| 401 |
+
tok.post_processor = processors.TemplateProcessing(
|
| 402 |
+
single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
|
| 403 |
+
pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
|
| 404 |
+
special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
tok.save(str(TOKENIZER_DIR / "tokenizer.json"))
|
| 408 |
+
|
| 409 |
+
fast = PreTrainedTokenizerFast(
|
| 410 |
+
tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"),
|
| 411 |
+
bos_token=BOS_TOKEN,
|
| 412 |
+
eos_token=EOS_TOKEN,
|
| 413 |
+
unk_token=UNK_TOKEN,
|
| 414 |
+
pad_token=PAD_TOKEN,
|
| 415 |
+
)
|
| 416 |
+
fast.save_pretrained(str(TOKENIZER_DIR))
|
| 417 |
+
|
| 418 |
+
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
|
| 419 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 420 |
+
|
| 421 |
+
if not tokenizer_ready():
|
| 422 |
+
if is_distributed():
|
| 423 |
+
if is_main():
|
| 424 |
+
print("Entraînement tokenizer 32k…")
|
| 425 |
+
train_tokenizer_once()
|
| 426 |
+
dist.barrier()
|
| 427 |
+
else:
|
| 428 |
+
print("Entraînement tokenizer 32k…")
|
| 429 |
+
train_tokenizer_once()
|
| 430 |
+
|
| 431 |
+
return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 435 |
+
# ║ MODÈLE GPT ║
|
| 436 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 437 |
+
|
| 438 |
+
@dataclass
|
| 439 |
+
class GPTConfig:
|
| 440 |
+
vocab_size: int = VOCAB_SIZE
|
| 441 |
+
block_size: int = BLOCK_SIZE
|
| 442 |
+
d_model: int = D_MODEL
|
| 443 |
+
n_heads: int = N_HEADS
|
| 444 |
+
n_layers: int = N_LAYERS
|
| 445 |
+
d_ff: int = D_FF
|
| 446 |
+
dropout: float = DROPOUT
|
| 447 |
+
use_checkpointing: bool = USE_CHECKPOINTING
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class RMSNorm(nn.Module):
|
| 451 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 452 |
+
super().__init__()
|
| 453 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 454 |
+
self.eps = eps
|
| 455 |
+
|
| 456 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 457 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class RotaryEmbedding(nn.Module):
|
| 461 |
+
def __init__(self, dim: int, base: int = 10_000, max_seq: int = 8_192):
|
| 462 |
+
super().__init__()
|
| 463 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 464 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 465 |
+
|
| 466 |
+
t = torch.arange(max_seq).float()
|
| 467 |
+
freqs = torch.outer(t, inv_freq)
|
| 468 |
+
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
|
| 469 |
+
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
|
| 470 |
+
|
| 471 |
+
def forward(self, seq_len: int, dtype: torch.dtype):
|
| 472 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 476 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 477 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
| 478 |
+
|
| 479 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 480 |
+
return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class CausalSelfAttention(nn.Module):
|
| 484 |
+
"""
|
| 485 |
+
Flash Attention 2 si disponible.
|
| 486 |
+
Sinon SDPA PyTorch, fusionné et performant sur H100.
|
| 487 |
+
"""
|
| 488 |
+
def __init__(self, cfg: GPTConfig):
|
| 489 |
+
super().__init__()
|
| 490 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 491 |
+
|
| 492 |
+
self.n_heads = cfg.n_heads
|
| 493 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 494 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 495 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 496 |
+
self.dropout_p = cfg.dropout
|
| 497 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 498 |
+
|
| 499 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 500 |
+
b, t, c = x.shape
|
| 501 |
+
|
| 502 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 503 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 504 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 505 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 506 |
+
|
| 507 |
+
cos, sin = self.rope(t, x.dtype)
|
| 508 |
+
q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
|
| 509 |
+
|
| 510 |
+
if HAS_FLASH:
|
| 511 |
+
q = q.transpose(1, 2)
|
| 512 |
+
k = k.transpose(1, 2)
|
| 513 |
+
v = v.transpose(1, 2)
|
| 514 |
+
y = flash_attn_func(
|
| 515 |
+
q,
|
| 516 |
+
k,
|
| 517 |
+
v,
|
| 518 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 519 |
+
causal=True,
|
| 520 |
+
)
|
| 521 |
+
y = y.reshape(b, t, c)
|
| 522 |
+
else:
|
| 523 |
+
y = F.scaled_dot_product_attention(
|
| 524 |
+
q,
|
| 525 |
+
k,
|
| 526 |
+
v,
|
| 527 |
+
dropout_p=self.dropout_p if self.training else 0.0,
|
| 528 |
+
is_causal=True,
|
| 529 |
+
)
|
| 530 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 531 |
+
|
| 532 |
+
return self.proj(y)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class SwiGLU(nn.Module):
|
| 536 |
+
def __init__(self, cfg: GPTConfig):
|
| 537 |
+
super().__init__()
|
| 538 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 539 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 540 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 541 |
+
|
| 542 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 543 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class Block(nn.Module):
|
| 547 |
+
def __init__(self, cfg: GPTConfig):
|
| 548 |
+
super().__init__()
|
| 549 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 550 |
+
self.attn = CausalSelfAttention(cfg)
|
| 551 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 552 |
+
self.ff = SwiGLU(cfg)
|
| 553 |
+
|
| 554 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 555 |
+
x = x + self.attn(self.ln1(x))
|
| 556 |
+
x = x + self.ff(self.ln2(x))
|
| 557 |
+
return x
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class GPT(nn.Module):
|
| 561 |
+
def __init__(self, cfg: GPTConfig):
|
| 562 |
+
super().__init__()
|
| 563 |
+
self.cfg = cfg
|
| 564 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 565 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 566 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 567 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 568 |
+
self.lm_head.weight = self.tok_emb.weight
|
| 569 |
+
self.apply(self._init_weights)
|
| 570 |
+
|
| 571 |
+
@staticmethod
|
| 572 |
+
def _init_weights(m: nn.Module) -> None:
|
| 573 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 574 |
+
nn.init.normal_(m.weight, 0.0, 0.02)
|
| 575 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 576 |
+
nn.init.zeros_(m.bias)
|
| 577 |
+
|
| 578 |
+
def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
|
| 579 |
+
x = self.tok_emb(input_ids)
|
| 580 |
+
|
| 581 |
+
for block in self.blocks:
|
| 582 |
+
if self.cfg.use_checkpointing and self.training:
|
| 583 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 584 |
+
else:
|
| 585 |
+
x = block(x)
|
| 586 |
+
|
| 587 |
+
logits = self.lm_head(self.ln_f(x))
|
| 588 |
+
loss = None
|
| 589 |
+
|
| 590 |
+
if labels is not None:
|
| 591 |
+
loss = F.cross_entropy(
|
| 592 |
+
logits.reshape(-1, logits.size(-1)),
|
| 593 |
+
labels.reshape(-1),
|
| 594 |
+
ignore_index=-100,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
return logits, loss
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 601 |
+
# ║ LoRA ║
|
| 602 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 603 |
+
|
| 604 |
+
class LoRALinear(nn.Module):
|
| 605 |
+
def __init__(self, base_layer: nn.Linear, r=LORA_R, alpha=LORA_ALPHA, dropout=LORA_DROPOUT):
|
| 606 |
+
super().__init__()
|
| 607 |
+
self.base = base_layer
|
| 608 |
+
self.r = r
|
| 609 |
+
self.scale = alpha / r
|
| 610 |
+
|
| 611 |
+
in_f, out_f = base_layer.in_features, base_layer.out_features
|
| 612 |
+
try:
|
| 613 |
+
dev = next(base_layer.parameters()).device
|
| 614 |
+
except StopIteration:
|
| 615 |
+
dev = torch.device("cpu")
|
| 616 |
+
|
| 617 |
+
self.lora_A = nn.Linear(in_f, r, bias=False, device=dev)
|
| 618 |
+
self.lora_B = nn.Linear(r, out_f, bias=False, device=dev)
|
| 619 |
+
self.drop = nn.Dropout(dropout)
|
| 620 |
+
|
| 621 |
+
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| 622 |
+
nn.init.zeros_(self.lora_B.weight)
|
| 623 |
+
|
| 624 |
+
for p in self.base.parameters():
|
| 625 |
+
p.requires_grad = False
|
| 626 |
+
|
| 627 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 628 |
+
return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def apply_qlora(model: GPT, device: torch.device) -> GPT:
|
| 632 |
+
if not USE_QLORA:
|
| 633 |
+
return model
|
| 634 |
+
|
| 635 |
+
targets = [
|
| 636 |
+
(name, module)
|
| 637 |
+
for name, module in model.named_modules()
|
| 638 |
+
if name.split(".")[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear)
|
| 639 |
+
]
|
| 640 |
+
|
| 641 |
+
for name, module in targets:
|
| 642 |
+
parts = name.split(".")
|
| 643 |
+
parent = model
|
| 644 |
+
for part in parts[:-1]:
|
| 645 |
+
parent = getattr(parent, part)
|
| 646 |
+
setattr(parent, parts[-1], LoRALinear(module))
|
| 647 |
+
|
| 648 |
+
if is_main():
|
| 649 |
+
print(f"LoRA : {len(targets)} couches remplacées (device={device})")
|
| 650 |
+
|
| 651 |
+
return model
|
| 652 |
+
|
| 653 |
+
def freeze_base_weights(model: GPT) -> None:
|
| 654 |
+
for name, p in model.named_parameters():
|
| 655 |
+
p.requires_grad = ("lora_A" in name or "lora_B" in name)
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 659 |
+
# ║ OPTIMIZER ║
|
| 660 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 661 |
+
|
| 662 |
+
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
|
| 663 |
+
decay, no_decay = [], []
|
| 664 |
+
|
| 665 |
+
for name, p in unwrap_model(model).named_parameters():
|
| 666 |
+
if not p.requires_grad:
|
| 667 |
+
continue
|
| 668 |
+
if p.ndim >= 2 and "weight" in name:
|
| 669 |
+
decay.append(p)
|
| 670 |
+
else:
|
| 671 |
+
no_decay.append(p)
|
| 672 |
+
|
| 673 |
+
groups = [
|
| 674 |
+
{"params": decay, "weight_decay": WEIGHT_DECAY},
|
| 675 |
+
{"params": no_decay, "weight_decay": 0.0},
|
| 676 |
+
]
|
| 677 |
+
|
| 678 |
+
if HAS_BNB:
|
| 679 |
+
return bnb.optim.PagedAdamW8bit(
|
| 680 |
+
groups,
|
| 681 |
+
lr=LEARNING_RATE,
|
| 682 |
+
betas=(0.9, 0.95),
|
| 683 |
+
eps=1e-8,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
return torch.optim.AdamW(
|
| 687 |
+
groups,
|
| 688 |
+
lr=LEARNING_RATE,
|
| 689 |
+
betas=(0.9, 0.95),
|
| 690 |
+
eps=1e-8,
|
| 691 |
+
fused=torch.cuda.is_available(),
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
def cosine_lr(step: int, total_steps: int) -> float:
|
| 695 |
+
if step < WARMUP_STEPS:
|
| 696 |
+
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
|
| 697 |
+
|
| 698 |
+
p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
|
| 699 |
+
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 703 |
+
# ║ CHECKPOINT ║
|
| 704 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 705 |
+
|
| 706 |
+
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
|
| 707 |
+
raw = unwrap_model(model)
|
| 708 |
+
torch.save(
|
| 709 |
+
{
|
| 710 |
+
"model": normalize_state_dict_keys(raw.state_dict()),
|
| 711 |
+
"optimizer": optimizer.state_dict(),
|
| 712 |
+
"epoch": epoch,
|
| 713 |
+
"step": step,
|
| 714 |
+
"best_loss": best_loss,
|
| 715 |
+
"config": asdict(raw.cfg),
|
| 716 |
+
},
|
| 717 |
+
path,
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
def maybe_load_base_checkpoint(model, device):
|
| 721 |
+
if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
|
| 722 |
+
return
|
| 723 |
+
|
| 724 |
+
ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
|
| 725 |
+
missing, unexpected = model.load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=False)
|
| 726 |
+
|
| 727 |
+
if is_main():
|
| 728 |
+
print(f"Base checkpoint chargé depuis {BASE_CHECKPOINT}")
|
| 729 |
+
if missing:
|
| 730 |
+
print(f"[warn] missing keys base ckpt: {len(missing)}")
|
| 731 |
+
if unexpected:
|
| 732 |
+
print(f"[warn] unexpected keys base ckpt: {len(unexpected)}")
|
| 733 |
+
|
| 734 |
+
def load_resume_checkpoint(model, optimizer, path, device):
|
| 735 |
+
ckpt = torch.load(path, map_location=device)
|
| 736 |
+
unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True)
|
| 737 |
+
|
| 738 |
+
try:
|
| 739 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 740 |
+
except Exception as e:
|
| 741 |
+
print(f"[warn] Optimizer state non repris: {e}")
|
| 742 |
+
|
| 743 |
+
return (
|
| 744 |
+
int(ckpt.get("epoch", 0)),
|
| 745 |
+
int(ckpt.get("step", 0)),
|
| 746 |
+
float(ckpt.get("best_loss", 1e9)),
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 751 |
+
# ║ ÉVALUATION ║
|
| 752 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 753 |
+
|
| 754 |
+
@torch.no_grad()
|
| 755 |
+
def evaluate(model, loader, device, max_batches=100) -> float:
|
| 756 |
+
model.eval()
|
| 757 |
+
losses = []
|
| 758 |
+
|
| 759 |
+
for i, batch in enumerate(loader):
|
| 760 |
+
if i >= max_batches:
|
| 761 |
+
break
|
| 762 |
+
|
| 763 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 764 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 765 |
+
|
| 766 |
+
with autocast_context(device):
|
| 767 |
+
_, loss = model(inp, lbl)
|
| 768 |
+
|
| 769 |
+
losses.append(loss.item())
|
| 770 |
+
|
| 771 |
+
model.train()
|
| 772 |
+
return sum(losses) / max(1, len(losses))
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 776 |
+
# ║ DATALOADER ║
|
| 777 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 778 |
+
|
| 779 |
+
def make_loader(dataset, batch_size, num_workers, is_cuda, drop_last=True):
|
| 780 |
+
kwargs = dict(
|
| 781 |
+
batch_size=batch_size,
|
| 782 |
+
num_workers=num_workers,
|
| 783 |
+
pin_memory=is_cuda,
|
| 784 |
+
drop_last=drop_last,
|
| 785 |
+
)
|
| 786 |
+
if num_workers > 0:
|
| 787 |
+
kwargs["persistent_workers"] = True
|
| 788 |
+
kwargs["prefetch_factor"] = PREFETCH_FACTOR
|
| 789 |
+
return torch.utils.data.DataLoader(dataset, **kwargs)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 793 |
+
# ║ CUDA / LOGGING ║
|
| 794 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 795 |
+
|
| 796 |
+
def maybe_limit_process_memory(device: torch.device) -> tuple[Optional[int], Optional[float]]:
|
| 797 |
+
if device.type != "cuda":
|
| 798 |
+
return None, None
|
| 799 |
+
|
| 800 |
+
cuda_idx = current_cuda_index(device)
|
| 801 |
+
|
| 802 |
+
if TARGET_VRAM_GIB is None:
|
| 803 |
+
return cuda_idx, None
|
| 804 |
+
|
| 805 |
+
_, total = torch.cuda.mem_get_info(cuda_idx)
|
| 806 |
+
vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.98)
|
| 807 |
+
torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)
|
| 808 |
+
return cuda_idx, vram_fraction
|
| 809 |
+
|
| 810 |
+
def sync_if_cuda(device: torch.device) -> None:
|
| 811 |
+
if device.type == "cuda":
|
| 812 |
+
torch.cuda.synchronize(current_cuda_index(device))
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 816 |
+
# ║ MAIN ║
|
| 817 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 818 |
+
|
| 819 |
+
def main() -> None:
|
| 820 |
+
ddp_device = init_distributed()
|
| 821 |
+
set_seed(SEED + get_rank())
|
| 822 |
+
|
| 823 |
+
device = get_device(ddp_device)
|
| 824 |
+
is_cuda = device.type == "cuda"
|
| 825 |
+
|
| 826 |
+
if is_cuda:
|
| 827 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 828 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 829 |
+
torch.set_float32_matmul_precision("high")
|
| 830 |
+
torch.backends.cudnn.benchmark = True
|
| 831 |
+
|
| 832 |
+
cuda_idx, vram_fraction = maybe_limit_process_memory(device)
|
| 833 |
+
|
| 834 |
+
if is_main():
|
| 835 |
+
print("=" * 72)
|
| 836 |
+
print(" GPT ~1B | H100 MAX VRAM | LoRA + BF16 + TF32 + compile | v4")
|
| 837 |
+
print("=" * 72)
|
| 838 |
+
print(f"Device : {device} | World: {get_world_size()} GPU(s)")
|
| 839 |
+
print(f"Flash-2 : {HAS_FLASH} | BNB: {HAS_BNB} | LoRA: {USE_QLORA}")
|
| 840 |
+
print(f"Grad ckpt : {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
|
| 841 |
+
print(f"BLOCK_SIZE : {BLOCK_SIZE} | BATCH_SIZE: {BATCH_SIZE} | GRAD_ACCUM: {GRAD_ACCUM_STEPS}")
|
| 842 |
+
print(f"Tokens/step : {BLOCK_SIZE * BATCH_SIZE * GRAD_ACCUM_STEPS:,}")
|
| 843 |
+
if is_cuda:
|
| 844 |
+
free, total = torch.cuda.mem_get_info(cuda_idx)
|
| 845 |
+
print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
|
| 846 |
+
print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")
|
| 847 |
+
if vram_fraction is None:
|
| 848 |
+
print("Cap VRAM : désactivé")
|
| 849 |
+
else:
|
| 850 |
+
print(f"Cap VRAM : {TARGET_VRAM_GIB:.1f} GiB ({100*vram_fraction:.1f}% du device)")
|
| 851 |
+
|
| 852 |
+
tokenizer = train_or_load_tokenizer()
|
| 853 |
+
cfg = GPTConfig(vocab_size=len(tokenizer))
|
| 854 |
+
|
| 855 |
+
if is_main():
|
| 856 |
+
CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")
|
| 857 |
+
|
| 858 |
+
# 1) Base model sur GPU
|
| 859 |
+
model = GPT(cfg).to(device)
|
| 860 |
+
|
| 861 |
+
# 2) Charger éventuel checkpoint base AVANT LoRA
|
| 862 |
+
maybe_load_base_checkpoint(model, device)
|
| 863 |
+
|
| 864 |
+
# 3) Injecter LoRA ensuite
|
| 865 |
+
if USE_QLORA:
|
| 866 |
+
model = apply_qlora(model, device)
|
| 867 |
+
freeze_base_weights(model)
|
| 868 |
+
|
| 869 |
+
# 4) Compiler le modèle ensuite
|
| 870 |
+
if USE_COMPILE and hasattr(torch, "compile"):
|
| 871 |
+
if is_main():
|
| 872 |
+
print(f"Compilation torch.compile({COMPILE_MODE})…")
|
| 873 |
+
try:
|
| 874 |
+
model = torch.compile(model, mode=COMPILE_MODE, fullgraph=False)
|
| 875 |
+
if is_main():
|
| 876 |
+
print("torch.compile : OK")
|
| 877 |
+
except Exception as e:
|
| 878 |
+
if is_main():
|
| 879 |
+
print(f"[warn] torch.compile échoué ({e}) — fallback eager")
|
| 880 |
+
|
| 881 |
+
# 5) DDP
|
| 882 |
+
if is_distributed():
|
| 883 |
+
model = DDP(model, device_ids=[device.index])
|
| 884 |
+
|
| 885 |
+
optimizer = build_optimizer(model)
|
| 886 |
+
|
| 887 |
+
# ── Datasets ──────────────────────────────────────────────────────────────
|
| 888 |
+
eval_texts = build_eval_texts()
|
| 889 |
+
eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
|
| 890 |
+
eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda, drop_last=False)
|
| 891 |
+
|
| 892 |
+
init_texts = build_epoch_train_texts(0)
|
| 893 |
+
steps_per_epoch = estimate_steps_per_epoch(init_texts, tokenizer, cfg.block_size, BATCH_SIZE * GRAD_ACCUM_STEPS)
|
| 894 |
+
total_steps_est = max(steps_per_epoch * NUM_EPOCHS, WARMUP_STEPS + 100)
|
| 895 |
+
|
| 896 |
+
# ── Reprise ───────────────────────────────────────────────────────────────
|
| 897 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 898 |
+
if STATE_FILE.exists():
|
| 899 |
+
try:
|
| 900 |
+
if is_main():
|
| 901 |
+
print(f"Reprise depuis {STATE_FILE}")
|
| 902 |
+
start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
|
| 903 |
+
except Exception as e:
|
| 904 |
+
if is_main():
|
| 905 |
+
try:
|
| 906 |
+
STATE_FILE.rename(STATE_FILE.with_suffix(".corrupt.pt"))
|
| 907 |
+
except Exception:
|
| 908 |
+
pass
|
| 909 |
+
print(f"[warn] Checkpoint illisible ({e}) — reprise ignorée")
|
| 910 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 911 |
+
|
| 912 |
+
if is_main():
|
| 913 |
+
raw = unwrap_model(model)
|
| 914 |
+
n_total = count_parameters(raw, False)
|
| 915 |
+
n_train = count_parameters(raw, True)
|
| 916 |
+
effective_bs = BATCH_SIZE * GRAD_ACCUM_STEPS * get_world_size()
|
| 917 |
+
|
| 918 |
+
print(f"\nParamètres totaux : {n_total/1e9:.3f}B")
|
| 919 |
+
print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1, n_total):.2f}%)")
|
| 920 |
+
print(f"Batch effectif : {effective_bs} ({BATCH_SIZE}×{GRAD_ACCUM_STEPS}×{get_world_size()} GPU)")
|
| 921 |
+
print(f"Tokens/step : {BLOCK_SIZE * effective_bs:,}")
|
| 922 |
+
print(f"Steps estimés : {total_steps_est:,}")
|
| 923 |
+
print()
|
| 924 |
+
print("┌── Pilotage VRAM ──────────────────────────────────────────────┐")
|
| 925 |
+
print("│ Lis 'max_reserved' après quelques logs : │")
|
| 926 |
+
print("│ < 72 GiB → +2 BS | 72–77.5 GiB → zone cible │")
|
| 927 |
+
print("│ vrai OOM → -2 BS | puis relance │")
|
| 928 |
+
print("└───────────────────────────────────────────────────────────────┘")
|
| 929 |
+
|
| 930 |
+
# ── Boucle principale ─────────────────────────────────────────────────────
|
| 931 |
+
model.train()
|
| 932 |
+
optimizer.zero_grad(set_to_none=True)
|
| 933 |
+
|
| 934 |
+
global_step = start_step
|
| 935 |
+
t0 = time.time()
|
| 936 |
+
log_loss_sum = 0.0
|
| 937 |
+
log_loss_count = 0
|
| 938 |
+
tokens_since_log = 0
|
| 939 |
+
last_log = time.time()
|
| 940 |
+
|
| 941 |
+
if is_cuda:
|
| 942 |
+
torch.cuda.reset_peak_memory_stats(cuda_idx)
|
| 943 |
+
|
| 944 |
+
for epoch in range(start_epoch, NUM_EPOCHS):
|
| 945 |
+
if is_main():
|
| 946 |
+
print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")
|
| 947 |
+
|
| 948 |
+
train_texts = build_epoch_train_texts(epoch)
|
| 949 |
+
train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
|
| 950 |
+
train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda, drop_last=True)
|
| 951 |
+
|
| 952 |
+
for micro_step, batch in enumerate(train_loader):
|
| 953 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 954 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 955 |
+
|
| 956 |
+
with autocast_context(device):
|
| 957 |
+
_, loss = model(inp, lbl)
|
| 958 |
+
|
| 959 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 960 |
+
|
| 961 |
+
log_loss_sum += loss.item()
|
| 962 |
+
log_loss_count += 1
|
| 963 |
+
tokens_since_log += inp.numel()
|
| 964 |
+
|
| 965 |
+
if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
|
| 966 |
+
continue
|
| 967 |
+
|
| 968 |
+
lr = cosine_lr(global_step, total_steps_est)
|
| 969 |
+
for group in optimizer.param_groups:
|
| 970 |
+
group["lr"] = lr
|
| 971 |
+
|
| 972 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
|
| 973 |
+
optimizer.step()
|
| 974 |
+
optimizer.zero_grad(set_to_none=True)
|
| 975 |
+
global_step += 1
|
| 976 |
+
|
| 977 |
+
if global_step % 50 == 0 and is_main():
|
| 978 |
+
sync_if_cuda(device)
|
| 979 |
+
now = time.time()
|
| 980 |
+
elapsed = max(1e-6, now - last_log)
|
| 981 |
+
tok_s = tokens_since_log / elapsed
|
| 982 |
+
avg_loss = log_loss_sum / max(1, log_loss_count)
|
| 983 |
+
|
| 984 |
+
print(
|
| 985 |
+
f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
|
| 986 |
+
f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
if is_cuda:
|
| 990 |
+
alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
|
| 991 |
+
reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
|
| 992 |
+
max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
|
| 993 |
+
|
| 994 |
+
status = (
|
| 995 |
+
"▲ OK"
|
| 996 |
+
if max_res < 75.0 else
|
| 997 |
+
"⚠ proche limite"
|
| 998 |
+
if max_res < 77.5 else
|
| 999 |
+
"🔴 DANGER OOM"
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
print(
|
| 1003 |
+
f" GPU mem | alloc={alloc:.1f} | reserved={reserved:.1f} | "
|
| 1004 |
+
f"max_reserved={max_res:.1f} GiB {status}"
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
last_log = now
|
| 1008 |
+
tokens_since_log = 0
|
| 1009 |
+
log_loss_sum = 0.0
|
| 1010 |
+
log_loss_count = 0
|
| 1011 |
+
|
| 1012 |
+
if global_step % EVAL_EVERY == 0 and is_main():
|
| 1013 |
+
val_loss = evaluate(model, eval_loader, device)
|
| 1014 |
+
print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
|
| 1015 |
+
|
| 1016 |
+
if val_loss < best_eval:
|
| 1017 |
+
best_eval = val_loss
|
| 1018 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
|
| 1019 |
+
print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
|
| 1020 |
+
|
| 1021 |
+
if global_step % SAVE_EVERY == 0 and is_main():
|
| 1022 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
|
| 1023 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
|
| 1024 |
+
print(f"✓ Checkpoint → {MODEL_FILE}")
|
| 1025 |
+
|
| 1026 |
+
if is_main():
|
| 1027 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
|
| 1028 |
+
ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
|
| 1029 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
|
| 1030 |
+
print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")
|
| 1031 |
+
|
| 1032 |
+
if is_main():
|
| 1033 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
|
| 1034 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
|
| 1035 |
+
total_min = (time.time() - t0) / 60
|
| 1036 |
+
|
| 1037 |
+
print(f"\nModèle final → {MODEL_FILE}")
|
| 1038 |
+
print(f"Meilleur modèle → {BEST_MODEL_FILE}")
|
| 1039 |
+
print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")
|
| 1040 |
+
|
| 1041 |
+
if is_distributed():
|
| 1042 |
+
dist.destroy_process_group()
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
if __name__ == "__main__":
|
| 1046 |
+
main()
|
train_nlp_h100_maxvram_v7.py
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
train_nlp_h100_maxvram_v7.py — v3 (fix gated OSCAR → public C4)
|
| 5 |
+
===========================================================
|
| 6 |
+
• Datasets publics seulement (plus de gated error)
|
| 7 |
+
• Toujours ~85 GB de données traitées sur 10 epochs
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import itertools
|
| 13 |
+
import json
|
| 14 |
+
import math
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import time
|
| 18 |
+
from collections import OrderedDict
|
| 19 |
+
from contextlib import nullcontext
|
| 20 |
+
from dataclasses import asdict, dataclass
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Iterator, Optional
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import bitsandbytes as bnb
|
| 31 |
+
HAS_BNB = True
|
| 32 |
+
except ImportError:
|
| 33 |
+
HAS_BNB = False
|
| 34 |
+
print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée")
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from flash_attn import flash_attn_func
|
| 38 |
+
HAS_FLASH = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
HAS_FLASH = False
|
| 41 |
+
print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention")
|
| 42 |
+
|
| 43 |
+
from datasets import load_dataset
|
| 44 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 45 |
+
from tokenizers import (
|
| 46 |
+
Tokenizer, decoders, models, normalizers,
|
| 47 |
+
pre_tokenizers, processors, trainers,
|
| 48 |
+
)
|
| 49 |
+
from transformers import PreTrainedTokenizerFast
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 53 |
+
# ║ CHEMINS ║
|
| 54 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 55 |
+
|
| 56 |
+
OUT_DIR = Path("./nlp_1b_h100_opt")
|
| 57 |
+
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
|
| 59 |
+
CONFIG_FILE = OUT_DIR / "config.json"
|
| 60 |
+
MODEL_FILE = OUT_DIR / "model.pt"
|
| 61 |
+
BEST_MODEL_FILE= OUT_DIR / "model_best.pt"
|
| 62 |
+
STATE_FILE = OUT_DIR / "train_state.pt"
|
| 63 |
+
BASE_CHECKPOINT: Optional[Path] = None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 67 |
+
# ║ HYPERPARAMÈTRES ║
|
| 68 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 69 |
+
|
| 70 |
+
SEED = 42
|
| 71 |
+
TARGET_VRAM_GIB= 78.0
|
| 72 |
+
|
| 73 |
+
BLOCK_SIZE = 1024
|
| 74 |
+
VOCAB_SIZE = 32_000
|
| 75 |
+
D_MODEL = 1536
|
| 76 |
+
N_HEADS = 24
|
| 77 |
+
N_LAYERS = 24
|
| 78 |
+
D_FF = 6144
|
| 79 |
+
DROPOUT = 0.0
|
| 80 |
+
|
| 81 |
+
USE_QLORA = True
|
| 82 |
+
LORA_R = 64
|
| 83 |
+
LORA_ALPHA = 128
|
| 84 |
+
LORA_DROPOUT = 0.05
|
| 85 |
+
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]
|
| 86 |
+
|
| 87 |
+
NUM_EPOCHS = 3
|
| 88 |
+
LEARNING_RATE = 3e-4
|
| 89 |
+
MIN_LR = 3e-5
|
| 90 |
+
WEIGHT_DECAY = 0.1
|
| 91 |
+
WARMUP_STEPS = 500
|
| 92 |
+
|
| 93 |
+
BATCH_SIZE = 28
|
| 94 |
+
GRAD_ACCUM_STEPS = 1
|
| 95 |
+
MAX_GRAD_NORM = 1.0
|
| 96 |
+
EVAL_EVERY = 500
|
| 97 |
+
SAVE_EVERY = 1_000
|
| 98 |
+
|
| 99 |
+
DTYPE = torch.bfloat16
|
| 100 |
+
|
| 101 |
+
USE_CHECKPOINTING = False
|
| 102 |
+
USE_COMPILE = True
|
| 103 |
+
COMPILE_MODE = "reduce-overhead"
|
| 104 |
+
|
| 105 |
+
TRAIN_NUM_WORKERS = 4
|
| 106 |
+
EVAL_NUM_WORKERS = 2
|
| 107 |
+
PREFETCH_FACTOR = 2
|
| 108 |
+
|
| 109 |
+
TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
|
| 110 |
+
TOKENIZER_CHAR_LIMIT = 2_000
|
| 111 |
+
TEXT_CHAR_LIMIT = 4_000
|
| 112 |
+
|
| 113 |
+
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
|
| 114 |
+
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 118 |
+
# ║ DATASETS — PUBLIC + MAX 100 GB (fix gated OSCAR) ║
|
| 119 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 120 |
+
|
| 121 |
+
DATA_SOURCES = [
|
| 122 |
+
# 1. FineWeb (anglais – très haute qualité)
|
| 123 |
+
{
|
| 124 |
+
"name": "HuggingFaceFW/fineweb",
|
| 125 |
+
"config": None,
|
| 126 |
+
"split": "train",
|
| 127 |
+
"text_column": "text",
|
| 128 |
+
"dev_docs": 10_000,
|
| 129 |
+
"train_docs_per_epoch": 1_200_000, # ~48 GB sur 10 epochs
|
| 130 |
+
"language_filter": None,
|
| 131 |
+
},
|
| 132 |
+
# 2. C4 multilingual → français
|
| 133 |
+
{
|
| 134 |
+
"name": "allenai/c4",
|
| 135 |
+
"config": "multilingual",
|
| 136 |
+
"split": "train",
|
| 137 |
+
"text_column": "text",
|
| 138 |
+
"dev_docs": 5_000,
|
| 139 |
+
"train_docs_per_epoch": 400_000, # ~16 GB sur 10 epochs
|
| 140 |
+
"language_filter": "fr",
|
| 141 |
+
},
|
| 142 |
+
# 3. C4 multilingual → arabe
|
| 143 |
+
{
|
| 144 |
+
"name": "allenai/c4",
|
| 145 |
+
"config": "multilingual",
|
| 146 |
+
"split": "train",
|
| 147 |
+
"text_column": "text",
|
| 148 |
+
"dev_docs": 5_000,
|
| 149 |
+
"train_docs_per_epoch": 300_000, # ~12 GB sur 10 epochs
|
| 150 |
+
"language_filter": "ar",
|
| 151 |
+
},
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 156 |
+
# ║ DISTRIBUTED + UTILS ║
|
| 157 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 158 |
+
|
| 159 |
+
def is_distributed() -> bool:
|
| 160 |
+
return dist.is_available() and dist.is_initialized()
|
| 161 |
+
|
| 162 |
+
def get_rank() -> int:
|
| 163 |
+
return dist.get_rank() if is_distributed() else 0
|
| 164 |
+
|
| 165 |
+
def get_world_size() -> int:
|
| 166 |
+
return dist.get_world_size() if is_distributed() else 1
|
| 167 |
+
|
| 168 |
+
def is_main() -> bool:
|
| 169 |
+
return get_rank() == 0
|
| 170 |
+
|
| 171 |
+
def init_distributed() -> Optional[torch.device]:
|
| 172 |
+
local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
| 173 |
+
if local_rank == -1:
|
| 174 |
+
return None
|
| 175 |
+
dist.init_process_group("nccl")
|
| 176 |
+
torch.cuda.set_device(local_rank)
|
| 177 |
+
return torch.device(f"cuda:{local_rank}")
|
| 178 |
+
|
| 179 |
+
def set_seed(seed: int) -> None:
|
| 180 |
+
random.seed(seed)
|
| 181 |
+
torch.manual_seed(seed)
|
| 182 |
+
if torch.cuda.is_available():
|
| 183 |
+
torch.cuda.manual_seed_all(seed)
|
| 184 |
+
|
| 185 |
+
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
|
| 186 |
+
if ddp_device is not None:
|
| 187 |
+
return ddp_device
|
| 188 |
+
if torch.cuda.is_available():
|
| 189 |
+
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
| 190 |
+
return torch.device("cpu")
|
| 191 |
+
|
| 192 |
+
def current_cuda_index(device: torch.device) -> int:
|
| 193 |
+
return device.index if device.index is not None else torch.cuda.current_device()
|
| 194 |
+
|
| 195 |
+
def autocast_context(device: torch.device):
|
| 196 |
+
if device.type == "cuda":
|
| 197 |
+
return torch.autocast("cuda", dtype=DTYPE)
|
| 198 |
+
return nullcontext()
|
| 199 |
+
|
| 200 |
+
def unwrap_model(model: nn.Module) -> nn.Module:
|
| 201 |
+
m = model.module if isinstance(model, DDP) else model
|
| 202 |
+
return m._orig_mod if hasattr(m, "_orig_mod") else m
|
| 203 |
+
|
| 204 |
+
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
|
| 205 |
+
return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad)
|
| 206 |
+
|
| 207 |
+
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
|
| 208 |
+
out = OrderedDict()
|
| 209 |
+
for k, v in sd.items():
|
| 210 |
+
for prefix in ("module._orig_mod.", "_orig_mod.", "module."):
|
| 211 |
+
if k.startswith(prefix):
|
| 212 |
+
k = k[len(prefix):]
|
| 213 |
+
break
|
| 214 |
+
out[k] = v
|
| 215 |
+
return out
|
| 216 |
+
|
| 217 |
+
def normalize_text(t: str) -> str:
|
| 218 |
+
return " ".join(t.strip().split())
|
| 219 |
+
|
| 220 |
+
def safe_str(x) -> str:
|
| 221 |
+
return x if isinstance(x, str) else ("" if x is None else str(x))
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 225 |
+
# ║ DATA LOADING (streaming + language filter) ║
|
| 226 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 227 |
+
|
| 228 |
+
def load_hf_stream(repo_id: str, config: str | None = None, split: str = "train"):
|
| 229 |
+
return load_dataset(repo_id, config, split=split, streaming=True)
|
| 230 |
+
|
| 231 |
+
def stream_texts_from_source(source: dict, start: int, count: int, char_limit: int) -> Iterator[str]:
|
| 232 |
+
ds = load_hf_stream(source["name"], source.get("config"), source.get("split", "train"))
|
| 233 |
+
col = source["text_column"]
|
| 234 |
+
|
| 235 |
+
for row in itertools.islice(ds, start, start + count):
|
| 236 |
+
text = normalize_text(safe_str(row.get(col, "")))
|
| 237 |
+
if len(text) < 20:
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
# Filtre langue (pour C4 multilingual)
|
| 241 |
+
if source.get("language_filter"):
|
| 242 |
+
if row.get("language") != source["language_filter"]:
|
| 243 |
+
continue
|
| 244 |
+
|
| 245 |
+
yield text[:char_limit]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def build_epoch_train_texts(epoch: int) -> list[str]:
|
| 249 |
+
texts: list[str] = []
|
| 250 |
+
rng = random.Random(SEED + epoch)
|
| 251 |
+
|
| 252 |
+
for src in DATA_SOURCES:
|
| 253 |
+
start = src["dev_docs"] + epoch * src["train_docs_per_epoch"]
|
| 254 |
+
texts.extend(stream_texts_from_source(
|
| 255 |
+
src, start, src["train_docs_per_epoch"], TEXT_CHAR_LIMIT
|
| 256 |
+
))
|
| 257 |
+
|
| 258 |
+
rng.shuffle(texts)
|
| 259 |
+
return texts
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def build_eval_texts() -> list[str]:
|
| 263 |
+
texts: list[str] = []
|
| 264 |
+
for src in DATA_SOURCES:
|
| 265 |
+
texts.extend(stream_texts_from_source(
|
| 266 |
+
src, 0, src["dev_docs"], TEXT_CHAR_LIMIT
|
| 267 |
+
))
|
| 268 |
+
return texts
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ╔═��════════════════════════════════════════════════════════════════════════════╗
|
| 272 |
+
# ║ TOKENIZER ║
|
| 273 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 274 |
+
|
| 275 |
+
def tokenizer_ready() -> bool:
|
| 276 |
+
return (TOKENIZER_DIR / "tokenizer.json").exists() and (TOKENIZER_DIR / "tokenizer_config.json").exists()
|
| 277 |
+
|
| 278 |
+
def train_tokenizer_once() -> None:
|
| 279 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 280 |
+
tok = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
|
| 281 |
+
tok.normalizer = normalizers.Sequence([normalizers.NFKC()])
|
| 282 |
+
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
| 283 |
+
tok.decoder = decoders.ByteLevel()
|
| 284 |
+
trainer = trainers.BpeTrainer(
|
| 285 |
+
vocab_size=VOCAB_SIZE, min_frequency=2, show_progress=is_main(),
|
| 286 |
+
special_tokens=SPECIAL_TOKENS, initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
| 287 |
+
)
|
| 288 |
+
tok.train_from_iterator(tokenizer_training_iterator(), trainer=trainer)
|
| 289 |
+
bos_id, eos_id = tok.token_to_id(BOS_TOKEN), tok.token_to_id(EOS_TOKEN)
|
| 290 |
+
tok.post_processor = processors.TemplateProcessing(
|
| 291 |
+
single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
|
| 292 |
+
pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
|
| 293 |
+
special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
|
| 294 |
+
)
|
| 295 |
+
tok.save(str(TOKENIZER_DIR / "tokenizer.json"))
|
| 296 |
+
fast = PreTrainedTokenizerFast(
|
| 297 |
+
tokenizer_file=str(TOKENIZER_DIR / "tokenizer.json"),
|
| 298 |
+
bos_token=BOS_TOKEN, eos_token=EOS_TOKEN, unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
|
| 299 |
+
)
|
| 300 |
+
fast.save_pretrained(str(TOKENIZER_DIR))
|
| 301 |
+
|
| 302 |
+
def tokenizer_training_iterator() -> Iterator[str]:
|
| 303 |
+
for src in DATA_SOURCES:
|
| 304 |
+
yield from stream_texts_from_source(src, 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT)
|
| 305 |
+
|
| 306 |
+
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
|
| 307 |
+
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
|
| 308 |
+
if not tokenizer_ready():
|
| 309 |
+
if is_distributed():
|
| 310 |
+
if is_main():
|
| 311 |
+
print("Entraînement tokenizer 32k…")
|
| 312 |
+
train_tokenizer_once()
|
| 313 |
+
dist.barrier()
|
| 314 |
+
else:
|
| 315 |
+
print("Entraînement tokenizer 32k…")
|
| 316 |
+
train_tokenizer_once()
|
| 317 |
+
return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 321 |
+
# ║ MODÈLE + QLORA + OPTIMIZER + CHECKPOINT + EVAL (inchangés) ║
|
| 322 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 323 |
+
|
| 324 |
+
# (Tout le reste du code est identique à la v2 que je t’ai donnée précédemment)
|
| 325 |
+
# Je le garde complet pour que tu puisses copier-coller directement.
|
| 326 |
+
|
| 327 |
+
@dataclass
|
| 328 |
+
class GPTConfig:
|
| 329 |
+
vocab_size: int = VOCAB_SIZE
|
| 330 |
+
block_size: int = BLOCK_SIZE
|
| 331 |
+
d_model: int = D_MODEL
|
| 332 |
+
n_heads: int = N_HEADS
|
| 333 |
+
n_layers: int = N_LAYERS
|
| 334 |
+
d_ff: int = D_FF
|
| 335 |
+
dropout: float = DROPOUT
|
| 336 |
+
use_checkpointing: bool = USE_CHECKPOINTING
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class RMSNorm(nn.Module):
|
| 340 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 341 |
+
super().__init__()
|
| 342 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 343 |
+
self.eps = eps
|
| 344 |
+
def forward(self, x):
|
| 345 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class RotaryEmbedding(nn.Module):
|
| 349 |
+
def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
|
| 350 |
+
super().__init__()
|
| 351 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 352 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 353 |
+
t = torch.arange(max_seq).float()
|
| 354 |
+
freqs = torch.outer(t, inv_freq)
|
| 355 |
+
self.register_buffer("cos_cache", torch.repeat_interleave(freqs.cos(), 2, dim=-1), persistent=False)
|
| 356 |
+
self.register_buffer("sin_cache", torch.repeat_interleave(freqs.sin(), 2, dim=-1), persistent=False)
|
| 357 |
+
def forward(self, seq_len: int, dtype: torch.dtype):
|
| 358 |
+
return self.cos_cache[:seq_len].to(dtype), self.sin_cache[:seq_len].to(dtype)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def rotate_half(x):
|
| 362 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 363 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
| 364 |
+
|
| 365 |
+
def apply_rope(x, cos, sin):
|
| 366 |
+
return x * cos.unsqueeze(0).unsqueeze(0) + rotate_half(x) * sin.unsqueeze(0).unsqueeze(0)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class CausalSelfAttention(nn.Module):
|
| 370 |
+
def __init__(self, cfg: GPTConfig):
|
| 371 |
+
super().__init__()
|
| 372 |
+
assert cfg.d_model % cfg.n_heads == 0
|
| 373 |
+
self.n_heads = cfg.n_heads
|
| 374 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 375 |
+
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 376 |
+
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 377 |
+
self.dropout_p = cfg.dropout
|
| 378 |
+
self.rope = RotaryEmbedding(self.head_dim)
|
| 379 |
+
|
| 380 |
+
def forward(self, x):
|
| 381 |
+
b, t, c = x.shape
|
| 382 |
+
q, k, v = self.qkv(x).split(c, dim=-1)
|
| 383 |
+
q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 384 |
+
k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 385 |
+
v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
|
| 386 |
+
cos, sin = self.rope(t, x.dtype)
|
| 387 |
+
q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
|
| 388 |
+
|
| 389 |
+
if HAS_FLASH:
|
| 390 |
+
q = q.transpose(1, 2)
|
| 391 |
+
k = k.transpose(1, 2)
|
| 392 |
+
v = v.transpose(1, 2)
|
| 393 |
+
y = flash_attn_func(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, causal=True)
|
| 394 |
+
y = y.reshape(b, t, c)
|
| 395 |
+
else:
|
| 396 |
+
y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_p if self.training else 0.0, is_causal=True)
|
| 397 |
+
y = y.transpose(1, 2).contiguous().view(b, t, c)
|
| 398 |
+
return self.proj(y)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class SwiGLU(nn.Module):
|
| 402 |
+
def __init__(self, cfg: GPTConfig):
|
| 403 |
+
super().__init__()
|
| 404 |
+
self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 405 |
+
self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
|
| 406 |
+
self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
|
| 407 |
+
def forward(self, x):
|
| 408 |
+
return self.w3(F.silu(self.w1(x)) * self.w2(x))
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class Block(nn.Module):
|
| 412 |
+
def __init__(self, cfg: GPTConfig):
|
| 413 |
+
super().__init__()
|
| 414 |
+
self.ln1 = RMSNorm(cfg.d_model)
|
| 415 |
+
self.attn = CausalSelfAttention(cfg)
|
| 416 |
+
self.ln2 = RMSNorm(cfg.d_model)
|
| 417 |
+
self.ff = SwiGLU(cfg)
|
| 418 |
+
def forward(self, x):
|
| 419 |
+
x = x + self.attn(self.ln1(x))
|
| 420 |
+
x = x + self.ff(self.ln2(x))
|
| 421 |
+
return x
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class GPT(nn.Module):
|
| 425 |
+
def __init__(self, cfg: GPTConfig):
|
| 426 |
+
super().__init__()
|
| 427 |
+
self.cfg = cfg
|
| 428 |
+
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
|
| 429 |
+
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
|
| 430 |
+
self.ln_f = RMSNorm(cfg.d_model)
|
| 431 |
+
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
|
| 432 |
+
self.lm_head.weight = self.tok_emb.weight
|
| 433 |
+
self.apply(self._init_weights)
|
| 434 |
+
|
| 435 |
+
@staticmethod
|
| 436 |
+
def _init_weights(m):
|
| 437 |
+
if isinstance(m, (nn.Linear, nn.Embedding)):
|
| 438 |
+
nn.init.normal_(m.weight, 0.0, 0.02)
|
| 439 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 440 |
+
nn.init.zeros_(m.bias)
|
| 441 |
+
|
| 442 |
+
def forward(self, input_ids, labels=None):
|
| 443 |
+
x = self.tok_emb(input_ids)
|
| 444 |
+
for block in self.blocks:
|
| 445 |
+
if self.cfg.use_checkpointing and self.training:
|
| 446 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 447 |
+
else:
|
| 448 |
+
x = block(x)
|
| 449 |
+
logits = self.lm_head(self.ln_f(x))
|
| 450 |
+
loss = None
|
| 451 |
+
if labels is not None:
|
| 452 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
|
| 453 |
+
return logits, loss
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
class LoRALinear(nn.Module):
|
| 457 |
+
def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.base = base_layer
|
| 460 |
+
self.r = r
|
| 461 |
+
self.scale = alpha / r
|
| 462 |
+
in_f, out_f = base_layer.in_features, base_layer.out_features
|
| 463 |
+
try:
|
| 464 |
+
dev = next(base_layer.parameters()).device
|
| 465 |
+
except StopIteration:
|
| 466 |
+
dev = torch.device("cpu")
|
| 467 |
+
self.lora_A = nn.Linear(in_f, r, bias=False, device=dev)
|
| 468 |
+
self.lora_B = nn.Linear(r, out_f, bias=False, device=dev)
|
| 469 |
+
self.drop = nn.Dropout(dropout)
|
| 470 |
+
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
| 471 |
+
nn.init.zeros_(self.lora_B.weight)
|
| 472 |
+
for p in self.base.parameters():
|
| 473 |
+
p.requires_grad = False
|
| 474 |
+
|
| 475 |
+
def forward(self, x):
|
| 476 |
+
return self.base(x) + self.lora_B(self.lora_A(self.drop(x))) * self.scale
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def apply_qlora(model: GPT, device: torch.device) -> GPT:
|
| 480 |
+
if not USE_QLORA:
|
| 481 |
+
return model
|
| 482 |
+
replaced = 0
|
| 483 |
+
targets = []
|
| 484 |
+
for name, module in model.named_modules():
|
| 485 |
+
parts = name.split(".")
|
| 486 |
+
if parts[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear):
|
| 487 |
+
targets.append((name, module))
|
| 488 |
+
for name, module in targets:
|
| 489 |
+
parts = name.split(".")
|
| 490 |
+
parent = model
|
| 491 |
+
for part in parts[:-1]:
|
| 492 |
+
parent = getattr(parent, part)
|
| 493 |
+
lora_layer = LoRALinear(module)
|
| 494 |
+
setattr(parent, parts[-1], lora_layer)
|
| 495 |
+
replaced += 1
|
| 496 |
+
if is_main():
|
| 497 |
+
print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})")
|
| 498 |
+
return model
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def freeze_base_weights(model: GPT) -> None:
|
| 502 |
+
for name, p in model.named_parameters():
|
| 503 |
+
p.requires_grad = ("lora_A" in name or "lora_B" in name)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
|
| 507 |
+
decay, no_decay = [], []
|
| 508 |
+
for name, p in unwrap_model(model).named_parameters():
|
| 509 |
+
if not p.requires_grad: continue
|
| 510 |
+
(decay if p.ndim >= 2 and "weight" in name else no_decay).append(p)
|
| 511 |
+
groups = [
|
| 512 |
+
{"params": decay, "weight_decay": WEIGHT_DECAY},
|
| 513 |
+
{"params": no_decay, "weight_decay": 0.0},
|
| 514 |
+
]
|
| 515 |
+
if HAS_BNB:
|
| 516 |
+
return bnb.optim.PagedAdamW8bit(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
|
| 517 |
+
return torch.optim.AdamW(groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8, fused=torch.cuda.is_available())
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def cosine_lr(step: int, total_steps: int) -> float:
|
| 521 |
+
if step < WARMUP_STEPS:
|
| 522 |
+
return LEARNING_RATE * step / max(1, WARMUP_STEPS)
|
| 523 |
+
p = min(1.0, (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS))
|
| 524 |
+
return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
|
| 528 |
+
raw = unwrap_model(model)
|
| 529 |
+
torch.save({
|
| 530 |
+
"model": normalize_state_dict_keys(raw.state_dict()),
|
| 531 |
+
"optimizer": optimizer.state_dict(),
|
| 532 |
+
"epoch": epoch, "step": step, "best_loss": best_loss,
|
| 533 |
+
"config": asdict(raw.cfg),
|
| 534 |
+
}, path)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def maybe_load_base_checkpoint(model, device):
|
| 538 |
+
if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
|
| 539 |
+
return
|
| 540 |
+
ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
|
| 541 |
+
unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=False)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def load_resume_checkpoint(model, optimizer, path, device):
|
| 545 |
+
ckpt = torch.load(path, map_location=device)
|
| 546 |
+
unwrap_model(model).load_state_dict(normalize_state_dict_keys(ckpt["model"]), strict=True)
|
| 547 |
+
try:
|
| 548 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 549 |
+
except Exception as e:
|
| 550 |
+
print(f"[warn] Optimizer state non repris: {e}")
|
| 551 |
+
return int(ckpt.get("epoch", 0)), int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
@torch.no_grad()
|
| 555 |
+
def evaluate(model, loader, device, max_batches=200) -> float:
|
| 556 |
+
model.eval()
|
| 557 |
+
losses = []
|
| 558 |
+
for i, batch in enumerate(loader):
|
| 559 |
+
if i >= max_batches: break
|
| 560 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 561 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 562 |
+
with autocast_context(device):
|
| 563 |
+
_, loss = model(inp, lbl)
|
| 564 |
+
losses.append(loss.item())
|
| 565 |
+
model.train()
|
| 566 |
+
return sum(losses) / max(1, len(losses))
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def make_loader(dataset, batch_size, num_workers, is_cuda):
|
| 570 |
+
kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=is_cuda)
|
| 571 |
+
if num_workers > 0:
|
| 572 |
+
kwargs["persistent_workers"] = True
|
| 573 |
+
kwargs["prefetch_factor"] = PREFETCH_FACTOR
|
| 574 |
+
return torch.utils.data.DataLoader(dataset, **kwargs)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class PackedTextList(torch.utils.data.IterableDataset):
|
| 578 |
+
def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
|
| 579 |
+
super().__init__()
|
| 580 |
+
self.texts = texts
|
| 581 |
+
self.tokenizer = tokenizer
|
| 582 |
+
self.block_size = block_size
|
| 583 |
+
self.epoch_seed = epoch_seed
|
| 584 |
+
|
| 585 |
+
def __iter__(self):
|
| 586 |
+
worker = torch.utils.data.get_worker_info()
|
| 587 |
+
rank, ws = get_rank(), get_world_size()
|
| 588 |
+
if worker is None:
|
| 589 |
+
shard_mod, shard_id = ws, rank
|
| 590 |
+
else:
|
| 591 |
+
shard_mod = worker.num_workers * ws
|
| 592 |
+
shard_id = rank * worker.num_workers + worker.id
|
| 593 |
+
|
| 594 |
+
rng = random.Random(self.epoch_seed)
|
| 595 |
+
indices = list(range(len(self.texts)))
|
| 596 |
+
rng.shuffle(indices)
|
| 597 |
+
|
| 598 |
+
bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
| 599 |
+
buf: list[int] = []
|
| 600 |
+
|
| 601 |
+
for li, ti in enumerate(indices):
|
| 602 |
+
if li % shard_mod != shard_id:
|
| 603 |
+
continue
|
| 604 |
+
ids = self.tokenizer.encode(self.texts[ti], add_special_tokens=False)
|
| 605 |
+
if not ids: continue
|
| 606 |
+
buf.extend([bos] + ids + [eos])
|
| 607 |
+
while len(buf) >= self.block_size + 1:
|
| 608 |
+
chunk = buf[:self.block_size + 1]
|
| 609 |
+
buf = buf[self.block_size + 1:]
|
| 610 |
+
yield {
|
| 611 |
+
"input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
|
| 612 |
+
"labels": torch.tensor(chunk[1:], dtype=torch.long),
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
# ╔══════════════════════════════════════════════════════════════════════════════╗
|
| 617 |
+
# ║ MAIN ║
|
| 618 |
+
# ╚══════════════════════════════════════════════════════════════════════════════╝
|
| 619 |
+
|
| 620 |
+
def main() -> None:
|
| 621 |
+
ddp_device = init_distributed()
|
| 622 |
+
set_seed(SEED + get_rank())
|
| 623 |
+
device = get_device(ddp_device)
|
| 624 |
+
is_cuda = device.type == "cuda"
|
| 625 |
+
|
| 626 |
+
cuda_idx = None
|
| 627 |
+
if is_cuda:
|
| 628 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 629 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 630 |
+
torch.set_float32_matmul_precision("high")
|
| 631 |
+
cuda_idx = current_cuda_index(device)
|
| 632 |
+
_, total = torch.cuda.mem_get_info(cuda_idx)
|
| 633 |
+
vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
|
| 634 |
+
torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)
|
| 635 |
+
|
| 636 |
+
if is_main():
|
| 637 |
+
print("=" * 72)
|
| 638 |
+
print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | MAX 100 GB (public)")
|
| 639 |
+
print("=" * 72)
|
| 640 |
+
print(f"Device : {device} | World: {get_world_size()} GPU(s)")
|
| 641 |
+
print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
|
| 642 |
+
print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
|
| 643 |
+
if is_cuda:
|
| 644 |
+
free, total = torch.cuda.mem_get_info(cuda_idx)
|
| 645 |
+
print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
|
| 646 |
+
print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")
|
| 647 |
+
|
| 648 |
+
tokenizer = train_or_load_tokenizer()
|
| 649 |
+
cfg = GPTConfig(vocab_size=len(tokenizer))
|
| 650 |
+
|
| 651 |
+
if is_main():
|
| 652 |
+
CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")
|
| 653 |
+
|
| 654 |
+
model = GPT(cfg).to(device)
|
| 655 |
+
|
| 656 |
+
if USE_QLORA:
|
| 657 |
+
model = apply_qlora(model, device)
|
| 658 |
+
freeze_base_weights(model)
|
| 659 |
+
|
| 660 |
+
maybe_load_base_checkpoint(model, device)
|
| 661 |
+
|
| 662 |
+
if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
|
| 663 |
+
try:
|
| 664 |
+
model = torch.compile(model, mode=COMPILE_MODE)
|
| 665 |
+
if is_main():
|
| 666 |
+
print(f"torch.compile activé ({COMPILE_MODE})")
|
| 667 |
+
except Exception as e:
|
| 668 |
+
if is_main():
|
| 669 |
+
print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile")
|
| 670 |
+
|
| 671 |
+
if is_distributed():
|
| 672 |
+
model = DDP(model, device_ids=[device.index])
|
| 673 |
+
|
| 674 |
+
optimizer = build_optimizer(model)
|
| 675 |
+
|
| 676 |
+
eval_texts = build_eval_texts()
|
| 677 |
+
eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
|
| 678 |
+
eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)
|
| 679 |
+
|
| 680 |
+
init_texts = build_epoch_train_texts(0)
|
| 681 |
+
steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE)
|
| 682 |
+
total_steps_est = steps_per_epoch * NUM_EPOCHS
|
| 683 |
+
|
| 684 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 685 |
+
if STATE_FILE.exists():
|
| 686 |
+
try:
|
| 687 |
+
if is_main(): print(f"Reprise depuis {STATE_FILE}")
|
| 688 |
+
start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
|
| 689 |
+
except Exception as e:
|
| 690 |
+
if is_main():
|
| 691 |
+
bad = STATE_FILE.with_suffix(".corrupt.pt")
|
| 692 |
+
print(f"[warn] Checkpoint illisible: {e}")
|
| 693 |
+
try: STATE_FILE.rename(bad)
|
| 694 |
+
except: pass
|
| 695 |
+
start_epoch, start_step, best_eval = 0, 0, 1e9
|
| 696 |
+
|
| 697 |
+
if is_main():
|
| 698 |
+
raw = unwrap_model(model)
|
| 699 |
+
n_total = count_parameters(raw, False)
|
| 700 |
+
n_train = count_parameters(raw, True)
|
| 701 |
+
print(f"Paramètres totaux : {n_total/1e9:.3f}B")
|
| 702 |
+
print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
|
| 703 |
+
print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
|
| 704 |
+
print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}")
|
| 705 |
+
print("\n── Conseil VRAM ────────────────────────────────────────────────")
|
| 706 |
+
print(" Surveille max_reserved à step 50.")
|
| 707 |
+
print(" Si OOM → baisse BATCH_SIZE ou active USE_CHECKPOINTING=True")
|
| 708 |
+
print("────────────────────────────────────────────────────────────────")
|
| 709 |
+
|
| 710 |
+
model.train()
|
| 711 |
+
optimizer.zero_grad(set_to_none=True)
|
| 712 |
+
|
| 713 |
+
global_step = start_step
|
| 714 |
+
t0 = time.time()
|
| 715 |
+
log_loss_sum = 0.0
|
| 716 |
+
log_loss_count = 0
|
| 717 |
+
tokens_since_log = 0
|
| 718 |
+
last_log = time.time()
|
| 719 |
+
|
| 720 |
+
if is_cuda:
|
| 721 |
+
torch.cuda.reset_peak_memory_stats(cuda_idx)
|
| 722 |
+
|
| 723 |
+
for epoch in range(start_epoch, NUM_EPOCHS):
|
| 724 |
+
if is_main():
|
| 725 |
+
print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")
|
| 726 |
+
|
| 727 |
+
train_texts = build_epoch_train_texts(epoch)
|
| 728 |
+
train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
|
| 729 |
+
train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)
|
| 730 |
+
|
| 731 |
+
for micro_step, batch in enumerate(train_loader):
|
| 732 |
+
inp = batch["input_ids"].to(device, non_blocking=True)
|
| 733 |
+
lbl = batch["labels"].to(device, non_blocking=True)
|
| 734 |
+
|
| 735 |
+
with autocast_context(device):
|
| 736 |
+
_, loss = model(inp, lbl)
|
| 737 |
+
|
| 738 |
+
(loss / GRAD_ACCUM_STEPS).backward()
|
| 739 |
+
|
| 740 |
+
log_loss_sum += loss.item()
|
| 741 |
+
log_loss_count += 1
|
| 742 |
+
tokens_since_log += inp.numel()
|
| 743 |
+
|
| 744 |
+
if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
|
| 745 |
+
continue
|
| 746 |
+
|
| 747 |
+
lr = cosine_lr(global_step, total_steps_est)
|
| 748 |
+
for group in optimizer.param_groups:
|
| 749 |
+
group["lr"] = lr
|
| 750 |
+
|
| 751 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
|
| 752 |
+
optimizer.step()
|
| 753 |
+
optimizer.zero_grad(set_to_none=True)
|
| 754 |
+
global_step += 1
|
| 755 |
+
|
| 756 |
+
if global_step % 50 == 0 and is_main():
|
| 757 |
+
now = time.time()
|
| 758 |
+
elapsed = max(1e-6, now - last_log)
|
| 759 |
+
tok_s = tokens_since_log / elapsed
|
| 760 |
+
avg_loss = log_loss_sum / max(1, log_loss_count)
|
| 761 |
+
print(f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s")
|
| 762 |
+
if is_cuda:
|
| 763 |
+
alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
|
| 764 |
+
reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
|
| 765 |
+
max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3
|
| 766 |
+
max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
|
| 767 |
+
print(f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | max_reserved={max_res:.2f} GiB")
|
| 768 |
+
last_log = now
|
| 769 |
+
tokens_since_log = 0
|
| 770 |
+
log_loss_sum = 0.0
|
| 771 |
+
log_loss_count = 0
|
| 772 |
+
|
| 773 |
+
if global_step % EVAL_EVERY == 0 and is_main():
|
| 774 |
+
val_loss = evaluate(model, eval_loader, device)
|
| 775 |
+
print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
|
| 776 |
+
if val_loss < best_eval:
|
| 777 |
+
best_eval = val_loss
|
| 778 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
|
| 779 |
+
print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
|
| 780 |
+
|
| 781 |
+
if global_step % SAVE_EVERY == 0 and is_main():
|
| 782 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
|
| 783 |
+
save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
|
| 784 |
+
print(f"✓ Checkpoint → {MODEL_FILE}")
|
| 785 |
+
|
| 786 |
+
if is_main():
|
| 787 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
|
| 788 |
+
ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
|
| 789 |
+
save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
|
| 790 |
+
print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")
|
| 791 |
+
|
| 792 |
+
if is_main():
|
| 793 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
|
| 794 |
+
save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
|
| 795 |
+
total_min = (time.time() - t0) / 60
|
| 796 |
+
print(f"\nModèle final → {MODEL_FILE}")
|
| 797 |
+
print(f"Meilleur modèle → {BEST_MODEL_FILE}")
|
| 798 |
+
print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")
|
| 799 |
+
|
| 800 |
+
if is_distributed():
|
| 801 |
+
dist.destroy_process_group()
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
if __name__ == "__main__":
|
| 805 |
+
main()
|
upload.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from huggingface_hub import HfApi
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
IGNORE_NAMES = {
|
| 11 |
+
".git",
|
| 12 |
+
"__pycache__",
|
| 13 |
+
".ipynb_checkpoints",
|
| 14 |
+
".DS_Store",
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
IGNORE_SUFFIXES = {
|
| 18 |
+
".tmp",
|
| 19 |
+
".lock",
|
| 20 |
+
".swp",
|
| 21 |
+
".swx",
|
| 22 |
+
".part",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def sanitize_repo_name(name: str) -> str:
|
| 27 |
+
name = name.strip().replace(" ", "-")
|
| 28 |
+
name = re.sub(r"[^A-Za-z0-9._-]+", "-", name)
|
| 29 |
+
name = re.sub(r"-{2,}", "-", name).strip("-")
|
| 30 |
+
return name[:96] or "model"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def should_ignore(path: Path) -> bool:
|
| 34 |
+
if any(part in IGNORE_NAMES for part in path.parts):
|
| 35 |
+
return True
|
| 36 |
+
if path.suffix.lower() in IGNORE_SUFFIXES:
|
| 37 |
+
return True
|
| 38 |
+
return False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def folder_stats(folder: Path):
|
| 42 |
+
file_count = 0
|
| 43 |
+
total_size = 0
|
| 44 |
+
|
| 45 |
+
for p in folder.rglob("*"):
|
| 46 |
+
if not p.is_file():
|
| 47 |
+
continue
|
| 48 |
+
if should_ignore(p):
|
| 49 |
+
continue
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
total_size += p.stat().st_size
|
| 53 |
+
file_count += 1
|
| 54 |
+
except FileNotFoundError:
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
return file_count, total_size
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def format_bytes(num_bytes: int) -> str:
|
| 61 |
+
value = float(num_bytes)
|
| 62 |
+
units = ["B", "KB", "MB", "GB", "TB"]
|
| 63 |
+
for unit in units:
|
| 64 |
+
if value < 1024 or unit == units[-1]:
|
| 65 |
+
return f"{value:.2f} {unit}"
|
| 66 |
+
value /= 1024
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def format_speed(bytes_per_sec: float) -> str:
|
| 70 |
+
return f"{format_bytes(bytes_per_sec)}/s"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main():
|
| 74 |
+
parser = argparse.ArgumentParser(
|
| 75 |
+
description="Upload one-shot d'un dossier modèle vers Hugging Face Hub."
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--model_dir",
|
| 79 |
+
type=str,
|
| 80 |
+
required=True,
|
| 81 |
+
help="Chemin du dossier modèle à uploader",
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--namespace",
|
| 85 |
+
type=str,
|
| 86 |
+
default="Medyassino",
|
| 87 |
+
help="Namespace ou username Hugging Face",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--repo_name",
|
| 91 |
+
type=str,
|
| 92 |
+
default=None,
|
| 93 |
+
help="Nom du repo distant. Par défaut: nom du dossier modèle",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--private",
|
| 97 |
+
action="store_true",
|
| 98 |
+
help="Créer le repo en privé. Sans ce flag, il sera public",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--large",
|
| 102 |
+
action="store_true",
|
| 103 |
+
help="Utiliser upload_large_folder pour les gros dossiers",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--token",
|
| 107 |
+
type=str,
|
| 108 |
+
default=None,
|
| 109 |
+
help="Token HF. Sinon utilise HF_TOKEN ou le login local",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--commit_message",
|
| 113 |
+
type=str,
|
| 114 |
+
default="Upload model from local folder",
|
| 115 |
+
help="Message de commit pour upload_folder",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
args = parser.parse_args()
|
| 119 |
+
|
| 120 |
+
model_dir = Path(args.model_dir).expanduser().resolve()
|
| 121 |
+
if not model_dir.exists() or not model_dir.is_dir():
|
| 122 |
+
raise RuntimeError(f"Dossier modèle introuvable: {model_dir}")
|
| 123 |
+
|
| 124 |
+
repo_name = sanitize_repo_name(args.repo_name or model_dir.name)
|
| 125 |
+
repo_id = f"{args.namespace}/{repo_name}"
|
| 126 |
+
|
| 127 |
+
token = args.token or os.environ.get("HF_TOKEN")
|
| 128 |
+
api = HfApi(token=token)
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
who = api.whoami()
|
| 132 |
+
except Exception as e:
|
| 133 |
+
raise RuntimeError(
|
| 134 |
+
"Authentification Hugging Face impossible. "
|
| 135 |
+
"Fais `hf auth login`, ou passe `--token`, ou définis HF_TOKEN."
|
| 136 |
+
) from e
|
| 137 |
+
|
| 138 |
+
file_count, total_size = folder_stats(model_dir)
|
| 139 |
+
|
| 140 |
+
print(f"Authentifié comme: {who.get('name') or who.get('fullname') or who}")
|
| 141 |
+
print(f"Upload de: {model_dir}")
|
| 142 |
+
print(f"Repo cible: {repo_id}")
|
| 143 |
+
print(f"Visibilité : {'privé' if args.private else 'public'}")
|
| 144 |
+
print(f"Fichiers détectés: {file_count}")
|
| 145 |
+
print(f"Taille totale : {format_bytes(total_size)}")
|
| 146 |
+
|
| 147 |
+
api.create_repo(
|
| 148 |
+
repo_id=repo_id,
|
| 149 |
+
repo_type="model",
|
| 150 |
+
private=args.private,
|
| 151 |
+
exist_ok=True,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
start_time = time.perf_counter()
|
| 155 |
+
|
| 156 |
+
if args.large:
|
| 157 |
+
api.upload_large_folder(
|
| 158 |
+
repo_id=repo_id,
|
| 159 |
+
repo_type="model",
|
| 160 |
+
folder_path=str(model_dir),
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
api.upload_folder(
|
| 164 |
+
repo_id=repo_id,
|
| 165 |
+
repo_type="model",
|
| 166 |
+
folder_path=str(model_dir),
|
| 167 |
+
commit_message=args.commit_message,
|
| 168 |
+
ignore_patterns=[
|
| 169 |
+
"**/.git/**",
|
| 170 |
+
"**/__pycache__/**",
|
| 171 |
+
"**/.ipynb_checkpoints/**",
|
| 172 |
+
"**/*.tmp",
|
| 173 |
+
"**/*.lock",
|
| 174 |
+
"**/*.swp",
|
| 175 |
+
"**/*.part",
|
| 176 |
+
],
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
elapsed = time.perf_counter() - start_time
|
| 180 |
+
avg_speed = (total_size / elapsed) if elapsed > 0 else 0.0
|
| 181 |
+
|
| 182 |
+
print()
|
| 183 |
+
print(f"Upload OK -> https://huggingface.co/{repo_id}")
|
| 184 |
+
print(f"Durée totale : {elapsed:.2f} s")
|
| 185 |
+
print(f"Vitesse moyenne : {format_speed(avg_speed)}")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
if __name__ == "__main__":
|
| 189 |
+
main()
|
wikipedia_ar_h100/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100/tokenizer_32k/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
wikipedia_ar_h100/tokenizer_32k/tokenizer_config.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"bos_token": "<bos>",
|
| 4 |
+
"eos_token": "<eos>",
|
| 5 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 6 |
+
"pad_token": "<pad>",
|
| 7 |
+
"tokenizer_class": "TokenizersBackend",
|
| 8 |
+
"unk_token": "<unk>"
|
| 9 |
+
}
|
wikipedia_ar_h100/train_state.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:09c3c854a6aaa593570f779f0a4c7281db9cc027a3897437dfe3a4cb1075b92e
|
| 3 |
+
size 2322908797
|
wikipedia_ar_h100_agri_30gb/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100_codealpaca/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100_env_fr_ar_77gb/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100_env_fr_ar_77gb/model_epoch_03.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2c04ad704abccb3f01a61fbec3d09442fd9edc3b8d656addc98ba6b25e01ac28
|
| 3 |
+
size 5225864649
|
wikipedia_ar_h100_multicode/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100_multicode/train_state.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8234246d0aa1dbe4ced74d07266bebd381919bd0aada72b8aafa11621f0846bc
|
| 3 |
+
size 5225862591
|
wikipedia_ar_h100_multicode_10x2000/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 1024,
|
| 4 |
+
"d_model": 1024,
|
| 5 |
+
"n_heads": 16,
|
| 6 |
+
"n_layers": 24,
|
| 7 |
+
"d_ff": 4096,
|
| 8 |
+
"dropout": 0.0,
|
| 9 |
+
"use_checkpointing": false
|
| 10 |
+
}
|
wikipedia_ar_h100_multicode_10x2000/model_round_06.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7c64f58ee916183c6f04c84c1e9fbd8b92991375ade718fe438385cea36c090
|
| 3 |
+
size 5225864649
|