Medyassino commited on
Commit
0b5f348
·
verified ·
1 Parent(s): f86198c

Add files using upload-large-folder tool

Browse files
modeleAIRAG/qa_terminal.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ from test_rag_doc_interne_100m import load_model, encode_texts, search
6
+
7
+
8
+ DEFAULT_CORPUS = [
9
+ "ARTICLE 12 - Les congés payés sont acquis à raison de 2,5 jours par mois travaillé.",
10
+ "Procédure de validation des notes de frais : transmettre via le portail RH avant le 5 du mois.",
11
+ "La politique RGPD impose un délai de 72h pour notifier une violation de données.",
12
+ "Le télétravail est autorisé jusqu'à 3 jours par semaine sur accord du manager.",
13
+ "Toute facture fournisseur doit être validée par le responsable budget avant paiement.",
14
+ "Formation obligatoire sécurité incendie : 1 fois par an, traçabilité dans le SIRH.",
15
+ "L'accord d'entreprise du 15/03/2024 fixe le taux de prime annuelle à 8% du salaire brut.",
16
+ ]
17
+
18
+
19
+ def load_corpus(path):
20
+ """
21
+ Formats acceptés :
22
+ - .txt : un passage par ligne
23
+ - .jsonl: champs possibles: positive, text, content, passage
24
+ """
25
+
26
+ if path is None:
27
+ return DEFAULT_CORPUS
28
+
29
+ path = Path(path)
30
+
31
+ if not path.exists():
32
+ raise FileNotFoundError(f"Corpus introuvable : {path}")
33
+
34
+ corpus = []
35
+
36
+ if path.suffix.lower() == ".txt":
37
+ with open(path, "r", encoding="utf-8") as f:
38
+ for line in f:
39
+ line = line.strip()
40
+ if line:
41
+ corpus.append(line)
42
+
43
+ elif path.suffix.lower() == ".jsonl":
44
+ with open(path, "r", encoding="utf-8") as f:
45
+ for line in f:
46
+ if not line.strip():
47
+ continue
48
+
49
+ obj = json.loads(line)
50
+
51
+ text = (
52
+ obj.get("positive")
53
+ or obj.get("text")
54
+ or obj.get("content")
55
+ or obj.get("passage")
56
+ )
57
+
58
+ if text:
59
+ corpus.append(text.strip())
60
+
61
+ else:
62
+ raise ValueError("Format corpus non supporté. Utilise .txt ou .jsonl")
63
+
64
+ if not corpus:
65
+ raise ValueError("Corpus vide.")
66
+
67
+ return corpus
68
+
69
+
70
+ def print_results(results, threshold, margin):
71
+ top1 = results[0]
72
+ top2_score = results[1]["score"] if len(results) > 1 else 0.0
73
+ diff = top1["score"] - top2_score
74
+
75
+ print("\n================ RÉPONSE ================")
76
+
77
+ if top1["score"] < threshold:
78
+ print("Aucun passage suffisamment pertinent trouvé.")
79
+ print(f"Score Top 1 : {top1['score']:.4f}")
80
+ else:
81
+ if diff < margin:
82
+ print("Résultat possible, mais incertain : Top 1 et Top 2 sont proches.")
83
+ print(f"Écart Top1 - Top2 : {diff:.4f}")
84
+
85
+ print(f"\nMeilleur passage | score={top1['score']:.4f}")
86
+ print(top1["text"])
87
+
88
+ print("\n================ TOP RÉSULTATS ================")
89
+
90
+ for i, r in enumerate(results, start=1):
91
+ print(f"\nTop {i} | score={r['score']:.4f}")
92
+ print(r["text"])
93
+
94
+
95
+ def main():
96
+ parser = argparse.ArgumentParser()
97
+
98
+ parser.add_argument(
99
+ "--save_dir",
100
+ type=str,
101
+ default="./checkpoints_rag_doc_100m",
102
+ help="Dossier du checkpoint.",
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--corpus",
107
+ type=str,
108
+ default=None,
109
+ help="Corpus .txt ou .jsonl. Si absent, utilise le corpus de test.",
110
+ )
111
+
112
+ parser.add_argument(
113
+ "--top_k",
114
+ type=int,
115
+ default=5,
116
+ help="Nombre de passages à retourner.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--threshold",
121
+ type=float,
122
+ default=0.45,
123
+ help="Score minimal pour accepter une réponse.",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--margin",
128
+ type=float,
129
+ default=0.03,
130
+ help="Écart minimal conseillé entre Top 1 et Top 2.",
131
+ )
132
+
133
+ args = parser.parse_args()
134
+
135
+ model, tokenizer, cfg, device = load_model(args.save_dir)
136
+
137
+ print(f"[INFO] Modèle chargé depuis : {args.save_dir}")
138
+ print(f"[INFO] Device : {device}")
139
+
140
+ corpus = load_corpus(args.corpus)
141
+
142
+ print(f"[INFO] Corpus chargé : {len(corpus)} passages")
143
+ print("[INFO] Encodage du corpus...")
144
+
145
+ corpus_embeddings = encode_texts(
146
+ model=model,
147
+ tokenizer=tokenizer,
148
+ texts=corpus,
149
+ device=device,
150
+ max_seq_len=cfg.max_seq_len,
151
+ )
152
+
153
+ print("\n==============================================")
154
+ print(" QA TERMINAL RAG")
155
+ print(" Tape ta question puis Entrée.")
156
+ print(" Commandes : exit, quit, q")
157
+ print("==============================================")
158
+
159
+ while True:
160
+ query = input("\nQuestion > ").strip()
161
+
162
+ if query.lower() in {"exit", "quit", "q"}:
163
+ print("Fin du QA.")
164
+ break
165
+
166
+ if not query:
167
+ continue
168
+
169
+ results = search(
170
+ query=query,
171
+ corpus=corpus,
172
+ corpus_embeddings=corpus_embeddings,
173
+ model=model,
174
+ tokenizer=tokenizer,
175
+ cfg=cfg,
176
+ device=device,
177
+ top_k=args.top_k,
178
+ )
179
+
180
+ print_results(results, args.threshold, args.margin)
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
modeleAIRAG/test_rag_doc_interne_100m.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ # =============================================================================
12
+ # CONFIG identique au modèle entraîné
13
+ # =============================================================================
14
+ @dataclass
15
+ class Config:
16
+ vocab_size: int = 32000
17
+ hidden_size: int = 768
18
+ num_hidden_layers: int = 12
19
+ num_attention_heads: int = 12
20
+ intermediate_size: int = 3072
21
+ max_position_embeddings: int = 512
22
+ hidden_dropout_prob: float = 0.1
23
+ attention_probs_dropout_prob: float = 0.1
24
+ layer_norm_eps: float = 1e-12
25
+ embedding_dim: int = 768
26
+ use_layer_scale: bool = True
27
+ layer_scale_init: float = 1e-5
28
+ use_grad_checkpointing: bool = False
29
+
30
+ max_seq_len: int = 384
31
+ save_dir: str = "./checkpoints_rag_doc_100m"
32
+
33
+
34
+ # =============================================================================
35
+ # ARCHITECTURE
36
+ # =============================================================================
37
+ class TransformerEncoderBlock(nn.Module):
38
+ def __init__(self, cfg):
39
+ super().__init__()
40
+ self.num_heads = cfg.num_attention_heads
41
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
42
+
43
+ self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
44
+ self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size)
45
+ self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
46
+
47
+ self.attn_drop_p = cfg.attention_probs_dropout_prob
48
+
49
+ self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
50
+ self.mlp = nn.Sequential(
51
+ nn.Linear(cfg.hidden_size, cfg.intermediate_size),
52
+ nn.GELU(),
53
+ nn.Linear(cfg.intermediate_size, cfg.hidden_size),
54
+ nn.Dropout(cfg.hidden_dropout_prob),
55
+ )
56
+
57
+ self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob)
58
+ self.use_ls = cfg.use_layer_scale
59
+
60
+ if cfg.use_layer_scale:
61
+ self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
62
+ self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
63
+
64
+ def forward(self, x, attn_mask):
65
+ B, T, C = x.shape
66
+
67
+ h = self.ln1(x)
68
+ qkv = self.qkv(h).view(B, T, 3, self.num_heads, self.head_dim)
69
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
70
+
71
+ kpm = attn_mask[:, None, None, :].bool()
72
+
73
+ a = F.scaled_dot_product_attention(
74
+ q,
75
+ k,
76
+ v,
77
+ attn_mask=kpm,
78
+ dropout_p=0.0,
79
+ is_causal=False,
80
+ )
81
+
82
+ a = a.transpose(1, 2).contiguous().view(B, T, C)
83
+ a = self.resid_drop(self.proj(a))
84
+
85
+ if self.use_ls:
86
+ a = a * self.gamma1
87
+
88
+ x = x + a
89
+
90
+ m = self.mlp(self.ln2(x))
91
+
92
+ if self.use_ls:
93
+ m = m * self.gamma2
94
+
95
+ return x + m
96
+
97
+
98
+ class TextEncoder(nn.Module):
99
+ def __init__(self, cfg):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ self.tok_emb = nn.Embedding(
104
+ cfg.vocab_size,
105
+ cfg.hidden_size,
106
+ padding_idx=0,
107
+ )
108
+
109
+ self.pos_emb = nn.Embedding(
110
+ cfg.max_position_embeddings,
111
+ cfg.hidden_size,
112
+ )
113
+
114
+ self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
115
+ self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
116
+
117
+ self.blocks = nn.ModuleList(
118
+ [TransformerEncoderBlock(cfg) for _ in range(cfg.num_hidden_layers)]
119
+ )
120
+
121
+ self.ln_f = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
122
+
123
+ self.proj_head = nn.Sequential(
124
+ nn.Linear(cfg.hidden_size, cfg.hidden_size),
125
+ nn.Tanh(),
126
+ nn.Linear(cfg.hidden_size, cfg.embedding_dim),
127
+ )
128
+
129
+ def encode_backbone(self, ids, mask):
130
+ B, T = ids.shape
131
+
132
+ pos = torch.arange(T, device=ids.device).unsqueeze(0).expand(B, T)
133
+
134
+ x = self.tok_emb(ids) + self.pos_emb(pos)
135
+ x = self.emb_drop(self.emb_ln(x))
136
+
137
+ for blk in self.blocks:
138
+ x = blk(x, mask)
139
+
140
+ return self.ln_f(x)
141
+
142
+ def forward(self, ids, mask):
143
+ x = self.encode_backbone(ids, mask)
144
+
145
+ m = mask.unsqueeze(-1).float()
146
+ pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-6)
147
+
148
+ emb = self.proj_head(pooled)
149
+ return F.normalize(emb, p=2, dim=-1)
150
+
151
+
152
+ # =============================================================================
153
+ # FONCTIONS TEST
154
+ # =============================================================================
155
+ @torch.no_grad()
156
+ def encode_texts(model, tokenizer, texts, device, max_seq_len=384, batch_size=32):
157
+ model.eval()
158
+ all_embeddings = []
159
+
160
+ for i in range(0, len(texts), batch_size):
161
+ batch = texts[i:i + batch_size]
162
+
163
+ enc = tokenizer(
164
+ batch,
165
+ padding=True,
166
+ truncation=True,
167
+ max_length=max_seq_len,
168
+ return_tensors="pt",
169
+ ).to(device)
170
+
171
+ with torch.autocast(
172
+ device_type="cuda",
173
+ dtype=torch.bfloat16,
174
+ enabled=torch.cuda.is_available(),
175
+ ):
176
+ emb = model(enc["input_ids"], enc["attention_mask"])
177
+
178
+ all_embeddings.append(emb.float().cpu())
179
+
180
+ return torch.cat(all_embeddings, dim=0)
181
+
182
+
183
+ def load_model(save_dir):
184
+ save_dir = Path(save_dir)
185
+ ckpt_path = save_dir / "model_best.pt"
186
+
187
+ if not ckpt_path.exists():
188
+ raise FileNotFoundError(f"Checkpoint introuvable : {ckpt_path}")
189
+
190
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
191
+
192
+ tokenizer = AutoTokenizer.from_pretrained(save_dir)
193
+
194
+ ckpt = torch.load(ckpt_path, map_location=device)
195
+
196
+ saved_cfg = ckpt.get("config", {})
197
+ cfg = Config(**{k: v for k, v in saved_cfg.items() if hasattr(Config, k)})
198
+ cfg.vocab_size = tokenizer.vocab_size
199
+ cfg.use_grad_checkpointing = False
200
+
201
+ model = TextEncoder(cfg).to(device)
202
+ model.load_state_dict(ckpt["model_state"], strict=False)
203
+ model.eval()
204
+
205
+ return model, tokenizer, cfg, device
206
+
207
+
208
+ def search(query, corpus, corpus_embeddings, model, tokenizer, cfg, device, top_k=3):
209
+ q_emb = encode_texts(
210
+ model,
211
+ tokenizer,
212
+ [query],
213
+ device,
214
+ max_seq_len=cfg.max_seq_len,
215
+ )
216
+
217
+ scores = q_emb @ corpus_embeddings.T
218
+ top = torch.topk(scores.squeeze(0), k=min(top_k, len(corpus)))
219
+
220
+ results = []
221
+
222
+ for score, idx in zip(top.values, top.indices):
223
+ results.append(
224
+ {
225
+ "score": float(score),
226
+ "text": corpus[int(idx)],
227
+ }
228
+ )
229
+
230
+ return results
231
+
232
+
233
+ # =============================================================================
234
+ # MAIN
235
+ # =============================================================================
236
+ def main():
237
+ parser = argparse.ArgumentParser()
238
+
239
+ parser.add_argument(
240
+ "--save_dir",
241
+ type=str,
242
+ default="./checkpoints_rag_doc_100m",
243
+ help="Dossier contenant model_best.pt et le tokenizer.",
244
+ )
245
+
246
+ parser.add_argument(
247
+ "--top_k",
248
+ type=int,
249
+ default=3,
250
+ help="Nombre de résultats à retourner.",
251
+ )
252
+
253
+ args = parser.parse_args()
254
+
255
+ model, tokenizer, cfg, device = load_model(args.save_dir)
256
+
257
+ print(f"[INFO] Modèle chargé depuis : {args.save_dir}")
258
+ print(f"[INFO] Device : {device}")
259
+
260
+ corpus = [
261
+ "ARTICLE 12 - Les congés payés sont acquis à raison de 2,5 jours par mois travaillé.",
262
+ "Procédure de validation des notes de frais : transmettre via le portail RH avant le 5 du mois.",
263
+ "La politique RGPD impose un délai de 72h pour notifier une violation de données.",
264
+ "Le télétravail est autorisé jusqu'à 3 jours par semaine sur accord du manager.",
265
+ "Toute facture fournisseur doit être validée par le responsable budget avant paiement.",
266
+ "Formation obligatoire sécurité incendie : 1 fois par an, traçabilité dans le SIRH.",
267
+ "L'accord d'entreprise du 15/03/2024 fixe le taux de prime annuelle à 8% du salaire brut.",
268
+ ]
269
+
270
+ print("[INFO] Encodage du corpus...")
271
+ corpus_embeddings = encode_texts(
272
+ model,
273
+ tokenizer,
274
+ corpus,
275
+ device,
276
+ max_seq_len=cfg.max_seq_len,
277
+ )
278
+
279
+ queries = [
280
+ "Combien de jours de congés je gagne par mois ?",
281
+ "Comment déclarer mes notes de frais ?",
282
+ "Quel est le quota de télétravail ?",
283
+ "Quel est le délai de notification RGPD ?",
284
+ "Quel est le taux de prime annuelle ?",
285
+ ]
286
+
287
+ print("\n================ TEST RAG DOC INTERNE ================")
288
+
289
+ for q in queries:
290
+ print(f"\nQuestion : {q}")
291
+
292
+ results = search(
293
+ query=q,
294
+ corpus=corpus,
295
+ corpus_embeddings=corpus_embeddings,
296
+ model=model,
297
+ tokenizer=tokenizer,
298
+ cfg=cfg,
299
+ device=device,
300
+ top_k=args.top_k,
301
+ )
302
+
303
+ for rank, r in enumerate(results, start=1):
304
+ print(f" Top {rank} | score={r['score']:.4f}")
305
+ print(f" -> {r['text']}")
306
+
307
+
308
+ if __name__ == "__main__":
309
+ main()
modeleAIRAG/train1.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ==============================================================================
3
+ RAG/NLP encoder ~100M params - SPÉCIALISÉ IT / TECH / CYBERSÉCURITÉ
4
+ Hardware : NVIDIA H100 80GB
5
+ Epochs : 20
6
+ ==============================================================================
7
+
8
+ Architecture :
9
+ - Encoder Transformer ~100M params (12 couches, hidden=768, 12 têtes)
10
+ - Tokenizer : camembert-base (32k FR) + extension domaine via BPE-suffixe
11
+ - Tête projection -> embeddings 768d L2-normalisés
12
+ - Loss : Symmetric MNRL + hard negatives (TF-IDF mining)
13
+ - MLM pré-entraînement (2 epochs) sur corpus IT FR
14
+ - EMA, LayerScale, BF16, SDPA (Flash Attention 2 sur H100)
15
+ - Gradient checkpointing ACTIVÉ (modèle 100M, batch large -> VRAM)
16
+
17
+ Datasets (IT / cybersécurité / dev / cloud / data) :
18
+ - mMARCO-FR (passages techniques)
19
+ - PIAF + FQuAD2 filtrés "tech"
20
+ - CodeSearchNet (docstrings -> code, FR/EN)
21
+ - StackExchange dumps (askubuntu, serverfault, security, stackoverflow)
22
+ - CVE / NVD descriptions (cybersécurité)
23
+ - OWASP / RFC-like (RFC corpus, MITRE ATT&CK)
24
+ - HuggingFace : "lhoestq/demo1", "code_search_net"
25
+ - Custom JSONL local optionnel (./data/custom_it.jsonl)
26
+
27
+ Usage :
28
+ pip install torch>=2.2 transformers>=4.40 datasets>=2.18 accelerate \\
29
+ sentencepiece tqdm numpy scikit-learn faiss-cpu
30
+ python train_rag_it_100m.py
31
+ """
32
+
33
+ import os
34
+ import math
35
+ import json
36
+ import random
37
+ import re
38
+ from dataclasses import dataclass, asdict
39
+ from pathlib import Path
40
+ from typing import List, Dict, Tuple, Optional
41
+
42
+ import numpy as np
43
+ import torch
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ import torch.utils.checkpoint as gc
47
+ from torch.utils.data import Dataset, DataLoader
48
+ from torch.optim import AdamW
49
+
50
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
51
+ from datasets import load_dataset, Dataset as HFDataset
52
+ from tqdm.auto import tqdm
53
+
54
+ # =============================================================================
55
+ # 1. CONFIG — 100M params, IT/Tech
56
+ # =============================================================================
57
+ @dataclass
58
+ class Config:
59
+ # --- Modèle ~100M ---
60
+ vocab_size: int = 32000
61
+ hidden_size: int = 768
62
+ num_hidden_layers: int = 12
63
+ num_attention_heads: int = 12
64
+ intermediate_size: int = 3072
65
+ max_position_embeddings: int = 384 # docs IT plus longs
66
+ hidden_dropout_prob: float = 0.1
67
+ attention_probs_dropout_prob: float = 0.1
68
+ layer_norm_eps: float = 1e-12
69
+ embedding_dim: int = 768
70
+ use_layer_scale: bool = True
71
+ layer_scale_init: float = 1e-5
72
+ use_grad_checkpointing: bool = True # OBLIGATOIRE à 100M
73
+
74
+ tokenizer_name: str = "camembert-base"
75
+
76
+ # --- MLM pré-entraînement ---
77
+ do_mlm_pretrain: bool = True
78
+ mlm_epochs: int = 2
79
+ mlm_prob: float = 0.15
80
+ mlm_lr: float = 1e-4
81
+
82
+ # --- Contrastif ---
83
+ epochs: int = 20
84
+ batch_size: int = 96 # 100M + GC -> batch raisonnable
85
+ grad_accum_steps: int = 4 # batch effectif = 384
86
+ max_seq_len: int = 192 # docs IT plus longs
87
+ lr: float = 2e-5 # plus bas pour 100M + 20 epochs
88
+ weight_decay: float = 0.01
89
+ warmup_ratio: float = 0.04
90
+ grad_clip: float = 1.0
91
+ temperature: float = 0.02
92
+ num_workers: int = 6
93
+ seed: int = 42
94
+
95
+ # --- Hard negatives ---
96
+ use_hard_negatives: bool = True
97
+ n_hard_neg: int = 1
98
+ hard_neg_pool_size: int = 100_000
99
+
100
+ # --- EMA ---
101
+ use_ema: bool = True
102
+ ema_decay: float = 0.9995 # plus agressif pour 20 epochs
103
+
104
+ # --- Données ---
105
+ max_samples_per_dataset: int = 300_000
106
+ eval_max_size: int = 5_000
107
+
108
+ # --- Optim H100 ---
109
+ use_bf16: bool = True
110
+ use_compile: bool = True
111
+ compile_mode: str = "default"
112
+ log_every: int = 50
113
+ save_dir: str = "./checkpoints_rag_it_100m"
114
+ save_every_epochs: int = 2 # checkpoint tous les 2 epochs
115
+
116
+ # --- Domaine IT : custom data path ---
117
+ custom_jsonl_path: str = "./data/custom_it.jsonl"
118
+
119
+
120
+ CFG = Config()
121
+ Path(CFG.save_dir).mkdir(parents=True, exist_ok=True)
122
+ random.seed(CFG.seed); np.random.seed(CFG.seed)
123
+ torch.manual_seed(CFG.seed); torch.cuda.manual_seed_all(CFG.seed)
124
+ torch.backends.cuda.matmul.allow_tf32 = True
125
+ torch.backends.cudnn.allow_tf32 = True
126
+ torch.set_float32_matmul_precision("high")
127
+
128
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129
+ print(f"[INFO] Device : {device}")
130
+ if torch.cuda.is_available():
131
+ print(f"[INFO] GPU : {torch.cuda.get_device_name(0)}")
132
+ print(f"[INFO] VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
133
+
134
+
135
+ # =============================================================================
136
+ # 2. ARCHITECTURE — 100M avec Gradient Checkpointing
137
+ # =============================================================================
138
+ class TransformerEncoderBlock(nn.Module):
139
+ def __init__(self, cfg: Config):
140
+ super().__init__()
141
+ self.num_heads = cfg.num_attention_heads
142
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
143
+ self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
144
+ self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size, bias=True)
145
+ self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
146
+ self.attn_drop_p = cfg.attention_probs_dropout_prob
147
+
148
+ self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
149
+ self.mlp = nn.Sequential(
150
+ nn.Linear(cfg.hidden_size, cfg.intermediate_size),
151
+ nn.GELU(),
152
+ nn.Linear(cfg.intermediate_size, cfg.hidden_size),
153
+ nn.Dropout(cfg.hidden_dropout_prob),
154
+ )
155
+ self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob)
156
+ self.use_ls = cfg.use_layer_scale
157
+ if cfg.use_layer_scale:
158
+ self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
159
+ self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
160
+
161
+ def forward(self, x, attn_mask):
162
+ B, T, C = x.shape
163
+ h = self.ln1(x)
164
+ qkv = self.qkv(h).view(B, T, 3, self.num_heads, self.head_dim)
165
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
166
+ key_padding_mask = attn_mask[:, None, None, :].bool()
167
+ attn_out = F.scaled_dot_product_attention(
168
+ q, k, v, attn_mask=key_padding_mask,
169
+ dropout_p=self.attn_drop_p if self.training else 0.0,
170
+ is_causal=False,
171
+ )
172
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, C)
173
+ attn_out = self.resid_drop(self.proj(attn_out))
174
+ if self.use_ls: attn_out = attn_out * self.gamma1
175
+ x = x + attn_out
176
+ mlp_out = self.mlp(self.ln2(x))
177
+ if self.use_ls: mlp_out = mlp_out * self.gamma2
178
+ return x + mlp_out
179
+
180
+
181
+ class TextEncoder(nn.Module):
182
+ def __init__(self, cfg: Config):
183
+ super().__init__()
184
+ self.cfg = cfg
185
+ self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size, padding_idx=0)
186
+ self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
187
+ self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
188
+ self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
189
+ self.blocks = nn.ModuleList([TransformerEncoderBlock(cfg)
190
+ for _ in range(cfg.num_hidden_layers)])
191
+ self.ln_f = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
192
+ self.proj_head = nn.Sequential(
193
+ nn.Linear(cfg.hidden_size, cfg.hidden_size),
194
+ nn.Tanh(),
195
+ nn.Linear(cfg.hidden_size, cfg.embedding_dim),
196
+ )
197
+ self.mlm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
198
+ self.mlm_head.weight = self.tok_emb.weight # tied
199
+ self.use_gc = cfg.use_grad_checkpointing
200
+ self.apply(self._init_weights)
201
+
202
+ @staticmethod
203
+ def _init_weights(m):
204
+ if isinstance(m, nn.Linear):
205
+ nn.init.normal_(m.weight, std=0.02)
206
+ if m.bias is not None: nn.init.zeros_(m.bias)
207
+ elif isinstance(m, nn.Embedding):
208
+ nn.init.normal_(m.weight, std=0.02)
209
+ elif isinstance(m, nn.LayerNorm):
210
+ nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
211
+
212
+ def encode_backbone(self, input_ids, attention_mask):
213
+ B, T = input_ids.shape
214
+ positions = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, T)
215
+ x = self.tok_emb(input_ids) + self.pos_emb(positions)
216
+ x = self.emb_drop(self.emb_ln(x))
217
+ for blk in self.blocks:
218
+ if self.use_gc and self.training:
219
+ x = gc.checkpoint(blk, x, attention_mask, use_reentrant=False)
220
+ else:
221
+ x = blk(x, attention_mask)
222
+ return self.ln_f(x)
223
+
224
+ def forward(self, input_ids, attention_mask):
225
+ x = self.encode_backbone(input_ids, attention_mask)
226
+ mask = attention_mask.unsqueeze(-1).float()
227
+ pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
228
+ emb = self.proj_head(pooled)
229
+ return F.normalize(emb, p=2, dim=-1)
230
+
231
+ def forward_mlm(self, input_ids, attention_mask):
232
+ x = self.encode_backbone(input_ids, attention_mask)
233
+ return self.mlm_head(x)
234
+
235
+
236
+ def count_parameters(model: nn.Module) -> int:
237
+ return sum(p.numel() for n, p in model.named_parameters()
238
+ if p.requires_grad and "mlm_head" not in n)
239
+
240
+
241
+ # =============================================================================
242
+ # 3. EMA
243
+ # =============================================================================
244
+ class EMA:
245
+ def __init__(self, model: nn.Module, decay: float = 0.999):
246
+ self.decay = decay
247
+ self.shadow = {n: p.detach().clone()
248
+ for n, p in model.named_parameters() if p.requires_grad}
249
+
250
+ @torch.no_grad()
251
+ def update(self, model):
252
+ for n, p in model.named_parameters():
253
+ if p.requires_grad and n in self.shadow:
254
+ self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
255
+
256
+ @torch.no_grad()
257
+ def apply_to(self, model):
258
+ backup = {}
259
+ for n, p in model.named_parameters():
260
+ if n in self.shadow:
261
+ backup[n] = p.detach().clone()
262
+ p.copy_(self.shadow[n])
263
+ return backup
264
+
265
+ @torch.no_grad()
266
+ def restore(self, model, backup):
267
+ for n, p in model.named_parameters():
268
+ if n in backup: p.copy_(backup[n])
269
+
270
+
271
+ # =============================================================================
272
+ # 4. CHARGEMENT DES DATASETS — DOMAINE IT / TECH
273
+ # =============================================================================
274
+ IT_KEYWORDS = re.compile(
275
+ r"\b(api|cloud|docker|kubernetes|server|réseau|network|sécurité|security|"
276
+ r"vuln|attaque|attack|cve|owasp|sql|nosql|python|java|javascript|linux|"
277
+ r"windows|firewall|chiffr|crypto|http|tcp|ip|dns|vpn|tls|ssl|iam|oauth|"
278
+ r"jwt|microservice|devops|ci/cd|pipeline|kernel|conteneur|container|"
279
+ r"machine learning|deep learning|llm|nlp|rag|gpu|cuda|pytorch|tensorflow|"
280
+ r"hadoop|spark|sql|bdd|database|données|data|backup|sauvegarde)\b",
281
+ re.IGNORECASE,
282
+ )
283
+
284
+ def is_it_text(t: str) -> bool:
285
+ return bool(IT_KEYWORDS.search(t)) if t else False
286
+
287
+
288
+ def load_it_pairs(cfg: Config) -> List[Dict[str, str]]:
289
+ print("\n[DATA] Chargement des datasets IT/Tech...")
290
+ pairs: List[Dict[str, str]] = []
291
+
292
+ # 4.1 mMARCO FR filtré IT
293
+ try:
294
+ ds = load_dataset("unicamp-dl/mmarco", "french", split="train")
295
+ ds = ds.select(range(min(500_000, len(ds))))
296
+ kept = 0
297
+ for ex in tqdm(ds, desc="mMARCO-FR (IT-filter)"):
298
+ q = (ex.get("query") or "").strip()
299
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
300
+ if q and p and (is_it_text(q) or is_it_text(p)):
301
+ pairs.append({"anchor": q, "positive": p})
302
+ kept += 1
303
+ if kept >= cfg.max_samples_per_dataset: break
304
+ except Exception as e:
305
+ print(f" [warn] mMARCO FR : {e}")
306
+
307
+ # 4.2 PIAF filtré IT
308
+ try:
309
+ ds = load_dataset("etalab-ia/piaf", split="train")
310
+ for ex in tqdm(ds, desc="PIAF (IT-filter)"):
311
+ q = (ex.get("question") or "").strip()
312
+ ctx = (ex.get("context") or "").strip()
313
+ if q and ctx and (is_it_text(q) or is_it_text(ctx)):
314
+ pairs.append({"anchor": q, "positive": ctx})
315
+ except Exception as e:
316
+ print(f" [warn] PIAF : {e}")
317
+
318
+ # 4.3 CodeSearchNet — docstring -> code (Python, JS, Go, Java)
319
+ for lang in ["python", "javascript", "java", "go"]:
320
+ try:
321
+ ds = load_dataset("code_search_net", lang, split="train",
322
+ trust_remote_code=True)
323
+ ds = ds.select(range(min(80_000, len(ds))))
324
+ for ex in tqdm(ds, desc=f"CodeSearchNet-{lang}"):
325
+ doc = (ex.get("func_documentation_string") or "").strip()
326
+ code = (ex.get("func_code_string") or "").strip()
327
+ if doc and code and len(doc) > 20 and len(code) > 30:
328
+ pairs.append({"anchor": doc, "positive": code[:1500]})
329
+ except Exception as e:
330
+ print(f" [warn] CodeSearchNet-{lang} : {e}")
331
+
332
+ # 4.4 StackExchange — Q/A techniques (security, serverfault, askubuntu)
333
+ for sub in ["security", "serverfault", "askubuntu", "stackoverflow"]:
334
+ try:
335
+ ds = load_dataset("flax-sentence-embeddings/stackexchange_xml",
336
+ sub, split="train", trust_remote_code=True)
337
+ ds = ds.select(range(min(60_000, len(ds))))
338
+ for ex in tqdm(ds, desc=f"SE-{sub}"):
339
+ title = (ex.get("title_body") or ex.get("title") or "").strip()
340
+ ans = (ex.get("upvoted_answer") or ex.get("answer") or "").strip()
341
+ if title and ans and len(ans) > 50:
342
+ pairs.append({"anchor": title, "positive": ans[:1500]})
343
+ except Exception as e:
344
+ print(f" [warn] SE-{sub} : {e}")
345
+
346
+ # 4.5 CVE / NVD descriptions (cybersécurité)
347
+ try:
348
+ ds = load_dataset("Iker/CVE-Description-and-Severity", split="train")
349
+ for ex in tqdm(ds, desc="CVE-NVD"):
350
+ cve_id = (ex.get("cve") or "").strip()
351
+ desc = (ex.get("description") or "").strip()
352
+ if cve_id and desc and len(desc) > 30:
353
+ # paire (cve_id + question implicite, description)
354
+ pairs.append({
355
+ "anchor": f"Quelle est la vulnérabilité {cve_id} ?",
356
+ "positive": desc[:1500],
357
+ })
358
+ except Exception as e:
359
+ print(f" [warn] CVE : {e}")
360
+
361
+ # 4.6 XNLI FR (entailment) - filtré IT
362
+ try:
363
+ ds = load_dataset("xnli", "fr", split="train")
364
+ ds = ds.filter(lambda x: x["label"] == 0)
365
+ for ex in tqdm(ds, desc="XNLI-FR (IT)"):
366
+ a = (ex.get("premise") or "").strip()
367
+ b = (ex.get("hypothesis") or "").strip()
368
+ if a and b and (is_it_text(a) or is_it_text(b)):
369
+ pairs.append({"anchor": a, "positive": b})
370
+ except Exception as e:
371
+ print(f" [warn] XNLI : {e}")
372
+
373
+ # 4.7 Custom JSONL local (corpus interne SecureRAG / OWASP / RFC)
374
+ if Path(cfg.custom_jsonl_path).exists():
375
+ print(f" [+] Lecture custom : {cfg.custom_jsonl_path}")
376
+ with open(cfg.custom_jsonl_path, "r", encoding="utf-8") as f:
377
+ for line in tqdm(f, desc="custom_it.jsonl"):
378
+ try:
379
+ ex = json.loads(line)
380
+ a = (ex.get("anchor") or ex.get("query") or "").strip()
381
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
382
+ if a and p:
383
+ pairs.append({"anchor": a, "positive": p})
384
+ except Exception:
385
+ continue
386
+ else:
387
+ print(f" [info] Pas de fichier custom à {cfg.custom_jsonl_path}")
388
+
389
+ # Dédoublonnage
390
+ seen = set(); uniq = []
391
+ for p in pairs:
392
+ k = (p["anchor"][:200], p["positive"][:200])
393
+ if k not in seen:
394
+ seen.add(k); uniq.append(p)
395
+
396
+ random.shuffle(uniq)
397
+ print(f"[DATA] Total paires IT uniques : {len(uniq):,}")
398
+ return uniq
399
+
400
+
401
+ # =============================================================================
402
+ # 5. HARD NEGATIVE MINING
403
+ # =============================================================================
404
+ def mine_hard_negatives(pairs: List[Dict[str, str]], cfg: Config) -> List[Dict[str, str]]:
405
+ print("\n[HN] Mining hard negatives via TF-IDF...")
406
+ try:
407
+ from sklearn.feature_extraction.text import TfidfVectorizer
408
+ from sklearn.metrics.pairwise import linear_kernel
409
+ except ImportError:
410
+ print(" [warn] sklearn manquant"); return pairs
411
+
412
+ n = len(pairs)
413
+ pool_size = min(cfg.hard_neg_pool_size, n)
414
+ pool_idx = np.random.choice(n, size=pool_size, replace=False)
415
+ pool_pass = [pairs[i]["positive"] for i in pool_idx]
416
+
417
+ vec = TfidfVectorizer(max_features=80_000, ngram_range=(1, 2),
418
+ lowercase=True, strip_accents="unicode")
419
+ X_pool = vec.fit_transform(pool_pass)
420
+ enriched = []
421
+ batch = 2000
422
+ anchors = [p["anchor"] for p in pairs]
423
+
424
+ for start in tqdm(range(0, n, batch), desc="HN-mine"):
425
+ end = min(start + batch, n)
426
+ Xq = vec.transform(anchors[start:end])
427
+ sims = linear_kernel(Xq, X_pool)
428
+ for i_loc, i_glob in enumerate(range(start, end)):
429
+ true_pos = pairs[i_glob]["positive"]
430
+ order = np.argsort(-sims[i_loc])
431
+ picked = None
432
+ for j in order[:30]:
433
+ if pool_pass[j] != true_pos:
434
+ picked = pool_pass[j]; break
435
+ if picked is None: picked = pool_pass[order[0]]
436
+ enriched.append({
437
+ "anchor": pairs[i_glob]["anchor"],
438
+ "positive": pairs[i_glob]["positive"],
439
+ "hard_neg": picked,
440
+ })
441
+ return enriched
442
+
443
+
444
+ # =============================================================================
445
+ # 6. DATASET / COLLATE
446
+ # =============================================================================
447
+ class PairDataset(Dataset):
448
+ def __init__(self, items, with_hn): self.items, self.with_hn = items, with_hn
449
+ def __len__(self): return len(self.items)
450
+ def __getitem__(self, i):
451
+ ex = self.items[i]
452
+ if self.with_hn:
453
+ return ex["anchor"], ex["positive"], ex.get("hard_neg", ex["positive"])
454
+ return ex["anchor"], ex["positive"]
455
+
456
+
457
+ def make_collate_fn(tokenizer, max_len, with_hn):
458
+ def collate(batch):
459
+ a_list = [b[0] for b in batch]
460
+ p_list = [b[1] for b in batch]
461
+ a = tokenizer(a_list, padding=True, truncation=True,
462
+ max_length=max_len, return_tensors="pt")
463
+ p = tokenizer(p_list, padding=True, truncation=True,
464
+ max_length=max_len, return_tensors="pt")
465
+ if with_hn:
466
+ n_list = [b[2] for b in batch]
467
+ n = tokenizer(n_list, padding=True, truncation=True,
468
+ max_length=max_len, return_tensors="pt")
469
+ return a, p, n
470
+ return a, p
471
+ return collate
472
+
473
+
474
+ # =============================================================================
475
+ # 7. LOSS
476
+ # =============================================================================
477
+ def symmetric_mnrl_loss(emb_a, emb_p, emb_n=None, temperature=0.02):
478
+ N = emb_a.size(0)
479
+ labels = torch.arange(N, device=emb_a.device)
480
+ if emb_n is not None:
481
+ targets = torch.cat([emb_p, emb_n], dim=0)
482
+ sim_a = emb_a @ targets.t() / temperature
483
+ loss_a2p = F.cross_entropy(sim_a, labels)
484
+ else:
485
+ sim_a = emb_a @ emb_p.t() / temperature
486
+ loss_a2p = F.cross_entropy(sim_a, labels)
487
+ sim_p = emb_p @ emb_a.t() / temperature
488
+ loss_p2a = F.cross_entropy(sim_p, labels)
489
+ loss = 0.5 * (loss_a2p + loss_p2a)
490
+ with torch.no_grad():
491
+ acc = (sim_a[:, :N].argmax(dim=1) == labels).float().mean().item()
492
+ return loss, acc
493
+
494
+
495
+ # =============================================================================
496
+ # 8. MLM PRÉ-ENTRAÎNEMENT
497
+ # =============================================================================
498
+ def mlm_pretrain(model, tokenizer, texts, cfg: Config):
499
+ print(f"\n[MLM] Pré-entraînement sur {len(texts):,} textes IT...")
500
+
501
+ class MLMDataset(Dataset):
502
+ def __init__(self, t): self.t = t
503
+ def __len__(self): return len(self.t)
504
+ def __getitem__(self, i): return self.t[i]
505
+
506
+ def mlm_collate(batch):
507
+ enc = tokenizer(batch, padding=True, truncation=True,
508
+ max_length=cfg.max_seq_len, return_tensors="pt")
509
+ ids = enc["input_ids"].clone()
510
+ labels = ids.clone()
511
+ special = torch.zeros_like(ids, dtype=torch.bool)
512
+ for sid in tokenizer.all_special_ids:
513
+ special |= (ids == sid)
514
+ prob = torch.full(ids.shape, cfg.mlm_prob)
515
+ prob.masked_fill_(special, 0.0)
516
+ masked = torch.bernoulli(prob).bool()
517
+ labels[~masked] = -100
518
+ rand = torch.rand(ids.shape)
519
+ ids[masked & (rand < 0.8)] = tokenizer.mask_token_id
520
+ replace_rand = masked & (rand >= 0.8) & (rand < 0.9)
521
+ rand_tokens = torch.randint(0, tokenizer.vocab_size, ids.shape)
522
+ ids[replace_rand] = rand_tokens[replace_rand]
523
+ return ids, enc["attention_mask"], labels
524
+
525
+ loader = DataLoader(MLMDataset(texts), batch_size=cfg.batch_size,
526
+ shuffle=True, num_workers=cfg.num_workers,
527
+ collate_fn=mlm_collate, pin_memory=True,
528
+ drop_last=True, persistent_workers=True)
529
+
530
+ optim = AdamW(model.parameters(), lr=cfg.mlm_lr, weight_decay=0.01,
531
+ betas=(0.9, 0.98), eps=1e-6)
532
+ total_steps = len(loader) * cfg.mlm_epochs
533
+ sched = get_cosine_schedule_with_warmup(optim, int(total_steps * 0.04), total_steps)
534
+
535
+ model.train()
536
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
537
+ for ep in range(cfg.mlm_epochs):
538
+ running = 0.0
539
+ pbar = tqdm(loader, desc=f"MLM ep{ep+1}/{cfg.mlm_epochs}")
540
+ for step, (ids, mask, labels) in enumerate(pbar, 1):
541
+ ids = ids.to(device, non_blocking=True)
542
+ mask = mask.to(device, non_blocking=True)
543
+ labels = labels.to(device, non_blocking=True)
544
+ optim.zero_grad(set_to_none=True)
545
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
546
+ logits = model.forward_mlm(ids, mask)
547
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
548
+ labels.view(-1), ignore_index=-100)
549
+ loss.backward()
550
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
551
+ optim.step(); sched.step()
552
+ running += loss.item()
553
+ if step % 50 == 0:
554
+ pbar.set_postfix(loss=f"{running/step:.4f}",
555
+ ppl=f"{math.exp(min(20, running/step)):.1f}")
556
+ print("[MLM] Terminé.\n")
557
+
558
+
559
+ # =============================================================================
560
+ # 9. EVAL
561
+ # =============================================================================
562
+ @torch.no_grad()
563
+ def evaluate_retrieval(model, tokenizer, eval_pairs, cfg: Config):
564
+ model.eval()
565
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
566
+ queries = [e["anchor"] for e in eval_pairs]
567
+ passages = [e["positive"] for e in eval_pairs]
568
+
569
+ def encode(texts):
570
+ embs = []
571
+ for i in range(0, len(texts), 64):
572
+ chunk = texts[i:i+64]
573
+ enc = tokenizer(chunk, padding=True, truncation=True,
574
+ max_length=cfg.max_seq_len, return_tensors="pt").to(device)
575
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
576
+ e = model(enc["input_ids"], enc["attention_mask"])
577
+ embs.append(e.float())
578
+ return torch.cat(embs, dim=0)
579
+
580
+ Q = encode(queries); P = encode(passages)
581
+ sims = Q @ P.t()
582
+ N = sims.size(0)
583
+ targets = torch.arange(N, device=sims.device)
584
+ ranks = sims.argsort(dim=1, descending=True)
585
+ pos_in_rank = (ranks == targets.unsqueeze(1)).nonzero()[:, 1]
586
+ return {
587
+ "R@1": (pos_in_rank == 0).float().mean().item(),
588
+ "R@5": (pos_in_rank < 5).float().mean().item(),
589
+ "R@10": (pos_in_rank < 10).float().mean().item(),
590
+ "MRR": (1.0 / (pos_in_rank.float() + 1)).mean().item(),
591
+ }
592
+
593
+
594
+ # =============================================================================
595
+ # 10. TRAIN
596
+ # =============================================================================
597
+ def train():
598
+ tokenizer = AutoTokenizer.from_pretrained(CFG.tokenizer_name)
599
+ CFG.vocab_size = tokenizer.vocab_size
600
+ print(f"[TOK ] vocab_size = {CFG.vocab_size}")
601
+
602
+ items_all = load_it_pairs(CFG)
603
+ n_eval = min(CFG.eval_max_size, max(2000, int(len(items_all) * 0.005)))
604
+ eval_items = items_all[:n_eval]
605
+ train_items = items_all[n_eval:]
606
+ print(f"[DATA] train={len(train_items):,} eval={len(eval_items):,}")
607
+
608
+ if CFG.use_hard_negatives:
609
+ train_items = mine_hard_negatives(train_items, CFG)
610
+
611
+ collate = make_collate_fn(tokenizer, CFG.max_seq_len, CFG.use_hard_negatives)
612
+ train_loader = DataLoader(
613
+ PairDataset(train_items, CFG.use_hard_negatives),
614
+ batch_size=CFG.batch_size, shuffle=True,
615
+ num_workers=CFG.num_workers, collate_fn=collate,
616
+ pin_memory=True, drop_last=True, persistent_workers=True,
617
+ )
618
+
619
+ model = TextEncoder(CFG).to(device)
620
+ n_params = count_parameters(model)
621
+ print(f"[MODEL] Paramètres entraînables : {n_params/1e6:.2f} M")
622
+
623
+ if CFG.do_mlm_pretrain:
624
+ mlm_texts = []
625
+ for it in train_items[:400_000]:
626
+ mlm_texts.append(it["anchor"]); mlm_texts.append(it["positive"])
627
+ random.shuffle(mlm_texts)
628
+ mlm_pretrain(model, tokenizer, mlm_texts, CFG)
629
+
630
+ if CFG.use_compile and hasattr(torch, "compile"):
631
+ print(f"[MODEL] torch.compile(mode={CFG.compile_mode!r})")
632
+ model = torch.compile(model, mode=CFG.compile_mode)
633
+
634
+ raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
635
+ ema = EMA(raw_model, decay=CFG.ema_decay) if CFG.use_ema else None
636
+
637
+ no_decay = ["bias", "LayerNorm.weight", "ln1", "ln2", "ln_f", "emb_ln",
638
+ "gamma1", "gamma2"]
639
+ grouped = [
640
+ {"params": [p for n, p in model.named_parameters()
641
+ if "mlm_head" not in n and not any(nd in n for nd in no_decay)],
642
+ "weight_decay": CFG.weight_decay},
643
+ {"params": [p for n, p in model.named_parameters()
644
+ if "mlm_head" not in n and any(nd in n for nd in no_decay)],
645
+ "weight_decay": 0.0},
646
+ ]
647
+ optimizer = AdamW(grouped, lr=CFG.lr, betas=(0.9, 0.98), eps=1e-6)
648
+ steps_per_epoch = len(train_loader) // CFG.grad_accum_steps
649
+ total_steps = steps_per_epoch * CFG.epochs
650
+ warmup_steps = int(total_steps * CFG.warmup_ratio)
651
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
652
+ print(f"[OPTIM] total_steps={total_steps} warmup={warmup_steps}")
653
+
654
+ autocast_dtype = torch.bfloat16 if CFG.use_bf16 else torch.float16
655
+ best_mrr = 0.0
656
+ history = []
657
+
658
+ for epoch in range(1, CFG.epochs + 1):
659
+ model.train()
660
+ running_loss = running_acc = 0.0
661
+ n_seen = 0
662
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{CFG.epochs}")
663
+ optimizer.zero_grad(set_to_none=True)
664
+
665
+ for step, batch in enumerate(pbar, start=1):
666
+ if CFG.use_hard_negatives:
667
+ a, p, hn = batch
668
+ hn = {k: v.to(device, non_blocking=True) for k, v in hn.items()}
669
+ else:
670
+ a, p = batch; hn = None
671
+ a = {k: v.to(device, non_blocking=True) for k, v in a.items()}
672
+ p = {k: v.to(device, non_blocking=True) for k, v in p.items()}
673
+
674
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
675
+ emb_a = model(a["input_ids"], a["attention_mask"])
676
+ emb_p = model(p["input_ids"], p["attention_mask"])
677
+ emb_n = (model(hn["input_ids"], hn["attention_mask"])
678
+ if hn is not None else None)
679
+ loss, acc = symmetric_mnrl_loss(emb_a, emb_p, emb_n, CFG.temperature)
680
+ loss = loss / CFG.grad_accum_steps
681
+
682
+ loss.backward()
683
+ if step % CFG.grad_accum_steps == 0:
684
+ torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
685
+ optimizer.step(); scheduler.step()
686
+ optimizer.zero_grad(set_to_none=True)
687
+ if ema is not None: ema.update(raw_model)
688
+
689
+ running_loss += loss.item() * CFG.grad_accum_steps
690
+ running_acc += acc; n_seen += 1
691
+ if step % CFG.log_every == 0:
692
+ pbar.set_postfix(loss=f"{running_loss/n_seen:.4f}",
693
+ acc=f"{running_acc/n_seen:.3f}",
694
+ lr=f"{scheduler.get_last_lr()[0]:.2e}")
695
+
696
+ # Eval
697
+ backup = ema.apply_to(raw_model) if ema is not None else None
698
+ metrics = evaluate_retrieval(model, tokenizer, eval_items, CFG)
699
+ if backup is not None: ema.restore(raw_model, backup)
700
+ print(f"\n[EVAL] epoch {epoch} : R@1={metrics['R@1']:.3f} "
701
+ f"R@5={metrics['R@5']:.3f} R@10={metrics['R@10']:.3f} "
702
+ f"MRR={metrics['MRR']:.3f}")
703
+ history.append({"epoch": epoch, **metrics,
704
+ "train_loss": running_loss / max(1, n_seen)})
705
+
706
+ # Sauvegarde
707
+ is_best = metrics["MRR"] > best_mrr
708
+ if is_best: best_mrr = metrics["MRR"]
709
+ if ema is not None: backup = ema.apply_to(raw_model)
710
+ state = {k: v for k, v in raw_model.state_dict().items() if "mlm_head" not in k}
711
+
712
+ if epoch % CFG.save_every_epochs == 0 or is_best or epoch == CFG.epochs:
713
+ torch.save({"epoch": epoch, "model_state": state,
714
+ "config": asdict(CFG), "metrics": metrics},
715
+ Path(CFG.save_dir) / f"model_epoch{epoch}.pt")
716
+ if is_best:
717
+ torch.save({"epoch": epoch, "model_state": state,
718
+ "config": asdict(CFG), "metrics": metrics},
719
+ Path(CFG.save_dir) / "model_best.pt")
720
+ if ema is not None: ema.restore(raw_model, backup)
721
+ print(f"[SAVE] epoch {epoch} best={'oui' if is_best else 'non'}")
722
+
723
+ with open(Path(CFG.save_dir) / "history.json", "w", encoding="utf-8") as f:
724
+ json.dump(history, f, ensure_ascii=False, indent=2)
725
+ tokenizer.save_pretrained(CFG.save_dir)
726
+ print(f"\n[OK] Best MRR = {best_mrr:.3f} -> {CFG.save_dir}/model_best.pt")
727
+
728
+
729
+ # =============================================================================
730
+ # 11. DÉMO
731
+ # =============================================================================
732
+ @torch.no_grad()
733
+ def demo():
734
+ tokenizer = AutoTokenizer.from_pretrained(CFG.save_dir)
735
+ ckpt = torch.load(Path(CFG.save_dir) / "model_best.pt", map_location=device)
736
+ saved_cfg = ckpt["config"]
737
+ cfg2 = Config(**{k: v for k, v in saved_cfg.items() if hasattr(Config, k)})
738
+ cfg2.vocab_size = tokenizer.vocab_size
739
+ model = TextEncoder(cfg2).to(device).eval()
740
+ model.load_state_dict(ckpt["model_state"], strict=False)
741
+
742
+ corpus = [
743
+ "OWASP LLM Top 10 liste les vulnérabilités des modèles de langage.",
744
+ "La prompt injection consiste à manipuler les instructions d'un LLM.",
745
+ "Le H100 NVIDIA est un GPU IA avec 80 Go HBM3.",
746
+ "Docker permet de conteneuriser des applications.",
747
+ "Kubernetes orchestre des conteneurs à grande échelle.",
748
+ "Le chiffrement AES-256 est utilisé pour protéger les données.",
749
+ "Une attaque SQL injection exploite des requêtes mal échappées.",
750
+ "Le RAG combine retriever vectoriel et LLM générateur.",
751
+ ]
752
+ queries = [
753
+ "Quelles sont les vulnérabilités des LLM ?",
754
+ "Comment orchestrer des conteneurs ?",
755
+ "Quel GPU pour entraîner une IA ?",
756
+ ]
757
+ enc_corpus = tokenizer(corpus, padding=True, truncation=True,
758
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
759
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
760
+ c_emb = model(enc_corpus["input_ids"], enc_corpus["attention_mask"])
761
+
762
+ print("\n[DEMO IT-100M]")
763
+ for q in queries:
764
+ eq = tokenizer([q], padding=True, truncation=True,
765
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
766
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
767
+ q_emb = model(eq["input_ids"], eq["attention_mask"])
768
+ sims = (q_emb @ c_emb.t()).squeeze(0)
769
+ top = sims.topk(3)
770
+ print(f"\nQ : {q}")
771
+ for s, i in zip(top.values, top.indices):
772
+ print(f" ({s.item():.3f}) -> {corpus[i.item()]}")
773
+
774
+
775
+ if __name__ == "__main__":
776
+ train()
777
+ try:
778
+ demo()
779
+ except Exception as e:
780
+ print(f"[demo] {e}")
modeleAIRAG/train2.py ADDED
@@ -0,0 +1,921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ==============================================================================
3
+ RAG/NLP encoder ~100M params - SPÉCIALISÉ DOCUMENTAIRE INTERNE ENTREPRISE
4
+ (RH, juridique, procédures, comptabilité, qualité, conformité, formation)
5
+ Hardware : NVIDIA H100 80GB
6
+ Epochs : 20
7
+ ==============================================================================
8
+
9
+ Spécificités vs version IT :
10
+ - max_seq_len = 384 (documents internes longs : procédures, contrats)
11
+ - Filtres lexicaux orientés "entreprise / documentation"
12
+ - Datasets : Common Crawl FR (filtré), Wikipédia FR (catégories doc),
13
+ FQuAD/PIAF (questions admin/juridique), MultiLegalPile-FR,
14
+ corpus interne JSONL (priorité absolue)
15
+ - Augmentation : "title -> contenu" et "section -> paragraphe"
16
+ - Loss : MNRL symétrique + 2 hard negatives par paire
17
+ - Pré-entraînement MLM sur corpus interne en priorité
18
+ - EMA decay 0.9995, LayerScale, BF16, SDPA, Gradient Checkpointing
19
+ - 20 epochs, batch effectif 384
20
+
21
+ Architecture identique 100M params (12L, 768d, 12H, FFN=3072).
22
+
23
+ Usage :
24
+ pip install torch>=2.2 transformers>=4.40 datasets>=2.18 accelerate \\
25
+ sentencepiece tqdm numpy scikit-learn faiss-cpu beautifulsoup4
26
+ python train_rag_doc_interne_100m.py
27
+
28
+ Préparation du corpus interne :
29
+ Place tes documents dans ./data/corpus_interne/ (PDF/DOCX/TXT/MD)
30
+ Ou directement un JSONL ./data/custom_doc.jsonl avec {"anchor","positive"}
31
+ """
32
+ import os
33
+ import math
34
+ import json
35
+ import random
36
+ import re
37
+ import glob
38
+ from dataclasses import dataclass, asdict
39
+ from pathlib import Path
40
+ from typing import List, Dict, Tuple, Optional
41
+
42
+ import numpy as np
43
+ import torch
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ import torch.utils.checkpoint as gc
47
+ from torch.utils.data import Dataset, DataLoader
48
+ from torch.optim import AdamW
49
+
50
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
51
+ from datasets import load_dataset, Dataset as HFDataset
52
+ from tqdm.auto import tqdm
53
+
54
+ # =============================================================================
55
+ # 1. CONFIG — 100M, Documentaire interne
56
+ # =============================================================================
57
+ @dataclass
58
+ class Config:
59
+ # --- Modèle ~100M ---
60
+ vocab_size: int = 32000
61
+ hidden_size: int = 768
62
+ num_hidden_layers: int = 12
63
+ num_attention_heads: int = 12
64
+ intermediate_size: int = 3072
65
+ max_position_embeddings: int = 512 # docs longs
66
+ hidden_dropout_prob: float = 0.1
67
+ attention_probs_dropout_prob: float = 0.1
68
+ layer_norm_eps: float = 1e-12
69
+ embedding_dim: int = 768
70
+ use_layer_scale: bool = True
71
+ layer_scale_init: float = 1e-5
72
+ use_grad_checkpointing: bool = True
73
+
74
+ tokenizer_name: str = "camembert-base"
75
+
76
+ # --- MLM (priorité corpus interne) ---
77
+ do_mlm_pretrain: bool = True
78
+ mlm_epochs: int = 3 # +1 vs IT, doc interne plus rare
79
+ mlm_prob: float = 0.15
80
+ mlm_lr: float = 1e-4
81
+
82
+ # --- Contrastif ---
83
+ epochs: int = 20
84
+ batch_size: int = 64 # seq_len 384 -> batch + petit
85
+ grad_accum_steps: int = 6 # effectif = 384
86
+ max_seq_len: int = 384
87
+ lr: float = 2e-5
88
+ weight_decay: float = 0.01
89
+ warmup_ratio: float = 0.05
90
+ grad_clip: float = 1.0
91
+ temperature: float = 0.02
92
+ num_workers: int = 6
93
+ seed: int = 42
94
+
95
+ # --- Hard negatives (2 par paire pour doc interne) ---
96
+ use_hard_negatives: bool = True
97
+ n_hard_neg: int = 2 # plus fort
98
+ hard_neg_pool_size: int = 100_000
99
+
100
+ use_ema: bool = True
101
+ ema_decay: float = 0.9995
102
+
103
+ max_samples_per_dataset: int = 250_000
104
+ eval_max_size: int = 5_000
105
+
106
+ use_bf16: bool = True
107
+ use_compile: bool = True
108
+ compile_mode: str = "default"
109
+ log_every: int = 50
110
+ save_dir: str = "./checkpoints_rag_doc_100m"
111
+ save_every_epochs: int = 2
112
+
113
+ # --- Corpus interne ---
114
+ custom_jsonl_path: str = "./data/custom_doc.jsonl"
115
+ custom_corpus_dir: str = "./data/corpus_interne" # PDF/DOCX/TXT/MD
116
+ internal_oversample: int = 5 # x5 pour booster apprentissage interne
117
+
118
+
119
+ CFG = Config()
120
+ Path(CFG.save_dir).mkdir(parents=True, exist_ok=True)
121
+ random.seed(CFG.seed); np.random.seed(CFG.seed)
122
+ torch.manual_seed(CFG.seed); torch.cuda.manual_seed_all(CFG.seed)
123
+ torch.backends.cuda.matmul.allow_tf32 = True
124
+ torch.backends.cudnn.allow_tf32 = True
125
+ torch.set_float32_matmul_precision("high")
126
+
127
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
+ print(f"[INFO] Device : {device}")
129
+ if torch.cuda.is_available():
130
+ print(f"[INFO] GPU : {torch.cuda.get_device_name(0)}")
131
+ print(f"[INFO] VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
132
+
133
+
134
+ # =============================================================================
135
+ # 2. ARCHITECTURE
136
+ # =============================================================================
137
+ class TransformerEncoderBlock(nn.Module):
138
+ def __init__(self, cfg):
139
+ super().__init__()
140
+ self.num_heads = cfg.num_attention_heads
141
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
142
+ self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
143
+ self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size, bias=True)
144
+ self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
145
+ self.attn_drop_p = cfg.attention_probs_dropout_prob
146
+ self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
147
+ self.mlp = nn.Sequential(
148
+ nn.Linear(cfg.hidden_size, cfg.intermediate_size),
149
+ nn.GELU(),
150
+ nn.Linear(cfg.intermediate_size, cfg.hidden_size),
151
+ nn.Dropout(cfg.hidden_dropout_prob),
152
+ )
153
+ self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob)
154
+ self.use_ls = cfg.use_layer_scale
155
+ if cfg.use_layer_scale:
156
+ self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
157
+ self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
158
+
159
+ def forward(self, x, attn_mask):
160
+ B, T, C = x.shape
161
+ h = self.ln1(x)
162
+ qkv = self.qkv(h).view(B, T, 3, self.num_heads, self.head_dim)
163
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
164
+ kpm = attn_mask[:, None, None, :].bool()
165
+ a = F.scaled_dot_product_attention(
166
+ q, k, v, attn_mask=kpm,
167
+ dropout_p=self.attn_drop_p if self.training else 0.0,
168
+ is_causal=False)
169
+ a = a.transpose(1, 2).contiguous().view(B, T, C)
170
+ a = self.resid_drop(self.proj(a))
171
+ if self.use_ls: a = a * self.gamma1
172
+ x = x + a
173
+ m = self.mlp(self.ln2(x))
174
+ if self.use_ls: m = m * self.gamma2
175
+ return x + m
176
+
177
+
178
+ class TextEncoder(nn.Module):
179
+ def __init__(self, cfg):
180
+ super().__init__()
181
+ self.cfg = cfg
182
+ self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size, padding_idx=0)
183
+ self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
184
+ self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
185
+ self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
186
+ self.blocks = nn.ModuleList([TransformerEncoderBlock(cfg)
187
+ for _ in range(cfg.num_hidden_layers)])
188
+ self.ln_f = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
189
+ self.proj_head = nn.Sequential(
190
+ nn.Linear(cfg.hidden_size, cfg.hidden_size),
191
+ nn.Tanh(),
192
+ nn.Linear(cfg.hidden_size, cfg.embedding_dim),
193
+ )
194
+ self.mlm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
195
+ self.mlm_head.weight = self.tok_emb.weight
196
+ self.use_gc = cfg.use_grad_checkpointing
197
+ self.apply(self._init_weights)
198
+
199
+ @staticmethod
200
+ def _init_weights(m):
201
+ if isinstance(m, nn.Linear):
202
+ nn.init.normal_(m.weight, std=0.02)
203
+ if m.bias is not None: nn.init.zeros_(m.bias)
204
+ elif isinstance(m, nn.Embedding):
205
+ nn.init.normal_(m.weight, std=0.02)
206
+ elif isinstance(m, nn.LayerNorm):
207
+ nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
208
+
209
+ def encode_backbone(self, ids, mask):
210
+ B, T = ids.shape
211
+ pos = torch.arange(T, device=ids.device).unsqueeze(0).expand(B, T)
212
+ x = self.tok_emb(ids) + self.pos_emb(pos)
213
+ x = self.emb_drop(self.emb_ln(x))
214
+ for blk in self.blocks:
215
+ if self.use_gc and self.training:
216
+ x = gc.checkpoint(blk, x, mask, use_reentrant=False)
217
+ else:
218
+ x = blk(x, mask)
219
+ return self.ln_f(x)
220
+
221
+ def forward(self, ids, mask):
222
+ x = self.encode_backbone(ids, mask)
223
+ m = mask.unsqueeze(-1).float()
224
+ pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-6)
225
+ emb = self.proj_head(pooled)
226
+ return F.normalize(emb, p=2, dim=-1)
227
+
228
+ def forward_mlm(self, ids, mask):
229
+ return self.mlm_head(self.encode_backbone(ids, mask))
230
+
231
+
232
+ def count_parameters(model):
233
+ return sum(p.numel() for n, p in model.named_parameters()
234
+ if p.requires_grad and "mlm_head" not in n)
235
+
236
+
237
+ # =============================================================================
238
+ # 3. EMA
239
+ # =============================================================================
240
+ class EMA:
241
+ def __init__(self, model, decay=0.999):
242
+ self.decay = decay
243
+ self.shadow = {n: p.detach().clone()
244
+ for n, p in model.named_parameters() if p.requires_grad}
245
+
246
+ @torch.no_grad()
247
+ def update(self, model):
248
+ for n, p in model.named_parameters():
249
+ if p.requires_grad and n in self.shadow:
250
+ self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
251
+
252
+ @torch.no_grad()
253
+ def apply_to(self, model):
254
+ backup = {}
255
+ for n, p in model.named_parameters():
256
+ if n in self.shadow:
257
+ backup[n] = p.detach().clone(); p.copy_(self.shadow[n])
258
+ return backup
259
+
260
+ @torch.no_grad()
261
+ def restore(self, model, backup):
262
+ for n, p in model.named_parameters():
263
+ if n in backup: p.copy_(backup[n])
264
+
265
+
266
+ # =============================================================================
267
+ # 4. EXTRACTION CORPUS INTERNE (PDF / DOCX / TXT / MD)
268
+ # =============================================================================
269
+ def extract_text_from_file(path: Path) -> str:
270
+ """Extracteur multi-format. Retourne texte brut ou ''."""
271
+ suffix = path.suffix.lower()
272
+ try:
273
+ if suffix in {".txt", ".md"}:
274
+ return path.read_text(encoding="utf-8", errors="ignore")
275
+
276
+ if suffix == ".pdf":
277
+ try:
278
+ from pypdf import PdfReader
279
+ except ImportError:
280
+ from PyPDF2 import PdfReader
281
+ reader = PdfReader(str(path))
282
+ return "\n".join((p.extract_text() or "") for p in reader.pages)
283
+
284
+ if suffix == ".docx":
285
+ from docx import Document
286
+ doc = Document(str(path))
287
+ return "\n".join(p.text for p in doc.paragraphs)
288
+
289
+ if suffix in {".html", ".htm"}:
290
+ from bs4 import BeautifulSoup
291
+ soup = BeautifulSoup(path.read_text(encoding="utf-8", errors="ignore"),
292
+ "html.parser")
293
+ return soup.get_text(separator="\n")
294
+ except Exception as e:
295
+ print(f" [warn] extract {path.name} : {e}")
296
+ return ""
297
+
298
+
299
+ def chunk_document(text: str, chunk_size: int = 1500,
300
+ overlap: int = 200) -> List[Tuple[str, str]]:
301
+ """
302
+ Découpe un document en (titre/section, contenu) pour générer des paires.
303
+ Utilise les titres Markdown / numérotation pour détecter les sections.
304
+ """
305
+ text = re.sub(r"\n{3,}", "\n\n", text).strip()
306
+ if not text:
307
+ return []
308
+
309
+ # Détection sections (Markdown ##, numérotation 1., 1.1, ARTICLE, etc.)
310
+ section_re = re.compile(
311
+ r"(?m)^(#{1,4}\s+.+|" # markdown
312
+ r"\d+(?:\.\d+)*\.?\s+[A-ZÀ-Ÿa-zà-ÿ].+|" # numérotation
313
+ r"ARTICLE\s+\d+[\s\-:].+|" # juridique
314
+ r"CHAPITRE\s+\d+[\s\-:].+|" # juridique
315
+ r"[A-ZÀ-Ÿ][A-ZÀ-Ÿ\s]{8,}$)" # ALL CAPS section
316
+ )
317
+ sections = []
318
+ matches = list(section_re.finditer(text))
319
+ if matches:
320
+ for i, m in enumerate(matches):
321
+ title = m.group(0).strip()
322
+ start = m.end()
323
+ end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
324
+ content = text[start:end].strip()
325
+ if title and content and len(content) > 80:
326
+ sections.append((title[:200], content))
327
+
328
+ # Si pas de sections détectées, fallback chunks fixes
329
+ if not sections:
330
+ for i in range(0, len(text), chunk_size - overlap):
331
+ chunk = text[i:i + chunk_size].strip()
332
+ if len(chunk) > 80:
333
+ # titre = première phrase
334
+ first_period = chunk.find(".")
335
+ title = chunk[:first_period if first_period > 20 else 80].strip()
336
+ sections.append((title, chunk))
337
+ return sections
338
+
339
+
340
+ def load_internal_corpus(cfg: Config) -> Tuple[List[Dict[str, str]], List[str]]:
341
+ """Lit ./data/corpus_interne/* et génère paires + textes pour MLM."""
342
+ pairs = []
343
+ raw_texts = []
344
+ corpus_dir = Path(cfg.custom_corpus_dir)
345
+ if not corpus_dir.exists():
346
+ print(f" [info] Dossier corpus interne absent : {corpus_dir}")
347
+ return pairs, raw_texts
348
+
349
+ files = []
350
+ for ext in ("*.pdf", "*.docx", "*.txt", "*.md", "*.html", "*.htm"):
351
+ files.extend(corpus_dir.rglob(ext))
352
+ print(f" [+] {len(files)} fichiers internes trouvés")
353
+
354
+ for fp in tqdm(files, desc="corpus_interne"):
355
+ text = extract_text_from_file(fp)
356
+ if not text or len(text) < 200:
357
+ continue
358
+ raw_texts.append(text)
359
+ sections = chunk_document(text)
360
+ for title, content in sections:
361
+ pairs.append({
362
+ "anchor": title,
363
+ "positive": content[:2500],
364
+ "_internal": True,
365
+ })
366
+ # Paire bonus : "où trouver X ?" -> contenu
367
+ pairs.append({
368
+ "anchor": f"Où trouver des informations sur : {title} ?",
369
+ "positive": content[:2500],
370
+ "_internal": True,
371
+ })
372
+ return pairs, raw_texts
373
+
374
+
375
+ # =============================================================================
376
+ # 5. CHARGEMENT DATASETS PUBLICS (DOC GÉNÉRIQUE FR)
377
+ # =============================================================================
378
+ DOC_KEYWORDS = re.compile(
379
+ r"\b(article|chapitre|procédure|politique|règlement|directive|note de service|"
380
+ r"manuel|guide|formation|RH|ressources humaines|congé|absence|salaire|paie|"
381
+ r"contrat|CDI|CDD|convention|accord|qualité|conformité|audit|ISO|RGPD|"
382
+ r"comité|conseil|assemblée|direction|département|service|budget|"
383
+ r"facture|comptabilité|comptable|TVA|achat|vente|client|fournisseur|"
384
+ r"juridique|légal|loi|décret|arrêté|jurisprudence|tribunal|"
385
+ r"sécurité|incident|risque|santé|hygiène|formation)\b",
386
+ re.IGNORECASE,
387
+ )
388
+
389
+ def is_doc_text(t: str) -> bool:
390
+ return bool(DOC_KEYWORDS.search(t)) if t else False
391
+
392
+
393
+ def load_doc_pairs(cfg: Config) -> List[Dict[str, str]]:
394
+ print("\n[DATA] Chargement des datasets DOC INTERNE...")
395
+ pairs: List[Dict[str, str]] = []
396
+
397
+ # 5.1 Corpus interne (priorité absolue, oversample)
398
+ internal_pairs, internal_texts = load_internal_corpus(cfg)
399
+ print(f" [+] Corpus interne : {len(internal_pairs):,} paires brutes")
400
+ pairs.extend(internal_pairs * cfg.internal_oversample)
401
+
402
+ # 5.2 PIAF + FQuAD (paires question / contexte FR génériques)
403
+ try:
404
+ ds = load_dataset("etalab-ia/piaf", split="train")
405
+ for ex in tqdm(ds, desc="PIAF"):
406
+ q = (ex.get("question") or "").strip()
407
+ ctx = (ex.get("context") or "").strip()
408
+ if q and ctx:
409
+ pairs.append({"anchor": q, "positive": ctx})
410
+ except Exception as e:
411
+ print(f" [warn] PIAF : {e}")
412
+
413
+ try:
414
+ ds = load_dataset("manu/fquad2_test", split="train")
415
+ for ex in tqdm(ds, desc="FQuAD2"):
416
+ q = (ex.get("question") or "").strip()
417
+ ctx = (ex.get("context") or "").strip()
418
+ if q and ctx:
419
+ pairs.append({"anchor": q, "positive": ctx})
420
+ except Exception as e:
421
+ print(f" [warn] FQuAD2 : {e}")
422
+
423
+ # 5.3 mMARCO FR filtré "documentaire"
424
+ try:
425
+ ds = load_dataset("unicamp-dl/mmarco", "french", split="train")
426
+ ds = ds.select(range(min(500_000, len(ds))))
427
+ kept = 0
428
+ for ex in tqdm(ds, desc="mMARCO-FR (DOC-filter)"):
429
+ q = (ex.get("query") or "").strip()
430
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
431
+ if q and p and (is_doc_text(q) or is_doc_text(p)):
432
+ pairs.append({"anchor": q, "positive": p})
433
+ kept += 1
434
+ if kept >= cfg.max_samples_per_dataset: break
435
+ except Exception as e:
436
+ print(f" [warn] mMARCO : {e}")
437
+
438
+ # 5.4 Wikipedia FR — paires (résumé/lead -> section)
439
+ try:
440
+ ds = load_dataset("wikipedia", "20220301.fr", split="train",
441
+ trust_remote_code=True)
442
+ ds = ds.select(range(min(100_000, len(ds))))
443
+ for ex in tqdm(ds, desc="Wikipedia-FR"):
444
+ title = (ex.get("title") or "").strip()
445
+ text = (ex.get("text") or "").strip()
446
+ if not title or not text or len(text) < 300:
447
+ continue
448
+ # Première section comme positif du titre
449
+ first_chunk = text[:2000]
450
+ pairs.append({"anchor": title, "positive": first_chunk})
451
+ # Sections suivantes si présentes
452
+ paragraphs = text.split("\n\n")
453
+ for para in paragraphs[1:6]:
454
+ if len(para) > 200:
455
+ pairs.append({
456
+ "anchor": f"Que dit l'article '{title}' à propos de cela ?",
457
+ "positive": para[:2000],
458
+ })
459
+ except Exception as e:
460
+ print(f" [warn] Wikipedia FR : {e}")
461
+
462
+ # 5.5 MultiLegalPile FR (juridique)
463
+ try:
464
+ ds = load_dataset("joelniklaus/Multi_Legal_Pile", "fr_caselaw",
465
+ split="train", streaming=True)
466
+ count = 0
467
+ for ex in tqdm(ds, desc="MultiLegalPile-FR", total=50_000):
468
+ text = (ex.get("text") or "").strip()
469
+ if len(text) < 500: continue
470
+ # Première phrase = anchor, reste = positif
471
+ first_period = text.find(".")
472
+ if 30 < first_period < 250:
473
+ anchor = text[:first_period + 1]
474
+ positive = text[first_period + 1:first_period + 2001]
475
+ if len(positive) > 100:
476
+ pairs.append({"anchor": anchor, "positive": positive})
477
+ count += 1
478
+ if count >= 50_000: break
479
+ except Exception as e:
480
+ print(f" [warn] MultiLegalPile : {e}")
481
+
482
+ # 5.6 XNLI FR (entailment)
483
+ try:
484
+ ds = load_dataset("xnli", "fr", split="train")
485
+ ds = ds.filter(lambda x: x["label"] == 0)
486
+ ds = ds.select(range(min(80_000, len(ds))))
487
+ for ex in tqdm(ds, desc="XNLI-FR"):
488
+ a = (ex.get("premise") or "").strip()
489
+ b = (ex.get("hypothesis") or "").strip()
490
+ if a and b:
491
+ pairs.append({"anchor": a, "positive": b})
492
+ except Exception as e:
493
+ print(f" [warn] XNLI : {e}")
494
+
495
+ # 5.7 Custom JSONL
496
+ if Path(cfg.custom_jsonl_path).exists():
497
+ with open(cfg.custom_jsonl_path, "r", encoding="utf-8") as f:
498
+ for line in tqdm(f, desc="custom_doc.jsonl"):
499
+ try:
500
+ ex = json.loads(line)
501
+ a = (ex.get("anchor") or ex.get("query") or "").strip()
502
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
503
+ if a and p:
504
+ pairs.append({"anchor": a, "positive": p, "_internal": True})
505
+ except Exception:
506
+ continue
507
+
508
+ # Dédup
509
+ seen = set(); uniq = []
510
+ for p in pairs:
511
+ k = (p["anchor"][:200], p["positive"][:200])
512
+ if k not in seen:
513
+ seen.add(k); uniq.append(p)
514
+ random.shuffle(uniq)
515
+ n_internal = sum(1 for p in uniq if p.get("_internal"))
516
+ print(f"[DATA] Total paires uniques : {len(uniq):,} (dont interne : {n_internal:,})")
517
+ return uniq
518
+
519
+
520
+ # =============================================================================
521
+ # 6. HARD NEGATIVE MINING (2 negs par paire)
522
+ # =============================================================================
523
+ def mine_hard_negatives_multi(pairs, cfg: Config):
524
+ print(f"\n[HN] Mining {cfg.n_hard_neg} hard negatives par paire...")
525
+ try:
526
+ from sklearn.feature_extraction.text import TfidfVectorizer
527
+ from sklearn.metrics.pairwise import linear_kernel
528
+ except ImportError:
529
+ print(" [warn] sklearn manquant"); return pairs
530
+
531
+ n = len(pairs)
532
+ pool_size = min(cfg.hard_neg_pool_size, n)
533
+ pool_idx = np.random.choice(n, size=pool_size, replace=False)
534
+ pool_pass = [pairs[i]["positive"] for i in pool_idx]
535
+ vec = TfidfVectorizer(max_features=80_000, ngram_range=(1, 2),
536
+ lowercase=True, strip_accents="unicode")
537
+ X_pool = vec.fit_transform(pool_pass)
538
+
539
+ enriched = []
540
+ batch = 2000
541
+ anchors = [p["anchor"] for p in pairs]
542
+ for start in tqdm(range(0, n, batch), desc="HN-mine"):
543
+ end = min(start + batch, n)
544
+ Xq = vec.transform(anchors[start:end])
545
+ sims = linear_kernel(Xq, X_pool)
546
+ for i_loc, i_glob in enumerate(range(start, end)):
547
+ true_pos = pairs[i_glob]["positive"]
548
+ order = np.argsort(-sims[i_loc])
549
+ picked = []
550
+ for j in order[:50]:
551
+ cand = pool_pass[j]
552
+ if cand != true_pos and cand not in picked:
553
+ picked.append(cand)
554
+ if len(picked) >= cfg.n_hard_neg: break
555
+ while len(picked) < cfg.n_hard_neg:
556
+ picked.append(pool_pass[random.randint(0, pool_size - 1)])
557
+ enriched.append({
558
+ "anchor": pairs[i_glob]["anchor"],
559
+ "positive": pairs[i_glob]["positive"],
560
+ "hard_negs": picked,
561
+ })
562
+ return enriched
563
+
564
+
565
+ # =============================================================================
566
+ # 7. DATASET / COLLATE (multi-hn)
567
+ # =============================================================================
568
+ class PairDataset(Dataset):
569
+ def __init__(self, items, n_hn): self.items, self.n_hn = items, n_hn
570
+ def __len__(self): return len(self.items)
571
+ def __getitem__(self, i):
572
+ ex = self.items[i]
573
+ if self.n_hn > 0:
574
+ negs = ex.get("hard_negs", [ex["positive"]] * self.n_hn)
575
+ return ex["anchor"], ex["positive"], negs[:self.n_hn]
576
+ return ex["anchor"], ex["positive"]
577
+
578
+
579
+ def make_collate_fn(tokenizer, max_len, n_hn):
580
+ def collate(batch):
581
+ a_l = [b[0] for b in batch]; p_l = [b[1] for b in batch]
582
+ a = tokenizer(a_l, padding=True, truncation=True,
583
+ max_length=max_len, return_tensors="pt")
584
+ p = tokenizer(p_l, padding=True, truncation=True,
585
+ max_length=max_len, return_tensors="pt")
586
+ if n_hn > 0:
587
+ # Flatten : [n0_p1, n0_p2, n1_p1, n1_p2, ...] -> on tokenize tout
588
+ all_negs = []
589
+ for b in batch:
590
+ all_negs.extend(b[2]) # n_hn négatifs par exemple
591
+ n = tokenizer(all_negs, padding=True, truncation=True,
592
+ max_length=max_len, return_tensors="pt")
593
+ return a, p, n
594
+ return a, p
595
+ return collate
596
+
597
+
598
+ # =============================================================================
599
+ # 8. LOSS — Symmetric MNRL avec multi-hard-negatives
600
+ # =============================================================================
601
+ def symmetric_mnrl_multi_hn(emb_a, emb_p, emb_neg=None, n_hn=0, temperature=0.02):
602
+ """
603
+ emb_neg : (N * n_hn, d) si fourni, sinon None.
604
+ Cibles a -> [P; N1; N2; ...] : N positifs + N*n_hn négatifs durs
605
+ """
606
+ N = emb_a.size(0)
607
+ labels = torch.arange(N, device=emb_a.device)
608
+ if emb_neg is not None and n_hn > 0:
609
+ targets = torch.cat([emb_p, emb_neg], dim=0)
610
+ sim_a = emb_a @ targets.t() / temperature
611
+ loss_a2p = F.cross_entropy(sim_a, labels)
612
+ else:
613
+ sim_a = emb_a @ emb_p.t() / temperature
614
+ loss_a2p = F.cross_entropy(sim_a, labels)
615
+ sim_p = emb_p @ emb_a.t() / temperature
616
+ loss_p2a = F.cross_entropy(sim_p, labels)
617
+ loss = 0.5 * (loss_a2p + loss_p2a)
618
+ with torch.no_grad():
619
+ acc = (sim_a[:, :N].argmax(dim=1) == labels).float().mean().item()
620
+ return loss, acc
621
+
622
+
623
+ # =============================================================================
624
+ # 9. MLM PRÉ-ENTRAÎNEMENT (priorité corpus interne)
625
+ # =============================================================================
626
+ def mlm_pretrain(model, tokenizer, internal_texts, public_texts, cfg: Config):
627
+ # 50% interne (oversampled) + 50% public pour spécialiser sans oublier
628
+ if internal_texts:
629
+ # On répète le corpus interne pour qu'il occupe ~50% du MLM
630
+ target_size = max(len(public_texts), 1)
631
+ repeats = max(1, target_size // max(len(internal_texts), 1))
632
+ internal_repeated = internal_texts * repeats
633
+ random.shuffle(internal_repeated)
634
+ public_texts = public_texts[:target_size]
635
+ all_texts = internal_repeated[:target_size] + public_texts
636
+ else:
637
+ all_texts = public_texts
638
+ random.shuffle(all_texts)
639
+ print(f"\n[MLM] Pré-entraînement sur {len(all_texts):,} textes "
640
+ f"(interne : {len(internal_texts):,})")
641
+
642
+ class MLMDataset(Dataset):
643
+ def __init__(self, t): self.t = t
644
+ def __len__(self): return len(self.t)
645
+ def __getitem__(self, i): return self.t[i]
646
+
647
+ def mlm_collate(batch):
648
+ enc = tokenizer(batch, padding=True, truncation=True,
649
+ max_length=cfg.max_seq_len, return_tensors="pt")
650
+ ids = enc["input_ids"].clone(); labels = ids.clone()
651
+ special = torch.zeros_like(ids, dtype=torch.bool)
652
+ for sid in tokenizer.all_special_ids: special |= (ids == sid)
653
+ prob = torch.full(ids.shape, cfg.mlm_prob)
654
+ prob.masked_fill_(special, 0.0)
655
+ masked = torch.bernoulli(prob).bool()
656
+ labels[~masked] = -100
657
+ rand = torch.rand(ids.shape)
658
+ ids[masked & (rand < 0.8)] = tokenizer.mask_token_id
659
+ rr = masked & (rand >= 0.8) & (rand < 0.9)
660
+ rt = torch.randint(0, tokenizer.vocab_size, ids.shape)
661
+ ids[rr] = rt[rr]
662
+ return ids, enc["attention_mask"], labels
663
+
664
+ loader = DataLoader(MLMDataset(all_texts), batch_size=cfg.batch_size,
665
+ shuffle=True, num_workers=cfg.num_workers,
666
+ collate_fn=mlm_collate, pin_memory=True,
667
+ drop_last=True, persistent_workers=True)
668
+ optim = AdamW(model.parameters(), lr=cfg.mlm_lr, weight_decay=0.01,
669
+ betas=(0.9, 0.98), eps=1e-6)
670
+ total_steps = len(loader) * cfg.mlm_epochs
671
+ sched = get_cosine_schedule_with_warmup(optim, int(total_steps * 0.04), total_steps)
672
+ model.train()
673
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
674
+ for ep in range(cfg.mlm_epochs):
675
+ running = 0.0
676
+ pbar = tqdm(loader, desc=f"MLM ep{ep+1}/{cfg.mlm_epochs}")
677
+ for step, (ids, mask, labels) in enumerate(pbar, 1):
678
+ ids = ids.to(device, non_blocking=True)
679
+ mask = mask.to(device, non_blocking=True)
680
+ labels = labels.to(device, non_blocking=True)
681
+ optim.zero_grad(set_to_none=True)
682
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
683
+ logits = model.forward_mlm(ids, mask)
684
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
685
+ labels.view(-1), ignore_index=-100)
686
+ loss.backward()
687
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
688
+ optim.step(); sched.step()
689
+ running += loss.item()
690
+ if step % 50 == 0:
691
+ pbar.set_postfix(loss=f"{running/step:.4f}",
692
+ ppl=f"{math.exp(min(20, running/step)):.1f}")
693
+ print("[MLM] Terminé.\n")
694
+
695
+
696
+ # =============================================================================
697
+ # 10. EVAL
698
+ # =============================================================================
699
+ @torch.no_grad()
700
+ def evaluate_retrieval(model, tokenizer, eval_pairs, cfg: Config):
701
+ model.eval()
702
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
703
+ queries = [e["anchor"] for e in eval_pairs]
704
+ passages = [e["positive"] for e in eval_pairs]
705
+
706
+ def encode(texts):
707
+ embs = []
708
+ for i in range(0, len(texts), 32):
709
+ chunk = texts[i:i+32]
710
+ enc = tokenizer(chunk, padding=True, truncation=True,
711
+ max_length=cfg.max_seq_len, return_tensors="pt").to(device)
712
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
713
+ e = model(enc["input_ids"], enc["attention_mask"])
714
+ embs.append(e.float())
715
+ return torch.cat(embs, dim=0)
716
+
717
+ Q = encode(queries); P = encode(passages)
718
+ sims = Q @ P.t()
719
+ N = sims.size(0)
720
+ targets = torch.arange(N, device=sims.device)
721
+ ranks = sims.argsort(dim=1, descending=True)
722
+ pos_in_rank = (ranks == targets.unsqueeze(1)).nonzero()[:, 1]
723
+ return {
724
+ "R@1": (pos_in_rank == 0).float().mean().item(),
725
+ "R@5": (pos_in_rank < 5).float().mean().item(),
726
+ "R@10": (pos_in_rank < 10).float().mean().item(),
727
+ "MRR": (1.0 / (pos_in_rank.float() + 1)).mean().item(),
728
+ }
729
+
730
+
731
+ # =============================================================================
732
+ # 11. TRAIN
733
+ # =============================================================================
734
+ def train():
735
+ tokenizer = AutoTokenizer.from_pretrained(CFG.tokenizer_name)
736
+ CFG.vocab_size = tokenizer.vocab_size
737
+ print(f"[TOK ] vocab_size = {CFG.vocab_size}")
738
+
739
+ items_all = load_doc_pairs(CFG)
740
+ n_eval = min(CFG.eval_max_size, max(2000, int(len(items_all) * 0.005)))
741
+ eval_items = items_all[:n_eval]
742
+ train_items = items_all[n_eval:]
743
+ print(f"[DATA] train={len(train_items):,} eval={len(eval_items):,}")
744
+
745
+ if CFG.use_hard_negatives:
746
+ train_items = mine_hard_negatives_multi(train_items, CFG)
747
+
748
+ n_hn = CFG.n_hard_neg if CFG.use_hard_negatives else 0
749
+ collate = make_collate_fn(tokenizer, CFG.max_seq_len, n_hn)
750
+ train_loader = DataLoader(
751
+ PairDataset(train_items, n_hn),
752
+ batch_size=CFG.batch_size, shuffle=True,
753
+ num_workers=CFG.num_workers, collate_fn=collate,
754
+ pin_memory=True, drop_last=True, persistent_workers=True,
755
+ )
756
+
757
+ model = TextEncoder(CFG).to(device)
758
+ n_params = count_parameters(model)
759
+ print(f"[MODEL] Paramètres entraînables : {n_params/1e6:.2f} M")
760
+
761
+ if CFG.do_mlm_pretrain:
762
+ # Sépare textes internes vs publics
763
+ internal_texts = []; public_texts = []
764
+ for it in train_items[:500_000]:
765
+ if it.get("_internal"):
766
+ internal_texts.append(it["anchor"])
767
+ internal_texts.append(it["positive"])
768
+ else:
769
+ public_texts.append(it["anchor"])
770
+ public_texts.append(it["positive"])
771
+ mlm_pretrain(model, tokenizer, internal_texts, public_texts, CFG)
772
+
773
+ if CFG.use_compile and hasattr(torch, "compile"):
774
+ model = torch.compile(model, mode=CFG.compile_mode)
775
+
776
+ raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
777
+ ema = EMA(raw_model, decay=CFG.ema_decay) if CFG.use_ema else None
778
+
779
+ no_decay = ["bias", "LayerNorm.weight", "ln1", "ln2", "ln_f", "emb_ln",
780
+ "gamma1", "gamma2"]
781
+ grouped = [
782
+ {"params": [p for n, p in model.named_parameters()
783
+ if "mlm_head" not in n and not any(nd in n for nd in no_decay)],
784
+ "weight_decay": CFG.weight_decay},
785
+ {"params": [p for n, p in model.named_parameters()
786
+ if "mlm_head" not in n and any(nd in n for nd in no_decay)],
787
+ "weight_decay": 0.0},
788
+ ]
789
+ optimizer = AdamW(grouped, lr=CFG.lr, betas=(0.9, 0.98), eps=1e-6)
790
+ steps_per_epoch = len(train_loader) // CFG.grad_accum_steps
791
+ total_steps = steps_per_epoch * CFG.epochs
792
+ warmup_steps = int(total_steps * CFG.warmup_ratio)
793
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
794
+ print(f"[OPTIM] total_steps={total_steps} warmup={warmup_steps}")
795
+
796
+ autocast_dtype = torch.bfloat16 if CFG.use_bf16 else torch.float16
797
+ best_mrr = 0.0
798
+ history = []
799
+
800
+ for epoch in range(1, CFG.epochs + 1):
801
+ model.train()
802
+ running_loss = running_acc = 0.0
803
+ n_seen = 0
804
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{CFG.epochs}")
805
+ optimizer.zero_grad(set_to_none=True)
806
+
807
+ for step, batch in enumerate(pbar, start=1):
808
+ if n_hn > 0:
809
+ a, p, neg = batch
810
+ neg = {k: v.to(device, non_blocking=True) for k, v in neg.items()}
811
+ else:
812
+ a, p = batch; neg = None
813
+ a = {k: v.to(device, non_blocking=True) for k, v in a.items()}
814
+ p = {k: v.to(device, non_blocking=True) for k, v in p.items()}
815
+
816
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
817
+ emb_a = model(a["input_ids"], a["attention_mask"])
818
+ emb_p = model(p["input_ids"], p["attention_mask"])
819
+ emb_n = (model(neg["input_ids"], neg["attention_mask"])
820
+ if neg is not None else None)
821
+ loss, acc = symmetric_mnrl_multi_hn(
822
+ emb_a, emb_p, emb_n, n_hn=n_hn, temperature=CFG.temperature)
823
+ loss = loss / CFG.grad_accum_steps
824
+
825
+ loss.backward()
826
+ if step % CFG.grad_accum_steps == 0:
827
+ torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
828
+ optimizer.step(); scheduler.step()
829
+ optimizer.zero_grad(set_to_none=True)
830
+ if ema is not None: ema.update(raw_model)
831
+
832
+ running_loss += loss.item() * CFG.grad_accum_steps
833
+ running_acc += acc; n_seen += 1
834
+ if step % CFG.log_every == 0:
835
+ pbar.set_postfix(loss=f"{running_loss/n_seen:.4f}",
836
+ acc=f"{running_acc/n_seen:.3f}",
837
+ lr=f"{scheduler.get_last_lr()[0]:.2e}")
838
+
839
+ backup = ema.apply_to(raw_model) if ema is not None else None
840
+ metrics = evaluate_retrieval(model, tokenizer, eval_items, CFG)
841
+ if backup is not None: ema.restore(raw_model, backup)
842
+ print(f"\n[EVAL] epoch {epoch} : R@1={metrics['R@1']:.3f} "
843
+ f"R@5={metrics['R@5']:.3f} R@10={metrics['R@10']:.3f} "
844
+ f"MRR={metrics['MRR']:.3f}")
845
+ history.append({"epoch": epoch, **metrics,
846
+ "train_loss": running_loss / max(1, n_seen)})
847
+
848
+ is_best = metrics["MRR"] > best_mrr
849
+ if is_best: best_mrr = metrics["MRR"]
850
+ if ema is not None: backup = ema.apply_to(raw_model)
851
+ state = {k: v for k, v in raw_model.state_dict().items() if "mlm_head" not in k}
852
+
853
+ if epoch % CFG.save_every_epochs == 0 or is_best or epoch == CFG.epochs:
854
+ torch.save({"epoch": epoch, "model_state": state,
855
+ "config": asdict(CFG), "metrics": metrics},
856
+ Path(CFG.save_dir) / f"model_epoch{epoch}.pt")
857
+ if is_best:
858
+ torch.save({"epoch": epoch, "model_state": state,
859
+ "config": asdict(CFG), "metrics": metrics},
860
+ Path(CFG.save_dir) / "model_best.pt")
861
+ if ema is not None: ema.restore(raw_model, backup)
862
+ print(f"[SAVE] epoch {epoch} best={'oui' if is_best else 'non'}")
863
+
864
+ with open(Path(CFG.save_dir) / "history.json", "w", encoding="utf-8") as f:
865
+ json.dump(history, f, ensure_ascii=False, indent=2)
866
+ tokenizer.save_pretrained(CFG.save_dir)
867
+ print(f"\n[OK] Best MRR = {best_mrr:.3f} -> {CFG.save_dir}/model_best.pt")
868
+
869
+
870
+ # =============================================================================
871
+ # 12. DÉMO
872
+ # =============================================================================
873
+ @torch.no_grad()
874
+ def demo():
875
+ tokenizer = AutoTokenizer.from_pretrained(CFG.save_dir)
876
+ ckpt = torch.load(Path(CFG.save_dir) / "model_best.pt", map_location=device)
877
+ saved_cfg = ckpt["config"]
878
+ cfg2 = Config(**{k: v for k, v in saved_cfg.items() if hasattr(Config, k)})
879
+ cfg2.vocab_size = tokenizer.vocab_size
880
+ model = TextEncoder(cfg2).to(device).eval()
881
+ model.load_state_dict(ckpt["model_state"], strict=False)
882
+
883
+ corpus = [
884
+ "ARTICLE 12 - Les congés payés sont acquis à raison de 2,5 jours par mois travaillé.",
885
+ "Procédure de validation des notes de frais : transmettre via le portail RH avant le 5 du mois.",
886
+ "La politique RGPD impose un délai de 72h pour notifier une violation de données.",
887
+ "Le télétravail est autorisé jusqu'à 3 jours par semaine sur accord du manager.",
888
+ "Toute facture fournisseur doit être validée par le responsable budget avant paiement.",
889
+ "Formation obligatoire sécurité incendie : 1 fois par an, traçabilité dans le SIRH.",
890
+ "L'accord d'entreprise du 15/03/2024 fixe le taux de prime annuelle à 8% du salaire brut.",
891
+ ]
892
+ queries = [
893
+ "Combien de jours de congés je gagne par mois ?",
894
+ "Comment déclarer mes notes de frais ?",
895
+ "Quel est le quota de télétravail ?",
896
+ "Quel taux de prime annuelle ?",
897
+ ]
898
+ enc = tokenizer(corpus, padding=True, truncation=True,
899
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
900
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
901
+ c_emb = model(enc["input_ids"], enc["attention_mask"])
902
+
903
+ print("\n[DEMO DOC-INTERNE-100M]")
904
+ for q in queries:
905
+ eq = tokenizer([q], padding=True, truncation=True,
906
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
907
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
908
+ q_emb = model(eq["input_ids"], eq["attention_mask"])
909
+ sims = (q_emb @ c_emb.t()).squeeze(0)
910
+ top = sims.topk(3)
911
+ print(f"\nQ : {q}")
912
+ for s, i in zip(top.values, top.indices):
913
+ print(f" ({s.item():.3f}) -> {corpus[i.item()]}")
914
+
915
+
916
+ if __name__ == "__main__":
917
+ train()
918
+ try:
919
+ demo()
920
+ except Exception as e:
921
+ print(f"[demo] {e}")
modeleAIRAG/train3_200m.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ==============================================================================
3
+ RAG/NLP encoder ~100M params - SPÉCIALISÉ DOCUMENTAIRE INTERNE ENTREPRISE
4
+ (RH, juridique, procédures, comptabilité, qualité, conformité, formation)
5
+ Hardware : NVIDIA H100 80GB
6
+ Epochs : 20
7
+ ==============================================================================
8
+
9
+ Spécificités vs version IT :
10
+ - max_seq_len = 384 (documents internes longs : procédures, contrats)
11
+ - Filtres lexicaux orientés "entreprise / documentation"
12
+ - Datasets : Common Crawl FR (filtré), Wikipédia FR (catégories doc),
13
+ FQuAD/PIAF (questions admin/juridique), MultiLegalPile-FR,
14
+ corpus interne JSONL (priorité absolue)
15
+ - Augmentation : "title -> contenu" et "section -> paragraphe"
16
+ - Loss : MNRL symétrique + 2 hard negatives par paire
17
+ - Pré-entraînement MLM sur corpus interne en priorité
18
+ - EMA decay 0.9995, LayerScale, BF16, SDPA, Gradient Checkpointing
19
+ - 20 epochs, batch effectif 384
20
+
21
+ Architecture identique 100M params (12L, 768d, 12H, FFN=3072).
22
+
23
+ Usage :
24
+ pip install torch>=2.2 transformers>=4.40 datasets>=2.18 accelerate \\
25
+ sentencepiece tqdm numpy scikit-learn faiss-cpu beautifulsoup4
26
+ python train_rag_doc_interne_100m.py
27
+
28
+ Préparation du corpus interne :
29
+ Place tes documents dans ./data/corpus_interne/ (PDF/DOCX/TXT/MD)
30
+ Ou directement un JSONL ./data/custom_doc.jsonl avec {"anchor","positive"}
31
+ """
32
+ import os
33
+ import math
34
+ import json
35
+ import random
36
+ import re
37
+ import glob
38
+ from dataclasses import dataclass, asdict
39
+ from pathlib import Path
40
+ from typing import List, Dict, Tuple, Optional
41
+
42
+ import numpy as np
43
+ import torch
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ import torch.utils.checkpoint as gc
47
+ from torch.utils.data import Dataset, DataLoader
48
+ from torch.optim import AdamW
49
+
50
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
51
+ from datasets import load_dataset, Dataset as HFDataset
52
+ from tqdm.auto import tqdm
53
+
54
+ # =============================================================================
55
+ # 1. CONFIG — 100M, Documentaire interne
56
+ # =============================================================================
57
+ @dataclass
58
+ class Config:
59
+ # --- Modèle ~100M ---
60
+ vocab_size: int = 32000
61
+ hidden_size: int = 1024
62
+ num_hidden_layers: int = 16
63
+ num_attention_heads: int = 16
64
+ intermediate_size: int = 4096
65
+ max_position_embeddings: int = 512 # docs longs
66
+ hidden_dropout_prob: float = 0.1
67
+ attention_probs_dropout_prob: float = 0.1
68
+ layer_norm_eps: float = 1e-12
69
+ embedding_dim: int = 1024
70
+ use_layer_scale: bool = True
71
+ layer_scale_init: float = 1e-5
72
+ use_grad_checkpointing: bool = True
73
+
74
+ tokenizer_name: str = "camembert-base"
75
+
76
+ # --- MLM (priorité corpus interne) ---
77
+ do_mlm_pretrain: bool = True
78
+ mlm_epochs: int = 2 # +1 vs IT, doc interne plus rare
79
+ mlm_prob: float = 0.15
80
+ mlm_lr: float = 8e-5
81
+
82
+ # --- Contrastif ---
83
+ epochs: int = 12
84
+ batch_size: int = 32 # seq_len 384 -> batch + petit
85
+ grad_accum_steps: int = 12 # effectif = 384
86
+ max_seq_len: int = 384
87
+ lr: float = 1.5e-5
88
+ weight_decay: float = 0.01
89
+ warmup_ratio: float = 0.06
90
+ grad_clip: float = 1.0
91
+ temperature: float = 0.02
92
+ num_workers: int = 6
93
+ seed: int = 42
94
+
95
+ # --- Hard negatives (2 par paire pour doc interne) ---
96
+ use_hard_negatives: bool = True
97
+ n_hard_neg: int = 2 # plus fort
98
+ hard_neg_pool_size: int = 200_000
99
+
100
+ use_ema: bool = True
101
+ ema_decay: float = 0.9995
102
+
103
+ max_samples_per_dataset: int = 250_000
104
+ eval_max_size: int = 5_000
105
+
106
+ use_bf16: bool = True
107
+ use_compile: bool = True
108
+ compile_mode: str = "default"
109
+ log_every: int = 50
110
+ save_dir: str = "./checkpoints_rag_doc_200m"
111
+ save_every_epochs: int = 2
112
+
113
+ # --- Corpus interne ---
114
+ custom_jsonl_path: str = "./data/custom_doc.jsonl"
115
+ custom_corpus_dir: str = "./data/corpus_interne" # PDF/DOCX/TXT/MD
116
+ internal_oversample: int = 8 # x5 pour booster apprentissage interne
117
+
118
+
119
+ CFG = Config()
120
+ Path(CFG.save_dir).mkdir(parents=True, exist_ok=True)
121
+ random.seed(CFG.seed); np.random.seed(CFG.seed)
122
+ torch.manual_seed(CFG.seed); torch.cuda.manual_seed_all(CFG.seed)
123
+ torch.backends.cuda.matmul.allow_tf32 = True
124
+ torch.backends.cudnn.allow_tf32 = True
125
+ torch.set_float32_matmul_precision("high")
126
+
127
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
+ print(f"[INFO] Device : {device}")
129
+ if torch.cuda.is_available():
130
+ print(f"[INFO] GPU : {torch.cuda.get_device_name(0)}")
131
+ print(f"[INFO] VRAM : {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
132
+
133
+
134
+ # =============================================================================
135
+ # 2. ARCHITECTURE
136
+ # =============================================================================
137
+ class TransformerEncoderBlock(nn.Module):
138
+ def __init__(self, cfg):
139
+ super().__init__()
140
+ self.num_heads = cfg.num_attention_heads
141
+ self.head_dim = cfg.hidden_size // cfg.num_attention_heads
142
+ self.ln1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
143
+ self.qkv = nn.Linear(cfg.hidden_size, 3 * cfg.hidden_size, bias=True)
144
+ self.proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
145
+ self.attn_drop_p = cfg.attention_probs_dropout_prob
146
+ self.ln2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
147
+ self.mlp = nn.Sequential(
148
+ nn.Linear(cfg.hidden_size, cfg.intermediate_size),
149
+ nn.GELU(),
150
+ nn.Linear(cfg.intermediate_size, cfg.hidden_size),
151
+ nn.Dropout(cfg.hidden_dropout_prob),
152
+ )
153
+ self.resid_drop = nn.Dropout(cfg.hidden_dropout_prob)
154
+ self.use_ls = cfg.use_layer_scale
155
+ if cfg.use_layer_scale:
156
+ self.gamma1 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
157
+ self.gamma2 = nn.Parameter(cfg.layer_scale_init * torch.ones(cfg.hidden_size))
158
+
159
+ def forward(self, x, attn_mask):
160
+ B, T, C = x.shape
161
+ h = self.ln1(x)
162
+ qkv = self.qkv(h).view(B, T, 3, self.num_heads, self.head_dim)
163
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
164
+ kpm = attn_mask[:, None, None, :].bool()
165
+ a = F.scaled_dot_product_attention(
166
+ q, k, v, attn_mask=kpm,
167
+ dropout_p=self.attn_drop_p if self.training else 0.0,
168
+ is_causal=False)
169
+ a = a.transpose(1, 2).contiguous().view(B, T, C)
170
+ a = self.resid_drop(self.proj(a))
171
+ if self.use_ls: a = a * self.gamma1
172
+ x = x + a
173
+ m = self.mlp(self.ln2(x))
174
+ if self.use_ls: m = m * self.gamma2
175
+ return x + m
176
+
177
+
178
+ class TextEncoder(nn.Module):
179
+ def __init__(self, cfg):
180
+ super().__init__()
181
+ self.cfg = cfg
182
+ self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.hidden_size, padding_idx=0)
183
+ self.pos_emb = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
184
+ self.emb_ln = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
185
+ self.emb_drop = nn.Dropout(cfg.hidden_dropout_prob)
186
+ self.blocks = nn.ModuleList([TransformerEncoderBlock(cfg)
187
+ for _ in range(cfg.num_hidden_layers)])
188
+ self.ln_f = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
189
+ self.proj_head = nn.Sequential(
190
+ nn.Linear(cfg.hidden_size, cfg.hidden_size),
191
+ nn.Tanh(),
192
+ nn.Linear(cfg.hidden_size, cfg.embedding_dim),
193
+ )
194
+ self.mlm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
195
+ self.mlm_head.weight = self.tok_emb.weight
196
+ self.use_gc = cfg.use_grad_checkpointing
197
+ self.apply(self._init_weights)
198
+
199
+ @staticmethod
200
+ def _init_weights(m):
201
+ if isinstance(m, nn.Linear):
202
+ nn.init.normal_(m.weight, std=0.02)
203
+ if m.bias is not None: nn.init.zeros_(m.bias)
204
+ elif isinstance(m, nn.Embedding):
205
+ nn.init.normal_(m.weight, std=0.02)
206
+ elif isinstance(m, nn.LayerNorm):
207
+ nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
208
+
209
+ def encode_backbone(self, ids, mask):
210
+ B, T = ids.shape
211
+ pos = torch.arange(T, device=ids.device).unsqueeze(0).expand(B, T)
212
+ x = self.tok_emb(ids) + self.pos_emb(pos)
213
+ x = self.emb_drop(self.emb_ln(x))
214
+ for blk in self.blocks:
215
+ if self.use_gc and self.training:
216
+ x = gc.checkpoint(blk, x, mask, use_reentrant=False)
217
+ else:
218
+ x = blk(x, mask)
219
+ return self.ln_f(x)
220
+
221
+ def forward(self, ids, mask):
222
+ x = self.encode_backbone(ids, mask)
223
+ m = mask.unsqueeze(-1).float()
224
+ pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1e-6)
225
+ emb = self.proj_head(pooled)
226
+ return F.normalize(emb, p=2, dim=-1)
227
+
228
+ def forward_mlm(self, ids, mask):
229
+ return self.mlm_head(self.encode_backbone(ids, mask))
230
+
231
+
232
+ def count_parameters(model):
233
+ return sum(p.numel() for n, p in model.named_parameters()
234
+ if p.requires_grad and "mlm_head" not in n)
235
+
236
+
237
+ # =============================================================================
238
+ # 3. EMA
239
+ # =============================================================================
240
+ class EMA:
241
+ def __init__(self, model, decay=0.999):
242
+ self.decay = decay
243
+ self.shadow = {n: p.detach().clone()
244
+ for n, p in model.named_parameters() if p.requires_grad}
245
+
246
+ @torch.no_grad()
247
+ def update(self, model):
248
+ for n, p in model.named_parameters():
249
+ if p.requires_grad and n in self.shadow:
250
+ self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)
251
+
252
+ @torch.no_grad()
253
+ def apply_to(self, model):
254
+ backup = {}
255
+ for n, p in model.named_parameters():
256
+ if n in self.shadow:
257
+ backup[n] = p.detach().clone(); p.copy_(self.shadow[n])
258
+ return backup
259
+
260
+ @torch.no_grad()
261
+ def restore(self, model, backup):
262
+ for n, p in model.named_parameters():
263
+ if n in backup: p.copy_(backup[n])
264
+
265
+
266
+ # =============================================================================
267
+ # 4. EXTRACTION CORPUS INTERNE (PDF / DOCX / TXT / MD)
268
+ # =============================================================================
269
+ def extract_text_from_file(path: Path) -> str:
270
+ """Extracteur multi-format. Retourne texte brut ou ''."""
271
+ suffix = path.suffix.lower()
272
+ try:
273
+ if suffix in {".txt", ".md"}:
274
+ return path.read_text(encoding="utf-8", errors="ignore")
275
+
276
+ if suffix == ".pdf":
277
+ try:
278
+ from pypdf import PdfReader
279
+ except ImportError:
280
+ from PyPDF2 import PdfReader
281
+ reader = PdfReader(str(path))
282
+ return "\n".join((p.extract_text() or "") for p in reader.pages)
283
+
284
+ if suffix == ".docx":
285
+ from docx import Document
286
+ doc = Document(str(path))
287
+ return "\n".join(p.text for p in doc.paragraphs)
288
+
289
+ if suffix in {".html", ".htm"}:
290
+ from bs4 import BeautifulSoup
291
+ soup = BeautifulSoup(path.read_text(encoding="utf-8", errors="ignore"),
292
+ "html.parser")
293
+ return soup.get_text(separator="\n")
294
+ except Exception as e:
295
+ print(f" [warn] extract {path.name} : {e}")
296
+ return ""
297
+
298
+
299
+ def chunk_document(text: str, chunk_size: int = 1500,
300
+ overlap: int = 200) -> List[Tuple[str, str]]:
301
+ """
302
+ Découpe un document en (titre/section, contenu) pour générer des paires.
303
+ Utilise les titres Markdown / numérotation pour détecter les sections.
304
+ """
305
+ text = re.sub(r"\n{3,}", "\n\n", text).strip()
306
+ if not text:
307
+ return []
308
+
309
+ # Détection sections (Markdown ##, numérotation 1., 1.1, ARTICLE, etc.)
310
+ section_re = re.compile(
311
+ r"(?m)^(#{1,4}\s+.+|" # markdown
312
+ r"\d+(?:\.\d+)*\.?\s+[A-ZÀ-Ÿa-zà-ÿ].+|" # numérotation
313
+ r"ARTICLE\s+\d+[\s\-:].+|" # juridique
314
+ r"CHAPITRE\s+\d+[\s\-:].+|" # juridique
315
+ r"[A-ZÀ-Ÿ][A-ZÀ-Ÿ\s]{8,}$)" # ALL CAPS section
316
+ )
317
+ sections = []
318
+ matches = list(section_re.finditer(text))
319
+ if matches:
320
+ for i, m in enumerate(matches):
321
+ title = m.group(0).strip()
322
+ start = m.end()
323
+ end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
324
+ content = text[start:end].strip()
325
+ if title and content and len(content) > 80:
326
+ sections.append((title[:200], content))
327
+
328
+ # Si pas de sections détectées, fallback chunks fixes
329
+ if not sections:
330
+ for i in range(0, len(text), chunk_size - overlap):
331
+ chunk = text[i:i + chunk_size].strip()
332
+ if len(chunk) > 80:
333
+ # titre = première phrase
334
+ first_period = chunk.find(".")
335
+ title = chunk[:first_period if first_period > 20 else 80].strip()
336
+ sections.append((title, chunk))
337
+ return sections
338
+
339
+
340
+ def load_internal_corpus(cfg: Config) -> Tuple[List[Dict[str, str]], List[str]]:
341
+ """Lit ./data/corpus_interne/* et génère paires + textes pour MLM."""
342
+ pairs = []
343
+ raw_texts = []
344
+ corpus_dir = Path(cfg.custom_corpus_dir)
345
+ if not corpus_dir.exists():
346
+ print(f" [info] Dossier corpus interne absent : {corpus_dir}")
347
+ return pairs, raw_texts
348
+
349
+ files = []
350
+ for ext in ("*.pdf", "*.docx", "*.txt", "*.md", "*.html", "*.htm"):
351
+ files.extend(corpus_dir.rglob(ext))
352
+ print(f" [+] {len(files)} fichiers internes trouvés")
353
+
354
+ for fp in tqdm(files, desc="corpus_interne"):
355
+ text = extract_text_from_file(fp)
356
+ if not text or len(text) < 200:
357
+ continue
358
+ raw_texts.append(text)
359
+ sections = chunk_document(text)
360
+ for title, content in sections:
361
+ pairs.append({
362
+ "anchor": title,
363
+ "positive": content[:2500],
364
+ "_internal": True,
365
+ })
366
+ # Paire bonus : "où trouver X ?" -> contenu
367
+ pairs.append({
368
+ "anchor": f"Où trouver des informations sur : {title} ?",
369
+ "positive": content[:2500],
370
+ "_internal": True,
371
+ })
372
+ return pairs, raw_texts
373
+
374
+
375
+ # =============================================================================
376
+ # 5. CHARGEMENT DATASETS PUBLICS (DOC GÉNÉRIQUE FR)
377
+ # =============================================================================
378
+ DOC_KEYWORDS = re.compile(
379
+ r"\b(article|chapitre|procédure|politique|règlement|directive|note de service|"
380
+ r"manuel|guide|formation|RH|ressources humaines|congé|absence|salaire|paie|"
381
+ r"contrat|CDI|CDD|convention|accord|qualité|conformité|audit|ISO|RGPD|"
382
+ r"comité|conseil|assemblée|direction|département|service|budget|"
383
+ r"facture|comptabilité|comptable|TVA|achat|vente|client|fournisseur|"
384
+ r"juridique|légal|loi|décret|arrêté|jurisprudence|tribunal|"
385
+ r"sécurité|incident|risque|santé|hygiène|formation)\b",
386
+ re.IGNORECASE,
387
+ )
388
+
389
+ def is_doc_text(t: str) -> bool:
390
+ return bool(DOC_KEYWORDS.search(t)) if t else False
391
+
392
+
393
+ def load_doc_pairs(cfg: Config) -> List[Dict[str, str]]:
394
+ print("\n[DATA] Chargement des datasets DOC INTERNE...")
395
+ pairs: List[Dict[str, str]] = []
396
+
397
+ # 5.1 Corpus interne (priorité absolue, oversample)
398
+ internal_pairs, internal_texts = load_internal_corpus(cfg)
399
+ print(f" [+] Corpus interne : {len(internal_pairs):,} paires brutes")
400
+ pairs.extend(internal_pairs * cfg.internal_oversample)
401
+
402
+ # 5.2 PIAF + FQuAD (paires question / contexte FR génériques)
403
+ try:
404
+ ds = load_dataset("etalab-ia/piaf", split="train")
405
+ for ex in tqdm(ds, desc="PIAF"):
406
+ q = (ex.get("question") or "").strip()
407
+ ctx = (ex.get("context") or "").strip()
408
+ if q and ctx:
409
+ pairs.append({"anchor": q, "positive": ctx})
410
+ except Exception as e:
411
+ print(f" [warn] PIAF : {e}")
412
+
413
+ try:
414
+ ds = load_dataset("manu/fquad2_test", split="train")
415
+ for ex in tqdm(ds, desc="FQuAD2"):
416
+ q = (ex.get("question") or "").strip()
417
+ ctx = (ex.get("context") or "").strip()
418
+ if q and ctx:
419
+ pairs.append({"anchor": q, "positive": ctx})
420
+ except Exception as e:
421
+ print(f" [warn] FQuAD2 : {e}")
422
+
423
+ # 5.3 mMARCO FR filtré "documentaire"
424
+ try:
425
+ ds = load_dataset("unicamp-dl/mmarco", "french", split="train")
426
+ ds = ds.select(range(min(500_000, len(ds))))
427
+ kept = 0
428
+ for ex in tqdm(ds, desc="mMARCO-FR (DOC-filter)"):
429
+ q = (ex.get("query") or "").strip()
430
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
431
+ if q and p and (is_doc_text(q) or is_doc_text(p)):
432
+ pairs.append({"anchor": q, "positive": p})
433
+ kept += 1
434
+ if kept >= cfg.max_samples_per_dataset: break
435
+ except Exception as e:
436
+ print(f" [warn] mMARCO : {e}")
437
+
438
+ # 5.4 Wikipedia FR — paires (résumé/lead -> section)
439
+ try:
440
+ ds = load_dataset("wikipedia", "20220301.fr", split="train",
441
+ trust_remote_code=True)
442
+ ds = ds.select(range(min(100_000, len(ds))))
443
+ for ex in tqdm(ds, desc="Wikipedia-FR"):
444
+ title = (ex.get("title") or "").strip()
445
+ text = (ex.get("text") or "").strip()
446
+ if not title or not text or len(text) < 300:
447
+ continue
448
+ # Première section comme positif du titre
449
+ first_chunk = text[:2000]
450
+ pairs.append({"anchor": title, "positive": first_chunk})
451
+ # Sections suivantes si présentes
452
+ paragraphs = text.split("\n\n")
453
+ for para in paragraphs[1:6]:
454
+ if len(para) > 200:
455
+ pairs.append({
456
+ "anchor": f"Que dit l'article '{title}' à propos de cela ?",
457
+ "positive": para[:2000],
458
+ })
459
+ except Exception as e:
460
+ print(f" [warn] Wikipedia FR : {e}")
461
+
462
+ # 5.5 MultiLegalPile FR (juridique)
463
+ try:
464
+ ds = load_dataset("joelniklaus/Multi_Legal_Pile", "fr_caselaw",
465
+ split="train", streaming=True)
466
+ count = 0
467
+ for ex in tqdm(ds, desc="MultiLegalPile-FR", total=50_000):
468
+ text = (ex.get("text") or "").strip()
469
+ if len(text) < 500: continue
470
+ # Première phrase = anchor, reste = positif
471
+ first_period = text.find(".")
472
+ if 30 < first_period < 250:
473
+ anchor = text[:first_period + 1]
474
+ positive = text[first_period + 1:first_period + 2001]
475
+ if len(positive) > 100:
476
+ pairs.append({"anchor": anchor, "positive": positive})
477
+ count += 1
478
+ if count >= 50_000: break
479
+ except Exception as e:
480
+ print(f" [warn] MultiLegalPile : {e}")
481
+
482
+ # 5.6 XNLI FR (entailment)
483
+ try:
484
+ ds = load_dataset("xnli", "fr", split="train")
485
+ ds = ds.filter(lambda x: x["label"] == 0)
486
+ ds = ds.select(range(min(80_000, len(ds))))
487
+ for ex in tqdm(ds, desc="XNLI-FR"):
488
+ a = (ex.get("premise") or "").strip()
489
+ b = (ex.get("hypothesis") or "").strip()
490
+ if a and b:
491
+ pairs.append({"anchor": a, "positive": b})
492
+ except Exception as e:
493
+ print(f" [warn] XNLI : {e}")
494
+
495
+ # 5.7 Custom JSONL
496
+ if Path(cfg.custom_jsonl_path).exists():
497
+ with open(cfg.custom_jsonl_path, "r", encoding="utf-8") as f:
498
+ for line in tqdm(f, desc="custom_doc.jsonl"):
499
+ try:
500
+ ex = json.loads(line)
501
+ a = (ex.get("anchor") or ex.get("query") or "").strip()
502
+ p = (ex.get("positive") or ex.get("passage") or "").strip()
503
+ if a and p:
504
+ pairs.append({"anchor": a, "positive": p, "_internal": True})
505
+ except Exception:
506
+ continue
507
+
508
+ # Dédup
509
+ seen = set(); uniq = []
510
+ for p in pairs:
511
+ k = (p["anchor"][:200], p["positive"][:200])
512
+ if k not in seen:
513
+ seen.add(k); uniq.append(p)
514
+ random.shuffle(uniq)
515
+ n_internal = sum(1 for p in uniq if p.get("_internal"))
516
+ print(f"[DATA] Total paires uniques : {len(uniq):,} (dont interne : {n_internal:,})")
517
+ return uniq
518
+
519
+
520
+ # =============================================================================
521
+ # 6. HARD NEGATIVE MINING (2 negs par paire)
522
+ # =============================================================================
523
+ def mine_hard_negatives_multi(pairs, cfg: Config):
524
+ print(f"\n[HN] Mining {cfg.n_hard_neg} hard negatives par paire...")
525
+ try:
526
+ from sklearn.feature_extraction.text import TfidfVectorizer
527
+ from sklearn.metrics.pairwise import linear_kernel
528
+ except ImportError:
529
+ print(" [warn] sklearn manquant"); return pairs
530
+
531
+ n = len(pairs)
532
+ pool_size = min(cfg.hard_neg_pool_size, n)
533
+ pool_idx = np.random.choice(n, size=pool_size, replace=False)
534
+ pool_pass = [pairs[i]["positive"] for i in pool_idx]
535
+ vec = TfidfVectorizer(max_features=80_000, ngram_range=(1, 2),
536
+ lowercase=True, strip_accents="unicode")
537
+ X_pool = vec.fit_transform(pool_pass)
538
+
539
+ enriched = []
540
+ batch = 2000
541
+ anchors = [p["anchor"] for p in pairs]
542
+ for start in tqdm(range(0, n, batch), desc="HN-mine"):
543
+ end = min(start + batch, n)
544
+ Xq = vec.transform(anchors[start:end])
545
+ sims = linear_kernel(Xq, X_pool)
546
+ for i_loc, i_glob in enumerate(range(start, end)):
547
+ true_pos = pairs[i_glob]["positive"]
548
+ order = np.argsort(-sims[i_loc])
549
+ picked = []
550
+ for j in order[:50]:
551
+ cand = pool_pass[j]
552
+ if cand != true_pos and cand not in picked:
553
+ picked.append(cand)
554
+ if len(picked) >= cfg.n_hard_neg: break
555
+ while len(picked) < cfg.n_hard_neg:
556
+ picked.append(pool_pass[random.randint(0, pool_size - 1)])
557
+ enriched.append({
558
+ "anchor": pairs[i_glob]["anchor"],
559
+ "positive": pairs[i_glob]["positive"],
560
+ "hard_negs": picked,
561
+ "_internal": pairs[i_glob].get("_internal", False),
562
+ })
563
+ return enriched
564
+
565
+
566
+ # =============================================================================
567
+ # 7. DATASET / COLLATE (multi-hn)
568
+ # =============================================================================
569
+ class PairDataset(Dataset):
570
+ def __init__(self, items, n_hn): self.items, self.n_hn = items, n_hn
571
+ def __len__(self): return len(self.items)
572
+ def __getitem__(self, i):
573
+ ex = self.items[i]
574
+ if self.n_hn > 0:
575
+ negs = ex.get("hard_negs", [ex["positive"]] * self.n_hn)
576
+ return ex["anchor"], ex["positive"], negs[:self.n_hn]
577
+ return ex["anchor"], ex["positive"]
578
+
579
+
580
+ def make_collate_fn(tokenizer, max_len, n_hn):
581
+ def collate(batch):
582
+ a_l = [b[0] for b in batch]; p_l = [b[1] for b in batch]
583
+ a = tokenizer(a_l, padding=True, truncation=True,
584
+ max_length=max_len, return_tensors="pt")
585
+ p = tokenizer(p_l, padding=True, truncation=True,
586
+ max_length=max_len, return_tensors="pt")
587
+ if n_hn > 0:
588
+ # Flatten : [n0_p1, n0_p2, n1_p1, n1_p2, ...] -> on tokenize tout
589
+ all_negs = []
590
+ for b in batch:
591
+ all_negs.extend(b[2]) # n_hn négatifs par exemple
592
+ n = tokenizer(all_negs, padding=True, truncation=True,
593
+ max_length=max_len, return_tensors="pt")
594
+ return a, p, n
595
+ return a, p
596
+ return collate
597
+
598
+
599
+ # =============================================================================
600
+ # 8. LOSS — Symmetric MNRL avec multi-hard-negatives
601
+ # =============================================================================
602
+ def symmetric_mnrl_multi_hn(emb_a, emb_p, emb_neg=None, n_hn=0, temperature=0.02):
603
+ """
604
+ emb_neg : (N * n_hn, d) si fourni, sinon None.
605
+ Cibles a -> [P; N1; N2; ...] : N positifs + N*n_hn négatifs durs
606
+ """
607
+ N = emb_a.size(0)
608
+ labels = torch.arange(N, device=emb_a.device)
609
+ if emb_neg is not None and n_hn > 0:
610
+ targets = torch.cat([emb_p, emb_neg], dim=0)
611
+ sim_a = emb_a @ targets.t() / temperature
612
+ loss_a2p = F.cross_entropy(sim_a, labels)
613
+ else:
614
+ sim_a = emb_a @ emb_p.t() / temperature
615
+ loss_a2p = F.cross_entropy(sim_a, labels)
616
+ sim_p = emb_p @ emb_a.t() / temperature
617
+ loss_p2a = F.cross_entropy(sim_p, labels)
618
+ loss = 0.5 * (loss_a2p + loss_p2a)
619
+ with torch.no_grad():
620
+ acc = (sim_a[:, :N].argmax(dim=1) == labels).float().mean().item()
621
+ return loss, acc
622
+
623
+
624
+ # =============================================================================
625
+ # 9. MLM PRÉ-ENTRAÎNEMENT (priorité corpus interne)
626
+ # =============================================================================
627
+ def mlm_pretrain(model, tokenizer, internal_texts, public_texts, cfg: Config):
628
+ # 50% interne (oversampled) + 50% public pour spécialiser sans oublier
629
+ if internal_texts:
630
+ # On répète le corpus interne pour qu'il occupe ~50% du MLM
631
+ target_size = max(len(public_texts), 1)
632
+ repeats = max(1, target_size // max(len(internal_texts), 1))
633
+ internal_repeated = internal_texts * repeats
634
+ random.shuffle(internal_repeated)
635
+ public_texts = public_texts[:target_size]
636
+ all_texts = internal_repeated[:target_size] + public_texts
637
+ else:
638
+ all_texts = public_texts
639
+ random.shuffle(all_texts)
640
+ print(f"\n[MLM] Pré-entraînement sur {len(all_texts):,} textes "
641
+ f"(interne : {len(internal_texts):,})")
642
+
643
+ class MLMDataset(Dataset):
644
+ def __init__(self, t): self.t = t
645
+ def __len__(self): return len(self.t)
646
+ def __getitem__(self, i): return self.t[i]
647
+
648
+ def mlm_collate(batch):
649
+ enc = tokenizer(batch, padding=True, truncation=True,
650
+ max_length=cfg.max_seq_len, return_tensors="pt")
651
+ ids = enc["input_ids"].clone(); labels = ids.clone()
652
+ special = torch.zeros_like(ids, dtype=torch.bool)
653
+ for sid in tokenizer.all_special_ids: special |= (ids == sid)
654
+ prob = torch.full(ids.shape, cfg.mlm_prob)
655
+ prob.masked_fill_(special, 0.0)
656
+ masked = torch.bernoulli(prob).bool()
657
+ labels[~masked] = -100
658
+ rand = torch.rand(ids.shape)
659
+ ids[masked & (rand < 0.8)] = tokenizer.mask_token_id
660
+ rr = masked & (rand >= 0.8) & (rand < 0.9)
661
+ rt = torch.randint(0, tokenizer.vocab_size, ids.shape)
662
+ ids[rr] = rt[rr]
663
+ return ids, enc["attention_mask"], labels
664
+
665
+ loader = DataLoader(MLMDataset(all_texts), batch_size=cfg.batch_size,
666
+ shuffle=True, num_workers=cfg.num_workers,
667
+ collate_fn=mlm_collate, pin_memory=True,
668
+ drop_last=True, persistent_workers=True)
669
+ optim = AdamW(model.parameters(), lr=cfg.mlm_lr, weight_decay=0.01,
670
+ betas=(0.9, 0.98), eps=1e-6)
671
+ total_steps = len(loader) * cfg.mlm_epochs
672
+ sched = get_cosine_schedule_with_warmup(optim, int(total_steps * 0.04), total_steps)
673
+ model.train()
674
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
675
+ for ep in range(cfg.mlm_epochs):
676
+ running = 0.0
677
+ pbar = tqdm(loader, desc=f"MLM ep{ep+1}/{cfg.mlm_epochs}")
678
+ for step, (ids, mask, labels) in enumerate(pbar, 1):
679
+ ids = ids.to(device, non_blocking=True)
680
+ mask = mask.to(device, non_blocking=True)
681
+ labels = labels.to(device, non_blocking=True)
682
+ optim.zero_grad(set_to_none=True)
683
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
684
+ logits = model.forward_mlm(ids, mask)
685
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
686
+ labels.view(-1), ignore_index=-100)
687
+ loss.backward()
688
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
689
+ optim.step(); sched.step()
690
+ running += loss.item()
691
+ if step % 50 == 0:
692
+ pbar.set_postfix(loss=f"{running/step:.4f}",
693
+ ppl=f"{math.exp(min(20, running/step)):.1f}")
694
+ print("[MLM] Terminé.\n")
695
+
696
+
697
+ # =============================================================================
698
+ # 10. EVAL
699
+ # =============================================================================
700
+ @torch.no_grad()
701
+ def evaluate_retrieval(model, tokenizer, eval_pairs, cfg: Config):
702
+ model.eval()
703
+ autocast_dtype = torch.bfloat16 if cfg.use_bf16 else torch.float16
704
+ queries = [e["anchor"] for e in eval_pairs]
705
+ passages = [e["positive"] for e in eval_pairs]
706
+
707
+ def encode(texts):
708
+ embs = []
709
+ for i in range(0, len(texts), 32):
710
+ chunk = texts[i:i+32]
711
+ enc = tokenizer(chunk, padding=True, truncation=True,
712
+ max_length=cfg.max_seq_len, return_tensors="pt").to(device)
713
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
714
+ e = model(enc["input_ids"], enc["attention_mask"])
715
+ embs.append(e.float())
716
+ return torch.cat(embs, dim=0)
717
+
718
+ Q = encode(queries); P = encode(passages)
719
+ sims = Q @ P.t()
720
+ N = sims.size(0)
721
+ targets = torch.arange(N, device=sims.device)
722
+ ranks = sims.argsort(dim=1, descending=True)
723
+ pos_in_rank = (ranks == targets.unsqueeze(1)).nonzero()[:, 1]
724
+ return {
725
+ "R@1": (pos_in_rank == 0).float().mean().item(),
726
+ "R@5": (pos_in_rank < 5).float().mean().item(),
727
+ "R@10": (pos_in_rank < 10).float().mean().item(),
728
+ "MRR": (1.0 / (pos_in_rank.float() + 1)).mean().item(),
729
+ }
730
+
731
+
732
+ # =============================================================================
733
+ # 11. TRAIN
734
+ # =============================================================================
735
+ def train():
736
+ tokenizer = AutoTokenizer.from_pretrained(CFG.tokenizer_name)
737
+ CFG.vocab_size = tokenizer.vocab_size
738
+ print(f"[TOK ] vocab_size = {CFG.vocab_size}")
739
+
740
+ items_all = load_doc_pairs(CFG)
741
+ n_eval = min(CFG.eval_max_size, max(2000, int(len(items_all) * 0.005)))
742
+ eval_items = items_all[:n_eval]
743
+ train_items = items_all[n_eval:]
744
+ print(f"[DATA] train={len(train_items):,} eval={len(eval_items):,}")
745
+
746
+ if CFG.use_hard_negatives:
747
+ train_items = mine_hard_negatives_multi(train_items, CFG)
748
+
749
+ n_hn = CFG.n_hard_neg if CFG.use_hard_negatives else 0
750
+ collate = make_collate_fn(tokenizer, CFG.max_seq_len, n_hn)
751
+ train_loader = DataLoader(
752
+ PairDataset(train_items, n_hn),
753
+ batch_size=CFG.batch_size, shuffle=True,
754
+ num_workers=CFG.num_workers, collate_fn=collate,
755
+ pin_memory=True, drop_last=True, persistent_workers=True,
756
+ )
757
+
758
+ model = TextEncoder(CFG).to(device)
759
+ n_params = count_parameters(model)
760
+ print(f"[MODEL] Paramètres entraînables : {n_params/1e6:.2f} M")
761
+
762
+ if CFG.do_mlm_pretrain:
763
+ # Sépare textes internes vs publics
764
+ internal_texts = []; public_texts = []
765
+ for it in train_items[:500_000]:
766
+ if it.get("_internal"):
767
+ internal_texts.append(it["anchor"])
768
+ internal_texts.append(it["positive"])
769
+ else:
770
+ public_texts.append(it["anchor"])
771
+ public_texts.append(it["positive"])
772
+ mlm_pretrain(model, tokenizer, internal_texts, public_texts, CFG)
773
+
774
+ if CFG.use_compile and hasattr(torch, "compile"):
775
+ model = torch.compile(model, mode=CFG.compile_mode)
776
+
777
+ raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
778
+ ema = EMA(raw_model, decay=CFG.ema_decay) if CFG.use_ema else None
779
+
780
+ no_decay = ["bias", "LayerNorm.weight", "ln1", "ln2", "ln_f", "emb_ln",
781
+ "gamma1", "gamma2"]
782
+ grouped = [
783
+ {"params": [p for n, p in model.named_parameters()
784
+ if "mlm_head" not in n and not any(nd in n for nd in no_decay)],
785
+ "weight_decay": CFG.weight_decay},
786
+ {"params": [p for n, p in model.named_parameters()
787
+ if "mlm_head" not in n and any(nd in n for nd in no_decay)],
788
+ "weight_decay": 0.0},
789
+ ]
790
+ optimizer = AdamW(grouped, lr=CFG.lr, betas=(0.9, 0.98), eps=1e-6)
791
+ steps_per_epoch = len(train_loader) // CFG.grad_accum_steps
792
+ total_steps = steps_per_epoch * CFG.epochs
793
+ warmup_steps = int(total_steps * CFG.warmup_ratio)
794
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
795
+ print(f"[OPTIM] total_steps={total_steps} warmup={warmup_steps}")
796
+
797
+ autocast_dtype = torch.bfloat16 if CFG.use_bf16 else torch.float16
798
+ best_mrr = 0.0
799
+ history = []
800
+
801
+ for epoch in range(1, CFG.epochs + 1):
802
+ model.train()
803
+ running_loss = running_acc = 0.0
804
+ n_seen = 0
805
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{CFG.epochs}")
806
+ optimizer.zero_grad(set_to_none=True)
807
+
808
+ for step, batch in enumerate(pbar, start=1):
809
+ if n_hn > 0:
810
+ a, p, neg = batch
811
+ neg = {k: v.to(device, non_blocking=True) for k, v in neg.items()}
812
+ else:
813
+ a, p = batch; neg = None
814
+ a = {k: v.to(device, non_blocking=True) for k, v in a.items()}
815
+ p = {k: v.to(device, non_blocking=True) for k, v in p.items()}
816
+
817
+ with torch.autocast(device_type="cuda", dtype=autocast_dtype):
818
+ emb_a = model(a["input_ids"], a["attention_mask"])
819
+ emb_p = model(p["input_ids"], p["attention_mask"])
820
+ emb_n = (model(neg["input_ids"], neg["attention_mask"])
821
+ if neg is not None else None)
822
+ loss, acc = symmetric_mnrl_multi_hn(
823
+ emb_a, emb_p, emb_n, n_hn=n_hn, temperature=CFG.temperature)
824
+ loss = loss / CFG.grad_accum_steps
825
+
826
+ loss.backward()
827
+ if step % CFG.grad_accum_steps == 0:
828
+ torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
829
+ optimizer.step(); scheduler.step()
830
+ optimizer.zero_grad(set_to_none=True)
831
+ if ema is not None: ema.update(raw_model)
832
+
833
+ running_loss += loss.item() * CFG.grad_accum_steps
834
+ running_acc += acc; n_seen += 1
835
+ if step % CFG.log_every == 0:
836
+ pbar.set_postfix(loss=f"{running_loss/n_seen:.4f}",
837
+ acc=f"{running_acc/n_seen:.3f}",
838
+ lr=f"{scheduler.get_last_lr()[0]:.2e}")
839
+
840
+ backup = ema.apply_to(raw_model) if ema is not None else None
841
+ metrics = evaluate_retrieval(model, tokenizer, eval_items, CFG)
842
+ if backup is not None: ema.restore(raw_model, backup)
843
+ print(f"\n[EVAL] epoch {epoch} : R@1={metrics['R@1']:.3f} "
844
+ f"R@5={metrics['R@5']:.3f} R@10={metrics['R@10']:.3f} "
845
+ f"MRR={metrics['MRR']:.3f}")
846
+ history.append({"epoch": epoch, **metrics,
847
+ "train_loss": running_loss / max(1, n_seen)})
848
+
849
+ is_best = metrics["MRR"] > best_mrr
850
+ if is_best: best_mrr = metrics["MRR"]
851
+ if ema is not None: backup = ema.apply_to(raw_model)
852
+ state = {k: v for k, v in raw_model.state_dict().items() if "mlm_head" not in k}
853
+
854
+ if epoch % CFG.save_every_epochs == 0 or is_best or epoch == CFG.epochs:
855
+ torch.save({"epoch": epoch, "model_state": state,
856
+ "config": asdict(CFG), "metrics": metrics},
857
+ Path(CFG.save_dir) / f"model_epoch{epoch}.pt")
858
+ if is_best:
859
+ torch.save({"epoch": epoch, "model_state": state,
860
+ "config": asdict(CFG), "metrics": metrics},
861
+ Path(CFG.save_dir) / "model_best.pt")
862
+ if ema is not None: ema.restore(raw_model, backup)
863
+ print(f"[SAVE] epoch {epoch} best={'oui' if is_best else 'non'}")
864
+
865
+ with open(Path(CFG.save_dir) / "history.json", "w", encoding="utf-8") as f:
866
+ json.dump(history, f, ensure_ascii=False, indent=2)
867
+ tokenizer.save_pretrained(CFG.save_dir)
868
+ print(f"\n[OK] Best MRR = {best_mrr:.3f} -> {CFG.save_dir}/model_best.pt")
869
+
870
+
871
+ # =============================================================================
872
+ # 12. DÉMO
873
+ # =============================================================================
874
+ @torch.no_grad()
875
+ def demo():
876
+ tokenizer = AutoTokenizer.from_pretrained(CFG.save_dir)
877
+ ckpt = torch.load(Path(CFG.save_dir) / "model_best.pt", map_location=device)
878
+ saved_cfg = ckpt["config"]
879
+ cfg2 = Config(**{k: v for k, v in saved_cfg.items() if hasattr(Config, k)})
880
+ cfg2.vocab_size = tokenizer.vocab_size
881
+ model = TextEncoder(cfg2).to(device).eval()
882
+ model.load_state_dict(ckpt["model_state"], strict=False)
883
+
884
+ corpus = [
885
+ "ARTICLE 12 - Les congés payés sont acquis à raison de 2,5 jours par mois travaillé.",
886
+ "Procédure de validation des notes de frais : transmettre via le portail RH avant le 5 du mois.",
887
+ "La politique RGPD impose un délai de 72h pour notifier une violation de données.",
888
+ "Le télétravail est autorisé jusqu'à 3 jours par semaine sur accord du manager.",
889
+ "Toute facture fournisseur doit être validée par le responsable budget avant paiement.",
890
+ "Formation obligatoire sécurité incendie : 1 fois par an, traçabilité dans le SIRH.",
891
+ "L'accord d'entreprise du 15/03/2024 fixe le taux de prime annuelle à 8% du salaire brut.",
892
+ ]
893
+ queries = [
894
+ "Combien de jours de congés je gagne par mois ?",
895
+ "Comment déclarer mes notes de frais ?",
896
+ "Quel est le quota de télétravail ?",
897
+ "Quel taux de prime annuelle ?",
898
+ ]
899
+ enc = tokenizer(corpus, padding=True, truncation=True,
900
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
901
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
902
+ c_emb = model(enc["input_ids"], enc["attention_mask"])
903
+
904
+ print("\n[DEMO DOC-INTERNE-100M]")
905
+ for q in queries:
906
+ eq = tokenizer([q], padding=True, truncation=True,
907
+ max_length=cfg2.max_seq_len, return_tensors="pt").to(device)
908
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
909
+ q_emb = model(eq["input_ids"], eq["attention_mask"])
910
+ sims = (q_emb @ c_emb.t()).squeeze(0)
911
+ top = sims.topk(3)
912
+ print(f"\nQ : {q}")
913
+ for s, i in zip(top.values, top.indices):
914
+ print(f" ({s.item():.3f}) -> {corpus[i.item()]}")
915
+
916
+
917
+ if __name__ == "__main__":
918
+ train()
919
+ try:
920
+ demo()
921
+ except Exception as e:
922
+ print(f"[demo] {e}")
rag_boolq_400m/checkpoints/training_info.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "latest_checkpoint": "rag_boolq_400m/checkpoints/clm_epoch_28.pt",
3
+ "latest_mtime": 1777473272.7815104,
4
+ "latest_mtime_iso": "2026-04-29T14:34:32.781510+00:00",
5
+ "size_bytes": 1643359381,
6
+ "epoch": 28
7
+ }
rag_boolq_400m/local_finetuned/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # RAG Custom v6.2 POWER
2
+
3
+ Profil: power_400m
4
+ Paramètres: 190.66M
5
+ Sauvegarde locale complète.
rag_boolq_400m/local_finetuned/config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "6.2",
3
+ "profile": "power_400m",
4
+ "total_params_M": 190.66,
5
+ "encoder_config": {
6
+ "vocab_size": 36000,
7
+ "max_len": 640,
8
+ "d_model": 640,
9
+ "n_heads": 10,
10
+ "n_layers": 8,
11
+ "dim_ff": 2560,
12
+ "dropout": 0.1
13
+ },
14
+ "decoder_config": {
15
+ "vocab_size": 36000,
16
+ "max_len": 640,
17
+ "d_model": 768,
18
+ "n_heads": 12,
19
+ "n_layers": 14,
20
+ "dim_ff": 3072,
21
+ "dropout": 0.1
22
+ },
23
+ "project_dir": "/workspace/rag_boolq_400m",
24
+ "local_finetuned_dir": "/workspace/rag_boolq_400m/local_finetuned",
25
+ "generation": {
26
+ "max_new_tokens": 160,
27
+ "temperature": 0.72,
28
+ "top_k": 60,
29
+ "top_p": 0.92,
30
+ "beam_size": 3
31
+ },
32
+ "retrieval": {
33
+ "use_hybrid": true,
34
+ "rag_top_k": 12,
35
+ "sim_threshold": 0.045,
36
+ "min_support": 0.28
37
+ },
38
+ "metrics": {
39
+ "retrieval": {
40
+ "recall@12": 0.933,
41
+ "n": 120
42
+ },
43
+ "demo": {
44
+ "demo_pass": 4,
45
+ "demo_total": 5,
46
+ "demo_pct": 80.0
47
+ }
48
+ },
49
+ "saved_at": "2026-04-29 14:38:20"
50
+ }
rag_boolq_400m/local_finetuned/tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
rag_boolq_400m/local_finetuned/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "[BOS]",
4
+ "cls_token": "[CLS]",
5
+ "eos_token": "[EOS]",
6
+ "mask_token": "[MASK]",
7
+ "max_length": 640,
8
+ "model_max_length": 640,
9
+ "pad_token": "[PAD]",
10
+ "sep_token": "[SEP]",
11
+ "stride": 0,
12
+ "tokenizer_class": "TokenizersBackend",
13
+ "truncation_side": "right",
14
+ "truncation_strategy": "longest_first",
15
+ "unk_token": "[UNK]"
16
+ }
rag_boolq_400m/local_finetuned/tokenizer/training_info.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "latest_checkpoint": null,
3
+ "latest_mtime": null,
4
+ "latest_mtime_iso": null,
5
+ "size_bytes": null,
6
+ "epoch": null
7
+ }
rag_boolq_400m/local_finetuned/training_info.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "latest_checkpoint": "rag_boolq_400m/local_finetuned/decoder_finetuned.pt",
3
+ "latest_mtime": 1777473500.6976762,
4
+ "latest_mtime_iso": "2026-04-29T14:38:20.697676+00:00",
5
+ "size_bytes": 620134204,
6
+ "epoch": null
7
+ }
rag_boolq_400m/local_finetuned/training_summary.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "6.2",
3
+ "profile": "power_400m",
4
+ "total_params_M": 190.66,
5
+ "encoder_config": {
6
+ "vocab_size": 36000,
7
+ "max_len": 640,
8
+ "d_model": 640,
9
+ "n_heads": 10,
10
+ "n_layers": 8,
11
+ "dim_ff": 2560,
12
+ "dropout": 0.1
13
+ },
14
+ "decoder_config": {
15
+ "vocab_size": 36000,
16
+ "max_len": 640,
17
+ "d_model": 768,
18
+ "n_heads": 12,
19
+ "n_layers": 14,
20
+ "dim_ff": 3072,
21
+ "dropout": 0.1
22
+ },
23
+ "project_dir": "/workspace/rag_boolq_400m",
24
+ "local_finetuned_dir": "/workspace/rag_boolq_400m/local_finetuned",
25
+ "generation": {
26
+ "max_new_tokens": 160,
27
+ "temperature": 0.72,
28
+ "top_k": 60,
29
+ "top_p": 0.92,
30
+ "beam_size": 3
31
+ },
32
+ "retrieval": {
33
+ "use_hybrid": true,
34
+ "rag_top_k": 12,
35
+ "sim_threshold": 0.045,
36
+ "min_support": 0.28
37
+ },
38
+ "metrics": {
39
+ "retrieval": {
40
+ "recall@12": 0.933,
41
+ "n": 120
42
+ },
43
+ "demo": {
44
+ "demo_pass": 4,
45
+ "demo_total": 5,
46
+ "demo_pct": 80.0
47
+ }
48
+ },
49
+ "saved_at": "2026-04-29 14:38:20"
50
+ }
rag_boolq_400m/models/custom_bpe_v6_2.json ADDED
The diff for this file is too large to render. See raw diff
 
rag_boolq_400m/models/tokenizer_fast/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
rag_boolq_400m/models/tokenizer_fast/tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "backend": "tokenizers",
3
+ "bos_token": "[BOS]",
4
+ "cls_token": "[CLS]",
5
+ "eos_token": "[EOS]",
6
+ "mask_token": "[MASK]",
7
+ "max_length": 640,
8
+ "model_max_length": 640,
9
+ "pad_token": "[PAD]",
10
+ "sep_token": "[SEP]",
11
+ "stride": 0,
12
+ "tokenizer_class": "TokenizersBackend",
13
+ "truncation_side": "right",
14
+ "truncation_strategy": "longest_first",
15
+ "unk_token": "[UNK]"
16
+ }
rag_boolq_400m/models/tokenizer_fast/training_info.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "latest_checkpoint": null,
3
+ "latest_mtime": null,
4
+ "latest_mtime_iso": null,
5
+ "size_bytes": null,
6
+ "epoch": null
7
+ }
rag_boolq_400m/models/training_info.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "latest_checkpoint": "rag_boolq_400m/models/decoder_v6_2.pt",
3
+ "latest_mtime": 1777473498.0736282,
4
+ "latest_mtime_iso": "2026-04-29T14:38:18.073628+00:00",
5
+ "size_bytes": 620133245,
6
+ "epoch": null
7
+ }
rag_boolq_400m/summary_v6_2.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "6.2",
3
+ "profile": "power_400m",
4
+ "vocab": 36000,
5
+ "max_len": 640,
6
+ "chunks": 108751,
7
+ "datasets": 23,
8
+ "total_params_M": 190.66,
9
+ "encoder_params_M": 63.29,
10
+ "decoder_params_M": 127.37,
11
+ "epochs": {
12
+ "mlm": 18,
13
+ "retriever": 16,
14
+ "clm": 28
15
+ },
16
+ "grad_accum": 12,
17
+ "retrieval": {
18
+ "recall@12": 0.933,
19
+ "n": 120
20
+ },
21
+ "demo": {
22
+ "demo_pass": 4,
23
+ "demo_total": 5,
24
+ "demo_pct": 80.0
25
+ },
26
+ "local_finetuned_dir": "/workspace/rag_boolq_400m/local_finetuned",
27
+ "project_dir": "/workspace/rag_boolq_400m"
28
+ }
rag_v6_2_400m_domains/summary_v6_2.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "6.2",
3
+ "profile": "power_400m",
4
+ "vocab": 48000,
5
+ "max_len": 1024,
6
+ "chunks": 108849,
7
+ "datasets": 37,
8
+ "dataset_groups": [
9
+ "single"
10
+ ],
11
+ "max_texts_per_dataset": 0,
12
+ "max_total_docs": 0,
13
+ "total_params_M": 450.67,
14
+ "encoder_params_M": 123.35,
15
+ "decoder_params_M": 327.32,
16
+ "epochs": {
17
+ "mlm": 18,
18
+ "retriever": 16,
19
+ "clm": 28
20
+ },
21
+ "grad_accum": 12,
22
+ "retrieval": {
23
+ "recall@12": 0.867,
24
+ "n": 120
25
+ },
26
+ "demo": {
27
+ "demo_pass": 4,
28
+ "demo_total": 5,
29
+ "demo_pct": 80.0
30
+ },
31
+ "local_finetuned_dir": "/workspace/rag_v6_2_400m_domains/local_finetuned",
32
+ "project_dir": "/workspace/rag_v6_2_400m_domains"
33
+ }
security/cyber_unified.py ADDED
@@ -0,0 +1,1370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ train_all_models_10datasets.py
6
+
7
+ Script unique pour entraîner localement :
8
+
9
+ 1. SecurityLLM -> LoRA SFT sur 10 datasets cyber
10
+ 2. Llama-Phishsense-1B -> LoRA SFT sur 10 datasets cyber/phishing
11
+ 3. CySecBERT -> classifier phishing
12
+ 4. SecBERT -> classifier phishing
13
+
14
+ Par défaut :
15
+ - 10 datasets SFT pour les LLM
16
+ - 3 epochs pour les LLM
17
+ - 3 epochs pour BERT/SecBERT
18
+ - entraînement séquentiel pour éviter de saturer RAM/GPU
19
+
20
+ Structure attendue :
21
+
22
+ security/
23
+ ├── train_all_models_10datasets.py
24
+ ├── models/
25
+ │ ├── SecurityLLM/
26
+ │ ├── Llama-Phishsense-1B/
27
+ │ ├── CySecBERT/
28
+ │ └── SecBERT/
29
+ ├── datasets/
30
+ │ └── cybersecurity-rules/
31
+ └── outputs/
32
+ """
33
+
34
+ import os
35
+ import gc
36
+ import json
37
+ import argparse
38
+ import inspect
39
+ from pathlib import Path
40
+ from typing import Dict, Any, List, Tuple, Optional
41
+
42
+ import numpy as np
43
+ import torch
44
+
45
+ from datasets import load_dataset, Dataset, concatenate_datasets
46
+
47
+ from transformers import (
48
+ AutoTokenizer,
49
+ AutoModelForCausalLM,
50
+ AutoModelForSequenceClassification,
51
+ TrainingArguments,
52
+ Trainer,
53
+ DataCollatorForLanguageModeling,
54
+ )
55
+
56
+ from peft import (
57
+ LoraConfig,
58
+ get_peft_model,
59
+ TaskType,
60
+ PeftModel,
61
+ )
62
+
63
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
64
+
65
+
66
+ # ============================================================
67
+ # Chemins locaux
68
+ # ============================================================
69
+
70
+ BASE_DIR = Path(__file__).resolve().parent
71
+
72
+ DEFAULT_MODELS = {
73
+ "securityllm": BASE_DIR / "models" / "SecurityLLM",
74
+ "phishsense": BASE_DIR / "models" / "Llama-Phishsense-1B",
75
+ "cysecbert": BASE_DIR / "models" / "CySecBERT",
76
+ "secbert": BASE_DIR / "models" / "SecBERT",
77
+ }
78
+
79
+ DEFAULT_OUTPUT_DIR = BASE_DIR / "outputs"
80
+
81
+
82
+ # ============================================================
83
+ # 10 datasets pour les LLM
84
+ # ============================================================
85
+
86
+ MULTI_CYBER_DATASETS = [
87
+ {
88
+ "name": "local_cybersecurity_rules",
89
+ "dataset": str(BASE_DIR / "datasets" / "cybersecurity-rules"),
90
+ "max_samples": 0,
91
+ },
92
+ {
93
+ "name": "phishing_email_dataset",
94
+ "dataset": "zefang-liu/phishing-email-dataset",
95
+ "max_samples": 0,
96
+ },
97
+ {
98
+ "name": "trendyol_cybersecurity_instruction",
99
+ "dataset": "Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
100
+ "max_samples": 20000,
101
+ },
102
+ {
103
+ "name": "cybersecurity_32k_instruction",
104
+ "dataset": "Vanessasml/cybersecurity_32k_instruction_input_output",
105
+ "max_samples": 12000,
106
+ },
107
+ {
108
+ "name": "cybersecurity_sharegpt",
109
+ "dataset": "ChaoticNeutrals/Cybersecurity-ShareGPT",
110
+ "max_samples": 12000,
111
+ },
112
+ {
113
+ "name": "cybersecurity_eval",
114
+ "dataset": "CyberNative/CyberSecurityEval",
115
+ "max_samples": 1000,
116
+ },
117
+ {
118
+ "name": "cybersecurity_corpus",
119
+ "dataset": "zeroshot/cybersecurity-corpus",
120
+ "max_samples": 1000,
121
+ },
122
+ {
123
+ "name": "practical_ai_for_cybersecurity",
124
+ "dataset": "Falah/Practical_AI_for_Cybersecurity",
125
+ "max_samples": 1000,
126
+ },
127
+ {
128
+ "name": "cybersecurity_llm_cve",
129
+ "dataset": "Bouquets/Cybersecurity-LLM-CVE",
130
+ "max_samples": 12000,
131
+ },
132
+ {
133
+ "name": "cve_llm_training",
134
+ "dataset": "morpheuslord/cve-llm-training",
135
+ "max_samples": 12000,
136
+ },
137
+ ]
138
+
139
+ DEFAULT_PHISHING_DATASET = "zefang-liu/phishing-email-dataset"
140
+
141
+
142
+ # ============================================================
143
+ # Utilitaires généraux
144
+ # ============================================================
145
+
146
+ def log(title: str):
147
+ print("\n" + "=" * 100)
148
+ print(title)
149
+ print("=" * 100)
150
+
151
+
152
+ def set_seed(seed: int = 42):
153
+ np.random.seed(seed)
154
+ torch.manual_seed(seed)
155
+ if torch.cuda.is_available():
156
+ torch.cuda.manual_seed_all(seed)
157
+
158
+
159
+ def cleanup_memory():
160
+ gc.collect()
161
+ if torch.cuda.is_available():
162
+ torch.cuda.empty_cache()
163
+
164
+
165
+ def check_path(path: Path, name: str):
166
+ if not path.exists():
167
+ raise FileNotFoundError(f"{name} introuvable : {path}")
168
+
169
+
170
+ def make_training_args(**kwargs):
171
+ """
172
+ Compatibilité avec plusieurs versions transformers.
173
+ Certaines versions utilisent evaluation_strategy, d'autres eval_strategy.
174
+ """
175
+ sig = inspect.signature(TrainingArguments.__init__)
176
+ allowed = set(sig.parameters.keys())
177
+
178
+ clean = {}
179
+
180
+ for k, v in kwargs.items():
181
+ if k in allowed:
182
+ clean[k] = v
183
+
184
+ if "evaluation_strategy" in kwargs and "eval_strategy" in allowed:
185
+ clean["eval_strategy"] = kwargs["evaluation_strategy"]
186
+
187
+ return TrainingArguments(**clean)
188
+
189
+
190
+ def reduce_dataset(ds: Dataset, max_samples: int = 0) -> Dataset:
191
+ if max_samples and max_samples > 0 and len(ds) > max_samples:
192
+ return ds.select(range(max_samples))
193
+ return ds
194
+
195
+
196
+ # ============================================================
197
+ # Chargement dataset local ou HF
198
+ # ============================================================
199
+
200
+ def load_local_or_hf_dataset(dataset_ref: str, split: str = "train") -> Dataset:
201
+ """
202
+ Charge :
203
+ - dossier local contenant .jsonl/.json/.csv/.parquet
204
+ - fichier local
205
+ - dataset Hugging Face
206
+ """
207
+
208
+ path = Path(dataset_ref)
209
+
210
+ if path.exists():
211
+ if path.is_file():
212
+ suffix = path.suffix.lower()
213
+ files = [str(path)]
214
+
215
+ if suffix in [".json", ".jsonl"]:
216
+ return load_dataset("json", data_files=files, split=split)
217
+ if suffix == ".csv":
218
+ return load_dataset("csv", data_files=files, split=split)
219
+ if suffix == ".parquet":
220
+ return load_dataset("parquet", data_files=files, split=split)
221
+
222
+ raise RuntimeError(f"Format fichier non supporté : {path}")
223
+
224
+ jsonl_files = list(path.rglob("*.jsonl"))
225
+ json_files = list(path.rglob("*.json"))
226
+ csv_files = list(path.rglob("*.csv"))
227
+ parquet_files = list(path.rglob("*.parquet"))
228
+
229
+ if jsonl_files:
230
+ return load_dataset(
231
+ "json",
232
+ data_files=[str(f) for f in jsonl_files],
233
+ split=split,
234
+ )
235
+
236
+ if json_files:
237
+ return load_dataset(
238
+ "json",
239
+ data_files=[str(f) for f in json_files],
240
+ split=split,
241
+ )
242
+
243
+ if csv_files:
244
+ return load_dataset(
245
+ "csv",
246
+ data_files=[str(f) for f in csv_files],
247
+ split=split,
248
+ )
249
+
250
+ if parquet_files:
251
+ return load_dataset(
252
+ "parquet",
253
+ data_files=[str(f) for f in parquet_files],
254
+ split=split,
255
+ )
256
+
257
+ raise RuntimeError(f"Aucun fichier dataset lisible trouvé dans : {path}")
258
+
259
+ return load_dataset(dataset_ref, split=split)
260
+
261
+
262
+ # ============================================================
263
+ # Conversion multi-formats vers SFT text
264
+ # ============================================================
265
+
266
+ def safe_str(x) -> str:
267
+ if x is None:
268
+ return ""
269
+ return str(x).strip()
270
+
271
+
272
+ def row_to_unified_sft_text(row: Dict[str, Any]) -> str:
273
+ """
274
+ Convertit plusieurs formats HF en format SFT.
275
+
276
+ Formats supportés :
277
+ - messages
278
+ - instruction/input/output
279
+ - system/user/assistant
280
+ - question/answer
281
+ - prompt/response
282
+ - text/label
283
+ - CVE-like
284
+ - fallback toutes colonnes
285
+ """
286
+
287
+ # 1. Format messages
288
+ if "messages" in row and row["messages"]:
289
+ try:
290
+ messages = row["messages"]
291
+ parts = []
292
+
293
+ for msg in messages:
294
+ if isinstance(msg, dict):
295
+ role = safe_str(msg.get("role", "user")).upper()
296
+ content = safe_str(msg.get("content", ""))
297
+ if content:
298
+ parts.append(f"{role}:\n{content}")
299
+
300
+ if parts:
301
+ return "\n\n".join(parts)
302
+ except Exception:
303
+ pass
304
+
305
+ # 2. Format system/user/assistant
306
+ system = safe_str(row.get("system", ""))
307
+ user = safe_str(row.get("user", ""))
308
+ assistant = safe_str(row.get("assistant", ""))
309
+
310
+ if user and assistant:
311
+ if not system:
312
+ system = "Tu es un assistant cybersécurité défensif."
313
+
314
+ return f"""### System:
315
+ {system}
316
+
317
+ ### User:
318
+ {user}
319
+
320
+ ### Assistant:
321
+ {assistant}"""
322
+
323
+ # 3. Format instruction/input/output
324
+ instruction = safe_str(row.get("instruction", ""))
325
+ input_text = safe_str(row.get("input", ""))
326
+ output = safe_str(row.get("output", ""))
327
+
328
+ if instruction and output:
329
+ user_content = instruction
330
+ if input_text:
331
+ user_content += "\n\nContexte :\n" + input_text
332
+
333
+ return f"""### System:
334
+ Tu es un assistant cybersécurité défensif.
335
+ Tu privilégies l'analyse, la détection, la remédiation et la prévention.
336
+
337
+ ### User:
338
+ {user_content}
339
+
340
+ ### Assistant:
341
+ {output}"""
342
+
343
+ # 4. Format prompt / response / completion
344
+ prompt_keys = ["prompt", "Prompt", "query", "Query", "question", "Question", "problem"]
345
+ answer_keys = ["response", "Response", "completion", "Completion", "answer", "Answer", "solution"]
346
+
347
+ prompt = ""
348
+ answer = ""
349
+
350
+ for k in prompt_keys:
351
+ if k in row and safe_str(row.get(k)):
352
+ prompt = safe_str(row.get(k))
353
+ break
354
+
355
+ for k in answer_keys:
356
+ if k in row and safe_str(row.get(k)):
357
+ answer = safe_str(row.get(k))
358
+ break
359
+
360
+ if prompt and answer:
361
+ return f"""### System:
362
+ Tu es un assistant cybersécurité défensif.
363
+
364
+ ### User:
365
+ {prompt}
366
+
367
+ ### Assistant:
368
+ {answer}"""
369
+
370
+ # 5. Format CVE-like
371
+ cve_keys = ["cve", "CVE", "cve_id", "CVE_ID", "id"]
372
+ desc_keys = ["description", "Description", "details", "Details", "summary"]
373
+
374
+ cve_id = ""
375
+ desc = ""
376
+
377
+ for k in cve_keys:
378
+ if k in row and safe_str(row.get(k)):
379
+ cve_id = safe_str(row.get(k))
380
+ break
381
+
382
+ for k in desc_keys:
383
+ if k in row and safe_str(row.get(k)):
384
+ desc = safe_str(row.get(k))
385
+ break
386
+
387
+ if cve_id or desc:
388
+ raw = "\n".join([f"{k}: {v}" for k, v in row.items() if v is not None])
389
+
390
+ return f"""### System:
391
+ Tu es un assistant cybersécurité défensif spécialisé en vulnérabilités.
392
+
393
+ ### User:
394
+ Analyse cette vulnérabilité et donne un résumé défensif, impact, priorité et remédiations.
395
+
396
+ {raw}
397
+
398
+ ### Assistant:
399
+ """
400
+
401
+ # 6. Format phishing / classification
402
+ text_keys = [
403
+ "text",
404
+ "Text",
405
+ "email",
406
+ "Email",
407
+ "Email Text",
408
+ "body",
409
+ "Body",
410
+ "message",
411
+ "Message",
412
+ "content",
413
+ "Content",
414
+ "url",
415
+ "URL",
416
+ "text_combined",
417
+ "sentence",
418
+ ]
419
+
420
+ label_keys = [
421
+ "label",
422
+ "Label",
423
+ "class",
424
+ "Class",
425
+ "category",
426
+ "Category",
427
+ "is_phishing",
428
+ "phishing",
429
+ "status",
430
+ "type",
431
+ ]
432
+
433
+ text = ""
434
+ label = ""
435
+
436
+ for k in text_keys:
437
+ if k in row and safe_str(row.get(k)):
438
+ text = safe_str(row.get(k))
439
+ break
440
+
441
+ for k in label_keys:
442
+ if k in row and row.get(k) is not None:
443
+ label = safe_str(row.get(k))
444
+ break
445
+
446
+ if text:
447
+ return f"""### System:
448
+ Tu es un assistant défensif spécialisé en cybersécurité.
449
+
450
+ ### User:
451
+ Analyse ce contenu dans un contexte cybersécurité.
452
+ Donne un verdict, les indices, le risque et les actions recommandées.
453
+
454
+ {text}
455
+
456
+ ### Assistant:
457
+ Label brut du dataset : {label}
458
+
459
+ Analyse défensive :
460
+ - Verdict :
461
+ - Risque :
462
+ - Indices :
463
+ - Actions recommandées :
464
+ """
465
+
466
+ # 7. Fallback général
467
+ raw = "\n".join([f"{k}: {v}" for k, v in row.items() if v is not None])
468
+
469
+ return f"""### System:
470
+ Tu es un assistant cybersécurité défensif.
471
+
472
+ ### User:
473
+ Analyse ce contenu cyber :
474
+
475
+ {raw}
476
+
477
+ ### Assistant:
478
+ """
479
+
480
+
481
+ def load_one_sft_dataset(
482
+ dataset_ref: str,
483
+ name: str,
484
+ split: str = "train",
485
+ max_samples: int = 0,
486
+ ) -> Optional[Dataset]:
487
+ print(f"\n[+] Chargement dataset SFT : {name}")
488
+ print(f" Source : {dataset_ref}")
489
+
490
+ try:
491
+ ds = load_local_or_hf_dataset(str(dataset_ref), split=split)
492
+ except Exception as e:
493
+ print(f"[ERREUR] Dataset ignoré : {name}")
494
+ print(f"Raison : {repr(e)}")
495
+ return None
496
+
497
+ try:
498
+ ds = reduce_dataset(ds, max_samples=max_samples)
499
+ print("[OK] Lignes :", len(ds))
500
+ print("[OK] Colonnes :", ds.column_names)
501
+ print("[OK] Exemple brut :", ds[0])
502
+ except Exception as e:
503
+ print(f"[ERREUR] Lecture impossible : {name}")
504
+ print(f"Raison : {repr(e)}")
505
+ return None
506
+
507
+ def mapper(row):
508
+ return {"text": row_to_unified_sft_text(row)}
509
+
510
+ try:
511
+ ds = ds.map(mapper, remove_columns=ds.column_names)
512
+ return ds
513
+ except Exception as e:
514
+ print(f"[ERREUR] Conversion SFT impossible : {name}")
515
+ print(f"Raison : {repr(e)}")
516
+ return None
517
+
518
+
519
+ def load_multi_sft_dataset(
520
+ dataset_configs: List[Dict[str, Any]],
521
+ split: str = "train",
522
+ global_max_samples: int = 0,
523
+ ) -> Dataset:
524
+ datasets_list = []
525
+
526
+ for cfg in dataset_configs:
527
+ ds = load_one_sft_dataset(
528
+ dataset_ref=cfg["dataset"],
529
+ name=cfg["name"],
530
+ split=split,
531
+ max_samples=cfg.get("max_samples", 0),
532
+ )
533
+
534
+ if ds is not None and len(ds) > 0:
535
+ datasets_list.append(ds)
536
+
537
+ if not datasets_list:
538
+ raise RuntimeError("Aucun dataset SFT n'a pu être chargé.")
539
+
540
+ merged = concatenate_datasets(datasets_list)
541
+ merged = merged.shuffle(seed=42)
542
+
543
+ if global_max_samples and global_max_samples > 0 and len(merged) > global_max_samples:
544
+ merged = merged.select(range(global_max_samples))
545
+
546
+ print("\n[OK] Dataset SFT fusionné.")
547
+ print("[OK] Total lignes :", len(merged))
548
+ print("[OK] Exemple final :", merged[0])
549
+
550
+ return merged
551
+
552
+
553
+ def tokenize_text_sft_dataset(
554
+ ds: Dataset,
555
+ tokenizer,
556
+ max_length: int,
557
+ ) -> Dataset:
558
+ def mapper(row):
559
+ encoded = tokenizer(
560
+ row["text"],
561
+ truncation=True,
562
+ max_length=max_length,
563
+ padding=False,
564
+ )
565
+ encoded["labels"] = encoded["input_ids"].copy()
566
+ return encoded
567
+
568
+ return ds.map(mapper, remove_columns=ds.column_names)
569
+
570
+
571
+ # ============================================================
572
+ # LoRA pour LLM
573
+ # ============================================================
574
+
575
+ def infer_lora_targets(model) -> List[str]:
576
+ """
577
+ Détection automatique des modules LoRA.
578
+ Compatible Llama/Mistral/Zephyr-like et plusieurs architectures.
579
+ """
580
+
581
+ common = [
582
+ "q_proj",
583
+ "k_proj",
584
+ "v_proj",
585
+ "o_proj",
586
+ "gate_proj",
587
+ "up_proj",
588
+ "down_proj",
589
+ "query",
590
+ "key",
591
+ "value",
592
+ "dense",
593
+ "fc1",
594
+ "fc2",
595
+ ]
596
+
597
+ found = set()
598
+
599
+ for name, module in model.named_modules():
600
+ last = name.split(".")[-1]
601
+ if last in common:
602
+ found.add(last)
603
+
604
+ found = sorted(found)
605
+
606
+ if not found:
607
+ raise RuntimeError(
608
+ "Impossible de détecter automatiquement les target_modules LoRA."
609
+ )
610
+
611
+ print("[+] Modules LoRA détectés :", found)
612
+ return found
613
+
614
+
615
+ def train_llm_lora_multi_dataset(
616
+ model_path: Path,
617
+ dataset_configs: List[Dict[str, Any]],
618
+ output_dir: Path,
619
+ split: str,
620
+ global_max_samples: int,
621
+ epochs: float,
622
+ batch_size: int,
623
+ grad_accum: int,
624
+ lr: float,
625
+ max_length: int,
626
+ save_steps: int,
627
+ logging_steps: int,
628
+ lora_r: int,
629
+ lora_alpha: int,
630
+ lora_dropout: float,
631
+ skip_existing: bool,
632
+ ):
633
+ log(f"ENTRAÎNEMENT LLM LoRA MULTI-DATASETS : {model_path.name}")
634
+
635
+ check_path(model_path, f"Modèle {model_path.name}")
636
+
637
+ if skip_existing and output_dir.exists() and (output_dir / "adapter_config.json").exists():
638
+ print(f"[SKIP] Adapter LoRA déjà présent : {output_dir}")
639
+ return
640
+
641
+ output_dir.mkdir(parents=True, exist_ok=True)
642
+
643
+ print("[+] Chargement tokenizer...")
644
+ tokenizer = AutoTokenizer.from_pretrained(
645
+ str(model_path),
646
+ local_files_only=True,
647
+ trust_remote_code=True,
648
+ )
649
+
650
+ if tokenizer.pad_token is None:
651
+ tokenizer.pad_token = tokenizer.eos_token
652
+
653
+ print("[+] Chargement modèle...")
654
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
655
+
656
+ model = AutoModelForCausalLM.from_pretrained(
657
+ str(model_path),
658
+ local_files_only=True,
659
+ trust_remote_code=True,
660
+ torch_dtype=dtype,
661
+ device_map="auto" if torch.cuda.is_available() else None,
662
+ )
663
+
664
+ if not torch.cuda.is_available():
665
+ model.to("cpu")
666
+
667
+ model.config.use_cache = False
668
+
669
+ if hasattr(model, "gradient_checkpointing_enable"):
670
+ model.gradient_checkpointing_enable()
671
+
672
+ target_modules = infer_lora_targets(model)
673
+
674
+ lora_config = LoraConfig(
675
+ r=lora_r,
676
+ lora_alpha=lora_alpha,
677
+ lora_dropout=lora_dropout,
678
+ bias="none",
679
+ task_type=TaskType.CAUSAL_LM,
680
+ target_modules=target_modules,
681
+ )
682
+
683
+ print("[+] Application LoRA...")
684
+ model = get_peft_model(model, lora_config)
685
+ model.print_trainable_parameters()
686
+
687
+ print("[+] Chargement + fusion des 10 datasets...")
688
+ ds = load_multi_sft_dataset(
689
+ dataset_configs=dataset_configs,
690
+ split=split,
691
+ global_max_samples=global_max_samples,
692
+ )
693
+
694
+ print("[+] Tokenisation...")
695
+ tokenized = tokenize_text_sft_dataset(
696
+ ds,
697
+ tokenizer=tokenizer,
698
+ max_length=max_length,
699
+ )
700
+
701
+ data_collator = DataCollatorForLanguageModeling(
702
+ tokenizer=tokenizer,
703
+ mlm=False,
704
+ )
705
+
706
+ use_fp16 = torch.cuda.is_available()
707
+ use_bf16 = False
708
+
709
+ if torch.cuda.is_available():
710
+ try:
711
+ use_bf16 = torch.cuda.is_bf16_supported()
712
+ use_fp16 = not use_bf16
713
+ except Exception:
714
+ use_bf16 = False
715
+ use_fp16 = True
716
+
717
+ training_args = make_training_args(
718
+ output_dir=str(output_dir),
719
+ num_train_epochs=epochs,
720
+ per_device_train_batch_size=batch_size,
721
+ gradient_accumulation_steps=grad_accum,
722
+ learning_rate=lr,
723
+ fp16=use_fp16,
724
+ bf16=use_bf16,
725
+ logging_steps=logging_steps,
726
+ save_steps=save_steps,
727
+ save_total_limit=2,
728
+ report_to="none",
729
+ optim="adamw_torch",
730
+ warmup_ratio=0.03,
731
+ lr_scheduler_type="cosine",
732
+ remove_unused_columns=False,
733
+ )
734
+
735
+ trainer = Trainer(
736
+ model=model,
737
+ args=training_args,
738
+ train_dataset=tokenized,
739
+ data_collator=data_collator,
740
+ )
741
+
742
+ print("[+] Début entraînement LoRA...")
743
+ trainer.train()
744
+
745
+ print("[+] Sauvegarde adapter LoRA :", output_dir)
746
+ model.save_pretrained(str(output_dir))
747
+ tokenizer.save_pretrained(str(output_dir))
748
+
749
+ del trainer
750
+ del model
751
+ del tokenizer
752
+ cleanup_memory()
753
+
754
+ print("[OK] Entraînement LoRA terminé :", output_dir)
755
+
756
+
757
+ # ============================================================
758
+ # BERT classification
759
+ # ============================================================
760
+
761
+
762
+ def detect_text_label_columns(ds: Dataset) -> Tuple[Optional[str], Optional[str]]:
763
+ """
764
+ Détection robuste des colonnes texte/label.
765
+ Compatible avec zefang-liu/phishing-email-dataset :
766
+ - Email Text
767
+ - Email Type
768
+ """
769
+
770
+ cols = ds.column_names
771
+ lower_map = {c.lower().strip(): c for c in cols}
772
+
773
+ text_candidates = [
774
+ "text",
775
+ "Text",
776
+ "email",
777
+ "Email",
778
+ "Email Text",
779
+ "email text",
780
+ "body",
781
+ "Body",
782
+ "message",
783
+ "Message",
784
+ "content",
785
+ "Content",
786
+ "url",
787
+ "URL",
788
+ "text_combined",
789
+ "sentence",
790
+ ]
791
+
792
+ label_candidates = [
793
+ "label",
794
+ "Label",
795
+ "class",
796
+ "Class",
797
+ "category",
798
+ "Category",
799
+ "Email Type",
800
+ "email type",
801
+ "type",
802
+ "Type",
803
+ "is_phishing",
804
+ "phishing",
805
+ "status",
806
+ "target",
807
+ ]
808
+
809
+ def find_column(candidates):
810
+ # Match exact insensible à la casse
811
+ for cand in candidates:
812
+ key = cand.lower().strip()
813
+ if key in lower_map:
814
+ return lower_map[key]
815
+
816
+ # Match partiel
817
+ for col in cols:
818
+ col_l = col.lower().strip()
819
+ for cand in candidates:
820
+ cand_l = cand.lower().strip()
821
+ if cand_l in col_l or col_l in cand_l:
822
+ return col
823
+
824
+ return None
825
+
826
+ text_col = find_column(text_candidates)
827
+ label_col = find_column(label_candidates)
828
+
829
+ return text_col, label_col
830
+
831
+ def normalize_labels(
832
+ ds: Dataset,
833
+ label_col: str,
834
+ ) -> Tuple[Dataset, Dict[str, int], Dict[int, str]]:
835
+ labels_raw = [str(x) for x in ds[label_col]]
836
+ unique = sorted(list(set(labels_raw)))
837
+
838
+ label2id = {label: i for i, label in enumerate(unique)}
839
+ id2label = {i: label for label, i in label2id.items()}
840
+
841
+ def mapper(row):
842
+ row["labels"] = label2id[str(row[label_col])]
843
+ return row
844
+
845
+ ds = ds.map(mapper)
846
+ return ds, label2id, id2label
847
+
848
+
849
+ def compute_metrics(eval_pred):
850
+ logits, labels = eval_pred
851
+ preds = np.argmax(logits, axis=-1)
852
+
853
+ precision, recall, f1, _ = precision_recall_fscore_support(
854
+ labels,
855
+ preds,
856
+ average="weighted",
857
+ zero_division=0,
858
+ )
859
+
860
+ acc = accuracy_score(labels, preds)
861
+
862
+ return {
863
+ "accuracy": acc,
864
+ "precision": precision,
865
+ "recall": recall,
866
+ "f1": f1,
867
+ }
868
+
869
+
870
+ def train_bert_classifier(
871
+ model_path: Path,
872
+ dataset_ref: str,
873
+ output_dir: Path,
874
+ split: str,
875
+ max_samples: int,
876
+ epochs: float,
877
+ batch_size: int,
878
+ lr: float,
879
+ max_length: int,
880
+ logging_steps: int,
881
+ skip_existing: bool,
882
+ ):
883
+ log(f"ENTRAÎNEMENT BERT CLASSIFIER : {model_path.name}")
884
+
885
+ check_path(model_path, f"Modèle {model_path.name}")
886
+
887
+ if skip_existing and output_dir.exists() and (output_dir / "config.json").exists():
888
+ print(f"[SKIP] Classifier déjà présent : {output_dir}")
889
+ return
890
+
891
+ output_dir.mkdir(parents=True, exist_ok=True)
892
+
893
+ print("[+] Chargement dataset classification :", dataset_ref)
894
+
895
+ ds = load_local_or_hf_dataset(str(dataset_ref), split=split)
896
+ ds = reduce_dataset(ds, max_samples=max_samples)
897
+
898
+ print("[+] Nombre d'exemples :", len(ds))
899
+ print("[+] Colonnes :", ds.column_names)
900
+ print("[+] Exemple brut :", ds[0])
901
+
902
+ text_col, label_col = detect_text_label_columns(ds)
903
+
904
+ if not text_col or not label_col:
905
+ raise ValueError(
906
+ "Impossible de détecter les colonnes texte/label.\n"
907
+ f"Colonnes disponibles : {ds.column_names}"
908
+ )
909
+
910
+ print("[+] Colonne texte :", text_col)
911
+ print("[+] Colonne label :", label_col)
912
+
913
+ ds, label2id, id2label = normalize_labels(ds, label_col)
914
+
915
+ split_ds = ds.train_test_split(test_size=0.15, seed=42)
916
+ train_ds = split_ds["train"]
917
+ eval_ds = split_ds["test"]
918
+
919
+ print("[+] Train size :", len(train_ds))
920
+ print("[+] Eval size :", len(eval_ds))
921
+ print("[+] Labels :", label2id)
922
+
923
+ tokenizer = AutoTokenizer.from_pretrained(
924
+ str(model_path),
925
+ local_files_only=True,
926
+ trust_remote_code=True,
927
+ )
928
+
929
+ def tok(batch):
930
+ return tokenizer(
931
+ batch[text_col],
932
+ truncation=True,
933
+ padding="max_length",
934
+ max_length=max_length,
935
+ )
936
+
937
+ train_ds = train_ds.map(tok, batched=True)
938
+ eval_ds = eval_ds.map(tok, batched=True)
939
+
940
+ keep = ["input_ids", "attention_mask", "labels"]
941
+
942
+ train_ds = train_ds.remove_columns(
943
+ [c for c in train_ds.column_names if c not in keep]
944
+ )
945
+
946
+ eval_ds = eval_ds.remove_columns(
947
+ [c for c in eval_ds.column_names if c not in keep]
948
+ )
949
+
950
+ model = AutoModelForSequenceClassification.from_pretrained(
951
+ str(model_path),
952
+ local_files_only=True,
953
+ trust_remote_code=True,
954
+ num_labels=len(label2id),
955
+ label2id=label2id,
956
+ id2label=id2label,
957
+ ignore_mismatched_sizes=True,
958
+ )
959
+
960
+ use_fp16 = torch.cuda.is_available()
961
+
962
+ training_args = make_training_args(
963
+ output_dir=str(output_dir),
964
+ num_train_epochs=epochs,
965
+ per_device_train_batch_size=batch_size,
966
+ per_device_eval_batch_size=batch_size,
967
+ learning_rate=lr,
968
+ fp16=use_fp16,
969
+ logging_steps=logging_steps,
970
+ evaluation_strategy="epoch",
971
+ save_strategy="epoch",
972
+ save_total_limit=2,
973
+ report_to="none",
974
+ load_best_model_at_end=True,
975
+ metric_for_best_model="f1",
976
+ )
977
+
978
+ trainer = Trainer(
979
+ model=model,
980
+ args=training_args,
981
+ train_dataset=train_ds,
982
+ eval_dataset=eval_ds,
983
+ compute_metrics=compute_metrics,
984
+ )
985
+
986
+ print("[+] Début entraînement classifier...")
987
+ trainer.train()
988
+
989
+ print("[+] Évaluation finale...")
990
+ metrics = trainer.evaluate()
991
+ print(metrics)
992
+
993
+ print("[+] Sauvegarde classifier :", output_dir)
994
+ trainer.save_model(str(output_dir))
995
+ tokenizer.save_pretrained(str(output_dir))
996
+
997
+ with open(output_dir / "label_mapping.json", "w", encoding="utf-8") as f:
998
+ json.dump(
999
+ {
1000
+ "label2id": label2id,
1001
+ "id2label": id2label,
1002
+ "text_col": text_col,
1003
+ "label_col": label_col,
1004
+ "metrics": metrics,
1005
+ },
1006
+ f,
1007
+ ensure_ascii=False,
1008
+ indent=2,
1009
+ )
1010
+
1011
+ del trainer
1012
+ del model
1013
+ del tokenizer
1014
+ cleanup_memory()
1015
+
1016
+ print("[OK] Entraînement BERT terminé :", output_dir)
1017
+
1018
+
1019
+ # ============================================================
1020
+ # Tests après entraînement
1021
+ # ============================================================
1022
+
1023
+ def test_lora_adapter(
1024
+ base_model: Path,
1025
+ adapter_dir: Path,
1026
+ prompt: str,
1027
+ max_new_tokens: int = 250,
1028
+ ):
1029
+ log(f"TEST LoRA : {adapter_dir.name}")
1030
+
1031
+ if not adapter_dir.exists():
1032
+ print("[SKIP] Adapter introuvable :", adapter_dir)
1033
+ return
1034
+
1035
+ tokenizer = AutoTokenizer.from_pretrained(
1036
+ str(base_model),
1037
+ local_files_only=True,
1038
+ trust_remote_code=True,
1039
+ )
1040
+
1041
+ if tokenizer.pad_token is None:
1042
+ tokenizer.pad_token = tokenizer.eos_token
1043
+
1044
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
1045
+
1046
+ base = AutoModelForCausalLM.from_pretrained(
1047
+ str(base_model),
1048
+ local_files_only=True,
1049
+ trust_remote_code=True,
1050
+ torch_dtype=dtype,
1051
+ device_map="auto" if torch.cuda.is_available() else None,
1052
+ )
1053
+
1054
+ model = PeftModel.from_pretrained(base, str(adapter_dir))
1055
+ model.eval()
1056
+
1057
+ full_prompt = f"""### System:
1058
+ Tu es un assistant cybersécurité défensif.
1059
+
1060
+ ### User:
1061
+ {prompt}
1062
+
1063
+ ### Assistant:
1064
+ """
1065
+
1066
+ inputs = tokenizer(full_prompt, return_tensors="pt")
1067
+ device = next(model.parameters()).device
1068
+ inputs = {k: v.to(device) for k, v in inputs.items()}
1069
+
1070
+ with torch.no_grad():
1071
+ out = model.generate(
1072
+ **inputs,
1073
+ max_new_tokens=max_new_tokens,
1074
+ temperature=0.2,
1075
+ do_sample=True,
1076
+ pad_token_id=tokenizer.eos_token_id,
1077
+ )
1078
+
1079
+ print(tokenizer.decode(out[0], skip_special_tokens=True))
1080
+
1081
+ del model
1082
+ del base
1083
+ del tokenizer
1084
+ cleanup_memory()
1085
+
1086
+
1087
+ def test_bert_classifier(model_dir: Path, text: str):
1088
+ log(f"TEST BERT CLASSIFIER : {model_dir.name}")
1089
+
1090
+ if not model_dir.exists():
1091
+ print("[SKIP] Classifier introuvable :", model_dir)
1092
+ return
1093
+
1094
+ tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
1095
+ model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
1096
+ model.eval()
1097
+
1098
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1099
+ model.to(device)
1100
+
1101
+ inputs = tokenizer(
1102
+ text,
1103
+ return_tensors="pt",
1104
+ truncation=True,
1105
+ padding=True,
1106
+ max_length=256,
1107
+ )
1108
+
1109
+ inputs = {k: v.to(device) for k, v in inputs.items()}
1110
+
1111
+ with torch.no_grad():
1112
+ out = model(**inputs)
1113
+ probs = torch.softmax(out.logits, dim=-1)[0].detach().cpu().numpy()
1114
+
1115
+ id2label = model.config.id2label
1116
+
1117
+ for idx, prob in enumerate(probs):
1118
+ print(f"{id2label[idx]}: {prob:.4f}")
1119
+
1120
+ del model
1121
+ del tokenizer
1122
+ cleanup_memory()
1123
+
1124
+
1125
+ # ============================================================
1126
+ # Orchestration
1127
+ # ============================================================
1128
+
1129
+ def train_selected(args):
1130
+ set_seed(args.seed)
1131
+
1132
+ models = {
1133
+ "securityllm": Path(args.security_model),
1134
+ "phishsense": Path(args.phish_model),
1135
+ "cysecbert": Path(args.cysecbert_model),
1136
+ "secbert": Path(args.secbert_model),
1137
+ }
1138
+
1139
+ outputs = {
1140
+ "securityllm": Path(args.output_dir) / "securityllm-10datasets-lora",
1141
+ "phishsense": Path(args.output_dir) / "phishsense-10datasets-lora",
1142
+ "cysecbert": Path(args.output_dir) / "cysecbert-phishing-classifier",
1143
+ "secbert": Path(args.output_dir) / "secbert-phishing-classifier",
1144
+ }
1145
+
1146
+ if args.train == "all":
1147
+ selected = ["securityllm", "phishsense", "cysecbert", "secbert"]
1148
+ else:
1149
+ selected = [args.train]
1150
+
1151
+ print("[+] Modèles sélectionnés :", selected)
1152
+
1153
+ if "securityllm" in selected:
1154
+ train_llm_lora_multi_dataset(
1155
+ model_path=models["securityllm"],
1156
+ dataset_configs=MULTI_CYBER_DATASETS,
1157
+ output_dir=outputs["securityllm"],
1158
+ split=args.split,
1159
+ global_max_samples=args.max_samples,
1160
+ epochs=args.llm_epochs,
1161
+ batch_size=args.llm_batch_size,
1162
+ grad_accum=args.grad_accum,
1163
+ lr=args.llm_lr,
1164
+ max_length=args.llm_max_length,
1165
+ save_steps=args.save_steps,
1166
+ logging_steps=args.logging_steps,
1167
+ lora_r=args.lora_r,
1168
+ lora_alpha=args.lora_alpha,
1169
+ lora_dropout=args.lora_dropout,
1170
+ skip_existing=args.skip_existing,
1171
+ )
1172
+
1173
+ if "phishsense" in selected:
1174
+ train_llm_lora_multi_dataset(
1175
+ model_path=models["phishsense"],
1176
+ dataset_configs=MULTI_CYBER_DATASETS,
1177
+ output_dir=outputs["phishsense"],
1178
+ split=args.split,
1179
+ global_max_samples=args.max_samples,
1180
+ epochs=args.llm_epochs,
1181
+ batch_size=args.llm_batch_size,
1182
+ grad_accum=args.grad_accum,
1183
+ lr=args.llm_lr,
1184
+ max_length=args.llm_max_length,
1185
+ save_steps=args.save_steps,
1186
+ logging_steps=args.logging_steps,
1187
+ lora_r=args.lora_r,
1188
+ lora_alpha=args.lora_alpha,
1189
+ lora_dropout=args.lora_dropout,
1190
+ skip_existing=args.skip_existing,
1191
+ )
1192
+
1193
+ if "cysecbert" in selected:
1194
+ train_bert_classifier(
1195
+ model_path=models["cysecbert"],
1196
+ dataset_ref=args.phishing_dataset,
1197
+ output_dir=outputs["cysecbert"],
1198
+ split=args.split,
1199
+ max_samples=args.bert_max_samples,
1200
+ epochs=args.bert_epochs,
1201
+ batch_size=args.bert_batch_size,
1202
+ lr=args.bert_lr,
1203
+ max_length=args.bert_max_length,
1204
+ logging_steps=args.logging_steps,
1205
+ skip_existing=args.skip_existing,
1206
+ )
1207
+
1208
+ if "secbert" in selected:
1209
+ train_bert_classifier(
1210
+ model_path=models["secbert"],
1211
+ dataset_ref=args.phishing_dataset,
1212
+ output_dir=outputs["secbert"],
1213
+ split=args.split,
1214
+ max_samples=args.bert_max_samples,
1215
+ epochs=args.bert_epochs,
1216
+ batch_size=args.bert_batch_size,
1217
+ lr=args.bert_lr,
1218
+ max_length=args.bert_max_length,
1219
+ logging_steps=args.logging_steps,
1220
+ skip_existing=args.skip_existing,
1221
+ )
1222
+
1223
+ print("\n[OK] Pipeline terminé.")
1224
+
1225
+
1226
+ def run_tests(args):
1227
+ outputs = {
1228
+ "securityllm": Path(args.output_dir) / "securityllm-10datasets-lora",
1229
+ "phishsense": Path(args.output_dir) / "phishsense-10datasets-lora",
1230
+ "cysecbert": Path(args.output_dir) / "cysecbert-phishing-classifier",
1231
+ "secbert": Path(args.output_dir) / "secbert-phishing-classifier",
1232
+ }
1233
+
1234
+ test_lora_adapter(
1235
+ base_model=Path(args.security_model),
1236
+ adapter_dir=outputs["securityllm"],
1237
+ prompt="Explique une règle Sigma permettant de détecter PowerShell EncodedCommand de manière défensive.",
1238
+ )
1239
+
1240
+ test_lora_adapter(
1241
+ base_model=Path(args.phish_model),
1242
+ adapter_dir=outputs["phishsense"],
1243
+ prompt="Analyse cet email : Votre compte sera suspendu. Cliquez ici pour confirmer votre mot de passe.",
1244
+ )
1245
+
1246
+ test_bert_classifier(
1247
+ model_dir=outputs["cysecbert"],
1248
+ text="Your account will be suspended. Click here to verify your password.",
1249
+ )
1250
+
1251
+ test_bert_classifier(
1252
+ model_dir=outputs["secbert"],
1253
+ text="Your account will be suspended. Click here to verify your password.",
1254
+ )
1255
+
1256
+
1257
+ # ============================================================
1258
+ # Main CLI
1259
+ # ============================================================
1260
+
1261
+ def main():
1262
+ parser = argparse.ArgumentParser(
1263
+ description="Entraîner tous les modèles cyber locaux avec 10 datasets et 3 epochs."
1264
+ )
1265
+
1266
+ parser.add_argument(
1267
+ "--train",
1268
+ default="all",
1269
+ choices=["all", "securityllm", "phishsense", "cysecbert", "secbert"],
1270
+ help="Quel modèle entraîner.",
1271
+ )
1272
+
1273
+ parser.add_argument(
1274
+ "--test-after",
1275
+ action="store_true",
1276
+ help="Tester les modèles/adapters après entraînement.",
1277
+ )
1278
+
1279
+ parser.add_argument(
1280
+ "--skip-existing",
1281
+ action="store_true",
1282
+ help="Ignorer un entraînement si la sortie existe déjà.",
1283
+ )
1284
+
1285
+ parser.add_argument("--seed", type=int, default=42)
1286
+
1287
+ # Modèles locaux
1288
+ parser.add_argument(
1289
+ "--security-model",
1290
+ default=str(DEFAULT_MODELS["securityllm"]),
1291
+ )
1292
+
1293
+ parser.add_argument(
1294
+ "--phish-model",
1295
+ default=str(DEFAULT_MODELS["phishsense"]),
1296
+ )
1297
+
1298
+ parser.add_argument(
1299
+ "--cysecbert-model",
1300
+ default=str(DEFAULT_MODELS["cysecbert"]),
1301
+ )
1302
+
1303
+ parser.add_argument(
1304
+ "--secbert-model",
1305
+ default=str(DEFAULT_MODELS["secbert"]),
1306
+ )
1307
+
1308
+ # Dataset classification BERT
1309
+ parser.add_argument(
1310
+ "--phishing-dataset",
1311
+ default=DEFAULT_PHISHING_DATASET,
1312
+ )
1313
+
1314
+ parser.add_argument("--split", default="train")
1315
+
1316
+ # Sorties
1317
+ parser.add_argument(
1318
+ "--output-dir",
1319
+ default=str(DEFAULT_OUTPUT_DIR),
1320
+ )
1321
+
1322
+ # Limitation globale LLM
1323
+ parser.add_argument(
1324
+ "--max-samples",
1325
+ type=int,
1326
+ default=0,
1327
+ help="Limiter le nombre total d'exemples SFT fusionnés. 0 = pas de limite globale.",
1328
+ )
1329
+
1330
+ # Limitation BERT
1331
+ parser.add_argument(
1332
+ "--bert-max-samples",
1333
+ type=int,
1334
+ default=0,
1335
+ help="Limiter le nombre d'exemples pour BERT. 0 = pas de limite.",
1336
+ )
1337
+
1338
+ # Paramètres LLM LoRA
1339
+ parser.add_argument("--llm-epochs", type=float, default=3.0)
1340
+ parser.add_argument("--llm-batch-size", type=int, default=1)
1341
+ parser.add_argument("--grad-accum", type=int, default=8)
1342
+ parser.add_argument("--llm-lr", type=float, default=2e-4)
1343
+ parser.add_argument("--llm-max-length", type=int, default=1024)
1344
+
1345
+ parser.add_argument("--lora-r", type=int, default=16)
1346
+ parser.add_argument("--lora-alpha", type=int, default=32)
1347
+ parser.add_argument("--lora-dropout", type=float, default=0.05)
1348
+
1349
+ # Paramètres BERT
1350
+ parser.add_argument("--bert-epochs", type=float, default=3.0)
1351
+ parser.add_argument("--bert-batch-size", type=int, default=8)
1352
+ parser.add_argument("--bert-lr", type=float, default=2e-5)
1353
+ parser.add_argument("--bert-max-length", type=int, default=256)
1354
+
1355
+ # Logs / sauvegarde
1356
+ parser.add_argument("--logging-steps", type=int, default=10)
1357
+ parser.add_argument("--save-steps", type=int, default=200)
1358
+
1359
+ args = parser.parse_args()
1360
+
1361
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
1362
+
1363
+ train_selected(args)
1364
+
1365
+ if args.test_after:
1366
+ run_tests(args)
1367
+
1368
+
1369
+ if __name__ == "__main__":
1370
+ main()
security/sec.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from huggingface_hub import snapshot_download
7
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
8
+ from datasets import load_dataset
9
+
10
+
11
+ BASE_DIR = Path(__file__).resolve().parent
12
+ MODELS_DIR = BASE_DIR / "models"
13
+ DATASETS_DIR = BASE_DIR / "datasets"
14
+
15
+
16
+ REPOS = {
17
+ "SecurityLLM": {
18
+ "repo_id": "ZySec-AI/SecurityLLM",
19
+ "repo_type": "model",
20
+ "local_dir": MODELS_DIR / "SecurityLLM",
21
+ "kind": "causal_lm",
22
+ },
23
+ "Llama-Phishsense-1B": {
24
+ "repo_id": "AcuteShrewdSecurity/Llama-Phishsense-1B",
25
+ "repo_type": "model",
26
+ "local_dir": MODELS_DIR / "Llama-Phishsense-1B",
27
+ "kind": "causal_lm",
28
+ },
29
+ "CySecBERT": {
30
+ "repo_id": "markusbayer/CySecBERT",
31
+ "repo_type": "model",
32
+ "local_dir": MODELS_DIR / "CySecBERT",
33
+ "kind": "bert",
34
+ },
35
+ "SecBERT": {
36
+ "repo_id": "jackaduma/SecBERT",
37
+ "repo_type": "model",
38
+ "local_dir": MODELS_DIR / "SecBERT",
39
+ "kind": "bert",
40
+ },
41
+ "cybersecurity-rules": {
42
+ "repo_id": "jcordon5/cybersecurity-rules",
43
+ "repo_type": "dataset",
44
+ "local_dir": DATASETS_DIR / "cybersecurity-rules",
45
+ "kind": "dataset",
46
+ },
47
+ }
48
+
49
+
50
+ def download_all():
51
+ MODELS_DIR.mkdir(exist_ok=True)
52
+ DATASETS_DIR.mkdir(exist_ok=True)
53
+
54
+ for name, item in REPOS.items():
55
+ print(f"\n[+] Téléchargement : {name}")
56
+ print(f" Repo : {item['repo_id']}")
57
+ print(f" Dossier: {item['local_dir']}")
58
+
59
+ snapshot_download(
60
+ repo_id=item["repo_id"],
61
+ repo_type=item["repo_type"],
62
+ local_dir=str(item["local_dir"]),
63
+ resume_download=True,
64
+ )
65
+
66
+ print(f"[OK] {name} téléchargé.")
67
+
68
+
69
+ def check_files():
70
+ print("\n[+] Vérification des fichiers locaux")
71
+
72
+ for name, item in REPOS.items():
73
+ path = item["local_dir"]
74
+
75
+ print(f"\n--- {name} ---")
76
+ print(f"Dossier : {path}")
77
+
78
+ if not path.exists():
79
+ print("[ERREUR] Dossier introuvable.")
80
+ continue
81
+
82
+ files = list(path.glob("*"))
83
+
84
+ if not files:
85
+ print("[ERREUR] Dossier vide.")
86
+ continue
87
+
88
+ for file in files[:20]:
89
+ print(" ", file.name)
90
+
91
+ if item["kind"] in ["causal_lm", "bert"]:
92
+ config = path / "config.json"
93
+ if config.exists():
94
+ print("[OK] config.json trouvé.")
95
+ else:
96
+ print("[ATTENTION] config.json absent.")
97
+
98
+ print("[OK] Vérification terminée.")
99
+
100
+
101
+ def set_offline_mode():
102
+ os.environ["HF_HUB_OFFLINE"] = "1"
103
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
104
+ os.environ["HF_DATASETS_OFFLINE"] = "1"
105
+
106
+
107
+ def test_causal_lm(name, path, prompt):
108
+ print(f"\n[+] Test modèle génératif : {name}")
109
+
110
+ tokenizer = AutoTokenizer.from_pretrained(
111
+ str(path),
112
+ local_files_only=True,
113
+ trust_remote_code=True,
114
+ )
115
+
116
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
117
+
118
+ model = AutoModelForCausalLM.from_pretrained(
119
+ str(path),
120
+ local_files_only=True,
121
+ trust_remote_code=True,
122
+ torch_dtype=dtype,
123
+ device_map="auto" if torch.cuda.is_available() else None,
124
+ )
125
+
126
+ if not torch.cuda.is_available():
127
+ model.to("cpu")
128
+
129
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
130
+
131
+ device = next(model.parameters()).device
132
+ inputs = {k: v.to(device) for k, v in inputs.items()}
133
+
134
+ with torch.no_grad():
135
+ output = model.generate(
136
+ **inputs,
137
+ max_new_tokens=120,
138
+ temperature=0.2,
139
+ do_sample=True,
140
+ pad_token_id=tokenizer.eos_token_id,
141
+ )
142
+
143
+ text = tokenizer.decode(output[0], skip_special_tokens=True)
144
+
145
+ print("\n===== SORTIE MODÈLE =====")
146
+ print(text)
147
+ print("=========================")
148
+
149
+
150
+ def test_bert(name, path):
151
+ print(f"\n[+] Test BERT : {name}")
152
+
153
+ tokenizer = AutoTokenizer.from_pretrained(
154
+ str(path),
155
+ local_files_only=True,
156
+ trust_remote_code=True,
157
+ )
158
+
159
+ model = AutoModel.from_pretrained(
160
+ str(path),
161
+ local_files_only=True,
162
+ trust_remote_code=True,
163
+ )
164
+
165
+ device = "cuda" if torch.cuda.is_available() else "cpu"
166
+ model.to(device)
167
+ model.eval()
168
+
169
+ text = "Suspicious PowerShell encoded command execution detected."
170
+
171
+ inputs = tokenizer(
172
+ text,
173
+ return_tensors="pt",
174
+ truncation=True,
175
+ padding=True,
176
+ max_length=512,
177
+ )
178
+
179
+ inputs = {k: v.to(device) for k, v in inputs.items()}
180
+
181
+ with torch.no_grad():
182
+ outputs = model(**inputs)
183
+
184
+ embedding = outputs.last_hidden_state[:, 0, :]
185
+
186
+ print("[OK] Modèle chargé.")
187
+ print("Texte :", text)
188
+ print("Shape embedding :", tuple(embedding.shape))
189
+
190
+
191
+ def find_dataset_files(path):
192
+ parquet_files = list(path.rglob("*.parquet"))
193
+ json_files = list(path.rglob("*.json")) + list(path.rglob("*.jsonl"))
194
+ csv_files = list(path.rglob("*.csv"))
195
+
196
+ if parquet_files:
197
+ return "parquet", [str(f) for f in parquet_files]
198
+ if json_files:
199
+ return "json", [str(f) for f in json_files]
200
+ if csv_files:
201
+ return "csv", [str(f) for f in csv_files]
202
+
203
+ return None, []
204
+
205
+
206
+ def test_dataset(path):
207
+ print("\n[+] Test dataset cybersecurity-rules")
208
+
209
+ dataset_type, files = find_dataset_files(path)
210
+
211
+ if dataset_type is None:
212
+ print("[ERREUR] Aucun fichier parquet/json/jsonl/csv trouvé.")
213
+ print("Fichiers présents :")
214
+ for f in list(path.rglob("*"))[:30]:
215
+ print(" ", f)
216
+ return
217
+
218
+ print(f"[OK] Type détecté : {dataset_type}")
219
+ print(f"[OK] Nombre de fichiers : {len(files)}")
220
+
221
+ ds = load_dataset(dataset_type, data_files=files, split="train")
222
+
223
+ print("[OK] Dataset chargé.")
224
+ print("Nombre de lignes :", len(ds))
225
+
226
+ print("\n===== PREMIÈRE LIGNE =====")
227
+ print(ds[0])
228
+ print("==========================")
229
+
230
+
231
+ def test_all():
232
+ set_offline_mode()
233
+
234
+ check_files()
235
+
236
+ # Test SecurityLLM
237
+ test_causal_lm(
238
+ "SecurityLLM",
239
+ REPOS["SecurityLLM"]["local_dir"],
240
+ "Tu es un analyste SOC. Donne une procédure défensive pour analyser une alerte SSH brute force.",
241
+ )
242
+
243
+ # Test Llama-Phishsense-1B
244
+ test_causal_lm(
245
+ "Llama-Phishsense-1B",
246
+ REPOS["Llama-Phishsense-1B"]["local_dir"],
247
+ "Analyse ce message pour phishing : Votre compte sera suspendu. Cliquez ici pour confirmer votre mot de passe.",
248
+ )
249
+
250
+ # Test CySecBERT
251
+ test_bert(
252
+ "CySecBERT",
253
+ REPOS["CySecBERT"]["local_dir"],
254
+ )
255
+
256
+ # Test SecBERT
257
+ test_bert(
258
+ "SecBERT",
259
+ REPOS["SecBERT"]["local_dir"],
260
+ )
261
+
262
+ # Test dataset
263
+ test_dataset(
264
+ REPOS["cybersecurity-rules"]["local_dir"],
265
+ )
266
+
267
+
268
+ def test_one(name):
269
+ set_offline_mode()
270
+
271
+ if name not in REPOS:
272
+ print("[ERREUR] Nom inconnu.")
273
+ print("Noms possibles :", ", ".join(REPOS.keys()))
274
+ return
275
+
276
+ item = REPOS[name]
277
+
278
+ if item["kind"] == "causal_lm":
279
+ test_causal_lm(
280
+ name,
281
+ item["local_dir"],
282
+ "Donne une analyse cybersécurité défensive courte.",
283
+ )
284
+
285
+ elif item["kind"] == "bert":
286
+ test_bert(name, item["local_dir"])
287
+
288
+ elif item["kind"] == "dataset":
289
+ test_dataset(item["local_dir"])
290
+
291
+
292
+ def main():
293
+ parser = argparse.ArgumentParser()
294
+
295
+ parser.add_argument(
296
+ "--download",
297
+ action="store_true",
298
+ help="Télécharger tous les modèles et datasets en local.",
299
+ )
300
+
301
+ parser.add_argument(
302
+ "--check",
303
+ action="store_true",
304
+ help="Vérifier les fichiers téléchargés.",
305
+ )
306
+
307
+ parser.add_argument(
308
+ "--test-all",
309
+ action="store_true",
310
+ help="Tester tous les modèles localement.",
311
+ )
312
+
313
+ parser.add_argument(
314
+ "--test-one",
315
+ type=str,
316
+ help="Tester un seul modèle : SecurityLLM, Llama-Phishsense-1B, CySecBERT, SecBERT, cybersecurity-rules",
317
+ )
318
+
319
+ args = parser.parse_args()
320
+
321
+ if args.download:
322
+ download_all()
323
+
324
+ if args.check:
325
+ check_files()
326
+
327
+ if args.test_all:
328
+ test_all()
329
+
330
+ if args.test_one:
331
+ test_one(args.test_one)
332
+
333
+ if not any([args.download, args.check, args.test_all, args.test_one]):
334
+ parser.print_help()
335
+
336
+
337
+ if __name__ == "__main__":
338
+ main()