Medyassino committed on
Commit
b9049d2
·
verified ·
1 Parent(s): 8a68ba4

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. __pycache__/train_aramix_h100_full.cpython-313.pyc +0 -0
  2. aramix_h100/config.json +10 -0
  3. aramix_h100/model.pt +3 -0
  4. aramix_h100/model_best.pt +3 -0
  5. aramix_h100/qa_test_report_simple.json +134 -0
  6. aramix_h100/qa_test_report_simple.txt +80 -0
  7. aramix_h100/tokenizer_32k/tokenizer.json +0 -0
  8. aramix_h100/tokenizer_32k/tokenizer_config.json +9 -0
  9. aramix_h100/train_log.jsonl +44 -0
  10. aramix_h100/train_state.pt +3 -0
  11. donner +0 -0
  12. nlp_1b_h100_maxvram/config.json +10 -0
  13. nlp_1b_h100_maxvram/tokenizer_32k/tokenizer.json +0 -0
  14. nlp_1b_h100_maxvram/tokenizer_32k/tokenizer_config.json +9 -0
  15. nlp_1b_h100_opt/config.json +10 -0
  16. nlp_1b_h100_opt/model.pt +3 -0
  17. nlp_1b_h100_opt/model_best.pt +3 -0
  18. nlp_1b_h100_opt/tokenizer_32k/tokenizer.json +0 -0
  19. nlp_1b_h100_opt/tokenizer_32k/tokenizer_config.json +9 -0
  20. nlp_1b_h100_opt/train_state.pt +3 -0
  21. nlp_1b_wiki_en_fr_ar/config.json +10 -0
  22. nlp_1b_wiki_en_fr_ar/model_best.pt +3 -0
  23. nlp_1b_wiki_en_fr_ar/model_epoch_02.pt +3 -0
  24. nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer.json +0 -0
  25. nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer_config.json +9 -0
  26. simple_qa_test_aramix.py +504 -0
  27. simple_qa_test_aramix_v2.py +472 -0
  28. simple_qa_test_aramix_v3.py +583 -0
  29. simple_qa_test_finished_model (1).py +309 -0
  30. simple_qa_test_finished_model.py +309 -0
  31. test.py +428 -0
  32. top_p +0 -0
  33. train.py +859 -0
  34. train2.py +852 -0
  35. train_aramix_h100_full.py +1055 -0
  36. train_nlp_h100_maxvram_v6.py +1046 -0
  37. train_nlp_h100_maxvram_v7.py +805 -0
  38. upload.py +189 -0
  39. wikipedia_ar_h100/config.json +10 -0
  40. wikipedia_ar_h100/tokenizer_32k/tokenizer.json +0 -0
  41. wikipedia_ar_h100/tokenizer_32k/tokenizer_config.json +9 -0
  42. wikipedia_ar_h100/train_state.pt +3 -0
  43. wikipedia_ar_h100_agri_30gb/config.json +10 -0
  44. wikipedia_ar_h100_codealpaca/config.json +10 -0
  45. wikipedia_ar_h100_env_fr_ar_77gb/config.json +10 -0
  46. wikipedia_ar_h100_env_fr_ar_77gb/model_epoch_03.pt +3 -0
  47. wikipedia_ar_h100_multicode/config.json +10 -0
  48. wikipedia_ar_h100_multicode/train_state.pt +3 -0
  49. wikipedia_ar_h100_multicode_10x2000/config.json +10 -0
  50. wikipedia_ar_h100_multicode_10x2000/model_round_06.pt +3 -0
__pycache__/train_aramix_h100_full.cpython-313.pyc ADDED
Binary file (48.4 kB). View file
 
aramix_h100/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
aramix_h100/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78d6b75f8a8a079e5ba8b85def0c733dd808a075993fcb8fbe21a07350e0d8dc
3
+ size 5225851307
aramix_h100/model_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1300c7623453b9d5a03fa6b2142ebe0f840676233ec122678429b8f2b7cd2ce
3
+ size 5225876625
aramix_h100/qa_test_report_simple.json ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "repo_dir": "/workspace/FirstChat/aramix_h100",
3
+ "train_script": "/workspace/FirstChat/train_aramix_h100_full.py",
4
+ "checkpoint": "/workspace/FirstChat/aramix_h100/model_best.pt",
5
+ "config_path": "/workspace/FirstChat/aramix_h100/config.json",
6
+ "tokenizer_dir": "/workspace/FirstChat/aramix_h100/tokenizer_32k",
7
+ "device": "cuda",
8
+ "total_questions": 10,
9
+ "avg_overlap_score": 0.1778,
10
+ "exact_match_rate": 0.0,
11
+ "avg_latency_s": 1.195,
12
+ "avg_words_generated": 28.6,
13
+ "scores_by_category": {
14
+ "Géographie": 0.25,
15
+ "Science": 0.05,
16
+ "Histoire": 0.0,
17
+ "Mathématiques": 0.0,
18
+ "Langage": 0.0,
19
+ "Raisonnement": 1.0,
20
+ "Dialogue": 0.0
21
+ },
22
+ "results": [
23
+ {
24
+ "id": 1,
25
+ "category": "Géographie",
26
+ "question": "Quelle est la capitale de la France ?",
27
+ "answer": "Au total, le personnage d'une société financière (dont le but à une institution) est une commission de la musique. Il a également considéré plusieurs sociétés, notamment les p",
28
+ "reference": "Paris",
29
+ "latency_s": 1.6,
30
+ "tokens_generated_approx": 28,
31
+ "overlap_score": 0.0,
32
+ "exact_match": false
33
+ },
34
+ {
35
+ "id": 2,
36
+ "category": "Géographie",
37
+ "question": "Quel est le plus long fleuve d'Afrique ?",
38
+ "answer": "Les petits chemins de fer français sont un type de classe moyenne de . Le groupe est le plus long, de poids lourds. Historique Le , le groupe est également connu comme secondaire d'une man",
39
+ "reference": "Le Nil",
40
+ "latency_s": 1.25,
41
+ "tokens_generated_approx": 35,
42
+ "overlap_score": 0.5,
43
+ "exact_match": false
44
+ },
45
+ {
46
+ "id": 3,
47
+ "category": "Science",
48
+ "question": "Qu'est-ce que la photosynthèse ?",
49
+ "answer": "Die verlore, die wat gevra het nie beslis nie! Chefsel en sy kragtige tempo van die opwindende akademiese medium. Hierdie photographies is gebaseer, asook hulle nie eens uitges",
50
+ "reference": "Processus par lequel les plantes convertissent la lumière en énergie",
51
+ "latency_s": 1.242,
52
+ "tokens_generated_approx": 28,
53
+ "overlap_score": 0.1,
54
+ "exact_match": false
55
+ },
56
+ {
57
+ "id": 4,
58
+ "category": "Science",
59
+ "question": "Combien d'os compte le corps humain adulte ?",
60
+ "answer": "Die volgende plek is available, sal ons beslis nie opgelewing te koop nie.<eos>",
61
+ "reference": "206",
62
+ "latency_s": 0.518,
63
+ "tokens_generated_approx": 13,
64
+ "overlap_score": 0.0,
65
+ "exact_match": false
66
+ },
67
+ {
68
+ "id": 5,
69
+ "category": "Histoire",
70
+ "question": "En quelle année a eu lieu la Révolution française ?",
71
+ "answer": "Les premiers élèves sont des trois fils de Jacques Salamine et Julien Fernandez. Leur nombre d'exemplaires est installée, confirmant la principale section du groupe de musique",
72
+ "reference": "1789",
73
+ "latency_s": 1.229,
74
+ "tokens_generated_approx": 26,
75
+ "overlap_score": 0.0,
76
+ "exact_match": false
77
+ },
78
+ {
79
+ "id": 6,
80
+ "category": "Histoire",
81
+ "question": "Qui a écrit Les Misérables ?",
82
+ "answer": "Historia et soutenue, L'Advention d'une partie de la mise en place des règnes du terme de cette matière. Le phénomène ou la présence d'un système de variable est écrit : «",
83
+ "reference": "Victor Hugo",
84
+ "latency_s": 1.22,
85
+ "tokens_generated_approx": 31,
86
+ "overlap_score": 0.0,
87
+ "exact_match": false
88
+ },
89
+ {
90
+ "id": 7,
91
+ "category": "Mathématiques",
92
+ "question": "Quelle est la formule de l'aire d'un cercle ?",
93
+ "answer": "Remarque : évoquant le premier au sens interdisant les règles différentes, on utilise les mots de type latin. Le cercle est un facteur plus ou moins électrique et une sorte de poids (",
94
+ "reference": "pi r carre",
95
+ "latency_s": 1.224,
96
+ "tokens_generated_approx": 33,
97
+ "overlap_score": 0.0,
98
+ "exact_match": false
99
+ },
100
+ {
101
+ "id": 8,
102
+ "category": "Langage",
103
+ "question": "Donne un synonyme du mot heureux.",
104
+ "answer": "le premier à créer une fois la relation de l'homme, qui serait naturellement réduite. Son grand-père est d'une famille enveloppant le chevalier (ou encore mythique). Les barrag",
105
+ "reference": "joyeux",
106
+ "latency_s": 1.228,
107
+ "tokens_generated_approx": 27,
108
+ "overlap_score": 0.0,
109
+ "exact_match": false
110
+ },
111
+ {
112
+ "id": 9,
113
+ "category": "Raisonnement",
114
+ "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
115
+ "answer": "Loi de la section 1 à 3 par exemple : Dans les différentes factions qui vont à cette occasion, une section de terminaisons n'est pas connue. En janvier 1940, le site est introduit dans la m",
116
+ "reference": "3",
117
+ "latency_s": 1.226,
118
+ "tokens_generated_approx": 36,
119
+ "overlap_score": 1.0,
120
+ "exact_match": false
121
+ },
122
+ {
123
+ "id": 10,
124
+ "category": "Dialogue",
125
+ "question": "Comment vas-tu aujourd'hui ?",
126
+ "answer": "Die verskaffinge en maandelikse digitale oor hierdie raadslede is, is nie gebruik van die Vrystaatse kombuisies nie. Dit was die swart-regering in die omgewing aan die beste metode te",
127
+ "reference": null,
128
+ "latency_s": 1.212,
129
+ "tokens_generated_approx": 29,
130
+ "overlap_score": null,
131
+ "exact_match": null
132
+ }
133
+ ]
134
+ }
aramix_h100/qa_test_report_simple.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ
2
+ ============================================================
3
+
4
+ [01] Géographie
5
+ User : Quelle est la capitale de la France ?
6
+ Assistant : Au total, le personnage d'une société financière (dont le but à une institution) est une commission de la musique. Il a également considéré plusieurs sociétés, notamment les p
7
+ Référence : Paris
8
+ Overlap : 0%
9
+ ExactMatch: False
10
+ Latence : 1.6s
11
+
12
+ [02] Géographie
13
+ User : Quel est le plus long fleuve d'Afrique ?
14
+ Assistant : Les petits chemins de fer français sont un type de classe moyenne de . Le groupe est le plus long, de poids lourds. Historique Le , le groupe est également connu comme secondaire d'une man
15
+ Référence : Le Nil
16
+ Overlap : 50%
17
+ ExactMatch: False
18
+ Latence : 1.25s
19
+
20
+ [03] Science
21
+ User : Qu'est-ce que la photosynthèse ?
22
+ Assistant : Die verlore, die wat gevra het nie beslis nie! Chefsel en sy kragtige tempo van die opwindende akademiese medium. Hierdie photographies is gebaseer, asook hulle nie eens uitges
23
+ Référence : Processus par lequel les plantes convertissent la lumière en énergie
24
+ Overlap : 10%
25
+ ExactMatch: False
26
+ Latence : 1.242s
27
+
28
+ [04] Science
29
+ User : Combien d'os compte le corps humain adulte ?
30
+ Assistant : Die volgende plek is available, sal ons beslis nie opgelewing te koop nie.<eos>
31
+ Référence : 206
32
+ Overlap : 0%
33
+ ExactMatch: False
34
+ Latence : 0.518s
35
+
36
+ [05] Histoire
37
+ User : En quelle année a eu lieu la Révolution française ?
38
+ Assistant : Les premiers élèves sont des trois fils de Jacques Salamine et Julien Fernandez. Leur nombre d'exemplaires est installée, confirmant la principale section du groupe de musique
39
+ Référence : 1789
40
+ Overlap : 0%
41
+ ExactMatch: False
42
+ Latence : 1.229s
43
+
44
+ [06] Histoire
45
+ User : Qui a écrit Les Misérables ?
46
+ Assistant : Historia et soutenue, L'Advention d'une partie de la mise en place des règnes du terme de cette matière. Le phénomène ou la présence d'un système de variable est écrit : «
47
+ Référence : Victor Hugo
48
+ Overlap : 0%
49
+ ExactMatch: False
50
+ Latence : 1.22s
51
+
52
+ [07] Mathématiques
53
+ User : Quelle est la formule de l'aire d'un cercle ?
54
+ Assistant : Remarque : évoquant le premier au sens interdisant les règles différentes, on utilise les mots de type latin. Le cercle est un facteur plus ou moins électrique et une sorte de poids (
55
+ Référence : pi r carre
56
+ Overlap : 0%
57
+ ExactMatch: False
58
+ Latence : 1.224s
59
+
60
+ [08] Langage
61
+ User : Donne un synonyme du mot heureux.
62
+ Assistant : le premier à créer une fois la relation de l'homme, qui serait naturellement réduite. Son grand-père est d'une famille enveloppant le chevalier (ou encore mythique). Les barrag
63
+ Référence : joyeux
64
+ Overlap : 0%
65
+ ExactMatch: False
66
+ Latence : 1.228s
67
+
68
+ [09] Raisonnement
69
+ User : Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?
70
+ Assistant : Loi de la section 1 à 3 par exemple : Dans les différentes factions qui vont à cette occasion, une section de terminaisons n'est pas connue. En janvier 1940, le site est introduit dans la m
71
+ Référence : 3
72
+ Overlap : 100%
73
+ ExactMatch: False
74
+ Latence : 1.226s
75
+
76
+ [10] Dialogue
77
+ User : Comment vas-tu aujourd'hui ?
78
+ Assistant : Die verskaffinge en maandelikse digitale oor hierdie raadslede is, is nie gebruik van die Vrystaatse kombuisies nie. Dit was die swart-regering in die omgewing aan die beste metode te
79
+ Latence : 1.212s
80
+
aramix_h100/tokenizer_32k/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
aramix_h100/tokenizer_32k/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<unk>"
9
+ }
aramix_h100/train_log.jsonl ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"step": 3050, "loss": 3.571237907409668, "ppl": 35.5605866297653, "lr": 0.00012589612961614406, "tok_s": 11144.981072882043, "time": 147.00788259506226}
2
+ {"step": 3100, "loss": 2.8578320932388306, "ppl": 17.423712977652347, "lr": 0.00012168934965573969, "tok_s": 57222.53834418063, "time": 175.63995742797852}
3
+ {"step": 3150, "loss": 2.83519070148468, "ppl": 17.03364833271062, "lr": 0.00011752894782453746, "tok_s": 57001.54441684658, "time": 204.3830382823944}
4
+ {"step": 3200, "loss": 2.9635899543762205, "ppl": 19.36737509653468, "lr": 0.00011341937918502494, "tok_s": 56964.73242461735, "time": 233.14469361305237}
5
+ {"step": 3250, "loss": 2.8826186561584475, "ppl": 17.860983768575373, "lr": 0.0001093650443662347, "tok_s": 57001.06640251769, "time": 261.88801550865173}
6
+ {"step": 3300, "loss": 2.703218548297882, "ppl": 14.927700012805772, "lr": 0.00010537028485144083, "tok_s": 60419.98038551336, "time": 289.00487303733826}
7
+ {"step": 3350, "loss": 2.745184621810913, "ppl": 15.56748776467281, "lr": 0.00010143937832918955, "tok_s": 57034.40571613923, "time": 317.7313930988312}
8
+ {"step": 3400, "loss": 2.7973192358016967, "ppl": 16.400621587712635, "lr": 9.757653411264333e-05, "tok_s": 61073.53656045087, "time": 344.55806946754456}
9
+ {"step": 3450, "loss": 2.7036573982238767, "ppl": 14.934252470519295, "lr": 9.378588863214297e-05, "tok_s": 56988.909283348396, "time": 373.30752301216125}
10
+ {"step": 3500, "loss": 2.5318048357963563, "ppl": 12.576183612317518, "lr": 9.007150100581427e-05, "tok_s": 57049.3650586449, "time": 402.02651047706604}
11
+ {"step": 3500, "val_loss": 2.7343945503234863, "val_ppl": 15.400416435550476, "per_domain": {"arabic_aramix": 2.896586099727042, "french_wiki": 2.638401048920083, "arabic_wiki": 2.9592110232303015, "math_stackexchange": 2.6955147483132103, "multilingual_cc": 2.75966284356334, "stories": 2.5633833775153527, "medical_pubmed": 2.8006927967071533, "medical_flashcards": 2.939271628856659}}
12
+ {"step": 3550, "loss": 2.622337305545807, "ppl": 13.767865716347135, "lr": 8.643734869296278e-05, "tok_s": 11633.872778271438, "time": 542.8566563129425}
13
+ {"step": 3600, "loss": 2.5885059988498687, "ppl": 13.30987177797193, "lr": 8.288732323491074e-05, "tok_s": 56913.92629532449, "time": 571.6439867019653}
14
+ {"step": 3650, "loss": 2.6797654819488526, "ppl": 14.581673229274589, "lr": 7.942522608783706e-05, "tok_s": 57007.48882356021, "time": 600.3840703964233}
15
+ {"step": 3700, "loss": 2.530478653907776, "ppl": 12.559516359730464, "lr": 7.605476455208276e-05, "tok_s": 56725.416636375594, "time": 629.2670667171478}
16
+ {"step": 3750, "loss": 2.569008586406708, "ppl": 13.052877224060135, "lr": 7.277954780228142e-05, "tok_s": 56930.38215870419, "time": 658.0460760593414}
17
+ {"step": 3800, "loss": 2.564538803100586, "ppl": 12.994663888764459, "lr": 6.960308302256383e-05, "tok_s": 56738.58728395681, "time": 686.922367811203}
18
+ {"step": 3850, "loss": 2.656132850646973, "ppl": 14.241109975260919, "lr": 6.652877165097785e-05, "tok_s": 56904.04773033927, "time": 715.7146956920624}
19
+ {"step": 3900, "loss": 2.603713195323944, "ppl": 13.51382445690133, "lr": 6.355990573714333e-05, "tok_s": 56984.61454772562, "time": 744.466315984726}
20
+ {"step": 3950, "loss": 2.4559333181381224, "ppl": 11.657308450637194, "lr": 6.069966441704281e-05, "tok_s": 56748.50220511147, "time": 773.3375625610352}
21
+ {"step": 4000, "loss": 2.4442277646064756, "ppl": 11.521648737554232, "lr": 5.795111050872301e-05, "tok_s": 59304.739097463025, "time": 800.9643597602844}
22
+ {"step": 4000, "val_loss": 2.6130046784877776, "val_ppl": 13.639973075654948, "per_domain": {"arabic_aramix": 2.794638446513438, "french_wiki": 2.522582699096084, "arabic_wiki": 2.849381112424951, "math_stackexchange": 2.5690312363884664, "multilingual_cc": 2.6081621836532247, "stories": 2.4209399079228495, "medical_pubmed": 2.6950340270996094, "medical_flashcards": 2.819118231534958}}
23
+ {"step": 4050, "loss": 2.616236004829407, "ppl": 13.684119567401414, "lr": 5.531718723255281e-05, "tok_s": 11053.809939264273, "time": 949.1847479343414}
24
+ {"step": 4100, "loss": 2.470917589664459, "ppl": 11.833299985286175, "lr": 5.280071505954885e-05, "tok_s": 56891.6055540174, "time": 977.9833726882935}
25
+ {"step": 4150, "loss": 2.51880806684494, "ppl": 12.413791432319185, "lr": 5.0404388691144755e-05, "tok_s": 56854.72699648875, "time": 1006.8006775379181}
26
+ {"step": 4200, "loss": 2.4818945503234864, "ppl": 11.96390918827006, "lr": 4.813077417363728e-05, "tok_s": 56728.39765714226, "time": 1035.682156085968}
27
+ {"step": 4250, "loss": 2.491965615749359, "ppl": 12.085007270264622, "lr": 4.5982306150399575e-05, "tok_s": 56775.80030404367, "time": 1064.5395212173462}
28
+ {"step": 4300, "loss": 2.5507710337638856, "ppl": 12.81698229994284, "lr": 4.3961285254804134e-05, "tok_s": 56859.33243706391, "time": 1093.3544919490814}
29
+ {"step": 4350, "loss": 2.5808962631225585, "ppl": 13.208971570046705, "lr": 4.206987564664711e-05, "tok_s": 58036.5193491192, "time": 1121.584992647171}
30
+ {"step": 4400, "loss": 2.7079242062568665, "ppl": 14.998110196375103, "lr": 4.031010269471151e-05, "tok_s": 57002.2214993296, "time": 1150.3277320861816}
31
+ {"step": 4450, "loss": 2.5953755283355715, "ppl": 13.401619104283894, "lr": 3.868385080795177e-05, "tok_s": 56785.49267694181, "time": 1179.1801717281342}
32
+ {"step": 4500, "loss": 2.5976397037506103, "ppl": 13.43199709835749, "lr": 3.71928614176214e-05, "tok_s": 56872.48957171817, "time": 1207.9884762763977}
33
+ {"step": 4500, "val_loss": 2.547382290661335, "val_ppl": 12.773622348939783, "per_domain": {"arabic_aramix": 2.7358595652868285, "french_wiki": 2.4578250470351537, "arabic_wiki": 2.7882657515375238, "math_stackexchange": 2.504125434702093, "multilingual_cc": 2.5334029245105656, "stories": 2.348106792994908, "medical_pubmed": 2.6415319442749023, "medical_flashcards": 2.757194072008133}}
34
+ {"step": 4550, "loss": 2.638249659538269, "ppl": 13.988697184009707, "lr": 3.583873111250479e-05, "tok_s": 11385.546618796197, "time": 1351.8902189731598}
35
+ {"step": 4600, "loss": 2.582155728340149, "ppl": 13.225618291081892, "lr": 3.462290992924992e-05, "tok_s": 56835.088159816194, "time": 1380.7174813747406}
36
+ {"step": 4650, "loss": 2.5022823572158814, "ppl": 12.210330518102197, "lr": 3.354669979963281e-05, "tok_s": 56918.493716463534, "time": 1409.5025017261505}
37
+ {"step": 4700, "loss": 2.4462553191185, "ppl": 11.545033207070272, "lr": 3.261125315641639e-05, "tok_s": 56923.2943292039, "time": 1438.285094499588}
38
+ {"step": 4750, "loss": 2.3780724477767943, "ppl": 10.784095909082684, "lr": 3.1817571699296604e-05, "tok_s": 57047.299242243054, "time": 1467.0051219463348}
39
+ {"step": 4800, "loss": 2.583693208694458, "ppl": 13.245968059053615, "lr": 3.116650532225727e-05, "tok_s": 56903.56946504383, "time": 1495.797691822052}
40
+ {"step": 4850, "loss": 2.5454819345474244, "ppl": 12.749370968040534, "lr": 3.065875120348237e-05, "tok_s": 56909.14470005539, "time": 1524.5874409675598}
41
+ {"step": 4900, "loss": 2.5480848932266236, "ppl": 12.782600282360375, "lr": 3.029485305880013e-05, "tok_s": 57001.02054004117, "time": 1553.3307859897614}
42
+ {"step": 4950, "loss": 2.5183346843719483, "ppl": 12.407916351721552, "lr": 3.007520055945856e-05, "tok_s": 56954.47416190712, "time": 1582.097621679306}
43
+ {"step": 5000, "loss": 2.488962812423706, "ppl": 12.048772799963553, "lr": 3.0000028914855615e-05, "tok_s": 56955.4739533805, "time": 1610.8639523983002}
44
+ {"step": 5000, "val_loss": 2.5137671560049055, "val_ppl": 12.351372073909298, "per_domain": {"arabic_aramix": 2.7070358647596118, "french_wiki": 2.427249958942895, "arabic_wiki": 2.755421937766828, "math_stackexchange": 2.469518039443276, "multilingual_cc": 2.489755442873998, "stories": 2.308292391536, "medical_pubmed": 2.6124343872070312, "medical_flashcards": 2.7211887538433075}}
aramix_h100/train_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc13a38fb25354acd0eba8d965216a73cd9fc292a9fa767ad6f27924eb855ac9
3
+ size 5225877311
donner ADDED
File without changes
nlp_1b_h100_maxvram/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 2048,
4
+ "d_model": 1536,
5
+ "n_heads": 24,
6
+ "n_layers": 24,
7
+ "d_ff": 6144,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
nlp_1b_h100_maxvram/tokenizer_32k/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nlp_1b_h100_maxvram/tokenizer_32k/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<unk>"
9
+ }
nlp_1b_h100_opt/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1536,
5
+ "n_heads": 24,
6
+ "n_layers": 24,
7
+ "d_ff": 6144,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
nlp_1b_h100_opt/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:404b4fea027e28fbcabe19077a7ffebb8830de27ca94db6908992d55dcd85e6d
3
+ size 4415622541
nlp_1b_h100_opt/model_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77b5bc51912e146de6f8855909bc84564dc6c30daa361613c88a11cc41ceb049
3
+ size 4415675901
nlp_1b_h100_opt/tokenizer_32k/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nlp_1b_h100_opt/tokenizer_32k/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<unk>"
9
+ }
nlp_1b_h100_opt/train_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a749a6a3258c1a8d1bd9d27d4f1bf0e45454c70bf598e88741b471b8f2afa088
3
+ size 4415677037
nlp_1b_wiki_en_fr_ar/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1536,
5
+ "n_heads": 24,
6
+ "n_layers": 24,
7
+ "d_ff": 6144,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
nlp_1b_wiki_en_fr_ar/model_best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc91b83ddaf6edf5e20142465b16b89b4505fc04eb5ef26680dae6839c030118
3
+ size 11462571709
nlp_1b_wiki_en_fr_ar/model_epoch_02.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:243fcb18b86fdcc825307bc552f4392aee2a996217a48e56e650c3fd00257fd3
3
+ size 11462574453
nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nlp_1b_wiki_en_fr_ar/tokenizer_32k/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<unk>"
9
+ }
simple_qa_test_aramix.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ simple_qa_test_aramix.py
6
+
7
+ Test QA simple pour un modèle déjà entraîné dans une repo de type :
8
+ - train_aramix_h100_full.py
9
+ - aramix_h100/
10
+ - config.json
11
+ - model_best.pt
12
+ - model.pt
13
+ - tokenizer_32k/
14
+
15
+ Hypothèses alignées avec ton repo :
16
+ - le module d'entraînement expose : GPT, GPTConfig, train_or_load_tokenizer,
17
+ load_checkpoint, DOMAINS
18
+ - le tokenizer est géré par train_or_load_tokenizer(DOMAINS)
19
+ - le checkpoint se recharge avec load_checkpoint(model, opt, ckpt_path, device)
20
+
21
+ Usage
22
+ -----
23
+ python simple_qa_test_aramix.py
24
+ python simple_qa_test_aramix.py --repo_dir ./aramix_h100
25
+ python simple_qa_test_aramix.py --ckpt ./aramix_h100/model.pt
26
+ python simple_qa_test_aramix.py --questions qa_questions.json
27
+ python simple_qa_test_aramix.py --max_new_tokens 96 --temperature 0.4 --top_k 40
28
+ python simple_qa_test_aramix.py --save_report
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import argparse
34
+ import importlib.util
35
+ import json
36
+ import os
37
+ import re
38
+ import sys
39
+ import time
40
+ import unicodedata
41
+ from pathlib import Path
42
+ from typing import Any, Dict, List, Optional
43
+
44
+ import torch
45
+ import torch.nn.functional as F
46
+
47
+
48
# Built-in French QA set used when no external questions file is supplied.
# Each entry carries a category label, the question text, and an optional
# reference answer (None means the item has no reference to score against).
DEFAULT_QUESTIONS = [
    {"category": category, "question": question, "reference": reference}
    for category, question, reference in (
        ("Géographie", "Quelle est la capitale de la France ?", "Paris"),
        ("Géographie", "Quel est le plus long fleuve d'Afrique ?", "Le Nil"),
        (
            "Science",
            "Qu'est-ce que la photosynthèse ?",
            "Processus par lequel les plantes convertissent la lumière en énergie",
        ),
        ("Science", "Combien d'os compte le corps humain adulte ?", "206"),
        ("Histoire", "En quelle année a eu lieu la Révolution française ?", "1789"),
        ("Histoire", "Qui a écrit Les Misérables ?", "Victor Hugo"),
        ("Mathématiques", "Quelle est la formule de l'aire d'un cercle ?", "pi r carre"),
        ("Langage", "Donne un synonyme du mot heureux.", "joyeux"),
        (
            "Raisonnement",
            "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
            "3",
        ),
        ("Dialogue", "Comment vas-tu aujourd'hui ?", None),
    )
]
100
+
101
+
102
def load_module_from_file(py_path: Path):
    """Import a Python source file as a module and return it.

    The module is registered in ``sys.modules`` under the file's stem before
    its body is executed, so code inside the file that looks itself up by
    name resolves correctly.

    Args:
        py_path: Path to the ``.py`` file to load (a ``str`` is also accepted).

    Returns:
        The executed module object.

    Raises:
        RuntimeError: If no import spec/loader can be built for ``py_path``.
        Exception: Whatever the module body raises while executing; in that
            case the ``sys.modules`` entry is rolled back first.
    """
    py_path = Path(py_path)
    module_name = py_path.stem

    spec = importlib.util.spec_from_file_location(module_name, py_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module cached: a later import of
        # the same name would silently receive the broken object.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
110
+
111
+
112
def normalize_text(text: str) -> str:
    """Return a canonical form of *text* for lenient string comparison.

    The input is trimmed and lower-cased, decomposed with NFKD so that
    combining marks (accents) can be dropped, the Greek letter pi is spelled
    out as ``pi``, every run of non-word characters or underscores becomes a
    single space, and surrounding whitespace is removed. A falsy input
    yields the empty string.
    """
    lowered = (text or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    # Keep only base characters; combining() is 0 for non-combining code points.
    base_chars = [c for c in decomposed if unicodedata.combining(c) == 0]
    spelled = "".join(base_chars).replace("π", "pi")
    spaced = re.sub(r"[\W_]+", " ", spelled, flags=re.UNICODE)
    return re.sub(r"\s+", " ", spaced).strip()
120
+
121
+
122
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the fraction of distinct reference words found in *answer*.

    Both strings are passed through ``normalize_text`` and split on spaces;
    the score is ``|ref ∩ ans| / |ref|`` over the resulting word sets.
    Returns ``None`` when there is no reference, or when the reference
    contains no words after normalization.
    """
    if not reference:
        return None
    ref_words, ans_words = (
        set(normalize_text(side).split()) for side in (reference, answer)
    )
    if len(ref_words) == 0:
        return None
    shared = ref_words.intersection(ans_words)
    return len(shared) / len(ref_words)
130
+
131
+
132
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether *answer* equals *reference* once both are normalized.

    Returns ``None`` when no reference is provided, otherwise the result of
    comparing the ``normalize_text`` forms of the two strings.
    """
    if reference:
        return normalize_text(reference) == normalize_text(answer)
    return None
136
+
137
+
138
def infer_repo_defaults(repo_dir: Path):
    """Derive the default artefact paths for a model directory.

    Returns a 4-tuple ``(train_script, ckpt, config, tokenizer_dir)``:

    * ``train_script``: ``train_aramix_h100_full.py`` next to ``repo_dir``
      if that file exists, otherwise the same file name inside ``repo_dir``.
    * ``ckpt``: ``model_best.pt`` if present in ``repo_dir``, else ``model.pt``.
    * ``config``: ``repo_dir / "config.json"``.
    * ``tokenizer_dir``: ``repo_dir / "tokenizer_32k"``.

    Only the two fallbacks consult the filesystem; the returned paths are
    not otherwise checked for existence.
    """
    script_name = "train_aramix_h100_full.py"
    sibling_script = repo_dir.parent / script_name
    train_script = sibling_script if sibling_script.exists() else repo_dir / script_name

    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"

    return (
        train_script,
        ckpt,
        repo_dir / "config.json",
        repo_dir / "tokenizer_32k",
    )
150
+
151
+
152
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Return the value of the first key in *names* that exists in *cfg*.

    Keys are tried in the order given; a key that is present wins even if
    its value is ``None``. When none of the keys is present, *default* is
    returned.
    """
    return next((cfg[key] for key in names if key in cfg), default)
157
+
158
+
159
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Translate a loaded ``config.json`` into model-config keyword arguments.

    Each hyper-parameter is looked up under several alternative key names via
    ``safe_get`` and falls back to a default when none is present
    (block_size 512, d_model 768, n_heads 12, n_layers 12, d_ff 4 * d_model).
    ``vocab_size`` is taken from the argument as-is, not from ``cfg_json``;
    every other value is coerced with ``int``.
    """
    raw_block_size = safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512)
    raw_d_model = safe_get(cfg_json, "d_model", "n_embd", "dim", default=768)
    raw_n_heads = safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12)
    raw_n_layers = safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12)
    # The fallback width is computed eagerly from the resolved d_model.
    raw_d_ff = safe_get(
        cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=raw_d_model * 4
    )

    raw_fields = {
        "block_size": raw_block_size,
        "d_model": raw_d_model,
        "n_heads": raw_n_heads,
        "n_layers": raw_n_layers,
        "d_ff": raw_d_ff,
    }

    result: Dict[str, Any] = {"vocab_size": vocab_size}
    result.update((key, int(value)) for key, value in raw_fields.items())
    return result
174
+
175
+
176
class AramixChatTester:
    """Answer questions with an already-trained Aramix GPT checkpoint.

    The model classes, the tokenizer factory and the checkpoint loader are
    taken from the training script, which is imported dynamically from
    ``train_script``.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
    ):
        """Import the training script, then load config, tokenizer and model.

        Args:
            repo_dir: Directory holding the checkpoint and config.
            train_script: Training script exposing ``GPT``, ``GPTConfig``,
                ``train_or_load_tokenizer``, ``load_checkpoint`` and ``DOMAINS``.
            ckpt_path: Checkpoint file to load.
            config_path: Optional ``config.json``; ignored when absent.
            device: Torch device string; defaults to CUDA when available.

        Raises:
            RuntimeError: If the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path

        self.device = torch.device(
            device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        self.M = load_module_from_file(self.train_script)

        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "load_checkpoint", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(
                f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}"
            )

        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Build the tokenizer via the training script, run from the repo's parent dir."""
        old_cwd = Path.cwd()
        try:
            # The working directory is switched for the duration of the call
            # and always restored afterwards.
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Create a ``GPTConfig``, falling back to ``vocab_size`` only if kwargs are rejected."""
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _load_state_dict_manually(self, model) -> None:
        """Load the checkpoint weights directly, refusing incomplete matches.

        Raises:
            RuntimeError: If parameters are missing (other than a
                ``lm_head.weight`` entry) or unexpected keys are present.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        if any(k.startswith("_orig_mod.") for k in state):
            state = {k.replace("_orig_mod.", ""): v for k, v in state.items()}

        # strict=False alone would silently keep randomly initialised weights
        # when keys do not line up, so the mismatch report is checked here.
        missing, unexpected = model.load_state_dict(state, strict=False)
        hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
        if hard_missing or unexpected:
            raise RuntimeError(
                "Chargement manuel incomplet.\n"
                f"Missing: {hard_missing[:20]}\n"
                f"Unexpected: {list(unexpected)[:20]}"
            )

    def _load_model(self):
        """Instantiate the model and load its weights, returning it in eval mode.

        The training script's ``load_checkpoint`` is tried with a 4-argument
        and then a 3-argument call; if both raise ``TypeError`` the state
        dict is loaded manually.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)

        try:
            self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
        except TypeError:
            try:
                self.M.load_checkpoint(model, self.ckpt_path, self.device)
            except TypeError:
                self._load_state_dict_manually(model)

        model.eval()
        return model

    def encode_prompt(self, question: str) -> List[int]:
        """Encode the question/answer prompt, prepending BOS and dropping a trailing EOS."""
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)

        prompt = f"Question: {question}\nRéponse:"
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        if bos is not None:
            ids = [bos] + ids
        if eos is not None and len(ids) > 0 and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 96,
        temperature: float = 0.4,
        top_k: int = 40,
        repetition_penalty: float = 1.12,
    ) -> str:
        """Generate an answer to ``question`` token by token.

        Args:
            question: Text inserted into the prompt template.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` best logits when sampling (``None``
                or ``<= 0`` disables the filter).
            repetition_penalty: Penalty applied to tokens seen in the last 64
                positions.

        Returns:
            The decoded continuation with whitespace collapsed, EOS tokens
            removed and any echoed "Réponse:" label stripped.
        """
        ids = self.encode_prompt(question)
        x = torch.tensor([ids], dtype=torch.long, device=self.device)

        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
        if block_size is None:
            block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
        block_size = int(block_size)

        for step in range(max_new_tokens):
            x_ctx = x[:, -block_size:]

            # Single forward pass; the model may return a tuple or bare logits.
            out = self.model(x_ctx)
            logits = out[0] if isinstance(out, tuple) else out
            logits = logits[:, -1, :]

            if repetition_penalty != 1.0:
                # Dividing a negative logit would raise it, so negative scores
                # are multiplied and positive ones divided: both moves lower
                # the probability of a recently used token.
                recent = torch.unique(x[0, -64:])
                seen = logits[0, recent]
                logits[0, recent] = torch.where(
                    seen < 0, seen * repetition_penalty, seen / repetition_penalty
                )

            if temperature <= 0:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-5)

                if top_k is not None and top_k > 0:
                    values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
                    kth = values[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)

                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_tok], dim=1)

            # EOS is ignored during the first two steps, so generation never
            # stops before three tokens have been produced.
            if eos_id is not None and next_tok.item() == eos_id and step >= 2:
                break

        new_ids = x[0, len(ids):].tolist()
        if eos_id is not None:
            # Keep the EOS marker out of the decoded text so it cannot
            # pollute the overlap / exact-match scores.
            new_ids = [tok for tok in new_ids if tok != eos_id]
        text = self.tokenizer.decode(new_ids).strip()
        text = re.sub(r"\s+", " ", text).strip()
        text = text.replace("Réponse :", "").replace("Réponse:", "").strip()
        return text
314
+
315
+
316
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Load the question set from a JSON file, or return the built-in set.

    Args:
        path: Path to a JSON file holding a list of question objects; a falsy
            value selects ``DEFAULT_QUESTIONS``.

    Returns:
        The list of question dicts.

    Raises:
        ValueError: If the file does not contain a JSON list, or if an entry
            is not an object with a ``"question"`` key.
    """
    if not path:
        return DEFAULT_QUESTIONS
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    # Validate up front so a malformed entry is reported here instead of
    # failing later with a KeyError, after the model has already been loaded.
    for idx, item in enumerate(data, 1):
        if not isinstance(item, dict) or "question" not in item:
            raise ValueError(
                f"Question #{idx} invalide: un objet JSON avec la clé 'question' est attendu."
            )
    return data
324
+
325
+
326
def format_bar(score: float, width: int = 20) -> str:
    """Render ``score`` (nominally 0..1) as a text progress bar of ``width`` cells.

    The number of filled cells is ``score * width`` rounded and clamped to
    the range ``[0, width]``.
    """
    filled = int(round(score * width))
    filled = min(width, filled)
    filled = max(0, filled)
    empty = width - filled
    return "".join(("█" * filled, "░" * empty))
329
+
330
+
331
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
    """Write the run summary as a JSON file and a human-readable text file.

    Both files are created inside ``output_dir`` (created if needed) under
    the names ``qa_test_report_simple.json`` and ``qa_test_report_simple.txt``.
    The text report lists one section per entry of ``summary["results"]``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    json_path = output_dir / "qa_test_report_simple.json"
    txt_path = output_dir / "qa_test_report_simple.txt"

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(summary, json_file, ensure_ascii=False, indent=2)

    def entry_lines(r):
        # Optional fields are emitted only when a value is available.
        yield f"[{r['id']:02d}] {r['category']}\n"
        yield f" User : {r['question']}\n"
        yield f" Assistant : {r['answer']}\n"
        if r["reference"]:
            yield f" Référence : {r['reference']}\n"
        if r["overlap_score"] is not None:
            yield f" Overlap : {r['overlap_score']:.0%}\n"
        if r["exact_match"] is not None:
            yield f" ExactMatch: {r['exact_match']}\n"
        yield f" Latence : {r['latency_s']}s\n\n"

    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n")
        txt_file.write("=" * 60 + "\n\n")
        for result in summary["results"]:
            txt_file.writelines(entry_lines(result))
354
+
355
+
356
def main():
    """Command-line entry point: run the QA set through the model and report scores.

    Parses the CLI options, resolves the training script / checkpoint / config
    paths, generates one answer per question, prints per-question and
    aggregate metrics, and optionally writes the reports into the repo dir.

    Raises:
        FileNotFoundError: If the training script or the checkpoint is missing.
    """
    parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné")
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    for optional_path in ("--train_script", "--ckpt", "--config", "--questions"):
        parser.add_argument(optional_path, type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.4)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.12)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)

    # Explicit CLI paths take precedence over the inferred repo layout.
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()

    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")

    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
    )

    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}

    heavy_rule = "═" * 70
    light_rule = "─" * 70

    print("\n" + heavy_rule)
    print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print(heavy_rule)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Questions : {len(questions)}")
    print(heavy_rule + "\n")

    for number, item in enumerate(questions, 1):
        question = item["question"]
        reference = item.get("reference")
        category = item.get("category", "Général")

        started = time.time()
        answer = tester.generate(
            question,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        latency = time.time() - started

        overlap = lexical_overlap(reference, answer)
        em = exact_match(reference, answer)

        entry = {
            "id": number,
            "category": category,
            "question": question,
            "answer": answer,
            "reference": reference,
            "latency_s": round(latency, 3),
            # Whitespace-separated word count, used as a rough length measure.
            "tokens_generated_approx": len(answer.split()),
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
        }
        results.append(entry)
        categories.setdefault(category, []).append(entry)

        overlap_str = "n/a" if overlap is None else f"{overlap:.0%}"
        if em is None:
            em_str = "n/a"
        elif em:
            em_str = "✓"
        else:
            em_str = "✗"

        print(light_rule)
        print(f"[{number:02d}] [{category}] overlap={overlap_str} | EM={em_str}")
        print(f" User : {question}")
        print(f" Assistant : {answer}")
        if reference:
            print(f" Référence : {reference}")
        print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")

    def mean(values):
        # Average of a list, 0.0 when there is nothing to average.
        return sum(values) / len(values) if values else 0.0

    def scored_overlaps(entries):
        return [r["overlap_score"] for r in entries if r["overlap_score"] is not None]

    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]

    avg_overlap = mean(scored_overlaps(results))
    em_rate = mean([1 if hit else 0 for hit in scored_em])
    avg_latency = mean([r["latency_s"] for r in results])
    avg_words = mean([r["tokens_generated_approx"] for r in results])

    cat_scores: Dict[str, float] = {
        name: mean(scored_overlaps(entries)) for name, entries in categories.items()
    }

    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
        "results": results,
    }

    print(heavy_rule)
    print(" RÉSUMÉ")
    print(heavy_rule)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat, score in sorted(cat_scores.items()):
        print(f" {cat:<15} {format_bar(score)} {score:.0%}")
    print(heavy_rule)

    if args.save_report:
        save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
        print(f" {repo_dir / 'qa_test_report_simple.txt'}")


if __name__ == "__main__":
    main()
simple_qa_test_aramix_v2.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ simple_qa_test_aramix_v2.py
6
+
7
+ Version corrigée du test QA simple pour repo Aramix.
8
+ Correction principale :
9
+ - gère les checkpoints sauvés depuis torch.compile() avec préfixe "_orig_mod."
10
+ - contourne load_checkpoint() du script train si celui-ci échoue sur ce cas
11
+
12
+ Usage
13
+ -----
14
+ python simple_qa_test_aramix_v2.py --repo_dir ./aramix_h100 --save_report
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import importlib.util
21
+ import json
22
+ import os
23
+ import re
24
+ import sys
25
+ import time
26
+ import unicodedata
27
+ from pathlib import Path
28
+ from typing import Any, Dict, List, Optional
29
+
30
+ import torch
31
+ import torch.nn.functional as F
32
+
33
+
34
# Built-in French question set used when no questions file is supplied.
# Each row is (category, question, reference answer); the last row has no
# reference answer.
DEFAULT_QUESTIONS = [
    {"category": category, "question": question, "reference": reference}
    for category, question, reference in (
        ("Géographie", "Quelle est la capitale de la France ?", "Paris"),
        ("Géographie", "Quel est le plus long fleuve d'Afrique ?", "Le Nil"),
        ("Science", "Qu'est-ce que la photosynthèse ?", "Processus par lequel les plantes convertissent la lumière en énergie"),
        ("Science", "Combien d'os compte le corps humain adulte ?", "206"),
        ("Histoire", "En quelle année a eu lieu la Révolution française ?", "1789"),
        ("Histoire", "Qui a écrit Les Misérables ?", "Victor Hugo"),
        ("Mathématiques", "Quelle est la formule de l'aire d'un cercle ?", "pi r carre"),
        ("Langage", "Donne un synonyme du mot heureux.", "joyeux"),
        ("Raisonnement", "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "3"),
        ("Dialogue", "Comment vas-tu aujourd'hui ?", None),
    )
]
46
+
47
+
48
def load_module_from_file(py_path: Path):
    """Import a Python source file as a module named after its file stem.

    The module is registered in ``sys.modules`` before it is executed. If
    execution fails, the registration is rolled back so no half-initialised
    module stays importable.

    Args:
        py_path: Path of the ``.py`` file to execute.

    Returns:
        The executed module object.

    Raises:
        RuntimeError: If no import spec/loader can be built for the file.
    """
    module_name = py_path.stem
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Restore whatever was registered under this name before the attempt.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
56
+
57
+
58
def normalize_text(text: str) -> str:
    """Canonicalise text for lenient comparison.

    Lower-cases, removes combining marks after NFKD decomposition, spells
    "π" as "pi", and turns every run of non-word characters or underscores
    into a single space. A falsy input yields an empty string.
    """
    lowered = (text or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_chars = [ch for ch in decomposed if not unicodedata.combining(ch)]
    plain = "".join(base_chars).replace("π", "pi")
    spaced = re.sub(r"[\W_]+", " ", plain, flags=re.UNICODE)
    # Only plain spaces remain as separators here, so split/join collapses
    # and trims them in one go.
    return " ".join(spaced.split())
66
+
67
+
68
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the share of distinct reference words that also occur in the answer.

    Both texts are normalised first. ``None`` is returned when there is no
    reference, or when the reference has no words left after normalisation.
    """
    if not reference:
        return None
    ref_words = set(normalize_text(reference).split())
    ans_words = set(normalize_text(answer).split())
    if len(ref_words) == 0:
        return None
    shared = ref_words.intersection(ans_words)
    return len(shared) / len(ref_words)
76
+
77
+
78
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether answer and reference are equal once both are normalised.

    Returns ``None`` when no reference is available.
    """
    if reference:
        expected = normalize_text(reference)
        produced = normalize_text(answer)
        return expected == produced
    return None
82
+
83
+
84
def infer_repo_defaults(repo_dir: Path):
    """Guess the default file locations for a model repo directory.

    Returns:
        A tuple ``(train_script, ckpt, config, tokenizer_dir)``. The training
        script is looked up beside ``repo_dir`` first and inside it otherwise;
        the checkpoint is ``model_best.pt`` when it exists, else ``model.pt``.
        None of the returned paths is guaranteed to exist.
    """
    script_name = "train_aramix_h100_full.py"
    beside_repo = repo_dir.parent / script_name
    train_script = beside_repo if beside_repo.exists() else repo_dir / script_name

    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"

    return train_script, ckpt, repo_dir / "config.json", repo_dir / "tokenizer_32k"
96
+
97
+
98
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Look up ``cfg`` under several alternative key names.

    The value of the first name that exists in ``cfg`` is returned, whatever
    that value is; ``default`` is returned when no name matches.
    """
    matches = (cfg[alias] for alias in names if alias in cfg)
    return next(matches, default)
103
+
104
+
105
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Translate a loosely-named ``config.json`` into GPTConfig keyword arguments.

    Each hyper-parameter is looked up under several common aliases and falls
    back to a built-in default when none is present.

    Args:
        cfg_json: Parsed content of the config file (may be empty).
        vocab_size: Vocabulary size, passed through unchanged.

    Returns:
        A dict with ``vocab_size``, ``block_size``, ``d_model``, ``n_heads``,
        ``n_layers`` and ``d_ff``; all but ``vocab_size`` are coerced to ``int``.
    """
    block_size = int(safe_get(cfg_json, "block_size", "max_seq_len", "seq_len", default=512))
    # Coerce before deriving the d_ff default: with a string value such as
    # "768", ``d_model * 4`` would repeat the string instead of multiplying.
    d_model = int(safe_get(cfg_json, "d_model", "n_embd", "dim", default=768))
    n_heads = int(safe_get(cfg_json, "n_heads", "n_head", "num_heads", default=12))
    n_layers = int(safe_get(cfg_json, "n_layers", "n_layer", "num_layers", default=12))
    d_ff = int(safe_get(cfg_json, "d_ff", "ffn_dim", "intermediate_size", default=d_model * 4))
    return {
        "vocab_size": vocab_size,
        "block_size": block_size,
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "d_ff": d_ff,
    }
119
+
120
+
121
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove the ``_orig_mod.`` marker from state-dict keys.

    When no key starts with the marker, the input mapping is returned as-is
    (same object). Otherwise a new dict is built in which the first
    occurrence of the marker is removed from every key.
    """
    marker = "_orig_mod."
    for key in state.keys():
        if key.startswith(marker):
            break
    else:
        return state
    return {key.replace(marker, "", 1): tensor for key, tensor in state.items()}
125
+
126
+
127
class AramixChatTester:
    """Answer questions with an already-trained Aramix GPT checkpoint.

    The model classes and the tokenizer factory are taken from the training
    script, which is imported dynamically from ``train_script``. Checkpoints
    saved with the ``_orig_mod.`` key prefix are handled by a manual loader.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
    ):
        """Import the training script, then load config, tokenizer and model.

        Args:
            repo_dir: Directory holding the checkpoint and config.
            train_script: Training script exposing ``GPT``, ``GPTConfig``,
                ``train_or_load_tokenizer`` and ``DOMAINS``.
            ckpt_path: Checkpoint file to load.
            config_path: Optional ``config.json``; ignored when absent.
            device: Torch device string; defaults to CUDA when available.

        Raises:
            RuntimeError: If the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path
        self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))

        self.M = load_module_from_file(self.train_script)

        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")

        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Build the tokenizer via the training script, run from the repo's parent dir."""
        old_cwd = Path.cwd()
        try:
            # The working directory is switched for the duration of the call
            # and always restored afterwards.
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Create a ``GPTConfig``, falling back to ``vocab_size`` only if kwargs are rejected."""
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _manual_load_state(self, model: torch.nn.Module):
        """Load the checkpoint weights directly into ``model``.

        Raises:
            RuntimeError: If the checkpoint holds no usable state dict, or if
                keys are missing/unexpected beyond the tolerated ones.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        if isinstance(ckpt, dict) and "model" in ckpt:
            state = ckpt["model"]
        else:
            state = ckpt

        if not isinstance(state, dict):
            raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")

        state = strip_orig_mod_prefix(state)
        missing, unexpected = model.load_state_dict(state, strict=False)

        # Only a missing lm_head weight (possibly tied) is tolerated;
        # anything else is treated as a failed load.
        hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
        hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
        if hard_missing or hard_unexpected:
            raise RuntimeError(
                "Chargement manuel incomplet.\n"
                f"Missing: {hard_missing[:20]}\n"
                f"Unexpected: {hard_unexpected[:20]}"
            )

    def _load_model(self):
        """Instantiate the model and load its weights, returning it in eval mode.

        The training script's ``load_checkpoint`` (when it exists) is tried
        with a 4-argument and then a 3-argument call. The manual loader is
        used when neither signature is accepted, or when the script's loader
        fails with a ``RuntimeError`` mentioning ``_orig_mod.``; any other
        ``RuntimeError`` is re-raised.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)

        loader = getattr(self.M, "load_checkpoint", None)
        if loader is not None:
            call_variants = (
                (model, None, self.ckpt_path, self.device),
                (model, self.ckpt_path, self.device),
            )
            for call_args in call_variants:
                try:
                    loader(*call_args)
                except TypeError:
                    # Signature mismatch: try the next variant. If none fits,
                    # fall through to the manual loader instead of crashing.
                    continue
                except RuntimeError as e:
                    if "_orig_mod." not in str(e):
                        raise
                    break
                else:
                    model.eval()
                    return model

        self._manual_load_state(model)
        model.eval()
        return model

    def encode_prompt(self, question: str) -> List[int]:
        """Encode the question/answer prompt, prepending BOS and dropping a trailing EOS."""
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)

        prompt = f"Question: {question}\nRéponse:"
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        if bos is not None:
            ids = [bos] + ids
        if eos is not None and ids and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 96,
        temperature: float = 0.4,
        top_k: int = 40,
        repetition_penalty: float = 1.12,
    ) -> str:
        """Generate an answer to ``question`` token by token.

        Args:
            question: Text inserted into the prompt template.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` best logits when sampling (``None``
                or ``<= 0`` disables the filter).
            repetition_penalty: Penalty applied to tokens seen in the last 64
                positions.

        Returns:
            The decoded continuation with whitespace collapsed, EOS tokens
            removed and any echoed "Réponse:" label stripped.
        """
        ids = self.encode_prompt(question)
        x = torch.tensor([ids], dtype=torch.long, device=self.device)

        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
        if block_size is None:
            block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)
        block_size = int(block_size)

        for step in range(max_new_tokens):
            x_ctx = x[:, -block_size:]

            # The model may return a tuple or bare logits.
            out = self.model(x_ctx)
            logits = out[0] if isinstance(out, tuple) else out
            logits = logits[:, -1, :]

            if repetition_penalty != 1.0:
                # Dividing a negative logit would raise it, so negative scores
                # are multiplied and positive ones divided: both moves lower
                # the probability of a recently used token.
                recent = torch.unique(x[0, -64:])
                seen = logits[0, recent]
                logits[0, recent] = torch.where(
                    seen < 0, seen * repetition_penalty, seen / repetition_penalty
                )

            if temperature <= 0:
                next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-5)
                if top_k is not None and top_k > 0:
                    values, _ = torch.topk(logits, k=min(top_k, logits.size(-1)))
                    kth = values[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_tok], dim=1)
            # EOS is ignored during the first two steps, so generation never
            # stops before three tokens have been produced.
            if eos_id is not None and next_tok.item() == eos_id and step >= 2:
                break

        new_ids = x[0, len(ids):].tolist()
        if eos_id is not None:
            # Keep the EOS marker out of the decoded text so it cannot
            # pollute the overlap / exact-match scores.
            new_ids = [tok for tok in new_ids if tok != eos_id]
        text = self.tokenizer.decode(new_ids).strip()
        text = re.sub(r"\s+", " ", text).strip()
        text = text.replace("Réponse :", "").replace("Réponse:", "").strip()
        return text
282
+
283
+
284
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Load the question set from a JSON file, or return the built-in set.

    Args:
        path: Path to a JSON file holding a list of question objects; a falsy
            value selects ``DEFAULT_QUESTIONS``.

    Returns:
        The list of question dicts.

    Raises:
        ValueError: If the file does not contain a JSON list, or if an entry
            is not an object with a ``"question"`` key.
    """
    if not path:
        return DEFAULT_QUESTIONS
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    # Validate up front so a malformed entry is reported here instead of
    # failing later with a KeyError, after the model has already been loaded.
    for idx, item in enumerate(data, 1):
        if not isinstance(item, dict) or "question" not in item:
            raise ValueError(
                f"Question #{idx} invalide: un objet JSON avec la clé 'question' est attendu."
            )
    return data
292
+
293
+
294
def format_bar(score: float, width: int = 20) -> str:
    """Draw ``score`` (nominally 0..1) as a bar of ``width`` block characters.

    The filled portion is ``score * width`` rounded, then clamped so it never
    exceeds ``width`` nor drops below zero.
    """
    cells = int(round(score * width))
    cells = min(width, cells)
    cells = max(0, cells)
    return "█" * cells + "░" * (width - cells)
297
+
298
+
299
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> None:
    """Write the run summary as a JSON file and a human-readable text file.

    Both files are created inside ``output_dir`` (created if needed) under
    the names ``qa_test_report_simple.json`` and ``qa_test_report_simple.txt``.
    The text report lists one section per entry of ``summary["results"]``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    json_path = output_dir / "qa_test_report_simple.json"
    txt_path = output_dir / "qa_test_report_simple.txt"

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(summary, json_file, ensure_ascii=False, indent=2)

    def entry_lines(r):
        # Optional fields are emitted only when a value is available.
        yield f"[{r['id']:02d}] {r['category']}\n"
        yield f" User : {r['question']}\n"
        yield f" Assistant : {r['answer']}\n"
        if r["reference"]:
            yield f" Référence : {r['reference']}\n"
        if r["overlap_score"] is not None:
            yield f" Overlap : {r['overlap_score']:.0%}\n"
        if r["exact_match"] is not None:
            yield f" ExactMatch: {r['exact_match']}\n"
        yield f" Latence : {r['latency_s']}s\n\n"

    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.write("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ\n")
        txt_file.write("=" * 60 + "\n\n")
        for result in summary["results"]:
            txt_file.writelines(entry_lines(result))
322
+
323
+
324
def main():
    """Command-line entry point: run the QA set through the model and report scores.

    Parses the CLI options, resolves the training script / checkpoint / config
    paths, generates one answer per question, prints per-question and
    aggregate metrics, and optionally writes the reports into the repo dir.

    Raises:
        FileNotFoundError: If the training script or the checkpoint is missing.
    """
    parser = argparse.ArgumentParser("Test QA simple pour modèle Aramix déjà entraîné")
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    for optional_path in ("--train_script", "--ckpt", "--config", "--questions"):
        parser.add_argument(optional_path, type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=96)
    parser.add_argument("--temperature", type=float, default=0.4)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--repetition_penalty", type=float, default=1.12)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)

    # Explicit CLI paths take precedence over the inferred repo layout.
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()

    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")

    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
    )

    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}

    heavy_rule = "═" * 70
    light_rule = "─" * 70

    print("\n" + heavy_rule)
    print(" TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print(heavy_rule)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Questions : {len(questions)}")
    print(heavy_rule + "\n")

    for number, item in enumerate(questions, 1):
        question = item["question"]
        reference = item.get("reference")
        category = item.get("category", "Général")

        started = time.time()
        answer = tester.generate(
            question,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
        )
        latency = time.time() - started

        overlap = lexical_overlap(reference, answer)
        em = exact_match(reference, answer)

        entry = {
            "id": number,
            "category": category,
            "question": question,
            "answer": answer,
            "reference": reference,
            "latency_s": round(latency, 3),
            # Whitespace-separated word count, used as a rough length measure.
            "tokens_generated_approx": len(answer.split()),
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
        }
        results.append(entry)
        categories.setdefault(category, []).append(entry)

        overlap_str = "n/a" if overlap is None else f"{overlap:.0%}"
        if em is None:
            em_str = "n/a"
        elif em:
            em_str = "✓"
        else:
            em_str = "✗"

        print(light_rule)
        print(f"[{number:02d}] [{category}] overlap={overlap_str} | EM={em_str}")
        print(f" User : {question}")
        print(f" Assistant : {answer}")
        if reference:
            print(f" Référence : {reference}")
        print(f" ⏱ {latency:.2f}s | ~{entry['tokens_generated_approx']} mots\n")

    def mean(values):
        # Average of a list, 0.0 when there is nothing to average.
        return sum(values) / len(values) if values else 0.0

    def scored_overlaps(entries):
        return [r["overlap_score"] for r in entries if r["overlap_score"] is not None]

    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]

    avg_overlap = mean(scored_overlaps(results))
    em_rate = mean([1 if hit else 0 for hit in scored_em])
    avg_latency = mean([r["latency_s"] for r in results])
    avg_words = mean([r["tokens_generated_approx"] for r in results])

    cat_scores: Dict[str, float] = {
        name: mean(scored_overlaps(entries)) for name, entries in categories.items()
    }

    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {k: round(v, 4) for k, v in cat_scores.items()},
        "results": results,
    }

    print(heavy_rule)
    print(" RÉSUMÉ")
    print(heavy_rule)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat, score in sorted(cat_scores.items()):
        print(f" {cat:<15} {format_bar(score)} {score:.0%}")
    print(heavy_rule)

    if args.save_report:
        save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {repo_dir / 'qa_test_report_simple.json'}")
        print(f" {repo_dir / 'qa_test_report_simple.txt'}")


if __name__ == "__main__":
    main()
simple_qa_test_aramix_v3.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ simple_qa_test_aramix_v3.py
6
+
7
+ Test QA simple et strict pour un modèle déjà entraîné dans une repo de type :
8
+ - train_aramix_h100_full.py
9
+ - aramix_h100/
10
+ - config.json
11
+ - model_best.pt
12
+ - model.pt
13
+ - tokenizer_32k/
14
+
15
+ Cette version améliore le test QA en :
16
+ - gérant les checkpoints torch.compile() avec préfixe "_orig_mod."
17
+ - utilisant un prompt plus directif pour des réponses courtes
18
+ - privilégiant une génération greedy / quasi-greedy
19
+ - tronquant proprement les réponses trop longues
20
+ - ajoutant une métrique "contains_reference"
21
+
22
+ Usage
23
+ -----
24
+ python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --save_report
25
+ python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --temperature 0 --max_new_tokens 16
26
+ python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --ckpt ./aramix_h100/model_best.pt
27
+ python simple_qa_test_aramix_v3.py --repo_dir ./aramix_h100 --questions qa_questions.json
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import argparse
33
+ import importlib.util
34
+ import json
35
+ import os
36
+ import re
37
+ import sys
38
+ import time
39
+ import unicodedata
40
+ from pathlib import Path
41
+ from typing import Any, Dict, List, Optional, Tuple
42
+
43
+ import torch
44
+ import torch.nn.functional as F
45
+
46
+
47
# Built-in question set returned by load_questions() when no file is given.
# Each entry carries a category, a question and a reference answer; a
# reference of None means the entry is not scored.
DEFAULT_QUESTIONS = [
    {
        "category": "Géographie",
        "question": "Quelle est la capitale de la France ?",
        "reference": "Paris",
    },
    {
        "category": "Géographie",
        "question": "Quel est le plus long fleuve d'Afrique ?",
        "reference": "Le Nil",
    },
    {
        "category": "Science",
        "question": "Qu'est-ce que la photosynthèse ?",
        "reference": "Processus par lequel les plantes convertissent la lumière en énergie",
    },
    {
        "category": "Science",
        "question": "Combien d'os compte le corps humain adulte ?",
        "reference": "206",
    },
    {
        "category": "Histoire",
        "question": "En quelle année a eu lieu la Révolution française ?",
        "reference": "1789",
    },
    {
        "category": "Histoire",
        "question": "Qui a écrit Les Misérables ?",
        "reference": "Victor Hugo",
    },
    {
        "category": "Mathématiques",
        "question": "Quelle est la formule de l'aire d'un cercle ?",
        "reference": "pi r carre",
    },
    {
        "category": "Langage",
        "question": "Donne un synonyme du mot heureux.",
        "reference": "joyeux",
    },
    {
        "category": "Raisonnement",
        "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
        "reference": "3",
    },
    {
        "category": "Dialogue",
        "question": "Comment vas-tu aujourd'hui ?",
        "reference": None,
    },
]
59
+
60
+
61
def load_module_from_file(py_path: Path):
    """Import the Python source file at ``py_path`` and return the module.

    The module is registered in ``sys.modules`` under the file's stem before
    its code is executed.

    Raises:
        RuntimeError: if no import spec (or no loader) can be built for the
            given path.
    """
    module_name = py_path.stem
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise RuntimeError(f"Impossible de charger le module: {py_path}")

    loaded = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = loaded
    loader.exec_module(loaded)
    return loaded
69
+
70
+
71
def normalize_text(text: str) -> str:
    """Return a lowercase, accent-free, punctuation-free form of ``text``.

    Steps, in order: trim and lowercase, NFKD-decompose and drop combining
    marks, spell the Greek letter pi as "pi", turn every run of non-word
    characters or underscores into one space, then collapse whitespace.
    A falsy input (``None`` or ``""``) gives ``""``.
    """
    lowered = (text or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    without_marks = "".join(
        char for char in decomposed if not unicodedata.combining(char)
    )
    spelled = without_marks.replace("π", "pi")
    spaced = re.sub(r"[\W_]+", " ", spelled, flags=re.UNICODE)
    return re.sub(r"\s+", " ", spaced).strip()
79
+
80
+
81
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the share of distinct reference tokens found in ``answer``.

    Both strings go through ``normalize_text`` and are split on whitespace.
    Returns ``None`` when ``reference`` is falsy or normalizes to no tokens.
    """
    if not reference:
        return None

    reference_tokens = set(normalize_text(reference).split())
    answer_tokens = set(normalize_text(answer).split())
    if len(reference_tokens) == 0:
        return None

    shared = reference_tokens.intersection(answer_tokens)
    return len(shared) / len(reference_tokens)
89
+
90
+
91
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether ``answer`` equals ``reference`` after normalization.

    Returns ``None`` when ``reference`` is falsy (nothing to compare against).
    """
    if reference:
        expected = normalize_text(reference)
        produced = normalize_text(answer)
        return expected == produced
    return None
95
+
96
+
97
def contains_reference(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether the normalized reference occurs in the normalized answer.

    The reference must appear as a contiguous run of whole tokens. A plain
    substring test would let the reference "3" match the answer "13", which
    inflates the metric for short numeric references.

    Returns:
        ``None`` when ``reference`` is falsy or normalizes to an empty
        string, otherwise ``True``/``False``.
    """
    if not reference:
        return None
    ref = normalize_text(reference)
    ans = normalize_text(answer)
    if not ref:
        return None
    # normalize_text leaves single spaces between tokens, so padding both
    # sides with a space anchors the match on token boundaries.
    return f" {ref} " in f" {ans} "
105
+
106
+
107
def infer_repo_defaults(repo_dir: Path):
    """Derive the default artefact paths for a model directory.

    Returns a tuple ``(train_script, ckpt, config, tokenizer_dir)``:

    * ``train_script``: ``train_aramix_h100_full.py`` next to ``repo_dir`` if
      that file exists, otherwise the same file name inside ``repo_dir``.
    * ``ckpt``: ``model_best.pt`` inside ``repo_dir`` if it exists, otherwise
      ``model.pt``.
    * ``config``: ``config.json`` inside ``repo_dir``.
    * ``tokenizer_dir``: ``tokenizer_32k`` inside ``repo_dir``.

    Only the two fallbacks above touch the file system; the returned paths
    are not otherwise checked.
    """
    script_name = "train_aramix_h100_full.py"
    sibling_script = repo_dir.parent / script_name
    train_script = sibling_script if sibling_script.exists() else repo_dir / script_name

    best_ckpt = repo_dir / "model_best.pt"
    ckpt = best_ckpt if best_ckpt.exists() else repo_dir / "model.pt"

    return train_script, ckpt, repo_dir / "config.json", repo_dir / "tokenizer_32k"
119
+
120
+
121
def safe_get(cfg: Dict[str, Any], *names: str, default=None):
    """Return the value of the first key in ``names`` that exists in ``cfg``.

    Keys are tried in the order given; ``default`` is returned when none of
    them is present. A key that is present with a falsy value still counts.
    """
    return next((cfg[key] for key in names if key in cfg), default)
126
+
127
+
128
def build_model_config_dict(cfg_json: Dict[str, Any], vocab_size: int) -> Dict[str, Any]:
    """Translate a loaded ``config.json`` into model-config keyword arguments.

    Each hyper-parameter is looked up under several alternative key names
    (first match wins) and falls back to a default when none is present:
    block_size 512, d_model 768, n_heads 12, n_layers 12, d_ff 4 * d_model.

    Args:
        cfg_json: Parsed configuration dictionary (may be empty).
        vocab_size: Vocabulary size, passed through unchanged.

    Returns:
        A dict with keys ``vocab_size``, ``block_size``, ``d_model``,
        ``n_heads``, ``n_layers`` and ``d_ff``; all but ``vocab_size`` are
        coerced with ``int()``.
    """

    def _pick(names, default):
        # First alias present in cfg_json wins, otherwise the default.
        for name in names:
            if name in cfg_json:
                return cfg_json[name]
        return default

    block_size = int(_pick(("block_size", "max_seq_len", "seq_len"), 512))
    # Coerce d_model BEFORE deriving the d_ff default: multiplying a raw
    # string value such as "64" by 4 would repeat the text ("64646464").
    d_model = int(_pick(("d_model", "n_embd", "dim"), 768))
    n_heads = int(_pick(("n_heads", "n_head", "num_heads"), 12))
    n_layers = int(_pick(("n_layers", "n_layer", "num_layers"), 12))
    d_ff = int(_pick(("d_ff", "ffn_dim", "intermediate_size"), d_model * 4))

    return {
        "vocab_size": vocab_size,
        "block_size": block_size,
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "d_ff": d_ff,
    }
142
+
143
+
144
def strip_orig_mod_prefix(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove the ``_orig_mod.`` marker from the keys of a state dict.

    The rewrite only happens when at least one key starts with the marker;
    in that case the first occurrence of the marker is removed from every
    key and a new dict is returned. Otherwise ``state`` is returned as is.
    """
    marker = "_orig_mod."
    if not any(key.startswith(marker) for key in state):
        return state
    return {key.replace(marker, "", 1): tensor for key, tensor in state.items()}
148
+
149
+
150
def clean_answer_text(text: str, max_words: int = 16) -> str:
    """Reduce raw generated text to one short answer.

    Processing order:
      1. remove a few special-token and "answer label" markers;
      2. keep only the first line;
      3. cut after the first sentence terminator, if it starts within the
         first 120 characters;
      4. collapse whitespace;
      5. keep at most ``max_words`` whitespace-separated words;
      6. strip trailing commas, semicolons, colons and spaces.

    Args:
        text: Raw decoded model output; a falsy value yields ``""``.
        max_words: Maximum number of words kept in the answer.

    Returns:
        The cleaned answer string.
    """
    text = (text or "").strip()

    # Drop a few frequent markers.
    for marker in ("<eos>", "</s>", "<pad>", "Réponse :", "Réponse:", "Answer:"):
        text = text.replace(marker, " ")

    # Keep the first line only.
    text = text.split("\n")[0].strip()

    # Cut at the first real sentence end. The terminator must be followed by
    # whitespace or the end of the text, so a dot inside a number ("3.14")
    # does not truncate the answer to "3.".
    m = re.search(r"[.!?](?=\s|$)", text)
    if m and m.start() < 120:
        text = text[: m.start() + 1]

    # Collapse whitespace.
    text = re.sub(r"\s+", " ", text).strip()

    # Truncate to the requested number of words.
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words]).strip()

    # Remove dangling separators left at the end.
    text = re.sub(r"[,\s;:]+$", "", text).strip()

    return text
181
+
182
+
183
def safe_top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Mask every logit below the ``top_k``-th largest value with ``-inf``.

    The ranking is done along the last dimension, so the function accepts a
    single logit vector as well as batched inputs of any rank.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        top_k: Number of candidates to keep. ``None`` or a value <= 0
            disables filtering; values larger than the vocabulary size are
            clamped to it.

    Returns:
        ``logits`` itself when filtering is disabled, otherwise a new tensor
        of the same shape. Entries tied with the k-th value are kept.
    """
    if top_k is None or top_k <= 0:
        return logits
    k = min(int(top_k), logits.size(-1))
    values, _ = torch.topk(logits, k=k, dim=-1)
    # topk returns values in descending order: the last one is the threshold.
    # Slicing with -1: keeps the dimension so it broadcasts for any rank.
    kth = values[..., -1:]
    return torch.where(logits < kth, torch.full_like(logits, float("-inf")), logits)
190
+
191
+
192
class AramixChatTester:
    """Load a trained model through its training script and query it.

    The training script is imported as a module and must expose ``GPT``,
    ``GPTConfig``, ``train_or_load_tokenizer`` and ``DOMAINS``. The tokenizer
    and the model are built once in ``__init__``; ``generate`` then produces
    a short cleaned answer for one question.
    """

    def __init__(
        self,
        repo_dir: Path,
        train_script: Path,
        ckpt_path: Path,
        config_path: Path,
        device: Optional[str] = None,
        prompt_style: str = "strict_qa",
    ):
        """Import the training script, then load config, tokenizer and model.

        Args:
            repo_dir: Model directory; its parent becomes the working
                directory while the tokenizer is loaded.
            train_script: Python file defining the model classes.
            ckpt_path: Checkpoint file to load.
            config_path: JSON config file; ignored when it does not exist.
            device: Torch device string. ``None`` selects "cuda" when
                available, otherwise "cpu".
            prompt_style: "strict_qa", "qa", or anything else for the raw
                question (see ``build_prompt``).

        Raises:
            RuntimeError: if the training script lacks a required symbol.
        """
        self.repo_dir = repo_dir
        self.train_script = train_script
        self.ckpt_path = ckpt_path
        self.config_path = config_path
        self.prompt_style = prompt_style
        self.device = torch.device(device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu"))

        self.M = load_module_from_file(self.train_script)

        required = ["GPT", "GPTConfig", "train_or_load_tokenizer", "DOMAINS"]
        missing = [x for x in required if not hasattr(self.M, x)]
        if missing:
            raise RuntimeError(f"Le fichier {self.train_script.name} ne contient pas les symboles attendus: {missing}")

        self.cfg_json: Dict[str, Any] = {}
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.cfg_json = json.load(f)

        self.tokenizer = self._load_tokenizer()
        self.model = self._load_model()

    def _load_tokenizer(self):
        """Call the training script's tokenizer loader from the repo parent.

        The working directory is switched to ``repo_dir.parent`` for the
        duration of the call and always restored afterwards.
        """
        old_cwd = Path.cwd()
        try:
            os.chdir(self.repo_dir.parent)
            tok = self.M.train_or_load_tokenizer(self.M.DOMAINS)
        finally:
            os.chdir(old_cwd)
        return tok

    def _make_gpt_config(self):
        """Build a ``GPTConfig`` from the JSON config.

        Falls back to ``GPTConfig(vocab_size=...)`` when the constructor
        rejects the derived keyword arguments with a ``TypeError``.
        """
        kwargs = build_model_config_dict(self.cfg_json, vocab_size=len(self.tokenizer))
        try:
            return self.M.GPTConfig(**kwargs)
        except TypeError:
            return self.M.GPTConfig(vocab_size=len(self.tokenizer))

    def _manual_load_state(self, model: torch.nn.Module):
        """Load the checkpoint weights into ``model`` without the script's loader.

        Accepts either a bare state dict or a dict holding it under "model".
        ``_orig_mod.`` key prefixes are stripped first.

        Raises:
            RuntimeError: if no state dict is found, or if keys are missing
                or unexpected beyond the tolerated ones.
        """
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        ckpt = torch.load(self.ckpt_path, map_location=self.device)
        if isinstance(ckpt, dict) and "model" in ckpt:
            state = ckpt["model"]
        else:
            state = ckpt

        if not isinstance(state, dict):
            raise RuntimeError("Checkpoint non reconnu: pas de state_dict exploitable.")

        state = strip_orig_mod_prefix(state)
        missing, unexpected = model.load_state_dict(state, strict=False)

        # Minimal tolerance: only a missing lm_head.weight (possible weight
        # tying) and leftover "_orig_mod."-prefixed keys are accepted.
        hard_missing = [k for k in missing if not k.endswith("lm_head.weight")]
        hard_unexpected = [k for k in unexpected if not k.startswith("_orig_mod.")]
        if hard_missing or hard_unexpected:
            raise RuntimeError(
                "Chargement manuel incomplet.\n"
                f"Missing: {hard_missing[:20]}\n"
                f"Unexpected: {hard_unexpected[:20]}"
            )

    def _load_model(self):
        """Instantiate the model, load its weights and switch to eval mode.

        The training script's ``load_checkpoint`` is tried first, with a
        4-argument call and then a 3-argument call if the first raises
        ``TypeError``. A ``RuntimeError`` mentioning "_orig_mod." triggers
        the manual loader; any other ``RuntimeError`` propagates. The manual
        loader is also used when the script has no ``load_checkpoint``.
        """
        cfg = self._make_gpt_config()
        model = self.M.GPT(cfg).to(self.device)

        if hasattr(self.M, "load_checkpoint"):
            try:
                try:
                    self.M.load_checkpoint(model, None, self.ckpt_path, self.device)
                    model.eval()
                    return model
                except TypeError:
                    self.M.load_checkpoint(model, self.ckpt_path, self.device)
                    model.eval()
                    return model
            except RuntimeError as e:
                if "_orig_mod." not in str(e):
                    raise

        self._manual_load_state(model)
        model.eval()
        return model

    def build_prompt(self, question: str) -> str:
        """Wrap ``question`` in the prompt template selected by ``prompt_style``.

        "strict_qa" adds brevity instructions, "qa" uses a bare
        question/answer frame, and any other style returns the question
        unchanged.
        """
        if self.prompt_style == "strict_qa":
            return (
                "Réponds très brièvement et uniquement en français.\n"
                "Donne seulement la réponse finale, sans explication.\n\n"
                f"Question : {question}\n"
                "Réponse :"
            )
        if self.prompt_style == "qa":
            return f"Question: {question}\nRéponse:"
        return question

    def encode_prompt(self, question: str) -> List[int]:
        """Tokenize the prompt built for ``question``.

        A BOS id is prepended when the tokenizer defines one, and a trailing
        EOS id is removed so that generation continues the prompt.
        """
        bos = getattr(self.tokenizer, "bos_token_id", None)
        eos = getattr(self.tokenizer, "eos_token_id", None)

        prompt = self.build_prompt(question)
        ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        if bos is not None:
            ids = [bos] + ids
        if eos is not None and ids and ids[-1] == eos:
            ids = ids[:-1]
        return ids

    @staticmethod
    def _apply_repetition_penalty(logits: torch.Tensor, token_ids: List[int], penalty: float) -> torch.Tensor:
        """Lower, in place, the logits of already-seen tokens in batch row 0.

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so a penalty above 1 always makes a seen token less
        likely. Plain division would raise the probability of tokens whose
        logit is negative.

        Args:
            logits: 2-D tensor; only row 0 is modified.
            token_ids: Token ids to penalise (duplicates are ignored).
            penalty: Penalty factor; 1.0 leaves the logits untouched.

        Returns:
            The same ``logits`` tensor, for convenience.
        """
        unique_ids = sorted(set(token_ids))
        if not unique_ids or penalty == 1.0:
            return logits
        idx = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
        seen = logits[0, idx]
        logits[0, idx] = torch.where(seen < 0, seen * penalty, seen / penalty)
        return logits

    @torch.no_grad()
    def generate(
        self,
        question: str,
        max_new_tokens: int = 16,
        temperature: float = 0.0,
        top_k: int = 1,
        repetition_penalty: float = 1.10,
        max_answer_words: int = 16,
    ) -> str:
        """Generate a short answer to ``question``.

        Args:
            question: User question, wrapped by ``build_prompt``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Values <= 0 select argmax decoding, or sampling from
                the unscaled top-k distribution when ``top_k`` > 1; positive
                values scale the logits before top-k sampling.
            top_k: Number of candidates kept for sampling.
            repetition_penalty: Penalty applied to the last 48 context
                tokens (prompt included).
            max_answer_words: Word budget; generation stops once the cleaned
                partial answer reaches it.

        Returns:
            The decoded continuation after ``clean_answer_text``.
        """
        ids = self.encode_prompt(question)
        x = torch.tensor([ids], dtype=torch.long, device=self.device)

        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        block_size = getattr(getattr(self.model, "cfg", None), "block_size", None)
        if block_size is None:
            block_size = safe_get(self.cfg_json, "block_size", "max_seq_len", default=512)

        for step in range(max_new_tokens):
            # Feed at most block_size tokens of context to the model.
            x_ctx = x[:, -int(block_size):]
            out = self.model(x_ctx)
            logits = out[0] if isinstance(out, tuple) else out
            logits = logits[:, -1, :]

            # Light repetition penalty over the most recent context tokens.
            self._apply_repetition_penalty(logits, x[0, -48:].tolist(), repetition_penalty)

            # Greedy by default for short QA.
            if temperature <= 0:
                if top_k is not None and top_k > 1:
                    masked = safe_top_k_logits(logits, top_k)
                    probs = F.softmax(masked, dim=-1)
                    next_tok = torch.multinomial(probs, num_samples=1)
                else:
                    next_tok = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / max(temperature, 1e-5)
                logits = safe_top_k_logits(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_tok], dim=1)

            # An EOS produced as the very first token does not stop generation.
            if eos_id is not None and next_tok.item() == eos_id and step >= 1:
                break

            # Early stop once the cleaned answer has used its word budget.
            partial = self.tokenizer.decode(x[0, len(ids):].tolist()).strip()
            partial = clean_answer_text(partial, max_words=max_answer_words)
            if len(partial.split()) >= max_answer_words:
                break

        new_ids = x[0, len(ids):].tolist()
        text = self.tokenizer.decode(new_ids).strip()
        return clean_answer_text(text, max_words=max_answer_words)
375
+
376
+
377
def load_questions(path: Optional[str]) -> List[Dict[str, Any]]:
    """Return the question list to evaluate.

    A falsy ``path`` selects the built-in ``DEFAULT_QUESTIONS``. Otherwise
    the file is parsed as UTF-8 JSON and must contain a list at top level.

    Raises:
        ValueError: if the JSON document is not a list.
    """
    if path:
        with open(path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
        if isinstance(loaded, list):
            return loaded
        raise ValueError("Le fichier questions doit contenir une liste JSON.")
    return DEFAULT_QUESTIONS
385
+
386
+
387
def format_bar(score: float, width: int = 20) -> str:
    """Render ``score`` as a text bar of ``width`` cells.

    The number of filled cells is ``score * width`` rounded to an integer
    and clamped to the range 0..``width``; remaining cells are drawn empty.
    """
    filled = int(round(score * width))
    if filled > width:
        filled = width
    if filled < 0:
        filled = 0
    empty = width - filled
    return "█" * filled + "░" * empty
390
+
391
+
392
def save_reports(output_dir: Path, summary: Dict[str, Any]) -> Tuple[Path, Path]:
    """Write the evaluation summary as a JSON file and a text file.

    ``output_dir`` is created if needed. The JSON file is a dump of
    ``summary``; the text file lists every entry of ``summary["results"]``
    with its optional reference and metrics.

    Returns:
        ``(json_path, txt_path)`` of the two files written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    json_path = output_dir / "qa_test_report_simple_v3.json"
    txt_path = output_dir / "qa_test_report_simple_v3.txt"

    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(summary, json_file, ensure_ascii=False, indent=2)

    # Assemble the whole text report first, then write it in one go.
    chunks = ["TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ (V3)\n", "=" * 60 + "\n\n"]
    for row in summary["results"]:
        chunks.append(f"[{row['id']:02d}] {row['category']}\n")
        chunks.append(f" User : {row['question']}\n")
        chunks.append(f" Assistant : {row['answer']}\n")
        if row["reference"]:
            chunks.append(f" Référence : {row['reference']}\n")
        if row["overlap_score"] is not None:
            chunks.append(f" Overlap : {row['overlap_score']:.0%}\n")
        if row["exact_match"] is not None:
            chunks.append(f" ExactMatch: {row['exact_match']}\n")
        if row["contains_reference"] is not None:
            chunks.append(f" Contains : {row['contains_reference']}\n")
        chunks.append(f" Latence : {row['latency_s']}s\n\n")

    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.write("".join(chunks))

    return json_path, txt_path
418
+
419
+
420
def main():
    """Command-line entry point: run the QA test and print the results.

    Paths default to what ``infer_repo_defaults`` derives from ``--repo_dir``
    and can each be overridden. Every question is answered once, scored with
    lexical overlap, exact match and reference containment, and printed. A
    global summary follows, and with ``--save_report`` the JSON and text
    reports are written into the repo directory.

    Categories with no scored question (no reference) get a ``None`` score
    and are shown as "n/a" instead of a misleading 0%.

    Raises:
        FileNotFoundError: if the training script or checkpoint is missing.
    """
    parser = argparse.ArgumentParser("Test QA simple et strict pour modèle Aramix déjà entraîné")
    parser.add_argument("--repo_dir", type=str, default="./aramix_h100")
    parser.add_argument("--train_script", type=str, default=None)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--questions", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_k", type=int, default=1)
    parser.add_argument("--repetition_penalty", type=float, default=1.10)
    parser.add_argument("--max_answer_words", type=int, default=16)
    parser.add_argument("--prompt_style", type=str, default="strict_qa", choices=["strict_qa", "qa", "raw"])
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--save_report", action="store_true")
    args = parser.parse_args()

    repo_dir = Path(args.repo_dir).resolve()
    train_script, ckpt_path, config_path, tokenizer_dir = infer_repo_defaults(repo_dir)

    # Explicit command-line paths override the inferred defaults.
    if args.train_script:
        train_script = Path(args.train_script).resolve()
    if args.ckpt:
        ckpt_path = Path(args.ckpt).resolve()
    if args.config:
        config_path = Path(args.config).resolve()

    if not train_script.exists():
        raise FileNotFoundError(f"Script train introuvable: {train_script}")
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint introuvable: {ckpt_path}")
    if not config_path.exists():
        # A missing config is not fatal: the tester falls back to defaults.
        print(f"[WARN] config.json introuvable: {config_path} — fallback sur GPTConfig(vocab_size=...).")

    questions = load_questions(args.questions)
    tester = AramixChatTester(
        repo_dir=repo_dir,
        train_script=train_script,
        ckpt_path=ckpt_path,
        config_path=config_path,
        device=args.device,
        prompt_style=args.prompt_style,
    )

    results: List[Dict[str, Any]] = []
    categories: Dict[str, List[Dict[str, Any]]] = {}

    print("\n" + "=" * 60)
    print("TEST QA SIMPLE — MODÈLE DÉJÀ ENTRAÎNÉ")
    print("=" * 60)
    print(f"Repo : {repo_dir}")
    print(f"Train script: {train_script}")
    print(f"Checkpoint : {ckpt_path}")
    print(f"Config : {config_path}")
    print(f"Tokenizer : {tokenizer_dir}")
    print(f"Device : {tester.device}")
    print(f"Prompt : {args.prompt_style}")
    print(f"Questions : {len(questions)}")
    print("=" * 60 + "\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        cat = item.get("category", "Général")

        t0 = time.time()
        ans = tester.generate(
            q,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            max_answer_words=args.max_answer_words,
        )
        latency = time.time() - t0

        # All three metrics are None when the question has no reference.
        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)
        contains_ref = contains_reference(ref, ans)

        entry = {
            "id": i,
            "category": cat,
            "question": q,
            "answer": ans,
            "reference": ref,
            "latency_s": round(latency, 3),
            # Word count of the cleaned answer, used as a token-count proxy.
            "tokens_generated_approx": len(ans.split()),
            "overlap_score": None if overlap is None else round(overlap, 4),
            "exact_match": em,
            "contains_reference": contains_ref,
        }
        results.append(entry)
        categories.setdefault(cat, []).append(entry)

        overlap_str = f"{overlap:.0%}" if overlap is not None else "n/a"
        em_str = "✓" if em else ("✗" if em is not None else "n/a")
        contains_str = "✓" if contains_ref else ("✗" if contains_ref is not None else "n/a")

        print(f"[{i:02d}] {cat}")
        print(f" User : {q}")
        print(f" Assistant : {ans}")
        if ref:
            print(f" Référence : {ref}")
        print(f" Overlap : {overlap_str}")
        print(f" ExactMatch: {em_str}")
        print(f" Contains : {contains_str}")
        print(f" Latence : {latency:.3f}s\n")

    # Global averages only take scored (non-None) entries into account.
    scored_overlap = [r["overlap_score"] for r in results if r["overlap_score"] is not None]
    scored_em = [r["exact_match"] for r in results if r["exact_match"] is not None]
    scored_contains = [r["contains_reference"] for r in results if r["contains_reference"] is not None]

    avg_overlap = sum(scored_overlap) / len(scored_overlap) if scored_overlap else 0.0
    em_rate = sum(1 for x in scored_em if x) / len(scored_em) if scored_em else 0.0
    contains_rate = sum(1 for x in scored_contains if x) / len(scored_contains) if scored_contains else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / len(results) if results else 0.0
    avg_words = sum(r["tokens_generated_approx"] for r in results) / len(results) if results else 0.0

    # A category without any scored question has no meaningful score:
    # record None rather than 0.0, which would read as "everything wrong".
    cat_scores: Dict[str, Optional[float]] = {}
    for cat, items in categories.items():
        vals = [r["overlap_score"] for r in items if r["overlap_score"] is not None]
        cat_scores[cat] = (sum(vals) / len(vals)) if vals else None

    summary = {
        "repo_dir": str(repo_dir),
        "train_script": str(train_script),
        "checkpoint": str(ckpt_path),
        "config_path": str(config_path),
        "tokenizer_dir": str(tokenizer_dir),
        "device": str(tester.device),
        "prompt_style": args.prompt_style,
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "contains_reference_rate": round(contains_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_words_generated": round(avg_words, 1),
        "scores_by_category": {k: (None if v is None else round(v, 4)) for k, v in cat_scores.items()},
        "results": results,
    }

    print("=" * 60)
    print("RÉSUMÉ")
    print("=" * 60)
    print(f"Questions testées : {len(results)}")
    print(f"Overlap moyen : {avg_overlap:.1%}")
    print(f"Exact match : {em_rate:.1%}")
    print(f"Contains ref : {contains_rate:.1%}")
    print(f"Latence moyenne : {avg_latency:.2f}s")
    print(f"Mots moyens : {avg_words:.1f}")
    print("Scores / catégorie:")
    for cat, score in sorted(cat_scores.items()):
        if score is None:
            print(f" {cat:<15} n/a")
        else:
            print(f" {cat:<15} {format_bar(score)} {score:.0%}")
    print("=" * 60)

    if args.save_report:
        json_path, txt_path = save_reports(repo_dir, summary)
        print(f"Rapports sauvegardés dans : {json_path}")
        print(f" {txt_path}")


if __name__ == "__main__":
    main()
simple_qa_test_finished_model (1).py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Test QA simple pour un modèle déjà entraîné.
6
+
7
+ Fonctionne avec le script : train_chatbot_100m_large.py
8
+ - charge la config depuis output_dir/train_config.json si disponible
9
+ - charge un checkpoint fini (par défaut: output_dir/sft_best.pt)
10
+ - pose une petite liste de questions QA
11
+ - calcule un score simple d'overlap lexical
12
+ - sauvegarde un rapport JSON + TXT
13
+
14
+ Exemples
15
+ --------
16
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m
17
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
18
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import importlib.util
25
+ import json
26
+ import os
27
+ import sys
28
+ import time
29
+ import unicodedata
30
+ import re
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional
33
+
34
+
35
# Built-in question set used when no --questions file is supplied.
# One entry per line: category, question, reference answer (None = unscored).
DEFAULT_QUESTIONS = [
    {"category": "Géographie", "question": "Quelle est la capitale de la France ?", "reference": "Paris"},
    {"category": "Géographie", "question": "Quel est le plus long fleuve d'Afrique ?", "reference": "Le Nil"},
    {"category": "Science", "question": "Qu'est-ce que la photosynthèse ?", "reference": "Processus par lequel les plantes convertissent la lumière en énergie"},
    {"category": "Science", "question": "Combien d'os compte le corps humain adulte ?", "reference": "206"},
    {"category": "Histoire", "question": "En quelle année a eu lieu la Révolution française ?", "reference": "1789"},
    {"category": "Histoire", "question": "Qui a écrit Les Misérables ?", "reference": "Victor Hugo"},
    {"category": "Mathématiques", "question": "Quelle est la formule de l'aire d'un cercle ?", "reference": "π × r²"},
    {"category": "Langage", "question": "Donne un synonyme du mot heureux.", "reference": "joyeux"},
    {"category": "Raisonnement", "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?", "reference": "3"},
    {"category": "Dialogue", "question": "Comment vas-tu aujourd'hui ?", "reference": None},
]
87
+
88
+
89
def normalize_text(s: str) -> str:
    """Return a lowercase, accent-free form of ``s`` with punctuation removed.

    The text is trimmed, lowercased and NFKD-decomposed, combining marks are
    dropped, every character that is neither a word character nor whitespace
    becomes a space, and whitespace runs are collapsed. A falsy input gives
    ``""``.
    """
    folded = unicodedata.normalize("NFKD", (s or "").strip().lower())
    base_chars = [ch for ch in folded if not unicodedata.combining(ch)]
    no_punct = re.sub(r"[^\w\s]", " ", "".join(base_chars), flags=re.UNICODE)
    return re.sub(r"\s+", " ", no_punct).strip()
96
+
97
+
98
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the share of distinct reference tokens found in ``answer``.

    Both strings are normalized with ``normalize_text`` and split on
    whitespace. Returns ``None`` when ``reference`` is falsy and ``0.0`` when
    the reference normalizes to no tokens at all.
    """
    if not reference:
        return None

    wanted = set(normalize_text(reference).split())
    produced = set(normalize_text(answer).split())
    if len(wanted) == 0:
        return 0.0

    return len(wanted.intersection(produced)) / len(wanted)
106
+
107
+
108
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether ``answer`` equals ``reference`` once both are normalized.

    Returns ``None`` when ``reference`` is falsy.
    """
    if reference:
        expected = normalize_text(reference)
        produced = normalize_text(answer)
        return expected == produced
    return None
112
+
113
+
114
def import_train_module(train_script_path: str):
    """Import the training script at ``train_script_path`` as a module.

    The module is always registered in ``sys.modules`` under the fixed name
    ``"train_module"`` before its code runs, and is then returned.

    Raises:
        FileNotFoundError: if the script does not exist.
        ImportError: if no import spec or loader can be created for it.
    """
    script = Path(train_script_path)
    if not script.exists():
        raise FileNotFoundError(f"Script d'entraînement introuvable: {script}")

    alias = "train_module"
    spec = importlib.util.spec_from_file_location(alias, str(script))
    loader = None if spec is None else spec.loader
    if loader is None:
        raise ImportError(f"Impossible de charger le module: {script}")

    loaded = importlib.util.module_from_spec(spec)
    sys.modules[alias] = loaded
    loader.exec_module(loaded)
    return loaded
127
+
128
+
129
def build_cfg(train_module, output_dir: str):
    """Create the training script's ``TrainConfig`` for ``output_dir``.

    When ``output_dir/train_config.json`` exists, its content is passed as
    keyword arguments to ``train_module.TrainConfig``; otherwise the config
    is built from ``output_dir`` and the tokenizer prefix only. In both
    cases ``output_dir`` and ``tokenizer_prefix`` are then overwritten on
    the resulting object so they point at the given directory.
    """
    tokenizer_prefix = f"{output_dir}/tokenizer"
    saved_config = Path(output_dir) / "train_config.json"

    if not saved_config.exists():
        cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=tokenizer_prefix)
    else:
        with open(saved_config, "r", encoding="utf-8") as handle:
            stored_kwargs = json.load(handle)
        cfg = train_module.TrainConfig(**stored_kwargs)

    cfg.output_dir = output_dir
    cfg.tokenizer_prefix = tokenizer_prefix
    return cfg
141
+
142
+
143
def run_test(
    train_script: str,
    output_dir: str,
    ckpt_path: str,
    questions: List[Dict],
    save_report: bool,
):
    """Ask every question to the trained chatbot and report simple scores.

    The training script is imported, its ``Chatbot`` is built from the
    config for ``output_dir`` and the checkpoint, and each question is sent
    through ``bot.chat`` with its optional context. Lexical overlap and
    exact match are computed for questions that have a reference. Results
    are printed and, when ``save_report`` is true, written to a JSON and a
    text report inside ``output_dir``.

    Returns:
        The summary dict (global averages, per-category scores, all rows).
    """

    def _mean(values):
        # Average of a list, 0.0 for an empty one.
        return sum(values) / len(values) if values else 0.0

    def _hit_rate(rows):
        # Share of rows whose exact_match flag is true.
        return sum(1 for row in rows if row["exact_match"]) / len(rows)

    train_module = import_train_module(train_script)
    cfg = build_cfg(train_module, output_dir)
    bot = train_module.Chatbot(cfg, ckpt_path)

    rows = []
    rows_by_category: Dict[str, List[Dict]] = {}

    thin_rule = "─" * 64
    thick_rule = "═" * 64
    print(f"\n{thick_rule}")
    print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ")
    print(f" Checkpoint : {ckpt_path}")
    print(f" Questions : {len(questions)}")
    print(f"{thick_rule}\n")

    for number, item in enumerate(questions, 1):
        question = item["question"]
        reference = item.get("reference")
        category = item.get("category", "Général")
        context = item.get("context", "")

        started = time.time()
        answer = bot.chat(question, context=context)
        latency = time.time() - started

        overlap = lexical_overlap(reference, answer)
        em = exact_match(reference, answer)

        row = {
            "id": number,
            "category": category,
            "question": question,
            "context": context,
            "reference": reference,
            "answer": answer,
            "overlap_score": overlap,
            "exact_match": em,
            "latency_s": round(latency, 3),
            # Word count of the answer, used as a rough token count.
            "tokens_generated_approx": len(answer.split()),
        }
        rows.append(row)
        rows_by_category.setdefault(category, []).append(row)

        # Build the optional "[overlap=.. | EM=..]" suffix of the header line.
        badges = []
        if overlap is not None:
            badges.append(f"overlap={overlap:.0%}")
        if em is not None:
            badges.append(f"EM={'oui' if em else 'non'}")
        badge_text = f" [{' | '.join(badges)}]" if badges else ""

        print(thin_rule)
        print(f"[{number:02d}] [{category}]{badge_text}")
        if context:
            print(f" Contexte : {context[:120]}{'...' if len(context) > 120 else ''}")
        print(f" User : {question}")
        print(f" Assistant : {answer}")
        if reference:
            print(f" Référence : {reference}")
        print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n")

    overlap_values = [row["overlap_score"] for row in rows if row["overlap_score"] is not None]
    avg_overlap = _mean(overlap_values)
    em_rows = [row for row in rows if row["exact_match"] is not None]
    em_rate = _hit_rate(em_rows) if em_rows else 0.0
    avg_latency = sum(row["latency_s"] for row in rows) / max(1, len(rows))
    avg_tokens = sum(row["tokens_generated_approx"] for row in rows) / max(1, len(rows))

    scores_by_category = {}
    for category, members in rows_by_category.items():
        member_overlaps = [m["overlap_score"] for m in members if m["overlap_score"] is not None]
        member_em = [m for m in members if m["exact_match"] is not None]
        scores_by_category[category] = {
            "avg_overlap": round(sum(member_overlaps) / len(member_overlaps), 4) if member_overlaps else None,
            "exact_match_rate": round(_hit_rate(member_em), 4) if member_em else None,
        }

    summary = {
        "checkpoint": ckpt_path,
        "total_questions": len(rows),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_tokens_generated_approx": round(avg_tokens, 1),
        "scores_by_category": scores_by_category,
        "results": rows,
    }

    print(f"{thick_rule}")
    print(" RÉSUMÉ")
    print(f"{thick_rule}")
    print(f" Questions testées : {len(rows)}")
    print(f" Overlap moyen : {avg_overlap:.1%}")
    print(f" Exact match : {em_rate:.1%}")
    print(f" Latence moyenne : {avg_latency:.2f}s")
    print(f" Mots moyens : {avg_tokens:.0f}")
    print(" Scores / catégorie :")
    for category, scores in scores_by_category.items():
        ov = scores["avg_overlap"]
        emc = scores["exact_match_rate"]
        print(f" - {category:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
    print(f"{thick_rule}\n")

    if not save_report:
        return summary

    report_json = Path(output_dir) / "qa_test_simple_report.json"
    report_txt = Path(output_dir) / "qa_test_simple_report.txt"

    with open(report_json, "w", encoding="utf-8") as json_file:
        json.dump(summary, json_file, ensure_ascii=False, indent=2)

    with open(report_txt, "w", encoding="utf-8") as txt_file:
        txt_file.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n")
        txt_file.write(f"Checkpoint : {ckpt_path}\n\n")
        for row in rows:
            lines = [f"[{row['id']:02d}] {row['category']}\n"]
            if row["context"]:
                lines.append(f" Contexte : {row['context']}\n")
            lines.append(f" User : {row['question']}\n")
            lines.append(f" Assistant : {row['answer']}\n")
            if row["reference"]:
                lines.append(f" Référence : {row['reference']}\n")
            if row["overlap_score"] is not None:
                lines.append(f" Overlap : {row['overlap_score']:.0%}\n")
            if row["exact_match"] is not None:
                lines.append(f" EM : {'oui' if row['exact_match'] else 'non'}\n")
            lines.append(f" Latence : {row['latency_s']}s\n\n")
            txt_file.write("".join(lines))

    print(f"Rapport JSON -> {report_json}")
    print(f"Rapport TXT -> {report_txt}")

    return summary
279
+
280
+
281
if __name__ == "__main__":
    # Script entry point: parse options, resolve the checkpoint and the
    # question list, then run the test. Reports are saved unless --no_save.
    cli = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné")
    cli.add_argument("--train_script", type=str, default="./train_chatbot_100m_large.py")
    cli.add_argument("--output_dir", type=str, default="./fr_100m")
    cli.add_argument("--ckpt", type=str, default=None)
    cli.add_argument("--questions", type=str, default=None, help="JSON optionnel [{question, reference, category, context?}]")
    cli.add_argument("--no_save", action="store_true")
    args = cli.parse_args()

    # Default checkpoint: sft_best.pt inside the output directory.
    ckpt_path = args.ckpt or os.path.join(args.output_dir, "sft_best.pt")
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(
            f"Checkpoint introuvable: {ckpt_path}\n"
            f"Exemple: --ckpt {args.output_dir}/sft_final.pt"
        )

    questions = DEFAULT_QUESTIONS
    if args.questions:
        with open(args.questions, "r", encoding="utf-8") as handle:
            questions = json.load(handle)

    run_test(
        train_script=args.train_script,
        output_dir=args.output_dir,
        ckpt_path=ckpt_path,
        questions=questions,
        save_report=not args.no_save,
    )
simple_qa_test_finished_model.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Test QA simple pour un modèle déjà entraîné.
6
+
7
+ Fonctionne avec le script : train_chatbot_100m_large.py
8
+ - charge la config depuis output_dir/train_config.json si disponible
9
+ - charge un checkpoint fini (par défaut: output_dir/sft_best.pt)
10
+ - pose une petite liste de questions QA
11
+ - calcule un score simple d'overlap lexical
12
+ - sauvegarde un rapport JSON + TXT
13
+
14
+ Exemples
15
+ --------
16
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m
17
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m --ckpt ./fr_100m/sft_final.pt
18
+ python simple_qa_test_finished_model.py --output_dir ./fr_100m --questions qa_questions.json
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import importlib.util
25
+ import json
26
+ import os
27
+ import sys
28
+ import time
29
+ import unicodedata
30
+ import re
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional
33
+
34
+
35
+ DEFAULT_QUESTIONS = [
36
+ {
37
+ "category": "Géographie",
38
+ "question": "Quelle est la capitale de la France ?",
39
+ "reference": "Paris",
40
+ },
41
+ {
42
+ "category": "Géographie",
43
+ "question": "Quel est le plus long fleuve d'Afrique ?",
44
+ "reference": "Le Nil",
45
+ },
46
+ {
47
+ "category": "Science",
48
+ "question": "Qu'est-ce que la photosynthèse ?",
49
+ "reference": "Processus par lequel les plantes convertissent la lumière en énergie",
50
+ },
51
+ {
52
+ "category": "Science",
53
+ "question": "Combien d'os compte le corps humain adulte ?",
54
+ "reference": "206",
55
+ },
56
+ {
57
+ "category": "Histoire",
58
+ "question": "En quelle année a eu lieu la Révolution française ?",
59
+ "reference": "1789",
60
+ },
61
+ {
62
+ "category": "Histoire",
63
+ "question": "Qui a écrit Les Misérables ?",
64
+ "reference": "Victor Hugo",
65
+ },
66
+ {
67
+ "category": "Mathématiques",
68
+ "question": "Quelle est la formule de l'aire d'un cercle ?",
69
+ "reference": "π × r²",
70
+ },
71
+ {
72
+ "category": "Langage",
73
+ "question": "Donne un synonyme du mot heureux.",
74
+ "reference": "joyeux",
75
+ },
76
+ {
77
+ "category": "Raisonnement",
78
+ "question": "Si j'ai 5 pommes et j'en donne 2, combien m'en reste-t-il ?",
79
+ "reference": "3",
80
+ },
81
+ {
82
+ "category": "Dialogue",
83
+ "question": "Comment vas-tu aujourd'hui ?",
84
+ "reference": None,
85
+ },
86
+ ]
87
+
88
+
89
def normalize_text(s: str) -> str:
    """Return a canonical form of *s* for lenient text comparison.

    The text is trimmed and lower-cased, accents are removed (NFKD
    decomposition followed by dropping combining marks), every character that
    is neither a word character nor whitespace becomes a space, and runs of
    whitespace collapse to a single space. A falsy input yields "".
    """
    lowered = (s or "").strip().lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_chars = [ch for ch in decomposed if not unicodedata.combining(ch)]
    no_punct = re.sub(r"[^\w\s]", " ", "".join(base_chars), flags=re.UNICODE)
    collapsed = re.sub(r"\s+", " ", no_punct)
    return collapsed.strip()
96
+
97
+
98
def lexical_overlap(reference: Optional[str], answer: str) -> Optional[float]:
    """Return the fraction of distinct reference tokens that occur in the answer.

    Both texts are normalised with normalize_text() and split on whitespace.
    Returns None when there is no reference to score against, and 0.0 when the
    reference normalises to no tokens at all.
    """
    if not reference:
        return None
    wanted = set(normalize_text(reference).split())
    produced = set(normalize_text(answer).split())
    if len(wanted) == 0:
        return 0.0
    hits = sum(1 for token in wanted if token in produced)
    return hits / len(wanted)
106
+
107
+
108
def exact_match(reference: Optional[str], answer: str) -> Optional[bool]:
    """Tell whether answer and reference are identical after normalisation.

    Returns None when no reference is available, otherwise a bool.
    """
    if reference:
        expected = normalize_text(reference)
        actual = normalize_text(answer)
        return expected == actual
    return None
112
+
113
+
114
def import_train_module(train_script_path: str):
    """Load the training script at *train_script_path* as a module.

    The module is registered in ``sys.modules`` under the name
    "train_module" before its code is executed, then returned.

    Raises:
        FileNotFoundError: the script does not exist.
        ImportError: no import spec/loader could be built for the path.
    """
    script = Path(train_script_path)
    if not script.exists():
        raise FileNotFoundError(f"Script d'entraînement introuvable: {script}")

    spec = importlib.util.spec_from_file_location("train_module", str(script))
    loader = None if spec is None else spec.loader
    if loader is None:
        raise ImportError(f"Impossible de charger le module: {script}")

    loaded = importlib.util.module_from_spec(spec)
    # Registered before execution, exactly like the import system does.
    sys.modules["train_module"] = loaded
    loader.exec_module(loaded)
    return loaded
127
+
128
+
129
def build_cfg(train_module, output_dir: str):
    """Build a ``train_module.TrainConfig`` for *output_dir*.

    If ``<output_dir>/train_config.json`` exists, its content is passed as
    keyword arguments to ``TrainConfig``; otherwise a config is created from
    the output directory alone. In both cases ``output_dir`` and
    ``tokenizer_prefix`` are then forced to point at *output_dir*.
    """
    tokenizer_prefix = f"{output_dir}/tokenizer"
    saved_cfg_file = Path(output_dir) / "train_config.json"

    if not saved_cfg_file.exists():
        cfg = train_module.TrainConfig(output_dir=output_dir, tokenizer_prefix=tokenizer_prefix)
    else:
        with open(saved_cfg_file, "r", encoding="utf-8") as f:
            stored = json.load(f)
        cfg = train_module.TrainConfig(**stored)

    # The saved config may carry paths from another machine; override them.
    cfg.output_dir = output_dir
    cfg.tokenizer_prefix = tokenizer_prefix
    return cfg
141
+
142
+
143
def run_test(
    train_script: str,
    output_dir: str,
    ckpt_path: str,
    questions: List[Dict],
    save_report: bool,
):
    """Ask every question to the trained chatbot, print a report and return a summary.

    Args:
        train_script: path of the training script; it must expose ``TrainConfig``
            and ``Chatbot(cfg, ckpt_path)`` with a ``chat(question, context=...)`` method.
        output_dir: model directory (config is read from it, reports are written to it).
        ckpt_path: checkpoint handed to ``Chatbot``.
        questions: list of dicts with a mandatory "question" key and optional
            "reference", "category" and "context" keys.
        save_report: when true, write ``qa_test_simple_report.json`` and
            ``qa_test_simple_report.txt`` into *output_dir*.

    Returns:
        The summary dict (global averages, per-category scores, per-question rows).

    Raises:
        KeyError: a question entry has no "question" key.
    """
    train_module = import_train_module(train_script)
    cfg = build_cfg(train_module, output_dir)
    bot = train_module.Chatbot(cfg, ckpt_path)

    results = []
    categories: Dict[str, List[Dict]] = {}

    sep = "─" * 64
    print(f"\n{'═'*64}")
    print(" TEST QA SIMPLE — MODÈLE ENTRAÎNÉ")
    print(f" Checkpoint : {ckpt_path}")
    print(f" Questions : {len(questions)}")
    print(f"{'═'*64}\n")

    for i, item in enumerate(questions, 1):
        q = item["question"]
        ref = item.get("reference")
        # `or` (rather than a .get default) also covers an explicit JSON null,
        # which would otherwise reach the model as None or break `{cat:<15}`.
        cat = item.get("category") or "Général"
        ctx = item.get("context") or ""

        t0 = time.time()
        ans = bot.chat(q, context=ctx)
        latency = time.time() - t0

        overlap = lexical_overlap(ref, ans)
        em = exact_match(ref, ans)

        row = {
            "id": i,
            "category": cat,
            "question": q,
            "context": ctx,
            "reference": ref,
            "answer": ans,
            "overlap_score": overlap,
            "exact_match": em,
            "latency_s": round(latency, 3),
            # Whitespace-separated word count, not a tokenizer token count.
            "tokens_generated_approx": len(ans.split()),
        }
        results.append(row)
        categories.setdefault(cat, []).append(row)

        score_text = []
        if overlap is not None:
            score_text.append(f"overlap={overlap:.0%}")
        if em is not None:
            score_text.append(f"EM={'oui' if em else 'non'}")
        score_str = f" [{' | '.join(score_text)}]" if score_text else ""

        print(sep)
        print(f"[{i:02d}] [{cat}]{score_str}")
        if ctx:
            print(f" Contexte : {ctx[:120]}{'...' if len(ctx) > 120 else ''}")
        print(f" User : {q}")
        print(f" Assistant : {ans}")
        if ref:
            print(f" Référence : {ref}")
        print(f" ⏱ {latency:.2f}s | ~{row['tokens_generated_approx']} mots\n")

    # Only questions that have a reference take part in the quality averages.
    scored = [r for r in results if r["overlap_score"] is not None]
    avg_overlap = sum(r["overlap_score"] for r in scored) / len(scored) if scored else 0.0
    em_rows = [r for r in results if r["exact_match"] is not None]
    em_rate = sum(1 for r in em_rows if r["exact_match"]) / len(em_rows) if em_rows else 0.0
    avg_latency = sum(r["latency_s"] for r in results) / max(1, len(results))
    avg_tokens = sum(r["tokens_generated_approx"] for r in results) / max(1, len(results))

    scores_by_category = {}
    for cat, items in categories.items():
        cat_scored = [x for x in items if x["overlap_score"] is not None]
        cat_em = [x for x in items if x["exact_match"] is not None]
        scores_by_category[cat] = {
            "avg_overlap": round(sum(x["overlap_score"] for x in cat_scored) / len(cat_scored), 4) if cat_scored else None,
            "exact_match_rate": round(sum(1 for x in cat_em if x["exact_match"]) / len(cat_em), 4) if cat_em else None,
        }

    summary = {
        "checkpoint": ckpt_path,
        "total_questions": len(results),
        "avg_overlap_score": round(avg_overlap, 4),
        "exact_match_rate": round(em_rate, 4),
        "avg_latency_s": round(avg_latency, 3),
        "avg_tokens_generated_approx": round(avg_tokens, 1),
        "scores_by_category": scores_by_category,
        "results": results,
    }

    print(f"{'═'*64}")
    print(" RÉSUMÉ")
    print(f"{'═'*64}")
    print(f" Questions testées : {len(results)}")
    print(f" Overlap moyen : {avg_overlap:.1%}")
    print(f" Exact match : {em_rate:.1%}")
    print(f" Latence moyenne : {avg_latency:.2f}s")
    print(f" Mots moyens : {avg_tokens:.0f}")
    print(" Scores / catégorie :")
    for cat, sc in scores_by_category.items():
        ov = sc["avg_overlap"]
        emc = sc["exact_match_rate"]
        print(f" - {cat:<15} overlap={ov if ov is not None else 'n/a'} | EM={emc if emc is not None else 'n/a'}")
    print(f"{'═'*64}\n")

    if save_report:
        report_dir = Path(output_dir)
        # The checkpoint may live elsewhere (--ckpt), so the report directory
        # is not guaranteed to exist yet.
        report_dir.mkdir(parents=True, exist_ok=True)
        report_json = report_dir / "qa_test_simple_report.json"
        report_txt = report_dir / "qa_test_simple_report.txt"

        with open(report_json, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        with open(report_txt, "w", encoding="utf-8") as f:
            f.write("TEST QA SIMPLE — MODÈLE ENTRAÎNÉ\n")
            f.write(f"Checkpoint : {ckpt_path}\n\n")
            for r in results:
                f.write(f"[{r['id']:02d}] {r['category']}\n")
                if r["context"]:
                    f.write(f" Contexte : {r['context']}\n")
                f.write(f" User : {r['question']}\n")
                f.write(f" Assistant : {r['answer']}\n")
                if r["reference"]:
                    f.write(f" Référence : {r['reference']}\n")
                if r["overlap_score"] is not None:
                    f.write(f" Overlap : {r['overlap_score']:.0%}\n")
                if r["exact_match"] is not None:
                    f.write(f" EM : {'oui' if r['exact_match'] else 'non'}\n")
                f.write(f" Latence : {r['latency_s']}s\n\n")

        print(f"Rapport JSON -> {report_json}")
        print(f"Rapport TXT -> {report_txt}")

    return summary
279
+
280
+
281
if __name__ == "__main__":
    # Command-line entry point: parse options, resolve the checkpoint and the
    # question set, then hand everything to run_test().
    parser = argparse.ArgumentParser("Test QA simple pour modèle déjà entraîné")
    for flag, options in (
        ("--train_script", {"type": str, "default": "./train_chatbot_100m_large.py"}),
        ("--output_dir", {"type": str, "default": "./fr_100m"}),
        ("--ckpt", {"type": str, "default": None}),
        (
            "--questions",
            {
                "type": str,
                "default": None,
                "help": "JSON optionnel [{question, reference, category, context?}]",
            },
        ),
        ("--no_save", {"action": "store_true"}),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # An explicit --ckpt wins; otherwise fall back to <output_dir>/sft_best.pt.
    ckpt_path = args.ckpt if args.ckpt else os.path.join(args.output_dir, "sft_best.pt")
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(
            f"Checkpoint introuvable: {ckpt_path}\n"
            f"Exemple: --ckpt {args.output_dir}/sft_final.pt"
        )

    # Built-in questions are used unless a JSON file is supplied.
    questions = DEFAULT_QUESTIONS
    if args.questions:
        with open(args.questions, "r", encoding="utf-8") as f:
            questions = json.load(f)

    run_test(
        train_script=args.train_script,
        output_dir=args.output_dir,
        ckpt_path=ckpt_path,
        questions=questions,
        save_report=not args.no_save,
    )
test.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ from collections import OrderedDict
9
+ from contextlib import nullcontext
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers import PreTrainedTokenizerFast
18
+
19
+
20
+ # ============================================================
21
+ # Paths par défaut
22
+ # ============================================================
23
+
24
+ MODEL_DIR = Path("./nlp_1b_wiki_en_fr_ar")
25
+
26
+ DEFAULT_CHECKPOINT = MODEL_DIR / "model_best.pt"
27
+ DEFAULT_CONFIG = MODEL_DIR / "config.json"
28
+ DEFAULT_TOKENIZER_DIR = MODEL_DIR / "tokenizer_32k"
29
+
30
+
31
+ # ============================================================
32
+ # Utils
33
+ # ============================================================
34
+
35
def get_device() -> torch.device:
    """Return the currently selected CUDA device if CUDA is available, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    gpu_index = torch.cuda.current_device()
    return torch.device(f"cuda:{gpu_index}")
39
+
40
+
41
def autocast_context(device: torch.device):
    """Return a bfloat16 autocast context on CUDA devices, a no-op context otherwise."""
    on_gpu = device.type == "cuda"
    return torch.autocast("cuda", dtype=torch.bfloat16) if on_gpu else nullcontext()
45
+
46
+
47
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    """Return a copy of *state_dict* whose keys have one wrapper prefix removed.

    At most one of "module._orig_mod.", "_orig_mod." or "module." is stripped
    from the start of each key (first match in that order); values and key
    order are preserved.
    """
    wrapper_prefixes = ("module._orig_mod.", "_orig_mod.", "module.")

    def strip_wrapper(key: str) -> str:
        for prefix in wrapper_prefixes:
            if key.startswith(prefix):
                return key[len(prefix):]
        return key

    return OrderedDict((strip_wrapper(name), tensor) for name, tensor in state_dict.items())
59
+
60
+
61
def postprocess_text(text: str) -> str:
    """Return *text* without leading or trailing whitespace."""
    trimmed = text.strip()
    return trimmed
63
+
64
+
65
+ # ============================================================
66
+ # Architecture
67
+ # ============================================================
68
+
69
@dataclass
class GPTConfig:
    """Hyperparameters of the GPT model defined in this file."""

    # --- required -----------------------------------------------------------
    vocab_size: int  # number of rows of the token embedding / output logits
    block_size: int  # generate() crops its input to this many trailing tokens
    d_model: int  # width of the residual stream
    n_heads: int  # attention heads; must divide d_model
    n_layers: int  # number of transformer blocks
    d_ff: int  # hidden width of the SwiGLU feed-forward

    # --- optional; not read by the model code in this file -------------------
    dropout: float = 0.0
    use_checkpointing: bool = False
79
+
80
+
81
+ class RMSNorm(nn.Module):
82
+ def __init__(self, dim: int, eps: float = 1e-6):
83
+ super().__init__()
84
+ self.weight = nn.Parameter(torch.ones(dim))
85
+ self.eps = eps
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
89
+
90
+
91
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings.

    The tables cover ``max_seq`` positions and ``dim`` features; each frequency
    is repeated for two adjacent features. All buffers are non-persistent, so
    they never appear in a state dict.
    """

    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)

        for buffer_name, table in (("cos_cache", angles.cos()), ("sin_cache", angles.sin())):
            self.register_buffer(
                buffer_name,
                torch.repeat_interleave(table, 2, dim=-1),
                persistent=False,
            )

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return the (cos, sin) tables for the first *seq_len* positions, cast to *dtype*."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
113
+
114
+
115
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent feature pair (a, b) of the last dimension to (-b, a)."""
    even = x[..., ::2]
    odd = x[..., 1::2]
    pairs = torch.stack((-odd, even), dim=-1)
    return pairs.flatten(-2)
118
+
119
+
120
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply the rotary embedding given by *cos*/*sin* to *x*.

    Two leading singleton dimensions are added to *cos* and *sin* so that they
    broadcast against *x* (the caller in this file passes x as
    (batch, heads, seq, head_dim)).
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    rotated = rotate_half(x)
    return x * cos_b + rotated * sin_b
124
+
125
+
126
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses one fused bias-free QKV projection, applies RoPE to queries and keys,
    and delegates the attention itself to
    ``F.scaled_dot_product_attention(..., is_causal=True)``.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        # The RoPE table must cover every position the model can see. It never
        # shrinks below the former fixed 4096; it only grows when the configured
        # context is longer, where the old table would have caused a shape
        # mismatch. The buffers are non-persistent, so checkpoints are unaffected.
        self.rope = RotaryEmbedding(self.head_dim, max_seq=max(4096, cfg.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attention output for *x* of shape (batch, seq, d_model)."""
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)

        # (b, t, c) -> (b, heads, t, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        # Rotary embedding is applied to queries and keys only, never to values.
        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=0.0,
            is_causal=True,
        )
        # Merge the heads back into the model dimension.
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)
155
+
156
+
157
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all projections bias-free."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.w1 = nn.Linear(width, hidden, bias=False)  # gate branch (goes through SiLU)
        self.w2 = nn.Linear(width, hidden, bias=False)  # linear branch
        self.w3 = nn.Linear(hidden, width, bias=False)  # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
166
+
167
+
168
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each in a residual branch."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
180
+
181
+
182
class GPT(nn.Module):
    """Decoder-only transformer language model.

    The output projection shares its weight with the token embedding.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: lm_head and tok_emb use the same parameter tensor.
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for every position of *input_ids*."""
        x = self.tok_emb(input_ids)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(self.ln_f(x))

    @torch.inference_mode()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 160,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.05,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively extend *input_ids* (batch, seq) by up to *max_new_tokens*.

        Args:
            input_ids: prompt token ids.
            max_new_tokens: upper bound on the number of generated tokens.
            temperature: softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: keep only the k highest logits (``<= 0`` disables).
            top_p: nucleus sampling threshold (active when ``0 < top_p < 1``).
            repetition_penalty: values above 1 make tokens already present in the
                sequence less likely; 1.0 disables the penalty.
            eos_token_id: generation stops once every sequence of the batch
                emits this id in the same step.

        Returns:
            The prompt followed by the generated tokens.
        """
        self.eval()

        for _ in range(max_new_tokens):
            # The model only ever sees the last block_size tokens.
            idx_cond = input_ids[:, -self.cfg.block_size :]
            logits = self(idx_cond)
            logits = logits[:, -1, :]

            if repetition_penalty != 1.0:
                # Sign-aware penalty: dividing a negative logit would move it
                # towards zero and make the repeated token MORE likely, so
                # negative scores are multiplied instead.
                seen_scores = torch.gather(logits, 1, input_ids)
                seen_scores = torch.where(
                    seen_scores < 0,
                    seen_scores * repetition_penalty,
                    seen_scores / repetition_penalty,
                )
                logits = logits.scatter(1, input_ids, seen_scores)

            if temperature <= 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Everything below the k-th best logit is excluded.
                    logits[logits < v[:, [-1]]] = -float("inf")

                if 0 < top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    probs = F.softmax(sorted_logits, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)

                    # Shift the mask right by one so the token that crosses the
                    # threshold is kept, and never drop the best token.
                    sorted_mask = cumulative_probs > top_p
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = False

                    # Map the mask from sorted order back to vocabulary order.
                    mask = torch.zeros_like(logits, dtype=torch.bool)
                    mask.scatter_(1, sorted_indices, sorted_mask)
                    logits = logits.masked_fill(mask, -float("inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return input_ids
252
+
253
+
254
+ # ============================================================
255
+ # Load / generate
256
+ # ============================================================
257
+
258
def load_model_and_tokenizer(
    checkpoint_path: Path,
    config_path: Path,
    tokenizer_dir: Path,
    device: torch.device,
    use_compile: bool = False,
):
    """Build the GPT model from its config, load the checkpoint weights and the tokenizer.

    Returns:
        ``(model, tokenizer, ckpt)`` where ``ckpt`` is the raw checkpoint dict;
        the model is in eval mode on *device* and optionally wrapped by
        ``torch.compile``.

    Raises:
        FileNotFoundError: the checkpoint, config or tokenizer path is missing.
    """
    required_paths = (
        ("Checkpoint", checkpoint_path),
        ("Config", config_path),
        ("Tokenizer", tokenizer_dir),
    )
    for label, location in required_paths:
        if not location.exists():
            raise FileNotFoundError(f"{label} introuvable: {location}")

    raw_config = config_path.read_text(encoding="utf-8")
    cfg = GPTConfig(**json.loads(raw_config))

    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(tokenizer_dir))
    model = GPT(cfg).to(device)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    # Weights are stored under the "model" key, possibly with DDP/compile prefixes.
    weights = normalize_state_dict_keys(ckpt["model"])
    model.load_state_dict(weights, strict=True)
    model.eval()

    compile_available = hasattr(torch, "compile")
    if use_compile and compile_available:
        model = torch.compile(model, mode="default")

    return model, tokenizer, ckpt
287
+
288
+
289
def generate_text(
    model: GPT,
    tokenizer: PreTrainedTokenizerFast,
    prompt: str,
    device: torch.device,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> str:
    """Generate a continuation of *prompt* and return only the newly generated text.

    The prompt is tokenized without special tokens; a BOS token is prepended
    when the tokenizer defines one. Generation runs under autocast_context()
    and stops early on the tokenizer's EOS token.
    """
    token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)

    bos_id = tokenizer.bos_token_id
    if bos_id is not None:
        bos_column = torch.tensor([[bos_id]], device=device, dtype=token_ids.dtype)
        token_ids = torch.cat([bos_column, token_ids], dim=1)

    # Remember where the prompt ends so it can be cut from the output.
    n_prompt_tokens = token_ids.shape[1]

    sampling = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
    with autocast_context(device):
        full_sequence = model.generate(
            input_ids=token_ids,
            eos_token_id=tokenizer.eos_token_id,
            **sampling,
        )

    continuation = full_sequence[0][n_prompt_tokens:]
    decoded = tokenizer.decode(continuation, skip_special_tokens=True)
    return postprocess_text(decoded)
323
+
324
+
325
def main():
    """Command-line entry point: load the model, then generate text.

    Three modes, checked in this order: ``--show_examples`` (a fixed list of
    prompts), ``--interactive`` (read prompts from stdin until "exit"/"quit",
    end of input or Ctrl-C), otherwise a single generation for ``--prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=str(DEFAULT_CHECKPOINT))
    parser.add_argument("--config", type=str, default=str(DEFAULT_CONFIG))
    parser.add_argument("--tokenizer_dir", type=str, default=str(DEFAULT_TOKENIZER_DIR))
    parser.add_argument("--prompt", type=str, default="Wikipedia is a free online encyclopedia")
    parser.add_argument("--max_new_tokens", type=int, default=160)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.05)
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--show_examples", action="store_true")
    parser.add_argument("--compile", action="store_true")
    args = parser.parse_args()

    device = get_device()

    if device.type == "cuda":
        # Allow TF32 matmuls/convolutions for faster float32 work on the GPU.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

    model, tokenizer, ckpt = load_model_and_tokenizer(
        checkpoint_path=Path(args.checkpoint),
        config_path=Path(args.config),
        tokenizer_dir=Path(args.tokenizer_dir),
        device=device,
        use_compile=args.compile,
    )

    print(f"Device: {device}")
    print(f"Checkpoint: {args.checkpoint}")
    print(f"epoch={ckpt.get('epoch', 'N/A')} | step={ckpt.get('step', 'N/A')} | best_loss={ckpt.get('best_loss', 'N/A')}")

    def complete(prompt: str) -> str:
        # Single place where the CLI sampling options are forwarded.
        return generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )

    if args.show_examples:
        examples = [
            "Wikipedia is a free online encyclopedia",
            "La France est un pays d'Europe",
            "الزراعة من أهم القطاعات الاقتصادية",
            "Machine learning is a field of artificial intelligence",
        ]
        for ex in examples:
            print("\n--- Prompt ---")
            print(ex)
            print("\n--- Output ---")
            print(complete(ex))
        return

    if args.interactive:
        print("Mode interactif. Tape 'exit' pour quitter.\n")
        while True:
            try:
                prompt = input("Prompt> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
                # instead of surfacing a traceback.
                print()
                break
            if prompt.lower() in {"exit", "quit"}:
                break
            if not prompt:
                continue

            print("\n=== Output ===")
            print(complete(prompt))
            print()
        return

    print(complete(args.prompt))
425
+
426
+
427
+ if __name__ == "__main__":
428
+ main()
top_p ADDED
File without changes
train.py ADDED
@@ -0,0 +1,859 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import math
8
+ import os
9
+ import random
10
+ import time
11
+ from collections import OrderedDict
12
+ from contextlib import nullcontext
13
+ from dataclasses import asdict, dataclass
14
+ from pathlib import Path
15
+ from typing import Iterator, Optional
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from datasets import load_dataset
23
+ from transformers import PreTrainedTokenizerFast
24
+
25
+
26
+ # ============================================================
27
+ # Base model / tokenizer / config
28
+ # ============================================================
29
+
30
+ BASE_CHECKPOINT = Path("./wikipedia_ar_h100_codealpaca/model_best.pt")
31
+ BASE_TOKENIZER_DIR = Path("./wikipedia_ar_h100/tokenizer_32k")
32
+ BASE_CONFIG_FILE = Path("./wikipedia_ar_h100/config.json")
33
+
34
+ OUT_DIR = Path("./wikipedia_ar_h100_multicode_10x2000")
35
+ OUT_DIR.mkdir(parents=True, exist_ok=True)
36
+
37
+ MODEL_FILE = OUT_DIR / "model.pt"
38
+ BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
39
+ STATE_FILE = OUT_DIR / "train_state.pt"
40
+ CONFIG_FILE = OUT_DIR / "config.json"
41
+
42
+
43
+ # ============================================================
44
+ # Datasets
45
+ # ============================================================
46
+
47
+ TRAIN_SOURCES = [
48
+ {
49
+ "name": "HuggingFaceH4/CodeAlpaca_20K",
50
+ "subset": None,
51
+ "split": "train",
52
+ "kind": "codealpaca",
53
+ "weight": 0.45,
54
+ "streaming": False,
55
+ },
56
+ {
57
+ "name": "open-r1/codeforces",
58
+ "subset": "verifiable-prompts",
59
+ "split": "train",
60
+ "kind": "codeforces_python",
61
+ "weight": 0.35,
62
+ "streaming": False,
63
+ },
64
+ {
65
+ "name": "wikimedia/wikipedia",
66
+ "subset": "20231101.ar",
67
+ "split": "train",
68
+ "kind": "wikipedia_ar",
69
+ "weight": 0.20,
70
+ "streaming": True,
71
+ },
72
+ ]
73
+
74
+ EVAL_SOURCE = {
75
+ "name": "HuggingFaceH4/CodeAlpaca_20K",
76
+ "subset": None,
77
+ "split": "test",
78
+ "kind": "codealpaca",
79
+ "streaming": False,
80
+ }
81
+
82
+ CODEFORCES_LANGUAGE = "python"
83
+
84
+
85
+ # ============================================================
86
+ # Hyperparamètres
87
+ # ============================================================
88
+
89
+ SEED = 42
90
+ TARGET_VRAM_GIB = 75.0
91
+
92
+ LEARNING_RATE = 5e-5
93
+ MIN_LR = 5e-6
94
+ WEIGHT_DECAY = 0.1
95
+ WARMUP_STEPS = 200
96
+
97
+ NUM_ROUNDS = 10
98
+ STEPS_PER_ROUND = 2000
99
+ MAX_STEPS = NUM_ROUNDS * STEPS_PER_ROUND # 20000
100
+
101
+ BATCH_SIZE = 24
102
+ GRAD_ACCUM_STEPS = 1
103
+ MAX_GRAD_NORM = 1.0
104
+
105
+ EVAL_EVERY = 250
106
+ SAVE_EVERY = 500
107
+ MAX_EVAL_EXAMPLES = 2000
108
+ TEXT_CHAR_LIMIT = 6000
109
+
110
+ DTYPE = torch.bfloat16
111
+ USE_COMPILE = True
112
+ COMPILE_MODE = "default"
113
+ USE_CHECKPOINTING = False
114
+
115
+ TRAIN_NUM_WORKERS = 0
116
+ EVAL_NUM_WORKERS = 0
117
+
118
+
119
+ # ============================================================
120
+ # Helpers
121
+ # ============================================================
122
+
123
def is_distributed() -> bool:
    """Return True only when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
125
+
126
def get_rank() -> int:
    """Return this process's distributed rank, or 0 outside a distributed run."""
    if is_distributed():
        return dist.get_rank()
    return 0
128
+
129
def get_world_size() -> int:
    """Return the number of distributed processes, or 1 outside a distributed run."""
    if is_distributed():
        return dist.get_world_size()
    return 1
131
+
132
def is_main() -> bool:
    """Return True for the rank-0 process (always True outside a distributed run)."""
    rank = get_rank()
    return rank == 0
134
+
135
def init_distributed() -> Optional[torch.device]:
    """Initialise NCCL distributed training when LOCAL_RANK is set.

    Returns:
        The CUDA device for this process's local rank, or None when the
        LOCAL_RANK environment variable is absent (or equal to -1).
    """
    raw_rank = os.environ.get("LOCAL_RANK")
    local_rank = -1 if raw_rank is None else int(raw_rank)
    if local_rank == -1:
        return None

    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")
142
+
143
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` and PyTorch (CPU, plus all CUDA devices if available)."""
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
148
+
149
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    """Pick the compute device.

    An explicit *ddp_device* is returned as is; otherwise the current CUDA
    device when CUDA is available, else the CPU.
    """
    if ddp_device is not None:
        return ddp_device
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
155
+
156
def current_cuda_index(device: torch.device) -> int:
    """Return the CUDA ordinal of `device`.

    Falls back to the current CUDA device when `device` carries no index.

    Raises:
        ValueError: if `device` is not a CUDA device.
    """
    if device.type != "cuda":
        raise ValueError("Device non CUDA")
    index = device.index
    if index is None:
        index = torch.cuda.current_device()
    return index
160
+
161
def autocast_context(device: torch.device):
    """Return a CUDA autocast context using DTYPE, or a no-op context off-GPU."""
    on_gpu = device.type == "cuda"
    return torch.autocast("cuda", dtype=DTYPE) if on_gpu else nullcontext()
165
+
166
def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip the DDP wrapper and then the `_orig_mod` wrapper, when present.

    `_orig_mod` is the attribute under which this script expects a
    torch.compile'd module to expose the original module.
    """
    inner = model.module if isinstance(model, DDP) else model
    return getattr(inner, "_orig_mod", inner)
171
+
172
def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements across trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
174
+
175
def normalize_state_dict_keys(state_dict: dict) -> OrderedDict:
    """Return a copy of `state_dict` with wrapper prefixes removed from its keys.

    At most one leading wrapper prefix is stripped per key. Both nesting
    orders of the DDP (`module.`) and torch.compile (`_orig_mod.`) wrappers
    are recognised, so a checkpoint saved from either wrapping order loads
    into a bare model. Values and key order are preserved.
    """
    # Longest prefixes first so a combined wrapper is removed in one go.
    wrapper_prefixes = (
        "module._orig_mod.",
        "_orig_mod.module.",
        "_orig_mod.",
        "module.",
    )
    normalized = OrderedDict()
    for key, value in state_dict.items():
        for prefix in wrapper_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        normalized[key] = value
    return normalized
187
+
188
def normalize_text(text: str) -> str:
    """Collapse every run of whitespace to one space and trim both ends."""
    words = text.strip().split()
    return " ".join(words)
190
+
191
+
192
+ # ============================================================
193
+ # Dataset loading / formatting
194
+ # ============================================================
195
+
196
def load_one_dataset(spec: dict):
    """Load the dataset described by `spec` through `load_dataset`.

    `spec` must provide "name", "split", "streaming" and "subset"; the subset
    is forwarded as the configuration name only when it is not None.
    """
    load_args = dict(
        path=spec["name"],
        split=spec["split"],
        streaming=spec["streaming"],
    )
    subset = spec["subset"]
    if subset is not None:
        load_args["name"] = subset
    return load_dataset(**load_args)
205
+
206
def format_record(row: dict, kind: str) -> str:
    """Turn one dataset row into a single training string.

    Args:
        row: the raw record; missing or None fields are treated as empty.
        kind: "codealpaca", "codeforces_python" or "wikipedia_ar".

    Returns:
        The whitespace-normalized text, or "" for an unknown `kind` or a
        Codeforces row whose language is not CODEFORCES_LANGUAGE.
    """

    def field(key: str) -> str:
        # A None value must not leak into the corpus as the literal "None".
        value = row.get(key, "")
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    # NOTE(review): normalize_text() collapses newlines and indentation, so the
    # "\n" separators below and any code layout in the fields are flattened to
    # single spaces. Kept as-is to match the existing data format.
    if kind == "codealpaca":
        prompt = field("prompt")
        completion = field("completion")
        text = (
            "### Instruction\n"
            f"{prompt.strip()}\n\n"
            "### Response\n"
            f"{completion.strip()}"
        )
        return normalize_text(text)

    if kind == "codeforces_python":
        language = row.get("language", "")
        if language != CODEFORCES_LANGUAGE:
            return ""

        prompt = field("prompt")
        title = field("title")
        text = (
            f"### Competitive Programming Problem ({language})\n"
            f"{title.strip()}\n\n"
            f"{prompt.strip()}"
        )
        return normalize_text(text)

    if kind == "wikipedia_ar":
        return normalize_text(field("text"))

    return ""
248
+
249
def example_text_iter(spec: dict, max_examples: Optional[int] = None) -> Iterator[str]:
    """Yield formatted, length-filtered texts from the dataset described by `spec`.

    Texts shorter than 20 characters are skipped and the rest are truncated to
    TEXT_CHAR_LIMIT when it is set. The limit is checked after each yield, so
    at least one text is produced whenever `max_examples` is given.
    """
    emitted = 0
    for row in load_one_dataset(spec):
        text = format_record(row, spec["kind"])
        if not (text and len(text) >= 20):
            continue
        yield text if TEXT_CHAR_LIMIT is None else text[:TEXT_CHAR_LIMIT]
        emitted += 1
        if max_examples is not None and emitted >= max_examples:
            return
262
+
263
+
264
class MixedTextSource:
    """Endless text source that samples among several datasets by weight."""

    def __init__(self, specs: list[dict]):
        """Open one text stream per spec; each spec must carry a "weight"."""
        self.specs = specs
        self.weights = [spec["weight"] for spec in specs]
        self.streams = [example_text_iter(spec) for spec in specs]

    def next_text(self) -> str:
        """Return the next text from a randomly chosen source.

        An exhausted source is reopened and a new source is drawn.
        """
        while True:
            (choice,) = random.choices(
                range(len(self.specs)), weights=self.weights, k=1
            )
            try:
                return next(self.streams[choice])
            except StopIteration:
                self.streams[choice] = example_text_iter(self.specs[choice])
277
+
278
+
279
def packed_block_stream_mixed(
    tokenizer: PreTrainedTokenizerFast,
    specs: list[dict],
    block_size: int,
) -> Iterator[list[int]]:
    """Yield an endless stream of packed token chunks of length `block_size + 1`.

    Each document is wrapped as BOS + tokens + EOS and appended to a running
    buffer, which is cut into consecutive, non-overlapping chunks. The extra
    token lets callers derive inputs and shifted labels from one chunk.
    """
    bos_id, eos_id = tokenizer.bos_token_id, tokenizer.eos_token_id
    pending: list[int] = []
    mixer = MixedTextSource(specs)
    chunk_len = block_size + 1

    while True:
        token_ids = tokenizer.encode(mixer.next_text(), add_special_tokens=False)
        if not token_ids:
            continue

        pending.extend([bos_id] + token_ids + [eos_id])

        while len(pending) >= chunk_len:
            yield pending[:chunk_len]
            pending = pending[chunk_len:]
299
+
300
+
301
class PackedMixedBlocks(torch.utils.data.IterableDataset):
    """Iterable training dataset of packed blocks, sharded by rank and worker."""

    def __init__(self, tokenizer, specs, block_size):
        super().__init__()
        self.tokenizer = tokenizer
        self.specs = specs
        self.block_size = block_size

    def __iter__(self):
        """Yield {"input_ids", "labels"} pairs for the chunks of this shard.

        Every rank/worker walks the whole chunk stream and keeps only the
        chunks whose index falls in its own shard.
        """
        worker = torch.utils.data.get_worker_info()
        rank, world_size = get_rank(), get_world_size()

        if worker is not None:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id
        else:
            shard_mod = world_size
            shard_id = rank

        chunks = packed_block_stream_mixed(
            tokenizer=self.tokenizer,
            specs=self.specs,
            block_size=self.block_size,
        )
        for idx, chunk in enumerate(chunks):
            if idx % shard_mod == shard_id:
                # Labels are the inputs shifted left by one token.
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
334
+
335
+
336
class PackedEvalBlocks(torch.utils.data.IterableDataset):
    """Finite evaluation dataset of packed blocks read from a single source."""

    def __init__(self, tokenizer, spec, block_size, max_examples):
        super().__init__()
        self.tokenizer = tokenizer
        self.spec = spec
        self.block_size = block_size
        self.max_examples = max_examples

    def __iter__(self):
        """Yield {"input_ids", "labels"} pairs packed from this shard's documents.

        Documents (not chunks) are sharded by rank and worker; tokens left in
        the buffer at the end that do not fill a chunk are dropped.
        """
        worker = torch.utils.data.get_worker_info()
        rank, world_size = get_rank(), get_world_size()

        if worker is not None:
            shard_mod = worker.num_workers * world_size
            shard_id = rank * worker.num_workers + worker.id
        else:
            shard_mod = world_size
            shard_id = rank

        bos_id, eos_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        pending: list[int] = []

        texts = example_text_iter(self.spec, max_examples=self.max_examples)
        for ex_idx, text in enumerate(texts):
            if ex_idx % shard_mod != shard_id:
                continue

            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue

            pending.extend([bos_id] + token_ids + [eos_id])

            while len(pending) >= self.block_size + 1:
                chunk = pending[: self.block_size + 1]
                pending = pending[self.block_size + 1:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
376
+
377
+
378
+ # ============================================================
379
+ # Architecture
380
+ # ============================================================
381
+
382
@dataclass
class GPTConfig:
    """Architecture hyperparameters consumed by the GPT model classes in this file."""

    vocab_size: int  # rows of the token embedding and width of the LM head
    block_size: int  # sequence length used when packing training/eval blocks
    d_model: int  # width of the residual stream
    n_heads: int  # attention heads; must divide d_model
    n_layers: int  # number of transformer blocks
    d_ff: int  # hidden width of the SwiGLU feed-forward
    dropout: float = 0.0  # attention dropout probability, applied in training mode only
    use_checkpointing: bool = False  # checkpoint each block's forward while training
392
+
393
+
394
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale `x` by the reciprocal RMS of its last dimension, then by `weight`."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * x * inv_rms
402
+
403
+
404
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings.

    The tables cover `max_seq` positions, repeat each frequency twice along
    the last axis (interleaved layout) and are non-persistent buffers, so
    they are not written to checkpoints.
    """

    def __init__(self, dim: int, base: int = 10000, max_seq: int = 4096):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)

        cos_table = torch.repeat_interleave(angles.cos(), 2, dim=-1)
        sin_table = torch.repeat_interleave(angles.sin(), 2, dim=-1)
        self.register_buffer("cos_cache", cos_table, persistent=False)
        self.register_buffer("sin_cache", sin_table, persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return the (cos, sin) tables for the first `seq_len` positions, cast to `dtype`."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
418
+
419
+
420
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent pair (a, b) along the last dimension to (-b, a)."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated_pairs = torch.stack((odd.neg(), even), dim=-1)
    return rotated_pairs.flatten(-2)
423
+
424
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to `x`.

    `cos` and `sin` gain two leading singleton axes so they broadcast against
    the leading dimensions of `x`.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    return x * cos_b + rotate_half(x) * sin_b
428
+
429
+
430
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        # Size the rotary tables to cover the configured context: the 4096
        # default would be too short for a larger block_size. The tables are
        # non-persistent buffers, so checkpoints are unaffected.
        self.rope = RotaryEmbedding(self.head_dim, max_seq=max(4096, cfg.block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over `x` of shape (batch, time, d_model); the output has the same shape."""
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)

        # (batch, time, d_model) -> (batch, heads, time, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        # Rotary embedding is applied to queries and keys only.
        cos, sin = self.rope(t, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True,
        )
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)
460
+
461
+
462
class SwiGLU(nn.Module):
    """Gated feed-forward layer: w3(silu(w1(x)) * w2(x)), all projections bias-free."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up through the SiLU gate and the linear branch, then back down."""
        gate = F.silu(self.w1(x))
        up = self.w2(x)
        return self.w3(gate * up)
471
+
472
+
473
class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the attention output, then the feed-forward output, to the residual stream."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        ff_out = self.ff(self.ln2(x))
        return x + ff_out
485
+
486
+
487
class GPT(nn.Module):
    """Decoder-only language model whose output head shares the embedding weights."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList(Block(cfg) for _ in range(cfg.n_layers))
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the LM head reuses the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Return `(logits, loss)`; `loss` is None when `labels` is not given.

        The loss is token-level cross-entropy that ignores label -100.
        """
        hidden = self.tok_emb(input_ids)

        # Activation checkpointing is only used in training mode.
        recompute = self.cfg.use_checkpointing and self.training
        for block in self.blocks:
            if recompute:
                hidden = torch.utils.checkpoint.checkpoint(block, hidden, use_reentrant=False)
            else:
                hidden = block(hidden)

        logits = self.lm_head(self.ln_f(hidden))
        if labels is None:
            return logits, None

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )
        return logits, loss
525
+
526
+
527
+ # ============================================================
528
+ # Optimizer / LR
529
+ # ============================================================
530
+
531
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build AdamW with weight decay restricted to matrix-shaped weights.

    Trainable parameters with at least two dimensions and "weight" in their
    name get WEIGHT_DECAY; every other trainable parameter gets none. The
    fused implementation is requested whenever CUDA is available.
    """
    decay, no_decay = [], []
    for name, param in unwrap_model(model).named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim >= 2 and "weight" in name:
            decay.append(param)
        else:
            no_decay.append(param)

    param_groups = [
        {"params": decay, "weight_decay": WEIGHT_DECAY},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        param_groups,
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),
    )
548
+
549
def cosine_lr(step: int) -> float:
    """Return the learning rate for `step`.

    Linear warmup from 0 to LEARNING_RATE over WARMUP_STEPS, then cosine decay
    down to MIN_LR at MAX_STEPS (and held there afterwards).
    """
    if step >= WARMUP_STEPS:
        decay_span = max(1, MAX_STEPS - WARMUP_STEPS)
        p = min(1.0, (step - WARMUP_STEPS) / decay_span)
        return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * p))
    return LEARNING_RATE * step / max(1, WARMUP_STEPS)
554
+
555
+
556
+ # ============================================================
557
+ # Checkpoints
558
+ # ============================================================
559
+
560
def load_base_config() -> GPTConfig:
    """Read BASE_CONFIG_FILE as JSON and build a GPTConfig from it.

    The file's `use_checkpointing` value is overridden by USE_CHECKPOINTING.
    """
    raw = BASE_CONFIG_FILE.read_text(encoding="utf-8")
    fields = json.loads(raw)
    fields["use_checkpointing"] = USE_CHECKPOINTING
    return GPTConfig(**fields)
564
+
565
def initialize_model_from_base(model: nn.Module, device: torch.device) -> None:
    """Load the weights stored under "model" in BASE_CHECKPOINT into `model`.

    Keys are normalized first and loading is strict.

    Raises:
        FileNotFoundError: if BASE_CHECKPOINT does not exist.
    """
    if BASE_CHECKPOINT.exists():
        checkpoint = torch.load(BASE_CHECKPOINT, map_location=device)
        weights = normalize_state_dict_keys(checkpoint["model"])
        unwrap_model(model).load_state_dict(weights, strict=True)
        return
    raise FileNotFoundError(f"Checkpoint de base introuvable: {BASE_CHECKPOINT}")
571
+
572
def save_checkpoint(model, optimizer, step, best_loss, path):
    """Atomically write a training checkpoint to `path`.

    The checkpoint holds the unwrapped model weights, the optimizer state,
    the step counter, the best evaluation loss and the model config. It is
    written to a sibling temporary file and then moved into place, so an
    interrupted save cannot leave a truncated file at `path`.
    """
    raw = unwrap_model(model)
    payload = {
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "step": step,
        "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }

    tmp_path = f"{os.fspath(path)}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present here if the save or the move failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
585
+
586
def load_resume_checkpoint(model, optimizer, path, device) -> tuple[int, float]:
    """Restore model (strictly) and optimizer (best effort) from `path`.

    A failure to restore the optimizer state is reported and ignored.

    Returns:
        `(step, best_loss)` from the checkpoint, defaulting to `(0, 1e9)`.
    """
    checkpoint = torch.load(path, map_location=device)
    weights = normalize_state_dict_keys(checkpoint["model"])
    unwrap_model(model).load_state_dict(weights, strict=True)

    try:
        optimizer.load_state_dict(checkpoint["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state non repris: {e}")

    resumed_step = int(checkpoint.get("step", 0))
    resumed_best = float(checkpoint.get("best_loss", 1e9))
    return resumed_step, resumed_best
598
+
599
+
600
+ # ============================================================
601
+ # Evaluation
602
+ # ============================================================
603
+
604
@torch.no_grad()
def evaluate(model, loader, device, max_batches: int = 100) -> float:
    """Return the mean loss over at most `max_batches` batches of `loader`.

    The model is put in eval mode for the pass and always returned to train
    mode afterwards, even if a batch raises.

    Returns:
        The average batch loss, or `inf` when the loader yields no batch, so
        an empty evaluation can never be recorded as a new best loss.
    """
    # main() calls this on rank 0 only. Going through the DDP wrapper's
    # forward may trigger collectives (e.g. buffer broadcast) that the other
    # ranks never join, so call the wrapped module directly.
    eval_model = model.module if isinstance(model, DDP) else model

    losses = []
    model.eval()
    try:
        for i, batch in enumerate(loader):
            if i >= max_batches:
                break

            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = eval_model(inp, lbl)

            losses.append(loss.item())
    finally:
        model.train()

    if not losses:
        return float("inf")
    return sum(losses) / len(losses)
623
+
624
+
625
+ # ============================================================
626
+ # Main
627
+ # ============================================================
628
+
629
def main() -> None:
    """Continue training the base checkpoint for NUM_ROUNDS x STEPS_PER_ROUND steps.

    Sets up (optional) distributed training and the VRAM budget, loads the
    tokenizer, config and base weights, optionally compiles and DDP-wraps the
    model, resumes from STATE_FILE when it exists, then runs the training
    loop with periodic logging, evaluation and checkpointing on rank 0.
    """
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)

    cuda_device_index = None
    vram_fraction = None

    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")

        # Cap the allocator at TARGET_VRAM_GIB of this GPU's total memory.
        cuda_device_index = current_cuda_index(device)
        _, total_mem_bytes = torch.cuda.mem_get_info(cuda_device_index)
        target_bytes = int(TARGET_VRAM_GIB * (1024 ** 3))
        vram_fraction = min(target_bytes / total_mem_bytes, 0.999)

        torch.cuda.memory.set_per_process_memory_fraction(
            vram_fraction,
            device=cuda_device_index,
        )

    if is_main():
        print("=" * 60)
        print(" Re-train même modèle | 10 x 2000 steps")
        print("=" * 60)
        print(f"Device: {device} | World: {get_world_size()} GPU(s)")
        if device.type == "cuda":
            free_mem, total_mem = torch.cuda.mem_get_info(cuda_device_index)
            print(f"GPU: {torch.cuda.get_device_name(cuda_device_index)}")
            print(f"VRAM cible: {TARGET_VRAM_GIB:.1f} GiB")
            print(f"Fraction PyTorch: {vram_fraction:.4f}")
            print(f"GPU total: {total_mem / 1024**3:.2f} GiB | libre: {free_mem / 1024**3:.2f} GiB")
        print(f"Rounds: {NUM_ROUNDS} | Steps/round: {STEPS_PER_ROUND} | MAX_STEPS: {MAX_STEPS}")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(str(BASE_TOKENIZER_DIR))
    cfg = load_base_config()
    cfg.vocab_size = len(tokenizer)

    if is_main():
        CONFIG_FILE.write_text(
            json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        print(f"Base checkpoint: {BASE_CHECKPOINT}")
        print(f"Tokenizer: {BASE_TOKENIZER_DIR}")

    model = GPT(cfg).to(device)
    initialize_model_from_base(model, device)

    if USE_COMPILE and hasattr(torch, "compile"):
        model = torch.compile(model, mode=COMPILE_MODE)
        if is_main():
            print(f"torch.compile activé ({COMPILE_MODE})")

    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    start_step, best_eval = 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Reprise depuis {STATE_FILE}")
            start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad_path = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Checkpoint illisible: {e}")
                try:
                    STATE_FILE.rename(bad_path)
                    print(f"[warn] Checkpoint corrompu renommé vers {bad_path}")
                except OSError as rename_err:
                    print(f"[warn] Renommage impossible: {rename_err}")
                print("[warn] Reprise ignorée, démarrage depuis le checkpoint de base.")
            # A strict load may already have copied part of the bad checkpoint
            # into the model before raising; restore the base weights so the
            # run really starts from the base checkpoint.
            initialize_model_from_base(model, device)
            start_step, best_eval = 0, 1e9

    if start_step >= MAX_STEPS:
        if is_main():
            print(f"[warn] start_step={start_step} >= MAX_STEPS={MAX_STEPS}")
            print("[warn] Rien à entraîner.")
        if is_distributed():
            dist.destroy_process_group()
        return

    train_ds = PackedMixedBlocks(
        tokenizer=tokenizer,
        specs=TRAIN_SOURCES,
        block_size=cfg.block_size,
    )

    eval_ds = PackedEvalBlocks(
        tokenizer=tokenizer,
        spec=EVAL_SOURCE,
        block_size=cfg.block_size,
        max_examples=MAX_EVAL_EXAMPLES,
    )

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        num_workers=TRAIN_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_ds,
        batch_size=BATCH_SIZE,
        num_workers=EVAL_NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    if is_main():
        raw_model = unwrap_model(model)
        n_params = count_parameters(raw_model)
        print(f"Paramètres: {n_params / 1e6:.1f}M")
        print(f"Architecture: d={cfg.d_model} | heads={cfg.n_heads} | layers={cfg.n_layers} | block={cfg.block_size}")
        print(f"Batch size: {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS}")
        print(f"Dtype: {DTYPE} | Compile: {USE_COMPILE} ({COMPILE_MODE if USE_COMPILE else 'off'})")

    model.train()
    optimizer.zero_grad(set_to_none=True)

    train_iter = iter(train_loader)
    step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(cuda_device_index)

    current_round = (step // STEPS_PER_ROUND) + 1

    while step < MAX_STEPS:
        for _ in range(GRAD_ACCUM_STEPS):
            batch = next(train_iter)

            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            # Scale so the accumulated gradient is the mean over micro-batches.
            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

        lr = cosine_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = lr

        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        step += 1

        new_round = ((step - 1) // STEPS_PER_ROUND) + 1
        if new_round != current_round and is_main():
            current_round = new_round
            print(f"\n===== Round {current_round}/{NUM_ROUNDS} =====")

        if step % 50 == 0 and is_main():
            now = time.time()
            elapsed = max(1e-6, now - last_log)
            tok_s = tokens_since_log / elapsed
            avg_loss = log_loss_sum / max(1, log_loss_count)
            round_idx = ((step - 1) // STEPS_PER_ROUND) + 1
            step_in_round = ((step - 1) % STEPS_PER_ROUND) + 1

            print(
                f"round {round_idx:2d}/{NUM_ROUNDS} | "
                f"step {step_in_round:4d}/{STEPS_PER_ROUND} | "
                f"global {step:5d}/{MAX_STEPS} | "
                f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
            )

            if device.type == "cuda":
                allocated = torch.cuda.memory_allocated(cuda_device_index) / 1024**3
                reserved = torch.cuda.memory_reserved(cuda_device_index) / 1024**3
                max_alloc = torch.cuda.max_memory_allocated(cuda_device_index) / 1024**3
                max_reserved = torch.cuda.max_memory_reserved(cuda_device_index) / 1024**3
                print(
                    f"GPU mem | alloc={allocated:.2f} GiB | reserved={reserved:.2f} GiB | "
                    f"max_alloc={max_alloc:.2f} GiB | max_reserved={max_reserved:.2f} GiB"
                )

            last_log = now
            tokens_since_log = 0
            log_loss_sum = 0.0
            log_loss_count = 0

        if step % EVAL_EVERY == 0 and is_main():
            val_loss = evaluate(model, eval_loader, device)
            print(f"[eval] global step {step:5d} | val_loss={val_loss:.4f}")

            if val_loss < best_eval:
                best_eval = val_loss
                save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
                print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")

        if step % SAVE_EVERY == 0 and is_main():
            save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
            save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
            print(f"✓ Checkpoint → {MODEL_FILE}")

        if step % STEPS_PER_ROUND == 0 and is_main():
            round_no = step // STEPS_PER_ROUND
            round_ckpt = OUT_DIR / f"model_round_{round_no:02d}.pt"
            save_checkpoint(model, optimizer, step, best_eval, round_ckpt)
            print(f"✓ Fin round {round_no}/{NUM_ROUNDS} → {round_ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
        total = (time.time() - t0) / 60
        print(f"\nModèle final → {MODEL_FILE}")
        print(f"Meilleur modèle → {BEST_MODEL_FILE}")
        print(f"Temps total : {total:.1f} min")
        print(f"Steps effectués : {step}")

    if is_distributed():
        dist.destroy_process_group()
856
+
857
+
858
+ if __name__ == "__main__":
859
+ main()
train2.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ train_nlp_h100_optimized.py — v2 (bugfix device mismatch)
5
+ ===========================================================
6
+ Corrections vs v1 :
7
+ • apply_qlora() appelé APRÈS model.to(device) → lora_A/lora_B naissent sur CUDA
8
+ • LoRALinear.__init__ : move explicite des adaptateurs sur le device du base_layer
9
+ • torch.compile désactivé quand USE_CHECKPOINTING=True (conflict dynamo+checkpoint
10
+ avec sous-modules custom) — on utilise COMPILE_AFTER_CKPT pour les cas où on
11
+ veut quand même compiler (USE_CHECKPOINTING=False)
12
+ • Ajout d'un fallback propre : si compile crash, on continue sans compile
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import itertools
18
+ import json
19
+ import math
20
+ import os
21
+ import random
22
+ import time
23
+ from collections import OrderedDict
24
+ from contextlib import nullcontext
25
+ from dataclasses import asdict, dataclass
26
+ from pathlib import Path
27
+ from typing import Iterator, Optional
28
+
29
+ import torch
30
+ import torch.distributed as dist
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ try:
35
+ import bitsandbytes as bnb
36
+ from bitsandbytes.nn import Params4bit
37
+ HAS_BNB = True
38
+ except ImportError:
39
+ HAS_BNB = False
40
+ print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée")
41
+
42
+ try:
43
+ from flash_attn import flash_attn_func
44
+ HAS_FLASH = True
45
+ except ImportError:
46
+ HAS_FLASH = False
47
+ print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention")
48
+
49
+ from datasets import load_dataset
50
+ from torch.nn.parallel import DistributedDataParallel as DDP
51
+ from tokenizers import (
52
+ Tokenizer, decoders, models, normalizers,
53
+ pre_tokenizers, processors, trainers,
54
+ )
55
+ from transformers import PreTrainedTokenizerFast
56
+
57
+
58
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
59
+ # ║ CHEMINS ║
60
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
61
+
62
# Output directory; created at import time if it does not exist.
OUT_DIR = Path("./nlp_1b_h100_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
CONFIG_FILE = OUT_DIR / "config.json"
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
BASE_CHECKPOINT: Optional[Path] = None


# ╔══════════════════════════════════════════════════════════════════════════════╗
# ║ HYPERPARAMETERS ║
# ╚══════════════════════════════════════════════════════════════════════════════╝

SEED = 42
TARGET_VRAM_GIB = 78.0

BLOCK_SIZE = 1024
VOCAB_SIZE = 32_000
D_MODEL = 1536
N_HEADS = 24
N_LAYERS = 24
D_FF = 6144
DROPOUT = 0.0

USE_QLORA = True
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]

NUM_EPOCHS = 10
LEARNING_RATE = 3e-4
MIN_LR = 3e-5
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 500

# ┌─────────────────────────────────────────────────────────────────────────────┐
# │ BATCH SIZE TUNING → 78 GB VRAM │
# │ Start with: BATCH_SIZE=8, GRAD_ACCUM_STEPS=2 │
# │ Raise BATCH_SIZE by +2 until max_reserved ≈ 77 GB in the logs │
# │ On OOM: BATCH_SIZE -= 1 or USE_CHECKPOINTING=True │
# └─────────────────────────────────────────────────────────────────────────────┘
BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0
EVAL_EVERY = 500
SAVE_EVERY = 1_000

DTYPE = torch.bfloat16

# ── Compile: disabled when USE_CHECKPOINTING=True to avoid the conflict between
# dynamo, checkpointing and custom sub-modules (LoRALinear).
# Set USE_CHECKPOINTING=False AND USE_COMPILE=True for maximum speed.
USE_CHECKPOINTING = False  # when True, saves ~8x activation VRAM (author's estimate)
USE_COMPILE = True  # ← set to True only if USE_CHECKPOINTING=False
COMPILE_MODE = "reduce-overhead"

TRAIN_NUM_WORKERS = 4
EVAL_NUM_WORKERS = 2
PREFETCH_FACTOR = 2

TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
TOKENIZER_CHAR_LIMIT = 2_000
TEXT_CHAR_LIMIT = 4_000

SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

WIKI_CONFIGS = ["20231101.en", "20231101.fr", "20231101.ar"]
FINEWEB_CONFIG = "sample-10BT"
DEV_DOCS_PER_WIKI_CONFIG = 1_500
DEV_DOCS_FINEWEB = 3_000
TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
TRAIN_DOCS_FINEWEB = 60_000
137
+
138
+
139
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
140
+ # ║ DISTRIBUTED ║
141
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
142
+
143
def is_distributed() -> bool:
    """Return True when torch.distributed is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
145
+
146
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group is active."""
    if is_distributed():
        return dist.get_rank()
    return 0
148
+
149
def get_world_size() -> int:
    """Return the number of processes in the group, or 1 when not distributed."""
    if is_distributed():
        return dist.get_world_size()
    return 1
151
+
152
def is_main() -> bool:
    """Return True on the rank-0 process (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
154
+
155
def init_distributed() -> Optional[torch.device]:
    """Initialize the NCCL process group when LOCAL_RANK is set.

    Returns:
        The CUDA device bound to this process, or None when LOCAL_RANK is
        absent (or -1), in which case nothing is initialized.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank != -1:
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    return None
162
+
163
def set_seed(seed: int) -> None:
    """Seed Python's `random`, torch's CPU RNG and, when CUDA is available, every CUDA RNG."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
168
+
169
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    """Pick the device to train on.

    An explicit `ddp_device` wins; otherwise the current CUDA device is used
    when CUDA is available, and the CPU as a last resort.
    """
    if ddp_device is None:
        if torch.cuda.is_available():
            return torch.device(f"cuda:{torch.cuda.current_device()}")
        return torch.device("cpu")
    return ddp_device
175
+
176
def current_cuda_index(device: torch.device) -> int:
    """Return `device.index`, or the current CUDA device when the index is unset."""
    index = device.index
    if index is None:
        index = torch.cuda.current_device()
    return index
178
+
179
def autocast_context(device: torch.device):
    """Return a CUDA autocast context using DTYPE, or a no-op context off-GPU."""
    on_gpu = device.type == "cuda"
    return torch.autocast("cuda", dtype=DTYPE) if on_gpu else nullcontext()
183
+
184
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module, peeling off DDP and torch.compile wrappers.

    The wrappers are removed in whatever order they were applied:
    DDP(compile(m)) and compile(DDP(m)) both resolve to ``m``. The previous
    version handled only the first nesting and returned the DDP wrapper for
    the second.

    Args:
        model: A bare module or one wrapped by DDP and/or torch.compile.

    Returns:
        The innermost module, i.e. the one that owns ``cfg`` and the weights.
    """
    while True:
        if isinstance(model, DDP):
            model = model.module
        elif hasattr(model, "_orig_mod"):
            # torch.compile keeps the eager module under `_orig_mod`.
            model = model._orig_mod
        else:
            return model
187
+
188
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count the scalar parameters of `model`.

    Args:
        model: Module whose parameters are counted.
        trainable_only: When True, skip parameters with requires_grad=False.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
190
+
191
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
    """Strip DDP / torch.compile wrapper prefixes from state-dict keys.

    Both nesting orders are recognized: ``module._orig_mod.`` (DDP around a
    compiled module) and ``_orig_mod.module.`` (compile around a DDP module).
    The second order was previously left with a stray ``module.`` prefix.

    Args:
        sd: Mapping of parameter name to tensor, possibly from a wrapped model.

    Returns:
        A new OrderedDict with the same values and insertion order, with at
        most one leading wrapper prefix removed from each key.
    """
    # Longest prefixes first so a combined prefix is removed in one step.
    prefixes = ("module._orig_mod.", "_orig_mod.module.", "_orig_mod.", "module.")
    out = OrderedDict()
    for key, value in sd.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        out[key] = value
    return out
200
+
201
def normalize_text(t: str) -> str:
    """Collapse every whitespace run to one space and trim both ends."""
    words = t.split()  # no-arg split already ignores leading/trailing whitespace
    return " ".join(words)
203
+
204
def safe_str(x) -> str:
    """Coerce `x` to str: strings pass through, None becomes "", anything else goes through str()."""
    if isinstance(x, str):
        return x
    if x is None:
        return ""
    return str(x)
206
+
207
+
208
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
209
+ # ║ DATASETS ║
210
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
211
+
212
def load_wiki_stream(cfg_name: str):
    """Open the `train` split of wikimedia/wikipedia for config `cfg_name` in streaming mode."""
    return load_dataset(
        "wikimedia/wikipedia",
        cfg_name,
        split="train",
        streaming=True,
    )
214
+
215
def load_fineweb_stream():
    """Open the `train` split of HuggingFaceFW/fineweb-edu (config FINEWEB_CONFIG) in streaming mode."""
    return load_dataset(
        "HuggingFaceFW/fineweb-edu",
        FINEWEB_CONFIG,
        split="train",
        streaming=True,
    )
217
+
218
def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
    """Yield cleaned `text` fields from rows [start, start + count) of `ds`.

    Each row's "text" value is coerced to str and whitespace-normalized.
    Texts shorter than 20 characters are dropped; the rest are truncated to
    `char_limit` characters.
    """
    window = itertools.islice(ds, start, start + count)
    for record in window:
        cleaned = normalize_text(safe_str(record.get("text", "")))
        if len(cleaned) < 20:
            continue
        yield cleaned[:char_limit]
223
+
224
def tokenizer_training_iterator() -> Iterator[str]:
    """Yield tokenizer-training texts: each Wikipedia config in turn, then FineWeb.

    Every source contributes rows [0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE),
    truncated to TOKENIZER_CHAR_LIMIT characters.
    """
    for wiki_cfg in WIKI_CONFIGS:
        wiki_stream = load_wiki_stream(wiki_cfg)
        yield from stream_texts(
            wiki_stream, 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT
        )
    fineweb_stream = load_fineweb_stream()
    yield from stream_texts(
        fineweb_stream, 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT
    )
228
+
229
def build_epoch_train_texts(epoch: int) -> list[str]:
    """Materialize and shuffle the training texts for `epoch`.

    Each epoch reads a fresh window from every source. The window starts after
    the rows reserved for evaluation (DEV_DOCS_*) and advances by one
    training-window size per epoch. The shuffle is seeded with SEED + epoch.
    """
    wiki_offset = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
    fineweb_offset = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB

    collected: list[str] = []
    for wiki_cfg in WIKI_CONFIGS:
        collected.extend(
            stream_texts(load_wiki_stream(wiki_cfg), wiki_offset, TRAIN_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT)
        )
    collected.extend(
        stream_texts(load_fineweb_stream(), fineweb_offset, TRAIN_DOCS_FINEWEB, TEXT_CHAR_LIMIT)
    )

    shuffler = random.Random(SEED + epoch)
    shuffler.shuffle(collected)
    return collected
238
+
239
def build_eval_texts() -> list[str]:
    """Materialize the evaluation texts: the first DEV_DOCS_* rows of every source, unshuffled."""
    held_out: list[str] = []
    for wiki_cfg in WIKI_CONFIGS:
        wiki_stream = load_wiki_stream(wiki_cfg)
        held_out.extend(stream_texts(wiki_stream, 0, DEV_DOCS_PER_WIKI_CONFIG, TEXT_CHAR_LIMIT))
    fineweb_stream = load_fineweb_stream()
    held_out.extend(stream_texts(fineweb_stream, 0, DEV_DOCS_FINEWEB, TEXT_CHAR_LIMIT))
    return held_out
245
+
246
+
247
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
248
+ # ║ PACKED DATASET ║
249
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
250
+
251
class PackedTextList(torch.utils.data.IterableDataset):
    """Iterable dataset that packs tokenized texts into fixed-length LM samples.

    Texts are visited in an order shuffled with `epoch_seed`, sharded across
    (rank, DataLoader worker) pairs, wrapped in BOS/EOS and concatenated into
    a running token buffer. Each sample consumes `block_size + 1` tokens and
    yields `input_ids` (first `block_size`) and `labels` (shifted by one).
    Tokens left in the buffer at the end are discarded.

    NOTE(review): shards can produce different numbers of samples, so ranks
    may see different batch counts per epoch — confirm this is acceptable
    under DDP.
    """

    def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
        super().__init__()
        self.texts = texts
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.epoch_seed = epoch_seed

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        rank, world = get_rank(), get_world_size()
        if info is not None:
            # One shard per (rank, worker) pair.
            n_shards = info.num_workers * world
            my_shard = rank * info.num_workers + info.id
        else:
            n_shards, my_shard = world, rank

        # Same seed on every shard, so all shards agree on the global order.
        order = list(range(len(self.texts)))
        random.Random(self.epoch_seed).shuffle(order)

        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        window = self.block_size + 1
        pending: list[int] = []

        for position, text_idx in enumerate(order):
            if position % n_shards != my_shard:
                continue
            token_ids = self.tokenizer.encode(self.texts[text_idx], add_special_tokens=False)
            if not token_ids:
                continue
            pending.append(bos_id)
            pending.extend(token_ids)
            pending.append(eos_id)
            while len(pending) >= window:
                chunk = pending[:window]
                del pending[:window]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
289
+
290
+
291
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
292
+ # ║ TOKENIZER ║
293
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
294
+
295
def tokenizer_ready() -> bool:
    """Return True when both tokenizer.json and tokenizer_config.json exist in TOKENIZER_DIR."""
    required = ("tokenizer.json", "tokenizer_config.json")
    return all((TOKENIZER_DIR / filename).exists() for filename in required)
297
+
298
def train_tokenizer_once() -> None:
    """Train a byte-level BPE tokenizer and save it under TOKENIZER_DIR.

    Writes tokenizer.json (raw `tokenizers` format) and then the
    PreTrainedTokenizerFast files via save_pretrained.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    json_path = str(TOKENIZER_DIR / "tokenizer.json")

    bpe = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    bpe.normalizer = normalizers.Sequence([normalizers.NFKC()])
    bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    bpe.decoder = decoders.ByteLevel()

    bpe_trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        show_progress=is_main(),
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    bpe.train_from_iterator(tokenizer_training_iterator(), trainer=bpe_trainer)

    # The template processor needs the ids the trainer assigned to BOS/EOS.
    special_ids = [(tok, bpe.token_to_id(tok)) for tok in (BOS_TOKEN, EOS_TOKEN)]
    bpe.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=special_ids,
    )
    bpe.save(json_path)

    wrapper = PreTrainedTokenizerFast(
        tokenizer_file=json_path,
        bos_token=BOS_TOKEN,
        eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN,
        pad_token=PAD_TOKEN,
    )
    wrapper.save_pretrained(str(TOKENIZER_DIR))
321
+
322
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the tokenizer from TOKENIZER_DIR, training it first if it is missing.

    In a distributed run only rank 0 trains, and every rank then waits on a
    barrier before loading. The barrier is unconditional: it used to sit
    inside a per-rank ``tokenizer_ready()`` check, so a rank that evaluated
    the check after rank 0 had written the files skipped the barrier and
    left the collectives out of step.

    Returns:
        The tokenizer loaded from TOKENIZER_DIR.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    if is_distributed():
        if is_main() and not tokenizer_ready():
            print("Entraînement tokenizer 32k…")
            train_tokenizer_once()
        # Reached by all ranks regardless of what each one saw on disk.
        dist.barrier()
    elif not tokenizer_ready():
        print("Entraînement tokenizer 32k…")
        train_tokenizer_once()
    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
332
+
333
+
334
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
335
+ # ║ MODÈLE ║
336
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
337
+
338
@dataclass
class GPTConfig:
    """Hyper-parameters of the GPT model; defaults come from the module-level constants."""
    vocab_size: int = VOCAB_SIZE  # rows of the embedding table / width of the LM head
    block_size: int = BLOCK_SIZE  # tokens per packed training sample
    d_model: int = D_MODEL  # residual-stream width
    n_heads: int = N_HEADS  # attention heads; must divide d_model
    n_layers: int = N_LAYERS  # number of transformer blocks
    d_ff: int = D_FF  # hidden width of the SwiGLU feed-forward
    dropout: float = DROPOUT  # attention dropout probability (training only)
    use_checkpointing: bool = USE_CHECKPOINTING  # activation checkpointing per block while training
348
+
349
+
350
class RMSNorm(nn.Module):
    """Root-mean-square layer norm over the last dimension with a learned per-feature gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Reciprocal RMS of each feature vector; eps keeps it finite for all-zero input.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * x * inv_rms
357
+
358
+
359
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings.

    The tables cover `max_seq` positions. Each frequency is repeated twice
    along the last axis, so adjacent feature pairs share one angle (the
    interleaved layout that `rotate_half` expects). All buffers are
    non-persistent and therefore absent from state dicts.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)
        cos_table = torch.repeat_interleave(angles.cos(), 2, dim=-1)
        self.register_buffer("cos_cache", cos_table, persistent=False)
        sin_table = torch.repeat_interleave(angles.sin(), 2, dim=-1)
        self.register_buffer("sin_cache", sin_table, persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return the (cos, sin) tables for the first `seq_len` positions, cast to `dtype`."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
370
+
371
+
372
def rotate_half(x):
    """Map each adjacent pair (a, b) along the last axis to (-b, a)."""
    even, odd = x[..., 0::2], x[..., 1::2]
    paired = torch.stack((-odd, even), dim=-1)
    return paired.flatten(-2)
375
+
376
def apply_rope(x, cos, sin):
    """Rotate `x` by the rotary tables; two leading axes are added to cos/sin so they broadcast."""
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)
    return x * cos_b + rotate_half(x) * sin_b
378
+
379
+
380
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses FlashAttention-2 when HAS_FLASH is set, otherwise PyTorch's
    scaled_dot_product_attention. Dropout is applied inside the attention
    kernel and only while training.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        """Attend over `x` of shape (batch, time, d_model); returns the same shape."""
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(t, x.dtype)
        # Under autocast x is fp32 while the projections are bf16, so the
        # rotation promotes q/k to fp32. Cast them back to v's dtype:
        # flash_attn_func rejects fp32 and mixed-dtype inputs.
        q = apply_rope(q, cos, sin).to(v.dtype)
        k = apply_rope(k, cos, sin).to(v.dtype)

        drop_p = self.dropout_p if self.training else 0.0
        if HAS_FLASH:
            # FlashAttention-2 expects (batch, time, heads, head_dim).
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            y = flash_attn_func(q, k, v, dropout_p=drop_p, causal=True)
            y = y.reshape(b, t, c)
        else:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p, is_causal=True)
            y = y.transpose(1, 2).contiguous().view(b, t, c)

        return self.proj(y)
412
+
413
+
414
class SwiGLU(nn.Module):
    """Gated feed-forward block: w3(silu(w1(x)) * w2(x)), all projections bias-free."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.w1 = nn.Linear(width, hidden, bias=False)
        self.w2 = nn.Linear(width, hidden, bias=False)
        self.w3 = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
422
+
423
+
424
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each wrapped in a residual connection."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
435
+
436
+
437
class GPT(nn.Module):
    """Decoder-only language model: token embedding, a stack of Blocks, final RMSNorm, tied LM head."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels=None):
        """Return (logits, loss); loss is None unless `labels` is given.

        Label positions equal to -100 are ignored by the cross-entropy.
        """
        hidden = self.tok_emb(input_ids)
        recompute = self.cfg.use_checkpointing and self.training
        for layer in self.blocks:
            if recompute:
                hidden = torch.utils.checkpoint.checkpoint(layer, hidden, use_reentrant=False)
            else:
                hidden = layer(hidden)
        logits = self.lm_head(self.ln_f(hidden))
        if labels is None:
            return logits, None
        flat_logits = logits.reshape(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, labels.reshape(-1), ignore_index=-100)
        return logits, loss
467
+
468
+
469
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
470
+ # ║ QLORA ║
471
+ # ║ ║
472
+ # ║ CORRECTIF CLÉ : apply_qlora() DOIT être appelé APRÈS model.to(device). ║
473
+ # ║ LoRALinear détecte automatiquement le device du base_layer et y crée ║
474
+ # ║ lora_A / lora_B directement, sans besoin de .to() séparé. ║
475
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
476
+
477
class LoRALinear(nn.Module):
    """LoRA adapter wrapped around an existing nn.Linear.

    The low-rank factors lora_A / lora_B are allocated on the same device as
    the wrapped layer's parameters, so the adapter can be attached after the
    model has been moved to its device without a device mismatch. The wrapped
    layer's own parameters are frozen.
    """

    def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT):
        super().__init__()
        self.base = base_layer
        self.r = r
        self.scale = alpha / r

        # Device of the wrapped layer; falls back to CPU if it has no parameters.
        first_param = next(base_layer.parameters(), None)
        target_device = torch.device("cpu") if first_param is None else first_param.device

        # Allocate the adapters directly on that device.
        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False, device=target_device)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False, device=target_device)
        self.drop = nn.Dropout(dropout)

        # Standard LoRA init: A random, B zero, so the adapter starts as a no-op.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        # Freeze the wrapped layer.
        for base_param in self.base.parameters():
            base_param.requires_grad = False

    def forward(self, x):
        frozen_out = self.base(x)
        low_rank = self.lora_B(self.lora_A(self.drop(x)))
        return frozen_out + low_rank * self.scale
514
+
515
+
516
def apply_qlora(model: GPT, device: torch.device) -> GPT:
    """Replace every targeted nn.Linear with a LoRALinear wrapper, in place.

    A layer is targeted when the last component of its module name is in
    LORA_TARGET_MODULES. Must be called after ``model.to(device)`` so the
    adapters land on the right device. `device` is used only in the log line.
    No weight quantization happens in this function.

    Returns:
        The same model object (unchanged when USE_QLORA is false).
    """
    if not USE_QLORA:
        return model

    # Collect first: the module tree must not be mutated while it is iterated.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name.split(".")[-1] in LORA_TARGET_MODULES and isinstance(module, nn.Linear)
    ]

    replaced = 0
    for name, module in targets:
        *parent_path, leaf = name.split(".")
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        setattr(parent, leaf, LoRALinear(module))
        replaced += 1

    if is_main():
        print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})")
    return model
545
+
546
+
547
def freeze_base_weights(model: GPT) -> None:
    """Leave only parameters whose name contains lora_A or lora_B trainable."""
    for param_name, param in model.named_parameters():
        is_adapter = "lora_A" in param_name or "lora_B" in param_name
        param.requires_grad = is_adapter
551
+
552
+
553
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
554
+ # ║ OPTIMIZER ║
555
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
556
+
557
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build AdamW over the trainable parameters, with weight decay on matrices only.

    Parameters with ndim >= 2 and "weight" in their name get WEIGHT_DECAY;
    every other trainable parameter gets none. Uses bitsandbytes' paged 8-bit
    AdamW when HAS_BNB is set, otherwise torch's AdamW (fused when CUDA is
    available).
    """
    with_decay, without_decay = [], []
    for param_name, param in unwrap_model(model).named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim >= 2 and "weight" in param_name:
            with_decay.append(param)
        else:
            without_decay.append(param)

    param_groups = [
        {"params": with_decay, "weight_decay": WEIGHT_DECAY},
        {"params": without_decay, "weight_decay": 0.0},
    ]
    adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)

    if HAS_BNB:
        return bnb.optim.PagedAdamW8bit(param_groups, **adam_kwargs)

    return torch.optim.AdamW(param_groups, fused=torch.cuda.is_available(), **adam_kwargs)
573
+
574
+
575
def cosine_lr(step: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warm-up, then cosine decay to MIN_LR.

    The warm-up rises from 0 at step 0 to LEARNING_RATE over WARMUP_STEPS.
    Past `total_steps` the rate stays at MIN_LR.
    """
    if step >= WARMUP_STEPS:
        decay_span = max(1, total_steps - WARMUP_STEPS)
        progress = min(1.0, (step - WARMUP_STEPS) / decay_span)
        return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1 + math.cos(math.pi * progress))
    return LEARNING_RATE * step / max(1, WARMUP_STEPS)
580
+
581
+
582
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
583
+ # ║ CHECKPOINT ║
584
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
585
+
586
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
    """Atomically write a training checkpoint to `path`.

    The payload holds the unwrapped model's state dict (wrapper prefixes
    stripped), the optimizer state, the epoch/step/best_loss counters and the
    model config as a plain dict.

    The data goes to a sibling ``<name>.tmp`` file first and is then renamed
    over `path`, so an interrupted save never leaves a truncated checkpoint
    behind.

    Args:
        model: Bare or wrapped (DDP / compiled) model.
        optimizer: Optimizer whose state is saved.
        epoch: Epoch index to resume from.
        step: Global optimizer step.
        best_loss: Best validation loss so far.
        path: Destination file (str or Path).
    """
    raw = unwrap_model(model)
    payload = {
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch, "step": step, "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }
    target = Path(path)
    tmp_path = target.with_name(target.name + ".tmp")
    torch.save(payload, tmp_path)
    # os.replace is atomic when source and destination share a filesystem.
    os.replace(tmp_path, target)
594
+
595
def maybe_load_base_checkpoint(model, device):
    """Load pretrained base weights from BASE_CHECKPOINT, if configured and present.

    Works whether or not LoRA adapters are already attached. A checkpoint key
    such as ``blocks.0.attn.qkv.weight`` that is not in the model, while
    ``blocks.0.attn.qkv.base.weight`` is, gets remapped onto the wrapped
    layer. Previously those keys were dropped silently by ``strict=False``,
    leaving the LoRA-wrapped layers at their random initialization.

    Args:
        model: Bare or wrapped model to load into.
        device: map_location for torch.load.
    """
    if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
        return
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    target = unwrap_model(model)
    expected = set(target.state_dict().keys())

    remapped = OrderedDict()
    for key, value in normalize_state_dict_keys(ckpt["model"]).items():
        if key not in expected:
            head, _, leaf = key.rpartition(".")
            wrapped_key = f"{head}.base.{leaf}"
            if wrapped_key in expected:
                key = wrapped_key
        remapped[key] = value

    result = target.load_state_dict(remapped, strict=False)
    # LoRA adapters are expected to be missing from a base checkpoint.
    absent = [k for k in result.missing_keys if "lora_A" not in k and "lora_B" not in k]
    unexpected = list(result.unexpected_keys)
    if (absent or unexpected) and is_main():
        print(
            f"[warn] Checkpoint de base : {len(absent)} clés manquantes, "
            f"{len(unexpected)} clés inattendues"
        )
600
+
601
def load_resume_checkpoint(model, optimizer, path, device):
    """Restore model and (best-effort) optimizer state from a training checkpoint.

    Model weights load strictly. A failure to restore the optimizer state is
    reported and otherwise ignored.

    Returns:
        (epoch, step, best_loss) from the checkpoint, defaulting to (0, 0, 1e9).
    """
    state = torch.load(path, map_location=device)
    weights = normalize_state_dict_keys(state["model"])
    unwrap_model(model).load_state_dict(weights, strict=True)
    try:
        optimizer.load_state_dict(state["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state non repris: {e}")
    epoch = int(state.get("epoch", 0))
    step = int(state.get("step", 0))
    best_loss = float(state.get("best_loss", 1e9))
    return epoch, step, best_loss
609
+
610
+
611
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
612
+ # ║ ÉVALUATION ║
613
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
614
+
615
@torch.no_grad()
def evaluate(model, loader, device, max_batches=200) -> float:
    """Return the mean loss over at most `max_batches` batches of `loader`.

    The forward pass runs on the module inside any DDP wrapper. The training
    loop calls this on rank 0 only, and a forward through the DDP wrapper can
    issue collectives (such as buffer broadcasts) that the other ranks never
    join.

    The model is put in eval mode for the pass and back in train mode after.

    Returns:
        The mean loss, or ``inf`` when the loader yields no batch, so an
        empty evaluation can never be recorded as the best loss.
    """
    net = model.module if isinstance(model, DDP) else model
    net.eval()
    losses = []
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        inp = batch["input_ids"].to(device, non_blocking=True)
        lbl = batch["labels"].to(device, non_blocking=True)
        with autocast_context(device):
            _, loss = net(inp, lbl)
        losses.append(loss.item())
    net.train()
    if not losses:
        return float("inf")
    return sum(losses) / len(losses)
629
+
630
+
631
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
632
+ # ║ DATALOADER ║
633
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
634
+
635
def make_loader(dataset, batch_size, num_workers, is_cuda):
    """Build a DataLoader; memory is pinned on CUDA, and worker options are set only when workers are used."""
    options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": is_cuda,
    }
    if num_workers > 0:
        # DataLoader rejects these two options when num_workers == 0.
        options.update(persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)
    return torch.utils.data.DataLoader(dataset, **options)
641
+
642
+
643
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
644
+ # ║ MAIN ║
645
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
646
+
647
def main() -> None:
    """Run the full training job: setup, optional resume, training loop, checkpoints.

    Changes from the previous version:
      * The base checkpoint is loaded before the LoRA wrappers are attached,
        so its keys match the model and are not dropped by strict=False.
      * Rank-0 evaluation runs on the module inside the DDP wrapper, to avoid
        collectives that the other ranks never join.
      * The VRAM-advice banner no longer contains mis-encoded characters.
    """
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)
    is_cuda = device.type == "cuda"

    cuda_idx = None
    vram_fraction = None

    if is_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        cuda_idx = current_cuda_index(device)
        _, total = torch.cuda.mem_get_info(cuda_idx)
        # Cap this process at TARGET_VRAM_GIB (never the full device).
        vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
        torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)

    if is_main():
        print("=" * 72)
        print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | v2 (device fix)")
        print("=" * 72)
        print(f"Device : {device} | World: {get_world_size()} GPU(s)")
        print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
        print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
        if is_cuda:
            free, total = torch.cuda.mem_get_info(cuda_idx)
            print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
            print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")

    tokenizer = train_or_load_tokenizer()
    cfg = GPTConfig(vocab_size=len(tokenizer))

    if is_main():
        CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")

    # ── 1. Build the model on its device ─────────────────────────────────────
    model = GPT(cfg).to(device)

    # ── 2. Load base weights while the parameter names are still unwrapped ───
    maybe_load_base_checkpoint(model, device)

    # ── 3. Attach LoRA after .to(device) so the adapters are created there ───
    if USE_QLORA:
        model = apply_qlora(model, device)
        freeze_base_weights(model)

    # ── 4. torch.compile (only when USE_CHECKPOINTING is False) ──────────────
    # compile + checkpointing + the custom LoRALinear is reported as unstable
    # under torch.dynamo, so the two options are mutually exclusive here.
    if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode=COMPILE_MODE)
            if is_main():
                print(f"torch.compile activé ({COMPILE_MODE})")
        except Exception as e:
            if is_main():
                print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile")

    # ── 5. DDP ────────────────────────────────────────────────────────────────
    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    # ── Datasets ──────────────────────────────────────────────────────────────
    eval_texts = build_eval_texts()
    eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
    eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)

    # NOTE(review): this estimate counts documents rather than packed blocks,
    # and ignores GRAD_ACCUM_STEPS and the world size. It only sets the length
    # of the cosine schedule.
    init_texts = build_epoch_train_texts(0)
    steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE)
    total_steps_est = steps_per_epoch * NUM_EPOCHS

    # ── Resume ────────────────────────────────────────────────────────────────
    start_epoch, start_step, best_eval = 0, 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Reprise depuis {STATE_FILE}")
            start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Checkpoint illisible: {e}")
                # Best effort: move the unreadable file aside; keep going if that fails.
                try:
                    STATE_FILE.rename(bad)
                except Exception:
                    pass
            start_epoch, start_step, best_eval = 0, 0, 1e9

    if is_main():
        raw = unwrap_model(model)
        n_total = count_parameters(raw, False)
        n_train = count_parameters(raw, True)
        print(f"Paramètres totaux : {n_total/1e9:.3f}B")
        print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
        print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
        print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}")
        print()
        print("── Conseil VRAM ────────────────────────────────────────────────")
        print(" Surveille 'max_reserved=XX GiB' à step 50.")
        print(" Augmente BATCH_SIZE par +2 jusqu'à ~77 Go réservés.")
        print(" Si OOM : BATCH_SIZE -= 1 ou USE_CHECKPOINTING=True.")
        print("────────────────────────────────────────────────────────────────")

    # ── Training loop ─────────────────────────────────────────────────────────
    model.train()
    optimizer.zero_grad(set_to_none=True)

    global_step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()

    if is_cuda:
        torch.cuda.reset_peak_memory_stats(cuda_idx)

    for epoch in range(start_epoch, NUM_EPOCHS):
        if is_main():
            print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")

        train_texts = build_epoch_train_texts(epoch)
        train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
        train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)

        for micro_step, batch in enumerate(train_loader):
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

            # Step the optimizer only once per GRAD_ACCUM_STEPS micro-batches.
            if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
                continue

            lr = cosine_lr(global_step, total_steps_est)
            for group in optimizer.param_groups:
                group["lr"] = lr

            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % 50 == 0 and is_main():
                now = time.time()
                elapsed = max(1e-6, now - last_log)
                tok_s = tokens_since_log / elapsed
                avg_loss = log_loss_sum / max(1, log_loss_count)
                print(
                    f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
                    f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
                )
                if is_cuda:
                    alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
                    reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
                    max_alloc = torch.cuda.max_memory_allocated(cuda_idx) / 1024**3
                    max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
                    print(
                        f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | "
                        f"max_alloc={max_alloc:.2f} | max_reserved={max_res:.2f} (GiB)"
                    )
                last_log = now
                tokens_since_log = 0
                log_loss_sum = 0.0
                log_loss_count = 0

            if global_step % EVAL_EVERY == 0 and is_main():
                # Only rank 0 evaluates, so bypass the DDP wrapper.
                eval_model = model.module if isinstance(model, DDP) else model
                val_loss = evaluate(eval_model, eval_loader, device)
                print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
                if val_loss < best_eval:
                    best_eval = val_loss
                    save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
                    print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")

            if global_step % SAVE_EVERY == 0 and is_main():
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
                print(f"✓ Checkpoint → {MODEL_FILE}")

        if is_main():
            # Store epoch + 1 so a resume starts at the next epoch.
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
            ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
            print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
        total_min = (time.time() - t0) / 60
        print(f"\nModèle final → {MODEL_FILE}")
        print(f"Meilleur modèle → {BEST_MODEL_FILE}")
        print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")

    if is_distributed():
        dist.destroy_process_group()
849
+
850
+
851
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
train_aramix_h100_full.py ADDED
@@ -0,0 +1,1055 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Entraînement LLM Multi-Domaine — H100 80 Go [CORRIGÉ]
6
+ ════════════════════════════════════════════════════════
7
+ Architecture : GPT causal ~435M
8
+ RMSNorm · RoPE · SwiGLU · Flash Attention (SDPA)
9
+ Précision : BF16 natif + TF32 + fused AdamW
10
+ Compilation : torch.compile(mode="reduce-overhead")
11
+ Dataset : 10 domaines en streaming HF, échantillonnage pondéré
12
+
13
+ Correctifs v2 :
14
+ ✓ trust_remote_code=True supprimé (déprécié datasets>=3.x)
15
+ ✓ wikipedia → datasets.load_dataset sans script legacy
16
+ ✓ Datasets remplacés par leurs équivalents Parquet/modernes
17
+ ✓ MAX_STEPS = 5 000 (2 epochs estimées sur corpus réduit)
18
+ ✓ pubmed_abstracts → pubmed_qa (Parquet natif)
19
+ ✓ RedPajama CC → allenai/c4 (en/fr/ar)
20
+ ✓ pile-of-law → joelniklaus/pile_of_law (Parquet)
21
+ ✓ Gestion propre de StopIteration dans le DataLoader
22
+
23
+ Usage mono-GPU :
24
+ python train_aramix_h100_full.py
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import copy
30
+ import gc
31
+ import json
32
+ import math
33
+ import os
34
+ import random
35
+ import time
36
+ from dataclasses import asdict, dataclass, field
37
+ from pathlib import Path
38
+ from typing import Dict, Iterator, List, Optional, Tuple
39
+
40
+ import torch
41
+ import torch.distributed as dist
42
+ import torch.nn as nn
43
+ import torch.nn.functional as F
44
+ from torch.nn.parallel import DistributedDataParallel as DDP
45
+ from torch.utils.data import DataLoader, IterableDataset
46
+ from datasets import load_dataset
47
+ from tokenizers import (
48
+ Tokenizer,
49
+ decoders,
50
+ models,
51
+ normalizers,
52
+ pre_tokenizers,
53
+ processors,
54
+ trainers,
55
+ )
56
+ from transformers import PreTrainedTokenizerFast
57
+
58
+
59
+ # ══════════════════════════════════════════════════════════════════
60
+ # §1 CONFIGURATION GLOBALE
61
+ # ══════════════════════════════════════════════════════════════════
62
+
63
# Output directory for every artifact of the run; created at import time.
OUT_DIR = Path("./aramix_h100")
OUT_DIR.mkdir(parents=True, exist_ok=True)

SEED = 42

# ── Tokenizer ─────────────────────────────────────────────────────
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
TOKENIZER_VOCAB = 32_000
TOKENIZER_SAMPLE_DOCS = 80_000  # reduced to speed things up
TOKENIZER_CHAR_LIMIT = 2_000

SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

# ── Architecture (~435M) ──────────────────────────────────────────
VOCAB_SIZE = 32_000  # updated after the tokenizer is loaded
BLOCK_SIZE = 1024
D_MODEL = 1024
# NOTE(review): 1024 is not divisible by 18. If the attention module derives
# head_dim as D_MODEL // N_HEADS (the model code is outside this view), this
# value cannot work and 16 was probably intended — confirm.
N_HEADS = 18
N_LAYERS = 24
D_FF = 4096
DROPOUT = 0.1

# ── Training ──────────────────────────────────────────────────────
LEARNING_RATE = 3e-4
MIN_LR = 3e-5
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 200  # reduced proportionally (5k steps)
MAX_STEPS = 5_000  # 5,000 steps / ~2 epochs
MAX_GRAD_NORM = 1.0

BATCH_SIZE = 32  # H100 80 GB: 32×1024 in BF16 ≈ 26 GB
GRAD_ACCUM_STEPS = 1
EVAL_EVERY = 500
SAVE_EVERY = 1_000
MAX_EVAL_DOCS_DOM = 300
TRAIN_CHAR_LIMIT = 4_000

# ── Precision & compilation ───────────────────────────────────────
DTYPE = torch.bfloat16
USE_COMPILE = True
USE_CHECKPOINTING = False  # unnecessary with 80 GB of VRAM

# ── Files ─────────────────────────────────────────────────────────
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
STATE_FILE = OUT_DIR / "train_state.pt"
CONFIG_FILE = OUT_DIR / "config.json"
LOG_FILE = OUT_DIR / "train_log.jsonl"
111
+ LOG_FILE = OUT_DIR / "train_log.jsonl"
112
+
113
+
114
+ # ══════════════════════════════════════════════════════════════════
115
+ # §2 REGISTRE MULTI-DOMAINES (tous compatibles Parquet / sans script)
116
+ # ══════════════════════════════════════════════════════════════════
117
+
118
@dataclass
class DomainConfig:
    """One Hugging Face text source taking part in the training mixture."""

    # Short identifier, used as stream key and as the per-block domain label.
    name: str
    # Repository id, configuration name and split handed to `load_dataset`.
    hf_path: str
    hf_subset: Optional[str]
    hf_split: str
    # Row key holding the raw text.
    text_field: str
    # Relative sampling weight inside the interleaved stream.
    weight: float
    # Documents are truncated to this many characters after cleaning.
    char_limit: int
    # Cleaned documents shorter than this are dropped.
    min_chars: int = 80
    # When set, rows whose "lang"/"language" value differs are skipped.
    lang_filter: Optional[str] = None
    # Extra per-dataset options for sources that need special handling.
    extra_kwargs: dict = field(default_factory=dict)
131
+
132
+
133
# Training mixture: one entry per source; weights are relative sampling
# probabilities and are checked to sum to ~1.0 right after the list.
DOMAINS: List[DomainConfig] = [

    # ── Code / Dev ────────────────────────────────────────────────
    # Author's note: bigcode/starcoderdata relies on a loading script, so the
    # Parquet-backed data on the hub is used instead.
    # NOTE(review): for bigcode/the-stack-dedup the language is usually chosen
    # with data_dir="data/<lang>", not a configuration name — confirm that the
    # hf_subset values below actually load.
    DomainConfig(
        name="code_python",
        hf_path="bigcode/the-stack-dedup", hf_subset="data/python",
        hf_split="train", text_field="content",
        weight=0.12, char_limit=6_000, min_chars=100,
    ),
    DomainConfig(
        name="code_js",
        hf_path="bigcode/the-stack-dedup", hf_subset="data/javascript",
        hf_split="train", text_field="content",
        weight=0.06, char_limit=5_000, min_chars=100,
    ),
    DomainConfig(
        name="code_csharp",
        hf_path="bigcode/the-stack-dedup", hf_subset="data/c-sharp",
        hf_split="train", text_field="content",
        weight=0.04, char_limit=5_000, min_chars=100,
    ),

    # ── Medical ───────────────────────────────────────────────────
    DomainConfig(
        name="medical_flashcards",
        hf_path="medalpaca/medical_meadow_medical_flashcards", hf_subset=None,
        hf_split="train", text_field="output",
        weight=0.06, char_limit=3_000, min_chars=60,
    ),
    # pubmed_abstracts replaced by PubMedQA (native Parquet per the author).
    DomainConfig(
        name="medical_pubmed",
        hf_path="qiaojin/PubMedQA", hf_subset="pqa_labeled",
        hf_split="train", text_field="long_answer",
        weight=0.06, char_limit=4_000, min_chars=100,
    ),

    # ── French ────────────────────────────────────────────────────
    # Script-free Wikipedia: the author relies on datasets>=2.14 loading the
    # Parquet files directly.
    DomainConfig(
        name="french_wiki",
        hf_path="wikimedia/wikipedia", hf_subset="20231101.fr",
        hf_split="train", text_field="text",
        weight=0.08, char_limit=5_000, min_chars=100,
    ),
    DomainConfig(
        name="french_culture",
        hf_path="lyon-nlp/corpus-france-culture-inter-2023", hf_subset=None,
        hf_split="train", text_field="text",
        weight=0.04, char_limit=4_000, min_chars=80,
    ),
    DomainConfig(
        name="french_news",
        hf_path="mlsum", hf_subset="fr",
        hf_split="train", text_field="text",
        weight=0.03, char_limit=3_000, min_chars=80,
    ),

    # ── Arabic ────────────────────────────────────────────────────
    DomainConfig(
        name="arabic_aramix",
        hf_path="AdaMLLab/AraMix", hf_subset="matched",
        hf_split="train", text_field="text",
        weight=0.10, char_limit=4_000, min_chars=80,
    ),
    DomainConfig(
        name="arabic_wiki",
        hf_path="wikimedia/wikipedia", hf_subset="20231101.ar",
        hf_split="train", text_field="text",
        weight=0.05, char_limit=4_000, min_chars=80,
    ),
    # Author's note: oscar-corpus/OSCAR-2301 is still loadable without a script.
    DomainConfig(
        name="arabic_oscar",
        hf_path="oscar-corpus/OSCAR-2301", hf_subset="ar",
        hf_split="train", text_field="content",
        weight=0.04, char_limit=3_000, min_chars=80,
    ),

    # ── Creative ──────────────────────────────────────────────────
    DomainConfig(
        name="creative_writing",
        hf_path="ajibawa-2023/creative-writing-40k", hf_subset=None,
        hf_split="train", text_field="output",
        weight=0.05, char_limit=5_000, min_chars=100,
    ),
    DomainConfig(
        name="stories",
        hf_path="roneneldan/TinyStories", hf_subset=None,
        hf_split="train", text_field="text",
        weight=0.04, char_limit=3_000, min_chars=80,
    ),
    DomainConfig(
        name="reddit_posts",
        hf_path="webis/tldr-17", hf_subset=None,
        hf_split="train", text_field="content",
        weight=0.03, char_limit=3_000, min_chars=80,
    ),

    # ── Mathematics ───────────────────────────────────────────────
    DomainConfig(
        name="math_stackexchange",
        hf_path="math-ai/StackMathQA", hf_subset=None,
        hf_split="train", text_field="A",
        weight=0.04, char_limit=4_000, min_chars=80,
    ),
    DomainConfig(
        name="math_problems",
        hf_path="lighteval/MATH", hf_subset=None,
        hf_split="train", text_field="solution",
        weight=0.03, char_limit=3_000, min_chars=60,
    ),

    # ── Legal ─────────────────────────────────────────────────────
    # Author's note: joelniklaus/pile_of_law is the Parquet version of pile-of-law.
    DomainConfig(
        name="legal_en",
        hf_path="joelniklaus/pile_of_law", hf_subset="courtlistener_opinions",
        hf_split="train", text_field="text",
        weight=0.04, char_limit=5_000, min_chars=100,
    ),
    DomainConfig(
        name="legal_fr",
        hf_path="antoinelouis/french-legal-corpus", hf_subset=None,
        hf_split="train", text_field="text",
        weight=0.02, char_limit=4_000, min_chars=80,
    ),

    # ── Science ───────────────────────────────────────────────────
    # RedPajama arXiv replaced by allenai/peS2o (Semantic Scholar, Parquet).
    DomainConfig(
        name="science_arxiv",
        hf_path="allenai/peS2o", hf_subset=None,
        hf_split="train", text_field="text",
        weight=0.05, char_limit=6_000, min_chars=100,
    ),

    # ── General multilingual ──────────────────────────────────────
    # RedPajama CC replaced by allenai/c4 multilingual (Parquet).
    DomainConfig(
        name="multilingual_cc",
        hf_path="allenai/c4", hf_subset="multilingual",
        hf_split="train", text_field="text",
        weight=0.02, char_limit=3_000, min_chars=80,
    ),
]

# Validation: the weights must sum to ~1.0. An explicit raise is used rather
# than `assert` so the check also runs under `python -O`.
_wsum = sum(d.weight for d in DOMAINS)
if abs(_wsum - 1.0) >= 0.01:
    raise ValueError(f"Somme des poids = {_wsum:.4f} ≠ 1.0")
296
+
297
+
298
def select_domains(*names: str) -> List[DomainConfig]:
    """Return copies of the named domains with weights rescaled to sum to 1.

    Entries keep their order from ``DOMAINS``; the registry objects themselves
    are never modified. Names absent from the registry are ignored.

    Raises:
        ValueError: if none of ``names`` matches a registered domain.
    """
    picked = [dom for dom in DOMAINS if dom.name in names]
    if len(picked) == 0:
        raise ValueError(f"Aucun domaine parmi : {names}")

    weight_sum = sum(dom.weight for dom in picked)

    def _rescaled(dom):
        # Shallow copy so the shared registry entry keeps its original weight.
        clone = copy.copy(dom)
        clone.weight = round(dom.weight / weight_sum, 6)
        return clone

    return [_rescaled(dom) for dom in picked]
310
+
311
+
312
def print_domain_summary(domains: Optional[List[DomainConfig]] = None) -> None:
    """Print a table of the domains (heaviest first) followed by the weight total.

    Args:
        domains: entries to display; the full ``DOMAINS`` registry when omitted.
    """
    rows = DOMAINS if domains is None else domains

    def _descending_weight(entry):
        return -entry.weight

    print(f"\n{'Domaine':<25} {'Dataset HF':<45} {'Poids':>6}")
    print("─" * 80)
    for entry in sorted(rows, key=_descending_weight):
        if entry.hf_subset:
            suffix = f"/{entry.hf_subset}"
        else:
            suffix = ""
        print(f"{entry.name:<25} {entry.hf_path + suffix:<45} {entry.weight:>6.1%}")
    total = sum(entry.weight for entry in rows)
    print(f"{'TOTAL':<25} {'':<45} {total:>6.1%}\n")
321
+
322
+
323
+ # ══════════════════════════════════════════════════════════════════
324
+ # §3 STREAMING DATASET MULTI-DOMAINES
325
+ # ══════════════════════════════════════════════════════════════════
326
+
327
def domain_text_stream(
    domain: DomainConfig,
    max_docs: Optional[int] = None,
) -> Iterator[str]:
    """Stream cleaned text documents for one domain.

    The dataset is opened in streaming mode. ``domain.extra_kwargs`` is
    forwarded to ``load_dataset`` so per-dataset options (for example a
    ``data_dir``) take effect; an empty dict leaves the call unchanged.
    ``trust_remote_code`` is deliberately not passed.

    If the dataset cannot be opened, a warning is printed and the generator
    ends without yielding, so one broken source does not abort training.

    Args:
        domain: source description (repository, subset, split, text field,
            filters and limits).
        max_docs: stop after this many yielded documents; ``None`` or ``0``
            means no limit.

    Yields:
        Whitespace-normalised text, truncated to ``domain.char_limit`` chars.
    """
    try:
        ds = load_dataset(
            domain.hf_path,
            domain.hf_subset,
            split=domain.hf_split,
            streaming=True,
            **domain.extra_kwargs,
        )
    except Exception as e:  # deliberate best-effort boundary: skip the domain
        print(f"[WARN] Domaine '{domain.name}' impossible à charger : {e}")
        return

    n = 0
    for row in ds:
        text = row.get(domain.text_field, "")
        if not text or not isinstance(text, str):
            continue
        if domain.lang_filter:
            # Rows carrying no language tag are kept; only explicit mismatches are dropped.
            lang = row.get("lang", row.get("language", ""))
            if lang and lang != domain.lang_filter:
                continue
        # Collapse every run of whitespace (newlines included) to one space.
        text = " ".join(text.split())
        if len(text) < domain.min_chars:
            continue
        yield text[: domain.char_limit]
        n += 1
        if max_docs and n >= max_docs:
            break
365
+
366
+
367
def interleaved_text_stream(
    domains: List[DomainConfig],
    max_docs_per_domain: Optional[int] = None,
    seed: int = 42,
) -> Iterator[Tuple[str, str]]:
    """Mix the domain streams by weighted random sampling.

    At each step one still-active domain is drawn with probability
    proportional to its weight and its next document is emitted. A domain
    whose stream has ended is removed from the draw; the generator finishes
    once every domain is exhausted.

    Yields:
        ``(domain_name, text)`` pairs.
    """
    rng = random.Random(seed)
    streams = {
        dom.name: domain_text_stream(dom, max_docs=max_docs_per_domain)
        for dom in domains
    }
    finished: set = set()

    while True:
        remaining = [dom for dom in domains if dom.name not in finished]
        if not remaining:
            return
        weights = [dom.weight for dom in remaining]
        pick = rng.choices(remaining, weights=weights, k=1)[0]
        try:
            text = next(streams[pick.name])
        except StopIteration:
            # This source is drained: exclude it from later draws.
            finished.add(pick.name)
            continue
        yield pick.name, text
391
+
392
+
393
def packed_block_stream(
    tokenizer: PreTrainedTokenizerFast,
    domains: List[DomainConfig],
    block_size: int,
    max_docs_per_domain: Optional[int] = None,
    seed: int = 42,
) -> Iterator[Dict]:
    """Tokenise the interleaved text and pack it into dense fixed-size blocks.

    Each document is wrapped as ``<bos> ids <eos>`` and appended to a running
    token buffer. Whenever the buffer holds ``block_size + 1`` tokens, one
    block is emitted: the first ``block_size`` tokens are the inputs and the
    same window shifted by one is the labels.

    The block's ``"domain"`` is the domain owning the most tokens in the
    window. Ties go to the domain that appears first in the window, so the
    label is reproducible across runs (iterating a ``set`` was not, because
    its order depends on the interpreter's hash seed).

    Yields:
        ``{"input_ids": list[int], "labels": list[int], "domain": str}``
    """
    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
    span = block_size + 1
    buffer: List[int] = []
    buffer_domain: List[str] = []

    for domain_name, text in interleaved_text_stream(domains, max_docs_per_domain, seed):
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            continue
        seq = [bos] + ids + [eos]
        buffer.extend(seq)
        buffer_domain.extend([domain_name] * len(seq))

        # Walk the buffer with an offset and trim once per document, instead
        # of re-slicing both lists after every emitted block.
        start = 0
        while len(buffer) - start >= span:
            chunk = buffer[start:start + span]
            chunk_domain = buffer_domain[start:start + span]
            start += span
            # dict.fromkeys keeps first-seen order and max() returns the first
            # maximal candidate, so ties resolve deterministically.
            majority = max(dict.fromkeys(chunk_domain), key=chunk_domain.count)
            yield {"input_ids": chunk[:-1], "labels": chunk[1:], "domain": majority}
        if start:
            del buffer[:start]
            del buffer_domain[:start]
423
+
424
+
425
class MultiDomainPackedDataset(IterableDataset):
    """Iterable dataset of packed token blocks, sharded across consumers.

    Every DataLoader worker (and every distributed rank, once a process group
    is initialised) iterates the same deterministic block stream and keeps
    only the blocks whose index falls in its own shard. All consumers
    therefore use the same seed: a per-worker seed would give each worker a
    different stream, and the modulo selection would no longer partition the
    data (workers would read overlapping documents).

    NOTE(review): each consumer still tokenises the whole stream and discards
    the blocks of the other shards.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        domains: List[DomainConfig],
        block_size: int,
        max_docs_per_domain: Optional[int] = None,
        seed: int = 42,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.domains = domains
        self.block_size = block_size
        self.max_docs_per_domain = max_docs_per_domain
        self.seed = seed

    def __iter__(self):
        # Position of this DataLoader worker (single-process loading -> 0 of 1).
        worker = torch.utils.data.get_worker_info()
        wid = worker.id if worker else 0
        n_workers = worker.num_workers if worker else 1

        # Position of this process among the distributed ranks, if any.
        if dist.is_available() and dist.is_initialized():
            rank, world = dist.get_rank(), dist.get_world_size()
        else:
            rank, world = 0, 1

        shard = rank * n_workers + wid
        n_shards = world * n_workers

        stream = packed_block_stream(
            self.tokenizer, self.domains, self.block_size,
            self.max_docs_per_domain, seed=self.seed,
        )
        for idx, block in enumerate(stream):
            if idx % n_shards != shard:
                continue
            yield {
                "input_ids": torch.tensor(block["input_ids"], dtype=torch.long),
                "labels": torch.tensor(block["labels"], dtype=torch.long),
                "domain": block["domain"],
            }
461
+
462
+
463
def build_dataloaders(
    tokenizer: PreTrainedTokenizerFast,
    domains: List[DomainConfig],
    block_size: int,
    train_batch_size: int,
    eval_batch_size: int = 16,
    max_eval_docs_per_dom: int = 300,
    num_workers: int = 4,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader]:
    """Create the training and evaluation loaders over the packed streams.

    The training dataset is unbounded per domain. The evaluation dataset is
    capped at ``max_eval_docs_per_dom`` documents per domain and uses a
    shifted seed (``seed + 9999``) so its mixing order differs. The
    evaluation loader always uses at least one worker process.

    ``prefetch_factor`` is only passed when worker processes are used:
    DataLoader rejects it for ``num_workers=0``, which previously made
    single-process loading impossible.

    Returns:
        ``(train_loader, eval_loader)``
    """
    train_ds = MultiDomainPackedDataset(
        tokenizer, domains, block_size,
        max_docs_per_domain=None,
        seed=seed,
    )
    eval_ds = MultiDomainPackedDataset(
        tokenizer, domains, block_size,
        max_docs_per_domain=max_eval_docs_per_dom,
        seed=seed + 9999,
    )

    def collate_fn(batch):
        # Stack the per-sample tensors; domain labels stay a plain list.
        return {
            "input_ids": torch.stack([b["input_ids"] for b in batch]),
            "labels": torch.stack([b["labels"] for b in batch]),
            "domain": [b["domain"] for b in batch],
        }

    def _make_loader(dataset, batch_size: int, workers: int) -> DataLoader:
        # Shared construction logic for both loaders.
        extra = {"prefetch_factor": 2} if workers > 0 else {}
        return DataLoader(
            dataset, batch_size=batch_size,
            num_workers=workers, pin_memory=True,
            collate_fn=collate_fn, **extra,
        )

    train_loader = _make_loader(train_ds, train_batch_size, num_workers)
    eval_loader = _make_loader(eval_ds, eval_batch_size, max(1, num_workers // 2))
    return train_loader, eval_loader
502
+
503
+
504
+ # ══════════════════════════════════════════════════════════════════
505
+ # §4 TOKENIZER BPE 32k
506
+ # ══════════════════════════════════════════════════════════════════
507
+
508
def train_or_load_tokenizer(
    domains: List[DomainConfig],
) -> PreTrainedTokenizerFast:
    """Return the tokenizer stored in ``TOKENIZER_DIR``, training it first if absent.

    When ``tokenizer.json`` and ``tokenizer_config.json`` both exist they are
    simply loaded. Otherwise a byte-level BPE tokenizer is trained on an equal
    number of documents from every domain, saved in the Hugging Face layout
    and reloaded from disk.

    NOTE(review): only the progress message is restricted to the main rank;
    under DDP every rank appears to train and write the same files
    concurrently — confirm whether that is intended.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    json_path = TOKENIZER_DIR / "tokenizer.json"
    config_path = TOKENIZER_DIR / "tokenizer_config.json"

    cached = json_path.exists() and config_path.exists()
    if cached:
        print("Tokenizer existant chargé depuis le cache.")
        return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))

    if is_main():
        print("Entraînement tokenizer BPE 32k…")

    def _iter_sample() -> Iterator[str]:
        # Same document budget for each domain, at least one document.
        quota = max(1, TOKENIZER_SAMPLE_DOCS // len(domains))
        for dom in domains:
            yield from domain_text_stream(dom, max_docs=quota)

    # Byte-level BPE with NFKC normalisation.
    bpe = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    bpe.normalizer = normalizers.Sequence([normalizers.NFKC()])
    bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    bpe.decoder = decoders.ByteLevel()

    bpe_trainer = trainers.BpeTrainer(
        vocab_size=TOKENIZER_VOCAB,
        min_frequency=2,
        show_progress=is_main(),
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    bpe.train_from_iterator(_iter_sample(), trainer=bpe_trainer)

    # Default template: wrap single sequences (and pairs) with <bos>/<eos>.
    special_ids = [
        (BOS_TOKEN, bpe.token_to_id(BOS_TOKEN)),
        (EOS_TOKEN, bpe.token_to_id(EOS_TOKEN)),
    ]
    bpe.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=special_ids,
    )
    bpe.save(str(json_path))

    wrapped = PreTrainedTokenizerFast(
        tokenizer_file=str(json_path),
        bos_token=BOS_TOKEN, eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN, pad_token=PAD_TOKEN,
    )
    wrapped.save_pretrained(str(TOKENIZER_DIR))

    # Make sure a special-tokens map exists even if save_pretrained wrote none.
    special_map = TOKENIZER_DIR / "special_tokens_map.json"
    if not special_map.exists():
        mapping = {
            "bos_token": BOS_TOKEN, "eos_token": EOS_TOKEN,
            "unk_token": UNK_TOKEN, "pad_token": PAD_TOKEN,
        }
        special_map.write_text(json.dumps(mapping, indent=2), encoding="utf-8")

    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
566
+
567
+
568
+ # ══════════════════════════════════════════════════════════════════
569
+ # §5 ARCHITECTURE GPT
570
+ # ══════════════════════════════════════════════════════════════════
571
+
572
@dataclass
class GPTConfig:
    """Model hyper-parameters; every default comes from the module constants."""

    vocab_size: int = VOCAB_SIZE  # number of token embeddings / output logits
    block_size: int = BLOCK_SIZE  # context length used for training
    d_model: int = D_MODEL  # width of the residual stream
    n_heads: int = N_HEADS  # attention heads (must divide d_model)
    n_layers: int = N_LAYERS  # number of transformer blocks
    d_ff: int = D_FF  # hidden width of the feed-forward layer
    dropout: float = DROPOUT  # attention dropout probability
    use_checkpointing: bool = USE_CHECKPOINTING  # recompute activations in training
582
+
583
+
584
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by 1 / sqrt(mean(x^2) + eps), then apply the per-feature gain.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * x * inv_rms
592
+
593
+
594
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    Tables cover ``max_seq`` positions; each frequency is repeated twice along
    the last axis so it lines up with adjacent feature pairs. All buffers are
    non-persistent, so they are not written to checkpoints.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4096):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)
        cos_table = torch.repeat_interleave(angles.cos(), 2, dim=-1)
        sin_table = torch.repeat_interleave(angles.sin(), 2, dim=-1)
        self.register_buffer("cos_cache", cos_table, persistent=False)
        self.register_buffer("sin_cache", sin_table, persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of both tables cast to ``dtype``."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
606
+
607
+
608
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each adjacent feature pair ``(a, b)`` to ``(-b, a)`` on the last axis."""
    evens = x[..., ::2]
    odds = x[..., 1::2]
    swapped = torch.stack((-odds, evens), dim=-1)
    return swapped.flatten(-2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the given tables, broadcast over the two leading axes."""
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    return x * cos_b + rotate_half(x) * sin_b
613
+
614
+
615
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.n_heads, self.head_dim)

        # Split the fused projection, then move the head axis before the sequence axis.
        q, k, v = (
            part.view(per_head).transpose(1, 2)
            for part in self.qkv(x).split(width, dim=-1)
        )

        cos, sin = self.rope(seq_len, x.dtype)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Fused scaled-dot-product attention (PyTorch >= 2.0); dropout only in training.
        attn_dropout = self.dropout_p if self.training else 0.0
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=attn_dropout,
            is_causal=True,
        )
        merged = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)
643
+
644
+
645
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1 x) * w2 x)``, without biases."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.w1 = nn.Linear(width, hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(width, hidden, bias=False)  # value branch
        self.w3 = nn.Linear(hidden, width, bias=False)  # projection back to d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
654
+
655
+
656
class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
668
+
669
+
670
class GPT(nn.Module):
    """Decoder-only language model: embedding, stacked blocks, final norm, tied head."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, loss)``; ``loss`` is ``None`` when no labels are given.

        The loss is the mean token cross-entropy, skipping labels equal to -100.
        """
        hidden = self.tok_emb(input_ids)
        # Activation checkpointing only applies while training.
        recompute = self.cfg.use_checkpointing and self.training
        for layer in self.blocks:
            if recompute:
                hidden = torch.utils.checkpoint.checkpoint(layer, hidden, use_reentrant=False)
            else:
                hidden = layer(hidden)
        logits = self.lm_head(self.ln_f(hidden))

        if labels is None:
            return logits, None
        flat_logits = logits.reshape(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, labels.reshape(-1), ignore_index=-100)
        return logits, loss
709
+
710
+
711
+ # ══════════════════════════════════════════════════════════════════
712
+ # §6 OPTIMIZER & LR SCHEDULE
713
+ # ══════════════════════════════════════════════════════════════════
714
+
715
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build AdamW with weight decay restricted to matrix-shaped weights.

    Trainable parameters with at least two dimensions and ``"weight"`` in
    their name receive ``WEIGHT_DECAY``; every other trainable parameter
    (norm gains, biases) gets none.

    The fused kernel is requested only when all parameters live on CUDA.
    ``main`` may fall back to CPU, and an unconditional ``fused=True`` is not
    accepted there by every PyTorch release.
    """
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2 and "weight" in name:
            decay.append(p)
        else:
            no_decay.append(p)

    trainable = decay + no_decay
    use_fused = bool(trainable) and all(p.is_cuda for p in trainable)

    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": WEIGHT_DECAY},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8,
        fused=use_fused,  # single fused GPU kernel when available
    )
729
+
730
+
731
def cosine_lr(step: int) -> float:
    """Learning rate at ``step``: linear warm-up, then cosine decay to ``MIN_LR``.

    The rate rises linearly from 0 over ``WARMUP_STEPS`` steps, follows half a
    cosine from ``LEARNING_RATE`` down to ``MIN_LR`` until ``MAX_STEPS``, and
    stays at ``MIN_LR`` afterwards.
    """
    if step >= WARMUP_STEPS:
        decay_span = max(1, MAX_STEPS - WARMUP_STEPS)
        progress = min(1.0, (step - WARMUP_STEPS) / decay_span)
        return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * (1.0 + math.cos(math.pi * progress))
    return LEARNING_RATE * step / max(1, WARMUP_STEPS)
736
+
737
+
738
+ # ══════════════════════════════════════════════════════════════════
739
+ # §7 CHECKPOINT
740
+ # ══════════════════════════════════════════════════════════════════
741
+
742
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    step: int,
    best_loss: float,
    path: Path,
) -> None:
    """Write model, optimizer and progress counters to ``path`` atomically.

    The payload is first written to ``<path>.tmp`` and then moved over
    ``path``. An interruption during the write therefore cannot leave a
    truncated file where the resume logic expects a valid checkpoint.

    Args:
        model: the network, optionally wrapped in DistributedDataParallel.
        optimizer: optimizer whose state is stored alongside the weights.
        step: number of optimizer steps completed so far.
        best_loss: best validation loss observed so far.
        path: destination file.
    """
    raw = model.module if isinstance(model, DDP) else model
    payload = {
        "model": raw.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
        "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }
    tmp_path = Path(f"{path}.tmp")
    torch.save(payload, tmp_path)
    # Same-directory rename: replaces any previous checkpoint in one step.
    os.replace(tmp_path, path)
757
+
758
+
759
def load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    path: Path,
    device: torch.device,
) -> Tuple[int, float]:
    """Restore model and optimizer state from ``path``.

    Checkpoints written from a ``torch.compile``-wrapped model carry an
    ``_orig_mod.`` prefix on every key. The prefix is stripped and the weights
    are loaded into the underlying module, so one checkpoint can be restored
    into a plain, a compiled or a DDP-wrapped model alike.

    Returns:
        ``(step, best_loss)`` as stored in the checkpoint; ``0`` and ``1e9``
        when the fields are absent.
    """
    ckpt = torch.load(path, map_location=device)

    raw = model.module if isinstance(model, DDP) else model
    # A compiled wrapper exposes the real module as `_orig_mod`.
    raw = getattr(raw, "_orig_mod", raw)

    prefix = "_orig_mod."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt["model"].items()
    }
    raw.load_state_dict(state)
    optimizer.load_state_dict(ckpt["optimizer"])
    return int(ckpt.get("step", 0)), float(ckpt.get("best_loss", 1e9))
770
+
771
+
772
+ # ══════════════════════════════════════════════════════════════════
773
+ # §8 DDP HELPERS
774
+ # ══════════════════════════════════════════════════════════════════
775
+
776
def init_distributed() -> Optional[torch.device]:
    """Join the NCCL process group when started by a distributed launcher.

    Returns:
        The CUDA device matching ``LOCAL_RANK``, or ``None`` when that
        environment variable is not set (single-process run).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank != -1:
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    return None
783
+
784
def is_distributed() -> bool:
    """True when torch.distributed is usable and a process group is initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()


def get_rank() -> int:
    """Global rank of this process; 0 outside a distributed run."""
    if is_distributed():
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """Number of participating processes; 1 outside a distributed run."""
    if is_distributed():
        return dist.get_world_size()
    return 1


def is_main() -> bool:
    """True on rank 0, the process that prints, evaluates and writes files."""
    return get_rank() == 0
795
+
796
+
797
+ # ══════════════════════════════════════════════════════════════════
798
+ # §9 ÉVALUATION
799
+ # ══════════════════════════════════════════════════════════════════
800
+
801
+ @torch.no_grad()
802
+ def evaluate(
803
+ model: nn.Module,
804
+ loader: DataLoader,
805
+ device: torch.device,
806
+ max_batches: int = 80,
807
+ ) -> Tuple[float, Dict[str, float]]:
808
+ model.eval()
809
+ total_loss, total_n = 0.0, 0
810
+ domain_losses: Dict[str, list] = {}
811
+
812
+ for i, batch in enumerate(loader):
813
+ if i >= max_batches:
814
+ break
815
+ inp = batch["input_ids"].to(device)
816
+ lbl = batch["labels"].to(device)
817
+ domains_batch = batch["domain"]
818
+
819
+ with torch.autocast("cuda", dtype=DTYPE):
820
+ _, loss = model(inp, lbl)
821
+
822
+ lv = loss.item()
823
+ total_loss += lv
824
+ total_n += 1
825
+
826
+ for dom in domains_batch:
827
+ domain_losses.setdefault(dom, []).append(lv)
828
+
829
+ model.train()
830
+ global_loss = total_loss / max(1, total_n)
831
+ per_domain = {k: sum(v) / len(v) for k, v in domain_losses.items()}
832
+ return global_loss, per_domain
833
+
834
+
835
+ # ══════════════════════════════════════════════════════════════════
836
+ # §10 LOGGING
837
+ # ══════════════════════════════════════════════════════════════════
838
+
839
def log_jsonl(path: Path, record: dict) -> None:
    """Append ``record`` to ``path`` as one UTF-8 JSON line (non-ASCII kept as is)."""
    with open(path, mode="a", encoding="utf-8") as handle:
        line = json.dumps(record, ensure_ascii=False)
        handle.write(line + "\n")
842
+
843
+
844
+ # ══════════════════════════════════════════════════════════════════
845
+ # §11 MAIN
846
+ # ══════════════════════════════════════════════════════════════════
847
+
848
def main() -> None:
    """Train the model end to end: setup, optional resume, train loop, final save.

    Steps: initialise DDP when launched with ``LOCAL_RANK``, seed the RNGs,
    load or train the tokenizer, build the model (optionally compiled and
    DDP-wrapped), resume from ``STATE_FILE`` if present, then train for
    ``MAX_STEPS`` optimizer steps with periodic logging, evaluation and
    checkpointing on the main rank.

    Changes against the earlier version:
      * when the data is exhausted even after restarting the iterator, the
        run stops; the ``break`` used to leave only the micro-batch loop, so
        the loop kept counting optimizer steps that had no gradients;
      * rank-0-only evaluation runs on the unwrapped module, avoiding any
        collective the DDP wrapper may issue in ``forward`` while the other
        ranks are not evaluating;
      * autocast uses ``device.type`` instead of a hard-coded ``"cuda"``.

    NOTE(review): resuming restores weights, optimizer and step counter, but
    the data stream restarts from its beginning.
    """

    # ── DDP init ──────────────────────────────────────────────────
    ddp_device = init_distributed()
    set_seed_fn(SEED + get_rank())
    device = ddp_device if ddp_device else torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # ── H100 optimisations ────────────────────────────────────────
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    if is_main():
        print("=" * 60)
        print(" LLM Multi-Domaine — H100 Training [v2 CORRIGÉ]")
        print("=" * 60)
        gpu = torch.cuda.get_device_name(device) if device.type == "cuda" else "CPU"
        print(f"Device : {device} ({gpu})")
        print(f"GPUs : {get_world_size()}")
        print(f"Steps : {MAX_STEPS} (~2 epochs sur corpus réduit)")
        print_domain_summary()

    # ── Tokenizer ─────────────────────────────────────────────────
    tokenizer = train_or_load_tokenizer(DOMAINS)
    vocab_size = len(tokenizer)
    if is_main():
        print(f"Tokenizer : {TOKENIZER_DIR} | vocab={vocab_size}")

    # ── Model ─────────────────────────────────────────────────────
    cfg = GPTConfig(vocab_size=vocab_size)
    if is_main():
        CONFIG_FILE.write_text(
            json.dumps(asdict(cfg), indent=2, ensure_ascii=False),
            encoding="utf-8",
        )

    model = GPT(cfg).to(device)

    if USE_COMPILE and device.type == "cuda":
        model = torch.compile(model, mode="reduce-overhead")
        if is_main():
            print("torch.compile : reduce-overhead ✓")

    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    # ── Resume from checkpoint ────────────────────────────────────
    start_step, best_eval = 0, 1e9
    if STATE_FILE.exists():
        if is_main():
            print(f"Reprise depuis {STATE_FILE}")
        start_step, best_eval = load_checkpoint(model, optimizer, STATE_FILE, device)
        if is_main():
            print(f" → reprise à step {start_step}, best_loss={best_eval:.4f}")

    # ── DataLoaders ───────────────────────────────────────────────
    train_loader, eval_loader = build_dataloaders(
        tokenizer=tokenizer,
        domains=DOMAINS,
        block_size=BLOCK_SIZE,
        train_batch_size=BATCH_SIZE,
        eval_batch_size=max(1, BATCH_SIZE // 2),
        max_eval_docs_per_dom=MAX_EVAL_DOCS_DOM,
        num_workers=4,
        seed=SEED,
    )

    # ── Summary ───────────────────────────────────────────────────
    if is_main():
        raw = model.module if isinstance(model, DDP) else model
        # Parameter count; guarded in case the compiled wrapper cannot be inspected.
        try:
            n_params = sum(p.numel() for p in raw.parameters() if p.requires_grad)
        except Exception:
            n_params = -1
        eff_batch = BATCH_SIZE * GRAD_ACCUM_STEPS * get_world_size()
        print(f"Paramètres : {n_params/1e6:.1f}M" if n_params > 0 else "Paramètres : N/A (compilé)")
        print(f"Architecture : d={D_MODEL} | heads={N_HEADS} | layers={N_LAYERS} | ctx={BLOCK_SIZE}")
        print(f"Batch effectif: {eff_batch} séq × {BLOCK_SIZE} tok = {eff_batch*BLOCK_SIZE:,} tok/step")
        print(f"Dtype : {DTYPE} | Steps : {MAX_STEPS} | Warmup : {WARMUP_STEPS}")
        print("=" * 60)

    # ── Training loop ─────────────────────────────────────────────
    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Evaluation bypasses the DDP wrapper (see docstring).
    eval_model = model.module if isinstance(model, DDP) else model

    train_iter = iter(train_loader)
    step = start_step
    t0 = time.time()
    log_loss_sum = 0.0
    log_loss_n = 0
    tokens_log = 0
    last_log = time.time()
    data_exhausted = False

    while step < MAX_STEPS:

        # ── gradient accumulation ──────────────────────────────────
        for _ in range(GRAD_ACCUM_STEPS):
            try:
                batch = next(train_iter)
            except StopIteration:
                # The stream ended: restart the iterator once.
                train_iter = iter(train_loader)
                try:
                    batch = next(train_iter)
                except StopIteration:
                    print("[WARN] Dataset entièrement épuisé avant MAX_STEPS.")
                    data_exhausted = True
                    break

            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with torch.autocast(device.type, dtype=DTYPE):
                _, loss = model(inp, lbl)

            (loss / GRAD_ACCUM_STEPS).backward()
            log_loss_sum += loss.item()
            log_loss_n += 1
            tokens_log += inp.numel()

        if data_exhausted:
            # No data can be produced any more: drop the partial accumulation
            # and stop instead of counting empty optimizer steps.
            optimizer.zero_grad(set_to_none=True)
            break

        # ── optimizer step ────────────────────────────────────────
        lr = cosine_lr(step)
        for g in optimizer.param_groups:
            g["lr"] = lr

        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        step += 1

        # ── logging every 50 steps ────────────────────────────────
        if step % 50 == 0 and is_main():
            now = time.time()
            elapsed = max(1e-6, now - last_log)
            tok_s = tokens_log / elapsed
            avg_l = log_loss_sum / max(1, log_loss_n)
            ppl = math.exp(min(avg_l, 20))
            print(
                f"step {step:5d}/{MAX_STEPS} | "
                f"loss={avg_l:.4f} | ppl={ppl:.1f} | "
                f"lr={lr:.2e} | {tok_s:,.0f} tok/s"
            )
            log_jsonl(LOG_FILE, {
                "step": step, "loss": avg_l, "ppl": ppl,
                "lr": lr, "tok_s": tok_s, "time": now - t0,
            })
            last_log = now
            tokens_log = 0
            log_loss_sum = 0.0
            log_loss_n = 0

        # ── evaluation ────────────────────────────────────────────
        if step % EVAL_EVERY == 0 and is_main():
            val_loss, per_dom = evaluate(eval_model, eval_loader, device)
            ppl_val = math.exp(min(val_loss, 20))
            print(f"\n[eval] step {step} | val_loss={val_loss:.4f} | ppl={ppl_val:.1f}")
            print(" Perplexité par domaine :")
            for dom, dl in sorted(per_dom.items(), key=lambda x: -x[1]):
                print(f" {dom:<25} loss={dl:.4f} ppl={math.exp(min(dl,20)):.1f}")
            print()

            log_jsonl(LOG_FILE, {
                "step": step, "val_loss": val_loss, "val_ppl": ppl_val,
                "per_domain": per_dom,
            })

            if val_loss < best_eval:
                best_eval = val_loss
                save_checkpoint(model, optimizer, step, best_eval, BEST_MODEL_FILE)
                print(f" ✓ Meilleur modèle → {BEST_MODEL_FILE}\n")

        # ── periodic checkpoint ───────────────────────────────────
        if step % SAVE_EVERY == 0 and is_main():
            save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
            save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
            print(f" ✓ Checkpoint step {step} → {MODEL_FILE}")

    # ── End ───────────────────────────────────────────────────────
    if is_main():
        save_checkpoint(model, optimizer, step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, step, best_eval, STATE_FILE)
        total_min = (time.time() - t0) / 60
        print(f"\n{'='*60}")
        print(f"Modèle final → {MODEL_FILE}")
        print(f"Meilleur modèle→ {BEST_MODEL_FILE}")
        print(f"Steps réalisés : {step}")
        print(f"Temps total : {total_min:.1f} min")
        print(f"{'='*60}")

    if is_distributed():
        dist.destroy_process_group()
1042
+
1043
+
1044
+ # ══════════════════════════════════════════════════════════════════
1045
+ # §12 UTILS
1046
+ # ══════════════════════════════════════════════════════════════════
1047
+
1048
def set_seed_fn(seed: int) -> None:
    """Seed Python's ``random`` and PyTorch (CPU and every CUDA device)."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


if __name__ == "__main__":
    main()
train_nlp_h100_maxvram_v6.py ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ train_nlp_h100_maxvram.py — v6 (démarrage sûr sur H100 80G)
5
+ ============================================================
6
+
7
+ Corrections principales vs v5
8
+ -----------------------------
9
+ 1. Le log v5 montre que, même sans cap VRAM logiciel,
10
+ BATCH_SIZE=32 provoque un vrai OOM matériel dès le premier
11
+ forward compilé sur H100 80G.
12
+
13
+ 2. On abaisse donc le réglage de départ à :
14
+ BATCH_SIZE = 28
15
+ pour démarrer dans une zone sûre, puis remonter ensuite si le
16
+ log montre encore de la marge.
17
+
18
+ 3. Le cap logiciel reste désactivé par défaut :
19
+ TARGET_VRAM_GIB = None
20
+ afin d'éviter tout faux OOM dû à PyTorch.
21
+
22
+ 4. Le mode torch.compile reste sur :
23
+ COMPILE_MODE = "max-autotune-no-cudagraphs"
24
+ qui garde un bon compromis perf / mémoire sans surcoût CUDA graphs.
25
+
26
+ 5. BASE_CHECKPOINT est chargé AVANT l'injection LoRA,
27
+ et l'estimation des steps reste corrigée pour le scheduler LR.
28
+
29
+ Conseils de réglage
30
+ -------------------
31
+ - Démarre avec :
32
+ BATCH_SIZE = 28
33
+ TARGET_VRAM_GIB = None
34
+ COMPILE_MODE = "max-autotune-no-cudagraphs"
35
+ - Si stable et max_reserved < 72 GiB après quelques logs :
36
+ BATCH_SIZE += 2
37
+ - Si vrai OOM matériel :
38
+ BATCH_SIZE -= 2 puis relance
39
+ - Si tu veux encore plus de marge au premier essai :
40
+ BATCH_SIZE = 24
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ import itertools
46
+ import json
47
+ import math
48
+ import os
49
+ import random
50
+ import time
51
+ from collections import OrderedDict
52
+ from contextlib import nullcontext
53
+ from dataclasses import asdict, dataclass
54
+ from pathlib import Path
55
+ from typing import Iterator, Optional
56
+
57
+ # A définir AVANT import torch
58
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
59
+
60
+ import torch
61
+ import torch.distributed as dist
62
+ import torch.nn as nn
63
+ import torch.nn.functional as F
64
+
65
+ try:
66
+ import bitsandbytes as bnb
67
+ HAS_BNB = True
68
+ except ImportError:
69
+ HAS_BNB = False
70
+ print("[warn] bitsandbytes non disponible")
71
+
72
+ try:
73
+ from flash_attn import flash_attn_func
74
+ HAS_FLASH = True
75
+ except ImportError:
76
+ HAS_FLASH = False
77
+ print("[warn] flash-attn non disponible – fallback SDPA (toujours fusionné sur H100)")
78
+
79
+ from datasets import load_dataset
80
+ from torch.nn.parallel import DistributedDataParallel as DDP
81
+ from tokenizers import (
82
+ Tokenizer, decoders, models, normalizers,
83
+ pre_tokenizers, processors, trainers,
84
+ )
85
+ from transformers import PreTrainedTokenizerFast
86
+
87
+
88
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
89
+ # ║ CHEMINS ║
90
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
91
+
92
+ OUT_DIR = Path("./nlp_1b_h100_maxvram")
93
+ OUT_DIR.mkdir(parents=True, exist_ok=True)
94
+ TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
95
+ CONFIG_FILE = OUT_DIR / "config.json"
96
+ MODEL_FILE = OUT_DIR / "model.pt"
97
+ BEST_MODEL_FILE = OUT_DIR / "model_best.pt"
98
+ STATE_FILE = OUT_DIR / "train_state.pt"
99
+ BASE_CHECKPOINT: Optional[Path] = None
100
+
101
+
102
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
103
+ # ║ HYPERPARAMÈTRES — H100 ║
104
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
105
+
106
+ SEED = 42
107
+ TARGET_VRAM_GIB = None # None = pas de cap logiciel ; évite les faux OOM dus au cap PyTorch
108
+
109
+ # ── Architecture ~1B ──────────────────────────────────────────────────────────
110
+ BLOCK_SIZE = 2048
111
+ VOCAB_SIZE = 32_000
112
+ D_MODEL = 1536
113
+ N_HEADS = 24
114
+ N_LAYERS = 24
115
+ D_FF = 6144
116
+ DROPOUT = 0.0
117
+
118
+ # ── LoRA / "QLoRA" si BNB dispo côté optimiseur ──────────────────────────────
119
+ USE_QLORA = True
120
+ LORA_R = 64
121
+ LORA_ALPHA = 128
122
+ LORA_DROPOUT = 0.05
123
+ LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]
124
+
125
+ # ── Entraînement ──────────────────────────────────────────────────────────────
126
+ NUM_EPOCHS = 10
127
+ LEARNING_RATE = 3e-4
128
+ MIN_LR = 3e-5
129
+ WEIGHT_DECAY = 0.1
130
+ WARMUP_STEPS = 500
131
+
132
+ # Réglage de départ conseillé
133
+ BATCH_SIZE = 28
134
+ GRAD_ACCUM_STEPS = 1
135
+
136
+ MAX_GRAD_NORM = 1.0
137
+ EVAL_EVERY = 500
138
+ SAVE_EVERY = 1_000
139
+
140
+ DTYPE = torch.bfloat16
141
+
142
+ # Compile : version plus robuste au démarrage
143
+ USE_CHECKPOINTING = False
144
+ USE_COMPILE = True
145
+ COMPILE_MODE = "max-autotune-no-cudagraphs"
146
+
147
+ # ── DataLoader ────────────────────────────────────────────────────────────────
148
+ TRAIN_NUM_WORKERS = 8
149
+ EVAL_NUM_WORKERS = 4
150
+ PREFETCH_FACTOR = 4
151
+
152
+ # ── Textes ────────────────────────────────────────────────────────────────────
153
+ TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
154
+ TOKENIZER_CHAR_LIMIT = 2_000
155
+ TEXT_CHAR_LIMIT = 8_000
156
+
157
+ SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
158
+ PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
159
+
160
+ WIKI_CONFIGS = ["20231101.en", "20231101.fr", "20231101.ar"]
161
+ FINEWEB_CONFIG = "sample-10BT"
162
+ DEV_DOCS_PER_WIKI_CONFIG = 1_500
163
+ DEV_DOCS_FINEWEB = 3_000
164
+ TRAIN_DOCS_PER_WIKI_CONFIG = 30_000
165
+ TRAIN_DOCS_FINEWEB = 60_000
166
+
167
+
168
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
169
+ # ║ DISTRIBUTED ║
170
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
171
+
172
def is_distributed() -> bool:
    """Return True when torch.distributed is available and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
174
+
175
def get_rank() -> int:
    """Return this process's distributed rank, or 0 when not running distributed."""
    if is_distributed():
        return dist.get_rank()
    return 0
177
+
178
def get_world_size() -> int:
    """Return the number of distributed processes, or 1 when not running distributed."""
    if is_distributed():
        return dist.get_world_size()
    return 1
180
+
181
def is_main() -> bool:
    """Return True on the rank-0 process (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
183
+
184
def init_distributed() -> Optional[torch.device]:
    """Initialize the NCCL process group when ``LOCAL_RANK`` is set.

    Returns:
        The ``cuda:<LOCAL_RANK>`` device after making it the current CUDA
        device, or ``None`` when ``LOCAL_RANK`` is absent from the environment.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if local_rank != -1:
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    return None
191
+
192
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, torch's CPU generator and, if CUDA is available, all CUDA devices.

    Args:
        seed: Value passed unchanged to each seeding call.
    """
    seeders = [random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
197
+
198
def get_device(ddp=None) -> torch.device:
    """Pick the device to run on.

    Args:
        ddp: Device already chosen by the distributed setup; returned as-is
            when it is not ``None``.

    Returns:
        ``ddp`` if given, otherwise the current CUDA device when CUDA is
        available, otherwise the CPU device.
    """
    if ddp is None:
        if not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return ddp
204
+
205
def current_cuda_index(device: torch.device) -> int:
    """Return ``device.index``, falling back to the current CUDA device when the index is ``None``."""
    idx = device.index
    if idx is None:
        idx = torch.cuda.current_device()
    return idx
207
+
208
def autocast_context(device: torch.device):
    """Return a CUDA autocast context using ``DTYPE``, or a no-op context for non-CUDA devices."""
    if device.type != "cuda":
        return nullcontext()
    return torch.autocast("cuda", dtype=DTYPE)
210
+
211
def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip a DDP wrapper and then an ``_orig_mod`` wrapper, if present.

    Args:
        model: A module, possibly wrapped in ``DistributedDataParallel``
            and/or exposing the wrapped module as ``_orig_mod``.

    Returns:
        The innermost module reachable through those two wrappers.
    """
    inner = model
    if isinstance(inner, DDP):
        inner = inner.module
    # Fall back to the module itself when it has no ``_orig_mod`` attribute.
    return getattr(inner, "_orig_mod", inner)
214
+
215
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count the scalar parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When True, skip parameters with ``requires_grad`` False.

    Returns:
        Total number of elements across the selected parameters.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
217
+
218
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
    """Return a copy of ``sd`` with one wrapper prefix removed from each key.

    The prefixes ``module._orig_mod.``, ``_orig_mod.`` and ``module.`` are
    tried in that order; only the first one that matches is stripped.

    Args:
        sd: Mapping of parameter names to values.

    Returns:
        An ``OrderedDict`` with the cleaned keys, in the input's iteration order.
    """
    prefixes = ("module._orig_mod.", "_orig_mod.", "module.")
    cleaned = OrderedDict()
    for key, value in sd.items():
        hit = next((p for p in prefixes if key.startswith(p)), None)
        cleaned[key if hit is None else key[len(hit):]] = value
    return cleaned
227
+
228
def normalize_text(t: str) -> str:
    """Trim ``t`` and collapse every run of whitespace into a single space."""
    words = t.strip().split()
    return " ".join(words)
230
+
231
def safe_str(x) -> str:
    """Coerce ``x`` to ``str``: strings pass through, ``None`` becomes ``""``, anything else goes through ``str()``."""
    if isinstance(x, str):
        return x
    if x is None:
        return ""
    return str(x)
233
+
234
+
235
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
236
+ # ║ DATASETS ║
237
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
238
+
239
def load_wiki_stream(cfg_name: str):
    """Open the ``train`` split of ``wikimedia/wikipedia`` for config ``cfg_name`` in streaming mode."""
    stream = load_dataset(
        "wikimedia/wikipedia",
        cfg_name,
        split="train",
        streaming=True,
    )
    return stream
241
+
242
def load_fineweb_stream():
    """Open the ``train`` split of ``HuggingFaceFW/fineweb-edu`` (config ``FINEWEB_CONFIG``) in streaming mode."""
    stream = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        FINEWEB_CONFIG,
        split="train",
        streaming=True,
    )
    return stream
244
+
245
def stream_texts(ds, start: int, count: int, char_limit: int) -> Iterator[str]:
    """Yield cleaned, truncated ``text`` fields from a window of ``ds``.

    Args:
        ds: Iterable of rows exposing ``.get("text", ...)``.
        start: Number of leading rows to skip.
        count: Number of rows to read after ``start``.
        char_limit: Maximum number of characters kept per yielded text.

    Yields:
        Whitespace-normalized texts of at least 20 characters, cut to
        ``char_limit`` characters.
    """
    window = itertools.islice(ds, start, start + count)
    for record in window:
        cleaned = normalize_text(safe_str(record.get("text", "")))
        if len(cleaned) < 20:
            # Too short to be useful; skip the row.
            continue
        yield cleaned[:char_limit]
250
+
251
def tokenizer_training_iterator() -> Iterator[str]:
    """Yield the texts used to train the tokenizer.

    Reads the first ``TOKENIZER_SAMPLE_DOCS_PER_SOURCE`` rows of every
    Wikipedia config in ``WIKI_CONFIGS`` and then of the FineWeb stream,
    each text cut to ``TOKENIZER_CHAR_LIMIT`` characters.
    """
    def _sources():
        # Lazy: each stream is opened only once the previous one is exhausted.
        for wiki_cfg in WIKI_CONFIGS:
            yield load_wiki_stream(wiki_cfg)
        yield load_fineweb_stream()

    for stream in _sources():
        yield from stream_texts(
            stream,
            0,
            TOKENIZER_SAMPLE_DOCS_PER_SOURCE,
            TOKENIZER_CHAR_LIMIT,
        )
265
+
266
def build_epoch_train_texts(epoch: int) -> list[str]:
    """Collect and shuffle the training texts for one epoch.

    Each epoch reads a fresh window of every source: the window starts after
    the rows reserved for evaluation and advances by one window per epoch.

    Args:
        epoch: Zero-based epoch index selecting the window.

    Returns:
        The texts of all sources, shuffled with a ``SEED + epoch`` generator.
    """
    # The offset is identical for every Wikipedia config.
    wiki_start = DEV_DOCS_PER_WIKI_CONFIG + epoch * TRAIN_DOCS_PER_WIKI_CONFIG
    fineweb_start = DEV_DOCS_FINEWEB + epoch * TRAIN_DOCS_FINEWEB

    collected: list[str] = []
    for wiki_cfg in WIKI_CONFIGS:
        collected += stream_texts(
            load_wiki_stream(wiki_cfg),
            wiki_start,
            TRAIN_DOCS_PER_WIKI_CONFIG,
            TEXT_CHAR_LIMIT,
        )
    collected += stream_texts(
        load_fineweb_stream(),
        fineweb_start,
        TRAIN_DOCS_FINEWEB,
        TEXT_CHAR_LIMIT,
    )

    shuffler = random.Random(SEED + epoch)
    shuffler.shuffle(collected)
    return collected
275
+
276
def build_eval_texts() -> list[str]:
    """Collect the evaluation texts: the first rows of every Wikipedia config, then of FineWeb.

    Returns:
        Unshuffled texts, each cut to ``TEXT_CHAR_LIMIT`` characters.
    """
    collected: list[str] = []
    for wiki_cfg in WIKI_CONFIGS:
        collected += stream_texts(
            load_wiki_stream(wiki_cfg),
            0,
            DEV_DOCS_PER_WIKI_CONFIG,
            TEXT_CHAR_LIMIT,
        )
    collected += stream_texts(
        load_fineweb_stream(),
        0,
        DEV_DOCS_FINEWEB,
        TEXT_CHAR_LIMIT,
    )
    return collected
282
+
283
def estimate_steps_per_epoch(
    texts: list[str],
    tokenizer: PreTrainedTokenizerFast,
    block_size: int,
    batch_size: int,
    sample_size: int = 512,
) -> int:
    """Estimate the number of optimizer steps in one epoch for the LR scheduler.

    The average token count is measured on a sample of ``texts`` (each
    non-empty text counts ``len(ids) + 2`` tokens for BOS/EOS) and scaled to
    the whole list. One packed example consumes ``block_size + 1`` tokens.

    Args:
        texts: All texts of the epoch.
        tokenizer: Tokenizer used to measure token counts.
        block_size: Sequence length of one packed example.
        batch_size: Examples per step on one process.
        sample_size: Maximum number of texts tokenized for the estimate.

    Returns:
        Estimated steps per epoch, at least 1.
    """
    if not texts:
        return 1

    if len(texts) > sample_size:
        probe = random.Random(SEED).sample(texts, sample_size)
    else:
        probe = texts

    raw_lengths = [
        len(tokenizer.encode(txt, add_special_tokens=False)) for txt in probe
    ]
    # Texts that tokenize to nothing are ignored; the rest gain BOS + EOS.
    padded_lengths = [n + 2 for n in raw_lengths if n]

    avg_tokens_per_text = sum(padded_lengths) / max(1, len(padded_lengths))
    est_epoch_tokens = avg_tokens_per_text * len(texts)
    tokens_per_step = (block_size + 1) * batch_size * get_world_size()
    return max(1, int(est_epoch_tokens // max(1, tokens_per_step)))
316
+
317
+
318
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
319
+ # ║ PACKED DATASET ║
320
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
321
+
322
class PackedTextList(torch.utils.data.IterableDataset):
    """Dense token packing without padding.

    Texts are tokenized, wrapped in BOS/EOS and concatenated into a buffer
    that is cut into chunks of ``block_size + 1`` tokens. Every yielded
    example has the same shape, which helps ``torch.compile`` when the
    DataLoader uses ``drop_last=True``.
    """

    def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
        """Store the texts, tokenizer, block size and shuffle seed."""
        super().__init__()
        self.texts = texts
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.epoch_seed = epoch_seed

    def __iter__(self):
        """Yield ``{"input_ids", "labels"}`` examples for this rank/worker shard.

        ``labels`` is ``input_ids`` shifted by one token within the same chunk.
        """
        info = torch.utils.data.get_worker_info()
        rank, world = get_rank(), get_world_size()

        # One shard per (rank, DataLoader worker) pair.
        if info is not None:
            n_shards = info.num_workers * world
            my_shard = rank * info.num_workers + info.id
        else:
            n_shards, my_shard = world, rank

        # Every shard builds the same permutation and keeps its own slots.
        order = list(range(len(self.texts)))
        random.Random(self.epoch_seed).shuffle(order)

        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        span = self.block_size + 1
        pending: list[int] = []

        for slot, text_idx in enumerate(order):
            if slot % n_shards != my_shard:
                continue

            ids = self.tokenizer.encode(self.texts[text_idx], add_special_tokens=False)
            if not ids:
                continue

            pending += [bos, *ids, eos]

            while len(pending) >= span:
                chunk, pending = pending[:span], pending[span:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
368
+
369
+
370
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
371
+ # ║ TOKENIZER ║
372
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
373
+
374
def tokenizer_ready() -> bool:
    """Return True when both tokenizer files exist in ``TOKENIZER_DIR``."""
    required = ("tokenizer.json", "tokenizer_config.json")
    return all((TOKENIZER_DIR / filename).exists() for filename in required)
379
+
380
def train_tokenizer_once() -> None:
    """Train a byte-level BPE tokenizer and write it to ``TOKENIZER_DIR``.

    Saves the raw ``tokenizer.json`` first, then wraps it in a
    ``PreTrainedTokenizerFast`` and saves that to the same directory.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer_json = str(TOKENIZER_DIR / "tokenizer.json")

    bpe = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    bpe.normalizer = normalizers.Sequence([normalizers.NFKC()])
    bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    bpe.decoder = decoders.ByteLevel()

    bpe.train_from_iterator(
        tokenizer_training_iterator(),
        trainer=trainers.BpeTrainer(
            vocab_size=VOCAB_SIZE,
            min_frequency=2,
            show_progress=is_main(),
            special_tokens=SPECIAL_TOKENS,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        ),
    )

    # Wrap every single sequence in BOS ... EOS when special tokens are requested.
    special_ids = [
        (token, bpe.token_to_id(token)) for token in (BOS_TOKEN, EOS_TOKEN)
    ]
    bpe.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=special_ids,
    )

    bpe.save(tokenizer_json)

    wrapped = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_json,
        bos_token=BOS_TOKEN,
        eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN,
        pad_token=PAD_TOKEN,
    )
    wrapped.save_pretrained(str(TOKENIZER_DIR))
417
+
418
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the tokenizer from ``TOKENIZER_DIR``, training it first if needed.

    In a distributed run only rank 0 checks for the files and trains; every
    rank then meets at a barrier before loading. The barrier is unconditional
    on purpose: if each rank decided from its own filesystem check, a rank
    that looked after rank 0 had finished writing would skip the barrier
    while rank 0 waited in it.

    Returns:
        The tokenizer loaded from ``TOKENIZER_DIR``.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)

    if is_distributed():
        if is_main() and not tokenizer_ready():
            print("Entraînement tokenizer 32k…")
            train_tokenizer_once()
        # Reached by all ranks, whatever each one sees on disk.
        dist.barrier()
    elif not tokenizer_ready():
        print("Entraînement tokenizer 32k…")
        train_tokenizer_once()

    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
432
+
433
+
434
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
435
+ # ║ MODÈLE GPT ║
436
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
437
+
438
@dataclass
class GPTConfig:
    """Architecture settings of the GPT model; defaults come from the module-level constants."""

    # Embedding table / output projection size.
    vocab_size: int = VOCAB_SIZE
    # Packed sequence length used by the datasets.
    block_size: int = BLOCK_SIZE
    # Width of the residual stream.
    d_model: int = D_MODEL
    # Attention heads; must divide d_model.
    n_heads: int = N_HEADS
    # Number of transformer blocks.
    n_layers: int = N_LAYERS
    # Hidden width of the feed-forward layers.
    d_ff: int = D_FF
    # Attention dropout probability.
    dropout: float = DROPOUT
    # Recompute block activations during training instead of storing them.
    use_checkpointing: bool = USE_CHECKPOINTING
448
+
449
+
450
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create a scale vector of ones of size ``dim``; ``eps`` is added to the mean square."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the learned weight and by the inverse RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * x * inv_rms
458
+
459
+
460
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    Each frequency is repeated twice along the last dimension so the tables
    line up with the interleaved (even, odd) pairing used by ``rotate_half``.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 8_192):
        """Build tables of shape ``(max_seq, dim)`` as non-persistent buffers."""
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)
        cos_table = torch.repeat_interleave(angles.cos(), 2, dim=-1)
        sin_table = torch.repeat_interleave(angles.sin(), 2, dim=-1)
        self.register_buffer("cos_cache", cos_table, persistent=False)
        self.register_buffer("sin_cache", sin_table, persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return the first ``seq_len`` rows of the cos and sin tables, cast to ``dtype``."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
473
+
474
+
475
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map each interleaved pair ``(a, b)`` of the last dimension to ``(-b, a)``."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    paired = torch.stack((-odd, even), dim=-1)
    # Merge the (pairs, 2) axes back into one last dimension.
    return paired.flatten(-2)
478
+
479
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``x``.

    Args:
        x: Tensor the tables are broadcast against after gaining two leading axes.
        cos: Cosine table.
        sin: Sine table.

    Returns:
        ``x * cos + rotate_half(x) * sin`` with the tables broadcast over two leading axes.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]
    return x * cos_b + rotate_half(x) * sin_b
481
+
482
+
483
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Uses ``flash_attn_func`` when the module-level ``HAS_FLASH`` flag is set,
    otherwise ``F.scaled_dot_product_attention``.
    """

    def __init__(self, cfg: GPTConfig):
        """Build the fused QKV projection, the output projection and the RoPE tables."""
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0

        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # One matmul produces queries, keys and values.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape ``(batch, seq, d_model)``.

        Returns:
            A tensor of shape ``(batch, seq, d_model)``.
        """
        b, t, c = x.shape

        q, k, v = self.qkv(x).split(c, dim=-1)
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rope(t, x.dtype)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        # Dropout is only active in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if HAS_FLASH:
            # flash_attn_func takes (batch, seq, heads, head_dim) tensors that
            # share one half-precision dtype. The RoPE tables follow x.dtype,
            # so under autocast with an fp32 input the multiply above promotes
            # q/k to fp32 while v keeps the projection's dtype; realign them.
            # NOTE(review): v is assumed to be fp16/bf16 here (autocast
            # active) — confirm for any non-autocast use of the flash path.
            q = q.transpose(1, 2).to(v.dtype)
            k = k.transpose(1, 2).to(v.dtype)
            v = v.transpose(1, 2)
            y = flash_attn_func(
                q,
                k,
                v,
                dropout_p=dropout_p,
                causal=True,
            )
            y = y.reshape(b, t, c)
        else:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=dropout_p,
                is_causal=True,
            )
            # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
            y = y.transpose(1, 2).contiguous().view(b, t, c)

        return self.proj(y)
533
+
534
+
535
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``."""

    def __init__(self, cfg: GPTConfig):
        """Create the gate (``w1``), up (``w2``) and down (``w3``) projections, all without bias."""
        super().__init__()
        self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w2 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.w3 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        gate = F.silu(self.w1(x))
        up = self.w2(x)
        return self.w3(gate * up)
544
+
545
+
546
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection."""

    def __init__(self, cfg: GPTConfig):
        """Create the two RMSNorm layers, the attention module and the SwiGLU feed-forward."""
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = x + self.attn(self.ln1(x))
        return attended + self.ff(self.ln2(attended))
558
+
559
+
560
class GPT(nn.Module):
    """Decoder-only language model: embedding, a stack of blocks, final norm and a tied output head."""

    def __init__(self, cfg: GPTConfig):
        """Build the layers, tie ``lm_head`` to the embedding table and initialize weights."""
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Run the model.

        Args:
            input_ids: Token ids fed to the embedding table.
            labels: Optional target ids; positions equal to -100 are ignored
                by the loss.

        Returns:
            ``(logits, loss)`` where ``loss`` is the cross-entropy over all
            positions, or ``None`` when ``labels`` is not given.
        """
        hidden = self.tok_emb(input_ids)

        # Activation checkpointing only applies while training.
        recompute = self.cfg.use_checkpointing and self.training
        for layer in self.blocks:
            if recompute:
                hidden = torch.utils.checkpoint.checkpoint(layer, hidden, use_reentrant=False)
            else:
                hidden = layer(hidden)

        logits = self.lm_head(self.ln_f(hidden))
        if labels is None:
            return logits, None

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
        )
        return logits, loss
598
+
599
+
600
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
601
+ # ║ LoRA ║
602
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
603
+
604
class LoRALinear(nn.Module):
    """A frozen linear layer plus a trainable low-rank (LoRA) update.

    The output is ``base(x) + lora_B(lora_A(dropout(x))) * (alpha / r)``.
    """

    def __init__(self, base_layer: nn.Linear, r=LORA_R, alpha=LORA_ALPHA, dropout=LORA_DROPOUT):
        """Wrap ``base_layer`` and create the rank-``r`` adapters on its device.

        Args:
            base_layer: Linear layer to wrap; its parameters are frozen.
            r: Rank of the low-rank update.
            alpha: Scaling numerator; the update is multiplied by ``alpha / r``.
            dropout: Dropout probability applied to the adapter input.
        """
        super().__init__()
        self.base = base_layer
        self.r = r
        self.scale = alpha / r

        # Adapters live on the wrapped layer's device (CPU if it has no parameters).
        first_param = next(base_layer.parameters(), None)
        dev = torch.device("cpu") if first_param is None else first_param.device

        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False, device=dev)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False, device=dev)
        self.drop = nn.Dropout(dropout)

        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        # B starts at zero, so the wrapped layer's output is unchanged at first.
        nn.init.zeros_(self.lora_B.weight)

        self.base.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen layer's output plus the scaled low-rank update."""
        frozen_out = self.base(x)
        low_rank = self.lora_B(self.lora_A(self.drop(x)))
        return frozen_out + low_rank * self.scale
629
+
630
+
631
def apply_qlora(model: GPT, device: torch.device) -> GPT:
    """Replace every targeted ``nn.Linear`` of ``model`` with a ``LoRALinear`` wrapper.

    A module is targeted when the last component of its dotted name is listed
    in ``LORA_TARGET_MODULES``. Does nothing when ``USE_QLORA`` is False.

    Args:
        model: Model modified in place.
        device: Only used in the log message.

    Returns:
        The same ``model`` object.
    """
    if not USE_QLORA:
        return model

    # Collect first: replacing modules while walking named_modules() is unsafe.
    targets = []
    for name, module in model.named_modules():
        leaf = name.split(".")[-1]
        if leaf in LORA_TARGET_MODULES and isinstance(module, nn.Linear):
            targets.append((name, module))

    for name, module in targets:
        parent_path, _, leaf = name.rpartition(".")
        parent = model
        for attr in filter(None, parent_path.split(".")):
            parent = getattr(parent, attr)
        setattr(parent, leaf, LoRALinear(module))

    if is_main():
        print(f"LoRA : {len(targets)} couches remplacées (device={device})")

    return model
652
+
653
def freeze_base_weights(model: GPT) -> None:
    """Leave only parameters whose name contains ``lora_A`` or ``lora_B`` trainable."""
    adapter_tags = ("lora_A", "lora_B")
    for param_name, param in model.named_parameters():
        param.requires_grad = any(tag in param_name for tag in adapter_tags)
656
+
657
+
658
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
659
+ # ║ OPTIMIZER ║
660
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
661
+
662
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build the AdamW optimizer over the trainable parameters of ``model``.

    Parameters with at least two dimensions and ``weight`` in their name get
    ``WEIGHT_DECAY``; all other trainable parameters get no decay. Uses the
    bitsandbytes paged 8-bit AdamW when ``HAS_BNB`` is set, else
    ``torch.optim.AdamW`` (fused when CUDA is available).
    """
    decayed, undecayed = [], []
    for name, param in unwrap_model(model).named_parameters():
        if not param.requires_grad:
            continue
        bucket = decayed if (param.ndim >= 2 and "weight" in name) else undecayed
        bucket.append(param)

    param_groups = [
        {"params": decayed, "weight_decay": WEIGHT_DECAY},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)

    if HAS_BNB:
        return bnb.optim.PagedAdamW8bit(param_groups, **adam_kwargs)

    return torch.optim.AdamW(
        param_groups,
        fused=torch.cuda.is_available(),
        **adam_kwargs,
    )
693
+
694
def cosine_lr(step: int, total_steps: int) -> float:
    """Learning rate at ``step``: linear warmup, then cosine decay to ``MIN_LR``.

    Args:
        step: Current optimizer step.
        total_steps: Step at which the decay reaches ``MIN_LR``; later steps
            stay at ``MIN_LR``.

    Returns:
        The learning rate for ``step``.
    """
    if step >= WARMUP_STEPS:
        decay_span = max(1, total_steps - WARMUP_STEPS)
        progress = min(1.0, (step - WARMUP_STEPS) / decay_span)
        cosine = 1 + math.cos(math.pi * progress)
        return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * cosine

    return LEARNING_RATE * step / max(1, WARMUP_STEPS)
700
+
701
+
702
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
703
+ # ║ CHECKPOINT ║
704
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
705
+
706
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
    """Write a training checkpoint to ``path`` atomically.

    The payload is written to a sibling ``<name>.tmp`` file and then renamed
    over ``path``, so an interrupted save cannot leave a truncated file at
    ``path``.

    Args:
        model: Model to save; DDP / ``_orig_mod`` wrappers are stripped first.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch index stored under ``"epoch"``.
        step: Global step stored under ``"step"``.
        best_loss: Best evaluation loss stored under ``"best_loss"``.
        path: Destination file (``str`` or ``Path``).
    """
    raw = unwrap_model(model)
    payload = {
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }

    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")
    try:
        torch.save(payload, tmp)
        # Same directory, so the rename replaces the old file in one step.
        os.replace(tmp, target)
    except BaseException:
        # Do not leave a half-written temp file behind.
        tmp.unlink(missing_ok=True)
        raise
719
+
720
def maybe_load_base_checkpoint(model, device):
    """Load ``BASE_CHECKPOINT`` weights into ``model`` when that file is configured and exists.

    Loading is non-strict; on the main process the counts of missing and
    unexpected keys are printed when non-zero.

    Args:
        model: Module receiving the weights.
        device: ``map_location`` for ``torch.load``.
    """
    if BASE_CHECKPOINT is None:
        return
    if not Path(BASE_CHECKPOINT).exists():
        return

    # NOTE(review): torch.load unpickles the file; only point BASE_CHECKPOINT at trusted checkpoints.
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    weights = normalize_state_dict_keys(ckpt["model"])
    missing, unexpected = model.load_state_dict(weights, strict=False)

    if not is_main():
        return

    print(f"Base checkpoint chargé depuis {BASE_CHECKPOINT}")
    if missing:
        print(f"[warn] missing keys base ckpt: {len(missing)}")
    if unexpected:
        print(f"[warn] unexpected keys base ckpt: {len(unexpected)}")
733
+
734
def load_resume_checkpoint(model, optimizer, path, device):
    """Restore model (strictly) and optimizer (best effort) state from ``path``.

    Args:
        model: Model to restore; wrappers are stripped before loading.
        optimizer: Optimizer to restore; a failure here is only printed.
        path: Checkpoint file.
        device: ``map_location`` for ``torch.load``.

    Returns:
        ``(epoch, step, best_loss)`` from the checkpoint, defaulting to
        ``0``, ``0`` and ``1e9`` for absent entries.
    """
    ckpt = torch.load(path, map_location=device)
    weights = normalize_state_dict_keys(ckpt["model"])
    unwrap_model(model).load_state_dict(weights, strict=True)

    # Deliberately tolerant: the optimizer state is optional for resuming.
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state non repris: {e}")

    epoch = int(ckpt.get("epoch", 0))
    step = int(ckpt.get("step", 0))
    best_loss = float(ckpt.get("best_loss", 1e9))
    return epoch, step, best_loss
748
+
749
+
750
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
751
+ # ║ ÉVALUATION ║
752
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
753
+
754
@torch.no_grad()
def evaluate(model, loader, device, max_batches=100) -> float:
    """Compute the mean loss of ``model`` over at most ``max_batches`` batches.

    The model is switched to eval mode for the pass and back to train mode
    afterwards, even if a batch raises.

    Args:
        model: Callable returning ``(logits, loss)`` for ``(input_ids, labels)``.
        loader: Iterable of batches with ``"input_ids"`` and ``"labels"`` tensors.
        device: Device the batches are moved to.
        max_batches: Upper bound on the number of batches consumed.

    Returns:
        The average batch loss, or ``float("inf")`` when no batch was
        evaluated, so that an empty loader can never look like a perfect
        (0.0) loss to a best-loss comparison.
    """
    model.eval()
    losses = []

    try:
        # islice stops without fetching a batch beyond the limit.
        for batch in itertools.islice(loader, max(0, max_batches)):
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            losses.append(loss.item())
    finally:
        model.train()

    if not losses:
        return float("inf")
    return sum(losses) / len(losses)
773
+
774
+
775
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
776
+ # ║ DATALOADER ║
777
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
778
+
779
def make_loader(dataset, batch_size, num_workers, is_cuda, drop_last=True):
    """Create a ``DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to wrap.
        batch_size: Examples per batch.
        num_workers: Worker processes; when positive, workers are persistent
            and prefetch ``PREFETCH_FACTOR`` batches each.
        is_cuda: Enables pinned memory.
        drop_last: Drop the final incomplete batch.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    # These two options are only set when worker processes are used.
    worker_opts = (
        {"persistent_workers": True, "prefetch_factor": PREFETCH_FACTOR}
        if num_workers > 0
        else {}
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=is_cuda,
        drop_last=drop_last,
        **worker_opts,
    )
790
+
791
+
792
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
793
+ # ║ CUDA / LOGGING ║
794
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
795
+
796
def maybe_limit_process_memory(device: torch.device) -> tuple[Optional[int], Optional[float]]:
    """Optionally cap this process's CUDA memory to ``TARGET_VRAM_GIB``.

    Args:
        device: Device the process runs on.

    Returns:
        ``(cuda_index, fraction)``. Both are ``None`` for non-CUDA devices;
        ``fraction`` is ``None`` when ``TARGET_VRAM_GIB`` is ``None`` (no cap).
        Otherwise ``fraction`` is the applied share of total device memory,
        never above 0.98.
    """
    if device.type != "cuda":
        return None, None

    cuda_idx = current_cuda_index(device)

    if TARGET_VRAM_GIB is not None:
        _, total_bytes = torch.cuda.mem_get_info(cuda_idx)
        vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total_bytes, 0.98)
        torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)
        return cuda_idx, vram_fraction

    return cuda_idx, None
809
+
810
def sync_if_cuda(device: torch.device) -> None:
    """Synchronize the given CUDA device; do nothing for any other device type."""
    if device.type != "cuda":
        return
    torch.cuda.synchronize(current_cuda_index(device))
813
+
814
+
815
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
816
+ # ║ MAIN ║
817
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
818
+
819
+ def main() -> None:
820
+ ddp_device = init_distributed()
821
+ set_seed(SEED + get_rank())
822
+
823
+ device = get_device(ddp_device)
824
+ is_cuda = device.type == "cuda"
825
+
826
+ if is_cuda:
827
+ torch.backends.cuda.matmul.allow_tf32 = True
828
+ torch.backends.cudnn.allow_tf32 = True
829
+ torch.set_float32_matmul_precision("high")
830
+ torch.backends.cudnn.benchmark = True
831
+
832
+ cuda_idx, vram_fraction = maybe_limit_process_memory(device)
833
+
834
+ if is_main():
835
+ print("=" * 72)
836
+ print(" GPT ~1B | H100 MAX VRAM | LoRA + BF16 + TF32 + compile | v4")
837
+ print("=" * 72)
838
+ print(f"Device : {device} | World: {get_world_size()} GPU(s)")
839
+ print(f"Flash-2 : {HAS_FLASH} | BNB: {HAS_BNB} | LoRA: {USE_QLORA}")
840
+ print(f"Grad ckpt : {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
841
+ print(f"BLOCK_SIZE : {BLOCK_SIZE} | BATCH_SIZE: {BATCH_SIZE} | GRAD_ACCUM: {GRAD_ACCUM_STEPS}")
842
+ print(f"Tokens/step : {BLOCK_SIZE * BATCH_SIZE * GRAD_ACCUM_STEPS:,}")
843
+ if is_cuda:
844
+ free, total = torch.cuda.mem_get_info(cuda_idx)
845
+ print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
846
+ print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")
847
+ if vram_fraction is None:
848
+ print("Cap VRAM : désactivé")
849
+ else:
850
+ print(f"Cap VRAM : {TARGET_VRAM_GIB:.1f} GiB ({100*vram_fraction:.1f}% du device)")
851
+
852
+ tokenizer = train_or_load_tokenizer()
853
+ cfg = GPTConfig(vocab_size=len(tokenizer))
854
+
855
+ if is_main():
856
+ CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")
857
+
858
+ # 1) Base model sur GPU
859
+ model = GPT(cfg).to(device)
860
+
861
+ # 2) Charger éventuel checkpoint base AVANT LoRA
862
+ maybe_load_base_checkpoint(model, device)
863
+
864
+ # 3) Injecter LoRA ensuite
865
+ if USE_QLORA:
866
+ model = apply_qlora(model, device)
867
+ freeze_base_weights(model)
868
+
869
+ # 4) Compiler le modèle ensuite
870
+ if USE_COMPILE and hasattr(torch, "compile"):
871
+ if is_main():
872
+ print(f"Compilation torch.compile({COMPILE_MODE})…")
873
+ try:
874
+ model = torch.compile(model, mode=COMPILE_MODE, fullgraph=False)
875
+ if is_main():
876
+ print("torch.compile : OK")
877
+ except Exception as e:
878
+ if is_main():
879
+ print(f"[warn] torch.compile échoué ({e}) — fallback eager")
880
+
881
+ # 5) DDP
882
+ if is_distributed():
883
+ model = DDP(model, device_ids=[device.index])
884
+
885
+ optimizer = build_optimizer(model)
886
+
887
+ # ── Datasets ──────────────────────────────────────────────────────────────
888
+ eval_texts = build_eval_texts()
889
+ eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
890
+ eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda, drop_last=False)
891
+
892
+ init_texts = build_epoch_train_texts(0)
893
+ steps_per_epoch = estimate_steps_per_epoch(init_texts, tokenizer, cfg.block_size, BATCH_SIZE * GRAD_ACCUM_STEPS)
894
+ total_steps_est = max(steps_per_epoch * NUM_EPOCHS, WARMUP_STEPS + 100)
895
+
896
+ # ── Reprise ───────────────────────────────────────────────────────────────
897
+ start_epoch, start_step, best_eval = 0, 0, 1e9
898
+ if STATE_FILE.exists():
899
+ try:
900
+ if is_main():
901
+ print(f"Reprise depuis {STATE_FILE}")
902
+ start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
903
+ except Exception as e:
904
+ if is_main():
905
+ try:
906
+ STATE_FILE.rename(STATE_FILE.with_suffix(".corrupt.pt"))
907
+ except Exception:
908
+ pass
909
+ print(f"[warn] Checkpoint illisible ({e}) — reprise ignorée")
910
+ start_epoch, start_step, best_eval = 0, 0, 1e9
911
+
912
+ if is_main():
913
+ raw = unwrap_model(model)
914
+ n_total = count_parameters(raw, False)
915
+ n_train = count_parameters(raw, True)
916
+ effective_bs = BATCH_SIZE * GRAD_ACCUM_STEPS * get_world_size()
917
+
918
+ print(f"\nParamètres totaux : {n_total/1e9:.3f}B")
919
+ print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1, n_total):.2f}%)")
920
+ print(f"Batch effectif : {effective_bs} ({BATCH_SIZE}×{GRAD_ACCUM_STEPS}×{get_world_size()} GPU)")
921
+ print(f"Tokens/step : {BLOCK_SIZE * effective_bs:,}")
922
+ print(f"Steps estimés : {total_steps_est:,}")
923
+ print()
924
+ print("┌── Pilotage VRAM ──────────────────────────────────────────────┐")
925
+ print("│ Lis 'max_reserved' après quelques logs : │")
926
+ print("│ < 72 GiB → +2 BS | 72–77.5 GiB → zone cible │")
927
+ print("│ vrai OOM → -2 BS | puis relance │")
928
+ print("└───────────────────────────────────────────────────────────────┘")
929
+
930
+ # ── Boucle principale ─────────────────────────────────────────────────────
931
+ model.train()
932
+ optimizer.zero_grad(set_to_none=True)
933
+
934
+ global_step = start_step
935
+ t0 = time.time()
936
+ log_loss_sum = 0.0
937
+ log_loss_count = 0
938
+ tokens_since_log = 0
939
+ last_log = time.time()
940
+
941
+ if is_cuda:
942
+ torch.cuda.reset_peak_memory_stats(cuda_idx)
943
+
944
+ for epoch in range(start_epoch, NUM_EPOCHS):
945
+ if is_main():
946
+ print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")
947
+
948
+ train_texts = build_epoch_train_texts(epoch)
949
+ train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
950
+ train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda, drop_last=True)
951
+
952
+ for micro_step, batch in enumerate(train_loader):
953
+ inp = batch["input_ids"].to(device, non_blocking=True)
954
+ lbl = batch["labels"].to(device, non_blocking=True)
955
+
956
+ with autocast_context(device):
957
+ _, loss = model(inp, lbl)
958
+
959
+ (loss / GRAD_ACCUM_STEPS).backward()
960
+
961
+ log_loss_sum += loss.item()
962
+ log_loss_count += 1
963
+ tokens_since_log += inp.numel()
964
+
965
+ if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
966
+ continue
967
+
968
+ lr = cosine_lr(global_step, total_steps_est)
969
+ for group in optimizer.param_groups:
970
+ group["lr"] = lr
971
+
972
+ torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
973
+ optimizer.step()
974
+ optimizer.zero_grad(set_to_none=True)
975
+ global_step += 1
976
+
977
+ if global_step % 50 == 0 and is_main():
978
+ sync_if_cuda(device)
979
+ now = time.time()
980
+ elapsed = max(1e-6, now - last_log)
981
+ tok_s = tokens_since_log / elapsed
982
+ avg_loss = log_loss_sum / max(1, log_loss_count)
983
+
984
+ print(
985
+ f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | "
986
+ f"loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s"
987
+ )
988
+
989
+ if is_cuda:
990
+ alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
991
+ reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
992
+ max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
993
+
994
+ status = (
995
+ "▲ OK"
996
+ if max_res < 75.0 else
997
+ "⚠ proche limite"
998
+ if max_res < 77.5 else
999
+ "🔴 DANGER OOM"
1000
+ )
1001
+
1002
+ print(
1003
+ f" GPU mem | alloc={alloc:.1f} | reserved={reserved:.1f} | "
1004
+ f"max_reserved={max_res:.1f} GiB {status}"
1005
+ )
1006
+
1007
+ last_log = now
1008
+ tokens_since_log = 0
1009
+ log_loss_sum = 0.0
1010
+ log_loss_count = 0
1011
+
1012
+ if global_step % EVAL_EVERY == 0 and is_main():
1013
+ val_loss = evaluate(model, eval_loader, device)
1014
+ print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
1015
+
1016
+ if val_loss < best_eval:
1017
+ best_eval = val_loss
1018
+ save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
1019
+ print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")
1020
+
1021
+ if global_step % SAVE_EVERY == 0 and is_main():
1022
+ save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
1023
+ save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
1024
+ print(f"✓ Checkpoint → {MODEL_FILE}")
1025
+
1026
+ if is_main():
1027
+ save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
1028
+ ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
1029
+ save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
1030
+ print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")
1031
+
1032
+ if is_main():
1033
+ save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
1034
+ save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
1035
+ total_min = (time.time() - t0) / 60
1036
+
1037
+ print(f"\nModèle final → {MODEL_FILE}")
1038
+ print(f"Meilleur modèle → {BEST_MODEL_FILE}")
1039
+ print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")
1040
+
1041
+ if is_distributed():
1042
+ dist.destroy_process_group()
1043
+
1044
+
1045
# Script entry point: start training only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
train_nlp_h100_maxvram_v7.py ADDED
@@ -0,0 +1,805 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ train_nlp_h100_maxvram_v7.py — v3 (fix gated OSCAR → public C4)
5
+ ===========================================================
6
+ • Datasets publics seulement (plus de gated error)
7
+ • Toujours ~85 GB de données traitées sur 10 epochs
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import itertools
13
+ import json
14
+ import math
15
+ import os
16
+ import random
17
+ import time
18
+ from collections import OrderedDict
19
+ from contextlib import nullcontext
20
+ from dataclasses import asdict, dataclass
21
+ from pathlib import Path
22
+ from typing import Iterator, Optional
23
+
24
+ import torch
25
+ import torch.distributed as dist
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+ try:
30
+ import bitsandbytes as bnb
31
+ HAS_BNB = True
32
+ except ImportError:
33
+ HAS_BNB = False
34
+ print("[warn] bitsandbytes non disponible – quantification 4-bit désactivée")
35
+
36
+ try:
37
+ from flash_attn import flash_attn_func
38
+ HAS_FLASH = True
39
+ except ImportError:
40
+ HAS_FLASH = False
41
+ print("[warn] flash-attn non disponible – fallback F.scaled_dot_product_attention")
42
+
43
+ from datasets import load_dataset
44
+ from torch.nn.parallel import DistributedDataParallel as DDP
45
+ from tokenizers import (
46
+ Tokenizer, decoders, models, normalizers,
47
+ pre_tokenizers, processors, trainers,
48
+ )
49
+ from transformers import PreTrainedTokenizerFast
50
+
51
+
52
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
53
+ # ║ CHEMINS ║
54
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
55
+
56
# Root directory for every artifact written by this script. It is created at
# import time so that later writes never fail on a missing folder.
OUT_DIR = Path("./nlp_1b_h100_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Holds tokenizer.json and tokenizer_config.json (see tokenizer_ready()).
TOKENIZER_DIR = OUT_DIR / "tokenizer_32k"
# JSON dump of the GPTConfig used for the run (written by main()).
CONFIG_FILE = OUT_DIR / "config.json"
MODEL_FILE = OUT_DIR / "model.pt"
BEST_MODEL_FILE= OUT_DIR / "model_best.pt"
# Resume checkpoint: main() reloads it at start-up when it exists.
STATE_FILE = OUT_DIR / "train_state.pt"
# Optional checkpoint read by maybe_load_base_checkpoint(); None disables it.
BASE_CHECKPOINT: Optional[Path] = None
64
+
65
+
66
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
67
+ # ║ HYPERPARAMÈTRES ║
68
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
69
+
70
# Base RNG seed; main() adds the process rank to it.
SEED = 42
# Per-process VRAM cap in GiB, turned into a memory fraction in main().
TARGET_VRAM_GIB= 78.0

# Model architecture (these are the defaults of GPTConfig).
BLOCK_SIZE = 1024
VOCAB_SIZE = 32_000
D_MODEL = 1536
N_HEADS = 24
N_LAYERS = 24
D_FF = 6144
DROPOUT = 0.0

# LoRA adapter settings (see LoRALinear / apply_qlora).
USE_QLORA = True
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
# Attribute names of the nn.Linear layers that get wrapped by LoRALinear.
LORA_TARGET_MODULES = ["qkv", "proj", "w1", "w2", "w3"]

# Optimization schedule (see cosine_lr / build_optimizer).
NUM_EPOCHS = 3
LEARNING_RATE = 3e-4
MIN_LR = 3e-5
WEIGHT_DECAY = 0.1
WARMUP_STEPS = 500

BATCH_SIZE = 28
GRAD_ACCUM_STEPS = 1
MAX_GRAD_NORM = 1.0
EVAL_EVERY = 500
SAVE_EVERY = 1_000

# dtype passed to torch.autocast on CUDA (see autocast_context).
DTYPE = torch.bfloat16

USE_CHECKPOINTING = False
USE_COMPILE = True
COMPILE_MODE = "reduce-overhead"

# DataLoader settings (see make_loader).
TRAIN_NUM_WORKERS = 4
EVAL_NUM_WORKERS = 2
PREFETCH_FACTOR = 2

# Tokenizer training sample: documents read per source and characters kept per document.
TOKENIZER_SAMPLE_DOCS_PER_SOURCE = 15_000
TOKENIZER_CHAR_LIMIT = 2_000
# Characters kept per document for training and evaluation texts.
TEXT_CHAR_LIMIT = 4_000

# Order matters: the tuple unpacking below relies on it.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS
115
+
116
+
117
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
118
+ # ║ DATASETS — PUBLIC + MAX 100 GB (fix gated OSCAR) ║
119
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
120
+
121
# Streaming data sources. For each source, rows [0, dev_docs) of the stream are
# used for evaluation, and epoch e reads the next window of
# train_docs_per_epoch rows after dev_docs + e * train_docs_per_epoch.
# NOTE(review): the size estimates below mention 10 epochs, but NUM_EPOCHS is 3
# in this file — confirm which one is intended.
DATA_SOURCES = [
    # 1. FineWeb (English, very high quality)
    {
        "name": "HuggingFaceFW/fineweb",
        "config": None,
        "split": "train",
        "text_column": "text",
        "dev_docs": 10_000,
        "train_docs_per_epoch": 1_200_000,  # author's estimate: ~48 GB over 10 epochs
        "language_filter": None,
    },
    # 2. C4 multilingual -> French
    # NOTE(review): the filter compares row["language"]; verify that this
    # dataset config actually exposes a "language" field, otherwise every row
    # is dropped.
    {
        "name": "allenai/c4",
        "config": "multilingual",
        "split": "train",
        "text_column": "text",
        "dev_docs": 5_000,
        "train_docs_per_epoch": 400_000,  # author's estimate: ~16 GB over 10 epochs
        "language_filter": "fr",
    },
    # 3. C4 multilingual -> Arabic (same caveat as above)
    {
        "name": "allenai/c4",
        "config": "multilingual",
        "split": "train",
        "text_column": "text",
        "dev_docs": 5_000,
        "train_docs_per_epoch": 300_000,  # author's estimate: ~12 GB over 10 epochs
        "language_filter": "ar",
    },
]
153
+
154
+
155
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
156
+ # ║ DISTRIBUTED + UTILS ║
157
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
158
+
159
def is_distributed() -> bool:
    """Return True when torch.distributed is usable and a process group is initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
161
+
162
def get_rank() -> int:
    """Return this process's rank, or 0 when not running distributed."""
    if is_distributed():
        return dist.get_rank()
    return 0
164
+
165
def get_world_size() -> int:
    """Return the number of participating processes, or 1 when not running distributed."""
    if is_distributed():
        return dist.get_world_size()
    return 1
167
+
168
def is_main() -> bool:
    """Return True on the rank-0 process (always True outside distributed runs)."""
    rank = get_rank()
    return rank == 0
170
+
171
def init_distributed() -> Optional[torch.device]:
    """Initialize the NCCL process group when LOCAL_RANK is set.

    Returns:
        The CUDA device bound to this process, or None when the LOCAL_RANK
        environment variable is absent (or equal to -1).
    """
    raw_rank = os.environ.get("LOCAL_RANK")
    local_rank = -1 if raw_rank is None else int(raw_rank)
    if local_rank == -1:
        return None

    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    return torch.device(f"cuda:{local_rank}")
178
+
179
def set_seed(seed: int) -> None:
    """Seed Python's `random`, torch's CPU generator and, if present, every CUDA generator."""
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
184
+
185
def get_device(ddp_device: Optional[torch.device] = None) -> torch.device:
    """Pick the compute device.

    Preference order: the explicit DDP device, then the current CUDA device,
    then the CPU.
    """
    if ddp_device is not None:
        return ddp_device
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
191
+
192
def current_cuda_index(device: torch.device) -> int:
    """Return the device's explicit index, or the current CUDA device when it has none."""
    idx = device.index
    if idx is None:
        idx = torch.cuda.current_device()
    return idx
194
+
195
def autocast_context(device: torch.device):
    """Return a CUDA autocast context using DTYPE, or a no-op context on other devices."""
    if device.type != "cuda":
        return nullcontext()
    return torch.autocast("cuda", dtype=DTYPE)
199
+
200
def unwrap_model(model: nn.Module) -> nn.Module:
    """Strip the DDP wrapper and, if present, the `_orig_mod` indirection from a model."""
    inner = model
    if isinstance(inner, DDP):
        inner = inner.module
    # `_orig_mod` is the attribute this file treats as the torch.compile wrapper's target.
    return getattr(inner, "_orig_mod", inner)
203
+
204
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count parameter elements of `model`.

    Args:
        model: Module whose parameters are counted.
        trainable_only: When True, skip parameters with requires_grad=False.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total
206
+
207
def normalize_state_dict_keys(sd: dict) -> OrderedDict:
    """Return a copy of `sd` with one wrapper prefix removed from each key.

    At most one prefix is stripped per key; the longest candidate is tried
    first so that "module._orig_mod." is removed as a whole.
    """
    wrapper_prefixes = ("module._orig_mod.", "_orig_mod.", "module.")
    cleaned = OrderedDict()
    for key, value in sd.items():
        matched = next((p for p in wrapper_prefixes if key.startswith(p)), "")
        cleaned[key[len(matched):]] = value
    return cleaned
216
+
217
def normalize_text(t: str) -> str:
    """Collapse every run of whitespace to a single space and trim both ends."""
    words = t.split()
    return " ".join(words)
219
+
220
def safe_str(x) -> str:
    """Coerce a value to str: strings pass through, None becomes "", anything else uses str()."""
    if isinstance(x, str):
        return x
    if x is None:
        return ""
    return str(x)
222
+
223
+
224
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
225
+ # ║ DATA LOADING (streaming + language filter) ║
226
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
227
+
228
def load_hf_stream(repo_id: str, config: str | None = None, split: str = "train"):
    """Open a Hugging Face dataset split in streaming mode.

    Args:
        repo_id: Dataset repository identifier.
        config: Optional dataset configuration name.
        split: Split to stream.
    """
    return load_dataset(path=repo_id, name=config, split=split, streaming=True)
230
+
231
def stream_texts_from_source(source: dict, start: int, count: int, char_limit: int) -> Iterator[str]:
    """Yield cleaned, truncated texts from a window of one streaming source.

    Args:
        source: Entry with "name" and "text_column", plus optional "config",
            "split" and "language_filter".
        start: Index of the first raw row to read.
        count: Number of raw rows to read. The window is taken before
            filtering, so fewer than `count` texts may be yielded.
        char_limit: Maximum number of characters kept per text.
    """
    dataset = load_hf_stream(source["name"], source.get("config"), source.get("split", "train"))
    column = source["text_column"]
    window = itertools.islice(dataset, start, start + count)

    for record in window:
        cleaned = normalize_text(safe_str(record.get(column, "")))
        # Texts shorter than 20 characters after normalization are skipped.
        if len(cleaned) < 20:
            continue

        # Language filter (used for the C4 multilingual sources).
        # NOTE(review): this relies on each row exposing a "language" field;
        # if the dataset has none, record.get returns None and every row of a
        # filtered source is dropped — confirm against the dataset schema.
        if source.get("language_filter") and record.get("language") != source["language_filter"]:
            continue

        yield cleaned[:char_limit]
246
+
247
+
248
def build_epoch_train_texts(epoch: int) -> list[str]:
    """Collect and shuffle the training texts of one epoch from every source.

    For each source the window starts after its dev rows and advances by
    `train_docs_per_epoch` rows per epoch. The shuffle is seeded with
    SEED + epoch, so a given epoch always produces the same order.
    """
    shuffler = random.Random(SEED + epoch)
    collected: list[str] = []

    for source in DATA_SOURCES:
        per_epoch = source["train_docs_per_epoch"]
        offset = source["dev_docs"] + epoch * per_epoch
        collected.extend(
            stream_texts_from_source(source, offset, per_epoch, TEXT_CHAR_LIMIT)
        )

    shuffler.shuffle(collected)
    return collected
260
+
261
+
262
def build_eval_texts() -> list[str]:
    """Collect the evaluation texts: the first `dev_docs` rows of every source, in source order."""
    return [
        text
        for source in DATA_SOURCES
        for text in stream_texts_from_source(source, 0, source["dev_docs"], TEXT_CHAR_LIMIT)
    ]
269
+
270
+
271
+ # ╔═��════════════════════════════════════════════════════════════════════════════╗
272
+ # ║ TOKENIZER ║
273
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
274
+
275
def tokenizer_ready() -> bool:
    """Return True when both tokenizer files exist in TOKENIZER_DIR."""
    required = ("tokenizer.json", "tokenizer_config.json")
    return all((TOKENIZER_DIR / filename).exists() for filename in required)
277
+
278
def train_tokenizer_once() -> None:
    """Train a byte-level BPE tokenizer and save it into TOKENIZER_DIR.

    Writes tokenizer.json first, then the Hugging Face fast-tokenizer files
    via save_pretrained.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer_json = str(TOKENIZER_DIR / "tokenizer.json")

    # Pipeline: NFKC normalization, byte-level pre-tokenization and decoding.
    bpe = Tokenizer(models.BPE(unk_token=UNK_TOKEN))
    bpe.normalizer = normalizers.Sequence([normalizers.NFKC()])
    bpe.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    bpe.decoder = decoders.ByteLevel()

    bpe_trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        min_frequency=2,
        show_progress=is_main(),
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    bpe.train_from_iterator(tokenizer_training_iterator(), trainer=bpe_trainer)

    # Special-token ids are only known once training has built the vocabulary.
    bos_id = bpe.token_to_id(BOS_TOKEN)
    eos_id = bpe.token_to_id(EOS_TOKEN)
    bpe.post_processor = processors.TemplateProcessing(
        single=f"{BOS_TOKEN} $A {EOS_TOKEN}",
        pair=f"{BOS_TOKEN} $A {EOS_TOKEN} $B:1 {EOS_TOKEN}:1",
        special_tokens=[(BOS_TOKEN, bos_id), (EOS_TOKEN, eos_id)],
    )
    bpe.save(tokenizer_json)

    wrapped = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_json,
        bos_token=BOS_TOKEN,
        eos_token=EOS_TOKEN,
        unk_token=UNK_TOKEN,
        pad_token=PAD_TOKEN,
    )
    wrapped.save_pretrained(str(TOKENIZER_DIR))
301
+
302
def tokenizer_training_iterator() -> Iterator[str]:
    """Yield the tokenizer training sample: the first rows of every source, one source after another."""
    for source in DATA_SOURCES:
        sample = stream_texts_from_source(
            source, 0, TOKENIZER_SAMPLE_DOCS_PER_SOURCE, TOKENIZER_CHAR_LIMIT
        )
        for text in sample:
            yield text
305
+
306
def train_or_load_tokenizer() -> PreTrainedTokenizerFast:
    """Load the tokenizer from TOKENIZER_DIR, training it first when missing.

    In distributed runs only the main rank trains, and every rank then meets
    at a barrier before loading. The barrier is unconditional: if it were only
    reached by ranks that saw a missing tokenizer, a rank that checks after
    the main rank has already written the files would skip it and desynchronize
    the collectives.

    Returns:
        The tokenizer loaded from TOKENIZER_DIR.
    """
    TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)

    if is_distributed():
        if is_main() and not tokenizer_ready():
            print("Entraînement tokenizer 32k…")
            train_tokenizer_once()
        # All ranks wait here so nobody loads a half-written tokenizer.
        dist.barrier()
    elif not tokenizer_ready():
        print("Entraînement tokenizer 32k…")
        train_tokenizer_once()

    return PreTrainedTokenizerFast.from_pretrained(str(TOKENIZER_DIR))
318
+
319
+
320
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
321
+ # ║ MODÈLE + QLORA + OPTIMIZER + CHECKPOINT + EVAL (inchangés) ║
322
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
323
+
324
+ # Model definition, LoRA adapters, optimizer, checkpoint and evaluation helpers.
325
+ # The full implementation is kept inline so that this script stays self-contained.
326
+
327
@dataclass
class GPTConfig:
    """Architecture settings of the GPT model; defaults come from the module-level constants."""

    vocab_size: int = VOCAB_SIZE  # rows of the embedding table and width of the output logits
    block_size: int = BLOCK_SIZE  # passed to PackedTextList as the packed sequence length
    d_model: int = D_MODEL  # width of the residual stream
    n_heads: int = N_HEADS  # must divide d_model (asserted in CausalSelfAttention)
    n_layers: int = N_LAYERS  # number of Block instances
    d_ff: int = D_FF  # hidden width of the SwiGLU feed-forward
    dropout: float = DROPOUT  # attention dropout probability, applied only in training
    use_checkpointing: bool = USE_CHECKPOINTING  # activation checkpointing per block in training
337
+
338
+
339
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), computed per position.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * x * inv_rms
346
+
347
+
348
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    The tables are stored as non-persistent buffers, so they are not written
    to the state dict.
    """

    def __init__(self, dim: int, base: int = 10_000, max_seq: int = 4_096):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(max_seq).float()
        angles = torch.outer(positions, inv_freq)
        # Each frequency is duplicated for two adjacent channels, matching
        # the even/odd pairing used by rotate_half.
        cos_table = torch.repeat_interleave(angles.cos(), 2, dim=-1)
        self.register_buffer("cos_cache", cos_table, persistent=False)
        sin_table = torch.repeat_interleave(angles.sin(), 2, dim=-1)
        self.register_buffer("sin_cache", sin_table, persistent=False)

    def forward(self, seq_len: int, dtype: torch.dtype):
        """Return the (cos, sin) tables for the first `seq_len` positions, cast to `dtype`."""
        cos = self.cos_cache[:seq_len].to(dtype)
        sin = self.sin_cache[:seq_len].to(dtype)
        return cos, sin
359
+
360
+
361
def rotate_half(x):
    """Map each adjacent channel pair (a, b) of the last dimension to (-b, a)."""
    even = x[..., 0::2]
    odd = x[..., 1::2]
    paired = torch.stack((-odd, even), dim=-1)
    return paired.flatten(-2)
364
+
365
def apply_rope(x, cos, sin):
    """Apply the rotary embedding to `x`; two leading axes are added to `cos`/`sin` for broadcasting."""
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)
    return x * cos_b + rotate_half(x) * sin_b
367
+
368
+
369
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses flash-attn when it is installed and the tensors are CUDA fp16/bf16;
    otherwise falls back to F.scaled_dot_product_attention.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        # Fused query/key/value projection, then the output projection.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.dropout_p = cfg.dropout
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x):
        """Attend over `x` of shape (batch, seq, d_model) and return the same shape."""
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(t, x.dtype)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        dropout_p = self.dropout_p if self.training else 0.0
        # flash-attn only accepts CUDA fp16/bf16 tensors; anything else must
        # take the SDPA path instead of raising inside the kernel.
        use_flash = HAS_FLASH and v.is_cuda and v.dtype in (torch.float16, torch.bfloat16)

        if use_flash:
            # Under autocast the RoPE tables follow x.dtype (fp32), which
            # promotes q and k to fp32 while v keeps the projection's
            # low-precision dtype. flash-attn requires q, k and v to share
            # one fp16/bf16 dtype, so align q and k with v.
            q = q.transpose(1, 2).to(v.dtype)
            k = k.transpose(1, 2).to(v.dtype)
            v = v.transpose(1, 2)
            y = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
            y = y.reshape(b, t, c)
        else:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
            y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.proj(y)
399
+
400
+
401
class SwiGLU(nn.Module):
    """Gated feed-forward layer: w3(silu(w1(x)) * w2(x)), all projections without bias."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width, hidden = cfg.d_model, cfg.d_ff
        self.w1 = nn.Linear(width, hidden, bias=False)  # gate branch
        self.w2 = nn.Linear(width, hidden, bias=False)  # value branch
        self.w3 = nn.Linear(hidden, width, bias=False)  # projection back to d_model

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
409
+
410
+
411
class Block(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU, each inside a residual connection."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.ln2 = RMSNorm(cfg.d_model)
        self.ff = SwiGLU(cfg)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))
422
+
423
+
424
class GPT(nn.Module):
    """Decoder-only transformer language model whose output head shares the embedding weights."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels=None):
        """Run the model.

        Args:
            input_ids: Token ids fed to the embedding table.
            labels: Optional target ids; positions equal to -100 are ignored.

        Returns:
            A (logits, loss) tuple; loss is None when `labels` is None.
        """
        hidden = self.tok_emb(input_ids)
        # Activation checkpointing is only used while training.
        checkpointed = self.cfg.use_checkpointing and self.training
        for layer in self.blocks:
            if checkpointed:
                hidden = torch.utils.checkpoint.checkpoint(layer, hidden, use_reentrant=False)
            else:
                hidden = layer(hidden)
        logits = self.lm_head(self.ln_f(hidden))

        if labels is None:
            return logits, None

        flat_logits = logits.reshape(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, labels.reshape(-1), ignore_index=-100)
        return logits, loss
454
+
455
+
456
class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update.

    The output is base(x) + lora_B(lora_A(dropout(x))) * (alpha / r). lora_B
    starts at zero, so the wrapped layer initially reproduces the base layer.
    """

    def __init__(self, base_layer: nn.Linear, r: int = LORA_R, alpha: int = LORA_ALPHA, dropout: float = LORA_DROPOUT):
        super().__init__()
        self.base = base_layer
        self.r = r
        self.scale = alpha / r

        # The adapters are created on the base layer's device (CPU if it has no parameters).
        first_param = next(base_layer.parameters(), None)
        target_device = torch.device("cpu") if first_param is None else first_param.device

        self.lora_A = nn.Linear(base_layer.in_features, r, bias=False, device=target_device)
        self.lora_B = nn.Linear(r, base_layer.out_features, bias=False, device=target_device)
        self.drop = nn.Dropout(dropout)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

        # The wrapped layer stays fixed; only the adapters are trained.
        for frozen in self.base.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        low_rank = self.lora_B(self.lora_A(self.drop(x)))
        return self.base(x) + low_rank * self.scale
477
+
478
+
479
def apply_qlora(model: GPT, device: torch.device) -> GPT:
    """Wrap every targeted nn.Linear of `model` in a LoRALinear, in place.

    A layer is targeted when the last component of its module path is listed
    in LORA_TARGET_MODULES. `device` is only used in the log message.

    Returns:
        The same model object, or the untouched model when USE_QLORA is False.
    """
    if not USE_QLORA:
        return model

    # Collect first, replace afterwards: the module tree must not change
    # while named_modules() is being iterated.
    targets = [
        (path, layer)
        for path, layer in model.named_modules()
        if path.split(".")[-1] in LORA_TARGET_MODULES and isinstance(layer, nn.Linear)
    ]

    for path, layer in targets:
        *parent_path, leaf = path.split(".")
        owner = model
        for attr in parent_path:
            owner = getattr(owner, attr)
        setattr(owner, leaf, LoRALinear(layer))

    replaced = len(targets)
    if is_main():
        print(f"QLoRA : {replaced} couches remplacées (device={device}, NF4={HAS_BNB})")
    return model
499
+
500
+
501
def freeze_base_weights(model: GPT) -> None:
    """Make only the parameters whose name contains "lora_A" or "lora_B" trainable."""
    for param_name, param in model.named_parameters():
        is_adapter = "lora_A" in param_name or "lora_B" in param_name
        param.requires_grad = is_adapter
504
+
505
+
506
def build_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    """Build the AdamW optimizer over the trainable parameters of `model`.

    Parameters with at least two dimensions and "weight" in their name get
    WEIGHT_DECAY; all other trainable parameters get no decay. The paged 8-bit
    bitsandbytes optimizer is used when available, otherwise torch's AdamW
    (fused when CUDA is available).
    """
    with_decay, without_decay = [], []
    for param_name, param in unwrap_model(model).named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim >= 2 and "weight" in param_name:
            with_decay.append(param)
        else:
            without_decay.append(param)

    param_groups = [
        {"params": with_decay, "weight_decay": WEIGHT_DECAY},
        {"params": without_decay, "weight_decay": 0.0},
    ]

    if HAS_BNB:
        return bnb.optim.PagedAdamW8bit(param_groups, lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-8)
    return torch.optim.AdamW(
        param_groups,
        lr=LEARNING_RATE,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),
    )
518
+
519
+
520
def cosine_lr(step: int, total_steps: int) -> float:
    """Return the learning rate for `step`: linear warmup, then cosine decay to MIN_LR.

    Warmup rises linearly from 0 at step 0 to LEARNING_RATE at WARMUP_STEPS.
    Beyond `total_steps` the rate stays at MIN_LR.
    """
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / max(1, WARMUP_STEPS)

    decay_span = max(1, total_steps - WARMUP_STEPS)
    progress = min(1.0, (step - WARMUP_STEPS) / decay_span)
    cosine_term = 1 + math.cos(math.pi * progress)
    return MIN_LR + 0.5 * (LEARNING_RATE - MIN_LR) * cosine_term
525
+
526
+
527
def save_checkpoint(model, optimizer, epoch, step, best_loss, path):
    """Write a training checkpoint to `path` atomically.

    The payload is first written to a sibling ".tmp" file and then moved over
    the destination with os.replace, so an interrupted write cannot leave a
    truncated checkpoint at `path`.

    Args:
        model: Model to save; DDP / compile wrappers are stripped first.
        optimizer: Optimizer whose state dict is stored.
        epoch: Epoch counter stored for resuming.
        step: Global step counter stored for resuming.
        best_loss: Best evaluation loss seen so far.
        path: Destination file (str or Path).
    """
    raw = unwrap_model(model)
    payload = {
        "model": normalize_state_dict_keys(raw.state_dict()),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch, "step": step, "best_loss": best_loss,
        "config": asdict(raw.cfg),
    }

    target = Path(path)
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
535
+
536
+
537
def maybe_load_base_checkpoint(model, device):
    """Load BASE_CHECKPOINT into `model` when it is configured and the file exists.

    Loading is non-strict, so key mismatches do not raise. They are reported
    on the main rank instead, because a silent mismatch means some weights
    were not loaded. This happens, for instance, when the model's linear
    layers have already been wrapped by LoRALinear, which stores the base
    layer under "<layer>.base".

    Args:
        model: Model (possibly wrapped) that receives the weights.
        device: map_location used when reading the checkpoint.
    """
    if BASE_CHECKPOINT is None or not Path(BASE_CHECKPOINT).exists():
        return
    ckpt = torch.load(BASE_CHECKPOINT, map_location=device)
    result = unwrap_model(model).load_state_dict(
        normalize_state_dict_keys(ckpt["model"]), strict=False
    )
    missing, unexpected = list(result.missing_keys), list(result.unexpected_keys)
    if is_main() and (missing or unexpected):
        print(
            f"[warn] Checkpoint base partiellement chargé : "
            f"{len(missing)} clés manquantes, {len(unexpected)} clés inattendues"
        )
        if missing:
            print(f"       manquantes (extrait) : {missing[:5]}")
        if unexpected:
            print(f"       inattendues (extrait) : {unexpected[:5]}")
542
+
543
+
544
def load_resume_checkpoint(model, optimizer, path, device):
    """Restore model (strictly) and optimizer (best effort) state from a checkpoint.

    A failure to restore the optimizer state is only logged; a failure to
    restore the model weights propagates to the caller.

    Returns:
        An (epoch, step, best_loss) tuple, defaulting to (0, 0, 1e9) for
        fields missing from the checkpoint.
    """
    state = torch.load(path, map_location=device)
    weights = normalize_state_dict_keys(state["model"])
    unwrap_model(model).load_state_dict(weights, strict=True)

    try:
        optimizer.load_state_dict(state["optimizer"])
    except Exception as e:
        print(f"[warn] Optimizer state non repris: {e}")

    epoch = int(state.get("epoch", 0))
    step = int(state.get("step", 0))
    best_loss = float(state.get("best_loss", 1e9))
    return epoch, step, best_loss
552
+
553
+
554
@torch.no_grad()
def evaluate(model, loader, device, max_batches=200) -> float:
    """Return the mean loss of `model` over at most `max_batches` batches of `loader`.

    The forward pass goes through the module wrapped by DDP rather than the
    DDP wrapper itself. A DDP forward can issue collective communication
    (buffer broadcast), which blocks when this function is called on one rank
    only; forwarding through the inner module makes single-rank evaluation
    safe and gives the same loss.

    Args:
        model: Model, optionally wrapped in DDP.
        loader: Iterable of batches with "input_ids" and "labels" tensors.
        device: Device the batches are moved to.
        max_batches: Upper bound on the number of batches evaluated.

    Returns:
        The average loss, or 0.0 when the loader yields no batch. The model
        is switched back to training mode before returning.
    """
    # Bypass only the DDP wrapper; a compiled module underneath is kept as is.
    net = model.module if isinstance(model, DDP) else model

    model.eval()
    losses = []
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        inp = batch["input_ids"].to(device, non_blocking=True)
        lbl = batch["labels"].to(device, non_blocking=True)
        with autocast_context(device):
            _, loss = net(inp, lbl)
        losses.append(loss.item())
    model.train()
    return sum(losses) / max(1, len(losses))
567
+
568
+
569
def make_loader(dataset, batch_size, num_workers, is_cuda):
    """Build a DataLoader; worker-only options are added only when workers are used.

    Args:
        dataset: Dataset to wrap.
        batch_size: Number of samples per batch.
        num_workers: Worker process count; 0 loads in the main process.
        is_cuda: Enables pinned host memory when True.
    """
    options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": is_cuda,
    }
    if num_workers > 0:
        # These two options are only set when worker processes are used.
        options.update(persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)
    return torch.utils.data.DataLoader(dataset, **options)
575
+
576
+
577
class PackedTextList(torch.utils.data.IterableDataset):
    """Iterable dataset that tokenizes texts and packs them into fixed-length blocks.

    Each document is wrapped as <bos> tokens <eos>, appended to a running
    buffer, and the buffer is cut into chunks of block_size + 1 tokens. A
    chunk yields `input_ids` (all but the last token) and `labels` (all but
    the first), i.e. next-token targets.
    """

    def __init__(self, texts, tokenizer, block_size, epoch_seed=0):
        super().__init__()
        self.texts = texts
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.epoch_seed = epoch_seed

    def __iter__(self):
        # Every (rank, worker) pair reads a disjoint slice of the shuffled order.
        # NOTE(review): shards may produce different numbers of blocks, so
        # ranks can see a different number of batches per epoch — confirm this
        # is handled by the training loop under DDP.
        worker = torch.utils.data.get_worker_info()
        rank, world = get_rank(), get_world_size()
        if worker is None:
            shard_count, shard_index = world, rank
        else:
            shard_count = worker.num_workers * world
            shard_index = rank * worker.num_workers + worker.id

        # Same seed on every shard, hence the same global order everywhere.
        order = list(range(len(self.texts)))
        random.Random(self.epoch_seed).shuffle(order)

        bos, eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        window = self.block_size + 1
        pending: list[int] = []

        for text_index in order[shard_index::shard_count]:
            ids = self.tokenizer.encode(self.texts[text_index], add_special_tokens=False)
            if not ids:
                continue
            pending.extend([bos, *ids, eos])
            while len(pending) >= window:
                chunk, pending = pending[:window], pending[window:]
                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "labels": torch.tensor(chunk[1:], dtype=torch.long),
                }
614
+
615
+
616
+ # ╔══════════════════════════════════════════════════════════════════════════════╗
617
+ # ║ MAIN ║
618
+ # ╚══════════════════════════════════════════════════════════════════════════════╝
619
+
620
def main() -> None:
    """Run the full training job: setup, optional resume, epoch loop, final save.

    Steps, in order:
      1. Initialise the (optional) distributed context, seed, and device.
      2. On CUDA, enable TF32 and cap the per-process memory fraction from
         TARGET_VRAM_GIB.
      3. Build tokenizer, config, and model; optionally apply QLoRA, load a
         base checkpoint, compile, and wrap in DDP.
      4. Resume from STATE_FILE when it exists; an unreadable state file is
         renamed aside and training restarts from scratch.
      5. Train for NUM_EPOCHS with gradient accumulation, periodic logging,
         evaluation, and checkpointing (all file output on the main rank only).
    """
    ddp_device = init_distributed()
    set_seed(SEED + get_rank())
    device = get_device(ddp_device)
    is_cuda = device.type == "cuda"

    cuda_idx = None
    if is_cuda:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        cuda_idx = current_cuda_index(device)
        _, total = torch.cuda.mem_get_info(cuda_idx)
        # Cap this process at TARGET_VRAM_GIB, never asking for more than 99.9 %.
        vram_fraction = min(TARGET_VRAM_GIB * (1024**3) / total, 0.999)
        torch.cuda.memory.set_per_process_memory_fraction(vram_fraction, device=cuda_idx)

    if is_main():
        print("=" * 72)
        print(" GPT ~1B | H100 80 Go | QLoRA + BF16 + TF32 | MAX 100 GB (public)")
        print("=" * 72)
        print(f"Device : {device} | World: {get_world_size()} GPU(s)")
        print(f"Flash-2 : {HAS_FLASH} | BNB 4-bit: {HAS_BNB} | QLoRA: {USE_QLORA}")
        print(f"Grad ckpt: {USE_CHECKPOINTING} | Compile: {USE_COMPILE} ({COMPILE_MODE})")
        if is_cuda:
            free, total = torch.cuda.mem_get_info(cuda_idx)
            print(f"GPU : {torch.cuda.get_device_name(cuda_idx)}")
            print(f"VRAM : {total/1024**3:.1f} GiB | libre: {free/1024**3:.1f} GiB")

    tokenizer = train_or_load_tokenizer()
    cfg = GPTConfig(vocab_size=len(tokenizer))

    if is_main():
        CONFIG_FILE.write_text(json.dumps(asdict(cfg), indent=2, ensure_ascii=False), encoding="utf-8")

    model = GPT(cfg).to(device)

    if USE_QLORA:
        model = apply_qlora(model, device)
        freeze_base_weights(model)

    maybe_load_base_checkpoint(model, device)

    # Compilation is skipped whenever gradient checkpointing is enabled.
    if USE_COMPILE and not USE_CHECKPOINTING and hasattr(torch, "compile"):
        try:
            model = torch.compile(model, mode=COMPILE_MODE)
            if is_main():
                print(f"torch.compile activé ({COMPILE_MODE})")
        except Exception as e:
            # Best effort: fall back to the uncompiled model.
            if is_main():
                print(f"[warn] torch.compile échoué ({e}) — poursuite sans compile")

    if is_distributed():
        model = DDP(model, device_ids=[device.index])

    optimizer = build_optimizer(model)

    eval_texts = build_eval_texts()
    eval_ds = PackedTextList(eval_texts, tokenizer, cfg.block_size, SEED + 999)
    eval_loader = make_loader(eval_ds, BATCH_SIZE, EVAL_NUM_WORKERS, is_cuda)

    # Rough step estimate (number of texts / batch size), fed to cosine_lr below.
    init_texts = build_epoch_train_texts(0)
    steps_per_epoch = max(1, len(init_texts) // BATCH_SIZE)
    total_steps_est = steps_per_epoch * NUM_EPOCHS

    start_epoch, start_step, best_eval = 0, 0, 1e9
    if STATE_FILE.exists():
        try:
            if is_main():
                print(f"Reprise depuis {STATE_FILE}")
            start_epoch, start_step, best_eval = load_resume_checkpoint(model, optimizer, STATE_FILE, device)
        except Exception as e:
            if is_main():
                bad = STATE_FILE.with_suffix(".corrupt.pt")
                print(f"[warn] Checkpoint illisible: {e}")
                try:
                    STATE_FILE.rename(bad)
                except OSError as rename_err:
                    # Keep going, but report the failure instead of hiding it:
                    # the unreadable file is still in place.
                    print(f"[warn] Impossible de renommer {STATE_FILE}: {rename_err}")
            start_epoch, start_step, best_eval = 0, 0, 1e9

    if is_main():
        raw = unwrap_model(model)
        n_total = count_parameters(raw, False)
        n_train = count_parameters(raw, True)
        print(f"Paramètres totaux : {n_total/1e9:.3f}B")
        print(f"Paramètres entraînés : {n_train/1e6:.1f}M ({100*n_train/max(1,n_total):.2f}%)")
        print(f"Batch size : {BATCH_SIZE} | Grad accum: {GRAD_ACCUM_STEPS} | Effective: {BATCH_SIZE*GRAD_ACCUM_STEPS}")
        print(f"Steps estimés: {total_steps_est} | Eval texts: {len(eval_texts)}")
        print("\n── Conseil VRAM ────────────────────────────────────────────────")
        print(" Surveille max_reserved à step 50.")
        print(" Si OOM → baisse BATCH_SIZE ou active USE_CHECKPOINTING=True")
        print("────────────────────────────────────────────────────────────────")

    model.train()
    optimizer.zero_grad(set_to_none=True)

    global_step = start_step
    t0 = time.time()
    # Running statistics, reset after every log line.
    log_loss_sum = 0.0
    log_loss_count = 0
    tokens_since_log = 0
    last_log = time.time()

    if is_cuda:
        torch.cuda.reset_peak_memory_stats(cuda_idx)

    for epoch in range(start_epoch, NUM_EPOCHS):
        if is_main():
            print(f"\n{'='*20} Epoch {epoch+1}/{NUM_EPOCHS} {'='*20}")

        # The dataset is rebuilt every epoch with an epoch-dependent seed.
        train_texts = build_epoch_train_texts(epoch)
        train_ds = PackedTextList(train_texts, tokenizer, cfg.block_size, SEED + epoch)
        train_loader = make_loader(train_ds, BATCH_SIZE, TRAIN_NUM_WORKERS, is_cuda)

        for micro_step, batch in enumerate(train_loader):
            inp = batch["input_ids"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            with autocast_context(device):
                _, loss = model(inp, lbl)

            # Scale so the accumulated gradient averages over the micro-batches.
            (loss / GRAD_ACCUM_STEPS).backward()

            log_loss_sum += loss.item()
            log_loss_count += 1
            tokens_since_log += inp.numel()

            # Only step the optimizer every GRAD_ACCUM_STEPS micro-batches.
            if (micro_step + 1) % GRAD_ACCUM_STEPS != 0:
                continue

            lr = cosine_lr(global_step, total_steps_est)
            for group in optimizer.param_groups:
                group["lr"] = lr

            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            if global_step % 50 == 0 and is_main():
                now = time.time()
                elapsed = max(1e-6, now - last_log)
                tok_s = tokens_since_log / elapsed
                avg_loss = log_loss_sum / max(1, log_loss_count)
                print(f"ep {epoch+1}/{NUM_EPOCHS} | step={global_step:5d} | loss={avg_loss:.4f} | lr={lr:.2e} | {tok_s:,.0f} tok/s")
                if is_cuda:
                    alloc = torch.cuda.memory_allocated(cuda_idx) / 1024**3
                    reserved = torch.cuda.memory_reserved(cuda_idx) / 1024**3
                    max_res = torch.cuda.max_memory_reserved(cuda_idx) / 1024**3
                    print(f"GPU mem | alloc={alloc:.2f} | reserved={reserved:.2f} | max_reserved={max_res:.2f} GiB")
                last_log = now
                tokens_since_log = 0
                log_loss_sum = 0.0
                log_loss_count = 0

            if global_step % EVAL_EVERY == 0 and is_main():
                val_loss = evaluate(model, eval_loader, device)
                print(f"[eval] step {global_step:5d} | val_loss={val_loss:.4f}")
                if val_loss < best_eval:
                    best_eval = val_loss
                    save_checkpoint(model, optimizer, epoch, global_step, best_eval, BEST_MODEL_FILE)
                    print(f"✓ Meilleur modèle → {BEST_MODEL_FILE}")

            if global_step % SAVE_EVERY == 0 and is_main():
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, STATE_FILE)
                save_checkpoint(model, optimizer, epoch, global_step, best_eval, MODEL_FILE)
                print(f"✓ Checkpoint → {MODEL_FILE}")

        if is_main():
            # The epoch is recorded as epoch + 1, i.e. the next epoch to run.
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, STATE_FILE)
            ckpt = OUT_DIR / f"model_epoch_{epoch+1:02d}.pt"
            save_checkpoint(model, optimizer, epoch + 1, global_step, best_eval, ckpt)
            print(f"✓ Fin epoch {epoch+1}/{NUM_EPOCHS} → {ckpt}")

    if is_main():
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, MODEL_FILE)
        save_checkpoint(model, optimizer, NUM_EPOCHS, global_step, best_eval, STATE_FILE)
        total_min = (time.time() - t0) / 60
        print(f"\nModèle final → {MODEL_FILE}")
        print(f"Meilleur modèle → {BEST_MODEL_FILE}")
        print(f"Temps total : {total_min:.1f} min | Steps: {global_step}")

    if is_distributed():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
upload.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import time
5
+ from pathlib import Path
6
+
7
+ from huggingface_hub import HfApi
8
+
9
+
10
# Directory or file names that exclude a path when they appear as any of its
# components.
IGNORE_NAMES = {".git", "__pycache__", ".ipynb_checkpoints", ".DS_Store"}

# File extensions (compared in lower case) that mark temporary or partial
# files to exclude.
IGNORE_SUFFIXES = {".tmp", ".lock", ".swp", ".swx", ".part"}
24
+
25
+
26
def sanitize_repo_name(name: str) -> str:
    """Turn an arbitrary string into a repository name acceptable to the Hub.

    Characters outside ``[A-Za-z0-9._-]`` become ``-``, runs of ``-`` and of
    ``.`` are collapsed, the result is limited to 96 characters, and leading
    or trailing ``-`` / ``.`` are removed.

    Args:
        name: Candidate name, typically a user-supplied value or a folder name.

    Returns:
        The sanitized name, or ``"model"`` when nothing usable remains.
    """
    name = name.strip().replace(" ", "-")
    name = re.sub(r"[^A-Za-z0-9._-]+", "-", name)
    name = re.sub(r"-{2,}", "-", name)
    name = re.sub(r"\.{2,}", ".", name)
    # Trim AFTER truncating: cutting at 96 characters can expose a new
    # trailing separator that the name must not end with.
    name = name[:96].strip("-.")
    return name or "model"
31
+
32
+
33
def should_ignore(path: Path) -> bool:
    """Tell whether *path* must be excluded.

    A path is excluded when any of its components is listed in IGNORE_NAMES,
    or when its lower-cased suffix is listed in IGNORE_SUFFIXES.
    """
    has_ignored_component = not IGNORE_NAMES.isdisjoint(path.parts)
    return has_ignored_component or path.suffix.lower() in IGNORE_SUFFIXES
39
+
40
+
41
def folder_stats(folder: Path):
    """Count the files under *folder* and sum their sizes, skipping ignored ones.

    Args:
        folder: Root directory to scan recursively.

    Returns:
        A ``(file_count, total_size)`` tuple; ``total_size`` is in bytes.
    """
    file_count = 0
    total_size = 0

    for p in folder.rglob("*"):
        if not p.is_file():
            continue
        # Test the path relative to the scanned root, so that an ancestor
        # directory of `folder` with an ignored name does not exclude
        # every file inside it.
        if should_ignore(p.relative_to(folder)):
            continue

        try:
            size = p.stat().st_size
        except OSError:
            # The file vanished or became unreadable between listing and
            # stat(); skip it rather than abort the whole scan.
            continue

        total_size += size
        file_count += 1

    return file_count, total_size
58
+
59
+
60
def format_bytes(num_bytes: int) -> str:
    """Render a byte count with two decimals and a unit from B up to TB.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    last_index = len(units) - 1
    value = float(num_bytes)
    index = 0
    while index < last_index and not value < 1024:
        value /= 1024
        index += 1
    return f"{value:.2f} {units[index]}"
67
+
68
+
69
def format_speed(bytes_per_sec: float) -> str:
    """Render a transfer rate as the format_bytes() text followed by ``/s``."""
    readable = format_bytes(bytes_per_sec)
    return f"{readable}/s"
71
+
72
+
73
def _build_ignore_patterns() -> list:
    """Build upload ignore patterns from IGNORE_NAMES and IGNORE_SUFFIXES.

    The patterns are written for fnmatch-style matching against paths
    relative to the uploaded folder, where ``*`` also matches ``/``. Each
    name gets four forms so it is caught both at the root of the folder and
    at any depth, whether it is a file or a directory.
    """
    patterns = []
    for name in sorted(IGNORE_NAMES):
        patterns.extend([name, f"{name}/*", f"*/{name}", f"*/{name}/*"])
    patterns.extend(f"*{suffix}" for suffix in sorted(IGNORE_SUFFIXES))
    return patterns


def main():
    """Upload a local model folder to the Hugging Face Hub in one shot.

    Parses the command line, authenticates, prints a summary of what will be
    sent, creates the target repository if needed, uploads the folder, then
    reports duration and average speed.

    Raises:
        RuntimeError: If the model folder does not exist or authentication
            with the Hub fails.
    """
    parser = argparse.ArgumentParser(
        description="Upload one-shot d'un dossier modèle vers Hugging Face Hub."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Chemin du dossier modèle à uploader",
    )
    parser.add_argument(
        "--namespace",
        type=str,
        default="Medyassino",
        help="Namespace ou username Hugging Face",
    )
    parser.add_argument(
        "--repo_name",
        type=str,
        default=None,
        help="Nom du repo distant. Par défaut: nom du dossier modèle",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Créer le repo en privé. Sans ce flag, il sera public",
    )
    parser.add_argument(
        "--large",
        action="store_true",
        help="Utiliser upload_large_folder pour les gros dossiers",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Token HF. Sinon utilise HF_TOKEN ou le login local",
    )
    parser.add_argument(
        "--commit_message",
        type=str,
        default="Upload model from local folder",
        help="Message de commit pour upload_folder",
    )

    args = parser.parse_args()

    model_dir = Path(args.model_dir).expanduser().resolve()
    if not model_dir.exists() or not model_dir.is_dir():
        raise RuntimeError(f"Dossier modèle introuvable: {model_dir}")

    repo_name = sanitize_repo_name(args.repo_name or model_dir.name)
    repo_id = f"{args.namespace}/{repo_name}"

    # Token precedence: explicit --token, then the HF_TOKEN environment variable.
    token = args.token or os.environ.get("HF_TOKEN")
    api = HfApi(token=token)

    try:
        who = api.whoami()
    except Exception as e:
        raise RuntimeError(
            "Authentification Hugging Face impossible. "
            "Fais `hf auth login`, ou passe `--token`, ou définis HF_TOKEN."
        ) from e

    file_count, total_size = folder_stats(model_dir)

    print(f"Authentifié comme: {who.get('name') or who.get('fullname') or who}")
    print(f"Upload de: {model_dir}")
    print(f"Repo cible: {repo_id}")
    print(f"Visibilité : {'privé' if args.private else 'public'}")
    print(f"Fichiers détectés: {file_count}")
    print(f"Taille totale : {format_bytes(total_size)}")

    api.create_repo(
        repo_id=repo_id,
        repo_type="model",
        private=args.private,
        exist_ok=True,
    )

    # One pattern list for both upload paths, derived from the same constants
    # that folder_stats() uses, so the printed statistics match what is
    # actually sent and caches/temp files are skipped in --large mode too.
    ignore_patterns = _build_ignore_patterns()

    start_time = time.perf_counter()

    if args.large:
        # No commit message is passed on this path.
        api.upload_large_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=str(model_dir),
            ignore_patterns=ignore_patterns,
        )
    else:
        api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=str(model_dir),
            commit_message=args.commit_message,
            ignore_patterns=ignore_patterns,
        )

    elapsed = time.perf_counter() - start_time
    avg_speed = (total_size / elapsed) if elapsed > 0 else 0.0

    print()
    print(f"Upload OK -> https://huggingface.co/{repo_id}")
    print(f"Durée totale : {elapsed:.2f} s")
    print(f"Vitesse moyenne : {format_speed(avg_speed)}")


if __name__ == "__main__":
    main()
wikipedia_ar_h100/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100/tokenizer_32k/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
wikipedia_ar_h100/tokenizer_32k/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "<bos>",
4
+ "eos_token": "<eos>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "pad_token": "<pad>",
7
+ "tokenizer_class": "TokenizersBackend",
8
+ "unk_token": "<unk>"
9
+ }
wikipedia_ar_h100/train_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09c3c854a6aaa593570f779f0a4c7281db9cc027a3897437dfe3a4cb1075b92e
3
+ size 2322908797
wikipedia_ar_h100_agri_30gb/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100_codealpaca/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100_env_fr_ar_77gb/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100_env_fr_ar_77gb/model_epoch_03.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c04ad704abccb3f01a61fbec3d09442fd9edc3b8d656addc98ba6b25e01ac28
3
+ size 5225864649
wikipedia_ar_h100_multicode/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100_multicode/train_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8234246d0aa1dbe4ced74d07266bebd381919bd0aada72b8aafa11621f0846bc
3
+ size 5225862591
wikipedia_ar_h100_multicode_10x2000/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 1024,
4
+ "d_model": 1024,
5
+ "n_heads": 16,
6
+ "n_layers": 24,
7
+ "d_ff": 4096,
8
+ "dropout": 0.0,
9
+ "use_checkpointing": false
10
+ }
wikipedia_ar_h100_multicode_10x2000/model_round_06.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7c64f58ee916183c6f04c84c1e9fbd8b92991375ade718fe438385cea36c090
3
+ size 5225864649