File size: 25,718 Bytes
a9f5d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
#!/usr/bin/env python3
"""
CogNet-1B โ€” Lanceur d'entraรฎnement Python pur
===============================================
Remplace acil_submit.sh โ€” tout est en Python !
Dรฉtecte les GPUs automatiquement, prรฉpare les donnรฉes,
lance l'entraรฎnement multi-GPU avec torchrun si nรฉcessaire.

Usage:
    # Simple โ€” tout automatique
    python run.py

    # Avec options
    python run.py --max-steps 100000 --batch-size 4 --hf-token hf_xxxx

    # Reprendre un checkpoint
    python run.py --resume ./checkpoints_1b/cognet_1b_latest.pt

    # Seulement prรฉparer les donnรฉes
    python run.py --prep-only

    # Sur un cluster avec SLURM (soumission auto)
    python run.py --slurm --time 72:00:00 --gpus 4
"""

import argparse
import os
import signal
import subprocess
import sys
import time
import json
import shutil
from datetime import datetime
from pathlib import Path

# ══════════════════════════════════════════════════════════════════
#  Default configuration
# ══════════════════════════════════════════════════════════════════

# Fallback values for the command-line options declared in main();
# each key matches an option name with dashes replaced by underscores.
DEFAULTS = dict(
    model_size='1b',
    batch_size=4,
    grad_accum=8,
    seq_len=512,
    max_lr=1e-4,
    min_lr=1e-5,
    warmup_steps=2000,
    max_steps=100000,
    ckpt_dir='./checkpoints_1b',
    data_dir='./data_1b',
    save_every=2000,
    eval_every=500,
    log_every=50,
    weight_decay=0.1,
    grad_clip=1.0,
)

# Absolute directory of this launcher; the training script lives next to it.
WORKSPACE = os.path.dirname(os.path.abspath(__file__))
TRAIN_SCRIPT = os.path.join(WORKSPACE, 'train_ultra.py')


# ══════════════════════════════════════════════════════════════════
#  GPU detection
# ══════════════════════════════════════════════════════════════════

def _query_nvidia_smi():
    """Return the GPU list reported by ``nvidia-smi``, or ``None`` if unusable.

    ``None`` means the tool is missing, timed out, exited with an error, or
    printed a memory value that is not a number.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        return None
    if result.returncode != 0:
        return None

    gpus = []
    for raw_line in result.stdout.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # Split on the LAST comma so a comma inside the device name is kept.
        name, sep, vram = line.rpartition(',')
        if not sep:
            name, vram = line, '0'
        try:
            vram_mb = float(vram.strip())
        except ValueError:
            return None
        gpus.append({'name': name.strip(), 'vram_mb': vram_mb})
    return gpus


def _query_torch():
    """Return the GPU list reported by ``torch.cuda`` (empty on any failure)."""
    try:
        import torch
        return [
            {
                'name': torch.cuda.get_device_name(i),
                # The attribute is ``total_memory`` (bytes); ``total_mem`` does not exist.
                'vram_mb': torch.cuda.get_device_properties(i).total_memory / 1e6,
            }
            for i in range(torch.cuda.device_count())
        ]
    except Exception:
        # Deliberate best effort: torch may be absent or CUDA init may fail.
        return []


def detect_gpus():
    """Detect the GPUs available on this machine.

    ``nvidia-smi`` is queried first; when it cannot be used (missing, failing,
    or unparsable output) ``torch.cuda`` is used instead.

    Returns:
        A ``(count, gpus)`` tuple where ``gpus`` is a list of
        ``{'name': str, 'vram_mb': float}`` dicts; ``(0, [])`` when no GPU
        could be detected.
    """
    gpus = _query_nvidia_smi()
    if gpus is None:
        gpus = _query_torch()
    return len(gpus), gpus


def get_gpu_type(gpus):
    """Map the first detected GPU to a short family label.

    Args:
        gpus: list of ``{'name': ...}`` dicts as returned by ``detect_gpus``.

    Returns:
        ``'CPU'`` when the list is empty, a known label (``'H100'``,
        ``'A100'``, ``'A6000'``, ``'RTX4090'``, ``'RTX3090'``, ``'V100'``) when
        the upper-cased name contains the matching marker, otherwise the raw
        device name.
    """
    if not gpus:
        return 'CPU'

    device_name = gpus[0]['name']
    upper_name = device_name.upper()

    # Checked in order; the first marker found in the name wins.
    known_families = (
        ('H100', 'H100'),
        ('A100', 'A100'),
        ('A6000', 'A6000'),
        ('4090', 'RTX4090'),
        ('3090', 'RTX3090'),
        ('V100', 'V100'),
    )
    for marker, label in known_families:
        if marker in upper_name:
            return label
    return device_name


# NOTE: Time estimates are computed dynamically by the real benchmark
# that train_ultra.py runs at the start of training.
# No made-up estimates are produced here any more.


# ══════════════════════════════════════════════════════════════════
#  Data preparation (Python)
# ══════════════════════════════════════════════════════════════════

def prepare_data_python(data_dir, hf_token='', skip=False):
    """Prepare the training data by running train_ultra.py with zero steps.

    Args:
        data_dir: directory expected to contain ``train_merged.pt``.
        hf_token: optional HuggingFace token, exported to the child process
            as ``HF_TOKEN``.
        skip: when true, do nothing and report success.

    Returns:
        ``True`` when preparation was skipped, was already done, or produced
        ``train_merged.pt``; ``False`` on any failure (non-zero exit code,
        2-hour timeout, launch error, or missing output file).
    """
    if skip:
        print('[DATA] Skip (--skip-data-prep)')
        return True

    merged = os.path.join(data_dir, 'train_merged.pt')
    if os.path.exists(merged):
        size_mb = os.path.getsize(merged) / 1e6
        print(f'[DATA] Déjà préparé: {merged} ({size_mb:.0f} MB)')
        return True

    print('[DATA] Préparation des datasets (HF + AICL + synthetic)...')
    env = os.environ.copy()
    if hf_token:
        env['HF_TOKEN'] = hf_token

    # Run WITHOUT --skip-data-prep: with --max-steps 0 the training script
    # only performs the data preparation.
    # NOTE(review): data_dir is not forwarded to the training script — confirm
    # that train_ultra.py writes train_merged.pt into the same directory.
    cmd = [sys.executable, TRAIN_SCRIPT, '--max-steps', '0']

    try:
        result = subprocess.run(cmd, env=env, cwd=WORKSPACE, timeout=7200)  # 2h max
        if result.returncode != 0:
            print(f'[DATA] ERREUR: data prep a échoué (code {result.returncode})')
            return False
    except subprocess.TimeoutExpired:
        print('[DATA] ERREUR: data prep a timeout (2h)')
        return False
    except Exception as e:
        print(f'[DATA] ERREUR: {e}')
        return False

    if os.path.exists(merged):
        size_mb = os.path.getsize(merged) / 1e6
        print(f'[DATA] Préparation terminée: {merged} ({size_mb:.0f} MB)')
        return True

    print('[DATA] ERREUR: fichier merged non trouvé après préparation')
    return False


# ══════════════════════════════════════════════════════════════════
#  Dependency check
# ══════════════════════════════════════════════════════════════════

def check_dependencies():
    """Report which Python packages cannot be imported.

    Returns:
        A ``(missing, optional_missing)`` tuple: ``missing`` lists the required
        packages whose import raised ``ImportError``; ``optional_missing``
        holds a descriptive entry for bitsandbytes when it is not importable.
    """
    def _importable(module_name):
        # Only ImportError counts as "not installed"; other errors propagate.
        try:
            __import__(module_name)
        except ImportError:
            return False
        return True

    required_packages = ('torch', 'datasets', 'huggingface_hub', 'tokenizers')
    missing = [pkg for pkg in required_packages if not _importable(pkg)]

    optional_missing = []
    if not _importable('bitsandbytes'):
        optional_missing.append('bitsandbytes (optionnel: 8-bit optimizer)')

    return missing, optional_missing


def install_dependencies(packages):
    """Install each package in *packages* with pip, one at a time.

    pip is run through the current interpreter in quiet mode; its exit status
    is not checked, so a failed install does not raise.
    """
    pip_install = [sys.executable, '-m', 'pip', 'install']
    for package in packages:
        print(f'[INSTALL] Installation de {package}...')
        subprocess.run(pip_install + [package, '-q'], check=False)


# ══════════════════════════════════════════════════════════════════
#  Training launch
# ══════════════════════════════════════════════════════════════════

def launch_training(args, num_gpus):
    """Run train_ultra.py as a child process and wait for it to finish.

    With more than one GPU and FSDP enabled, the script is started through
    ``python -m torch.distributed.run`` with one process per GPU; otherwise it
    is started directly with the current interpreter.

    SIGTERM and SIGINT received while waiting are forwarded to the child. The
    handlers that were installed before the call are restored on exit so the
    caller can be interrupted normally afterwards.

    Args:
        args: parsed command-line namespace (including the derived ``bf16``,
            ``compile``, ``use_fsdp``, ``cuda_prefetch``, ``seq_warmup``,
            ``async_ckpt`` and ``use_8bit`` attributes set in ``main``).
        num_gpus: number of GPUs to train on.

    Returns:
        ``True`` when the child exits with status 0 or a KeyboardInterrupt
        reaches this function, ``False`` otherwise.
    """
    # Options that always carry a value.
    common_args = [
        '--model-size', str(args.model_size),
        '--batch-size', str(args.batch_size),
        '--grad-accum', str(args.grad_accum),
        '--seq-len', str(args.seq_len),
        '--max-lr', str(args.max_lr),
        '--min-lr', str(args.min_lr),
        '--warmup-steps', str(args.warmup_steps),
        '--max-steps', str(args.max_steps),
        '--ckpt-dir', str(args.ckpt_dir),
        '--save-every', str(args.save_every),
        '--eval-every', str(args.eval_every),
        '--log-every', str(args.log_every),
        '--weight-decay', str(args.weight_decay),
        '--grad-clip', str(args.grad_clip),
    ]

    # Boolean switches, appended only when enabled.
    switches = (
        (args.bf16, '--bf16'),
        (args.compile, '--compile'),
        (args.cuda_prefetch, '--cuda-prefetch'),
        (args.seq_warmup, '--seq-warmup'),
        (args.async_ckpt, '--async-ckpt'),
        (args.use_8bit, '--8bit-optim'),
    )
    common_args.extend(flag for enabled, flag in switches if enabled)

    if args.resume:
        common_args.extend(['--resume', args.resume])

    # Data preparation is handled by the launcher before this point.
    common_args.append('--skip-data-prep')

    # Child environment.
    env = os.environ.copy()
    if args.hf_token:
        env['HF_TOKEN'] = args.hf_token
    env['COGNET_WORKSPACE'] = WORKSPACE
    env['AICL_REPEAT'] = str(args.aicl_repeat)

    # CUDA/NCCL tuning variables; NCCL_P2P_LEVEL is only set when the caller
    # has not already chosen a value.
    env['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
    env['TORCH_NCCL_AVOID_RECORD_STREAMS'] = '1'
    if 'NCCL_P2P_LEVEL' not in env:
        env['NCCL_P2P_LEVEL'] = 'NVL'

    if num_gpus > 1 and args.use_fsdp:
        # Multi-GPU: one process per GPU through torch.distributed.run.
        common_args.append('--use-fsdp')

        cmd = [
            sys.executable, '-m', 'torch.distributed.run',
            '--standalone',
            f'--nproc_per_node={num_gpus}',
            TRAIN_SCRIPT,
        ] + common_args

        print(f'\n[TRAIN] Lancement FSDP avec {num_gpus} GPUs via torchrun...')
        print(f'[TRAIN] Commande: {" ".join(cmd[:8])}... ({" ".join(common_args[:6])}...)')
    else:
        # Single process: run the training script directly.
        if args.compile_step:
            common_args.append('--compile-step')

        cmd = [sys.executable, TRAIN_SCRIPT] + common_args

        print(f'\n[TRAIN] Lancement single GPU...')
        print(f'[TRAIN] Commande: {" ".join(cmd[:4])}... ({" ".join(common_args[:6])}...)')

    start_time = time.time()
    previous_handlers = {}
    try:
        process = subprocess.Popen(
            cmd, env=env, cwd=WORKSPACE,
            stdout=sys.stdout, stderr=sys.stderr,
        )

        def forward_signal(signum, frame):
            """Relay the received signal to the training process."""
            process.send_signal(signum)

        for sig in (signal.SIGTERM, signal.SIGINT):
            previous_handlers[sig] = signal.signal(sig, forward_signal)

        return_code = process.wait()
        elapsed = time.time() - start_time

        if return_code == 0:
            print(f'\n[TRAIN] Entraînement terminé avec succès! ({elapsed/3600:.1f}h)')
        else:
            print(f'\n[TRAIN] Entraînement terminé avec code {return_code} ({elapsed/3600:.1f}h)')

        return return_code == 0

    except KeyboardInterrupt:
        # NOTE(review): the message assumes train_ultra.py saves a checkpoint
        # when interrupted; that cannot be verified from this file.
        print('\n[TRAIN] Interruption clavier — checkpoint sauvegardé par train_ultra.py')
        return True
    except Exception as e:
        print(f'\n[TRAIN] ERREUR: {e}')
        return False
    finally:
        # Restore the caller's handlers; otherwise Ctrl+C after training would
        # keep being routed to the finished child and be silently ignored.
        for sig, handler in previous_handlers.items():
            if handler is not None:
                signal.signal(sig, handler)


# ══════════════════════════════════════════════════════════════════
#  SLURM submission (optional)
# ══════════════════════════════════════════════════════════════════

def submit_slurm(args, num_gpus):
    """Submit this launcher as a SLURM batch job.

    A temporary ``_slurm_submit.sh`` is written to the workspace, handed to
    ``sbatch`` and always removed afterwards. The job changes into the
    workspace and re-runs ``run.py`` with the arguments returned by
    ``get_run_args_for_slurm``.

    Args:
        args: parsed command-line namespace; ``args.time`` is the wall-time limit.
        num_gpus: number of GPUs (and tasks per node) to request.

    Returns:
        None. The outcome is reported on stdout.
    """
    import shlex  # local import: only needed on the SLURM path

    # Quote every word: the workspace path, the token or a checkpoint path may
    # contain spaces or shell metacharacters.
    # NOTE(review): the HF token, when set, is written in clear text into the
    # temporary script and appears on the job command line.
    run_command = ' '.join(
        shlex.quote(word)
        for word in [sys.executable, 'run.py', *get_run_args_for_slurm(args)]
    )
    slurm_script = f"""#!/bin/bash
#SBATCH --job-name=cognet-1b
#SBATCH --partition=gpu
#SBATCH --nodes=1
#SBATCH --ntasks-per-node={num_gpus}
#SBATCH --cpus-per-task=8
#SBATCH --mem=256G
#SBATCH --gres=gpu:{num_gpus}
#SBATCH --time={args.time}
#SBATCH --output=logs/cognet-%j.out
#SBATCH --error=logs/cognet-%j.err

cd {shlex.quote(WORKSPACE)}
{run_command}
"""
    script_path = os.path.join(WORKSPACE, '_slurm_submit.sh')
    os.makedirs(os.path.join(WORKSPACE, 'logs'), exist_ok=True)

    with open(script_path, 'w') as f:
        f.write(slurm_script)

    print(f'[SLURM] Soumission du job...')
    try:
        result = subprocess.run(['sbatch', script_path], capture_output=True, text=True)
    except OSError as e:
        # Typically: sbatch is not installed on this machine.
        print(f'[SLURM] ERREUR: {e}')
        return
    finally:
        # Remove the temporary script whether or not the submission worked.
        os.remove(script_path)

    if result.returncode == 0:
        # The job id is taken as the last word printed by sbatch.
        words = result.stdout.split()
        job_id = words[-1] if words else '?'
        print(f'[SLURM] Job soumis: {job_id}')
        print(f'[SLURM] Logs: logs/cognet-{job_id}.out')
    else:
        print(f'[SLURM] ERREUR: {result.stderr}')


def get_run_args_for_slurm(args):
    """Build the ``run.py`` argument list used inside the SLURM job.

    Every training-related option present on *args* is forwarded, so the job
    runs with the same configuration as the submitting command line. The
    ``--slurm``, ``--gpus``, ``--time`` and ``--check-only`` options are
    intentionally not forwarded.

    Args:
        args: namespace with at least ``hf_token``, ``max_steps``,
            ``batch_size``, ``grad_accum``, ``seq_len``, ``no_compile`` and
            ``no_fsdp``. Any other known option is forwarded only when the
            attribute exists (and, for valued options, is not ``None``).

    Returns:
        A list of strings (unquoted command-line words).
    """
    arg_list = []
    if args.hf_token:
        arg_list.extend(['--hf-token', args.hf_token])
    arg_list.extend(['--max-steps', str(args.max_steps)])
    arg_list.extend(['--batch-size', str(args.batch_size)])
    arg_list.extend(['--grad-accum', str(args.grad_accum)])
    arg_list.extend(['--seq-len', str(args.seq_len)])

    # Valued options that were previously dropped on the way to the job.
    optional_values = (
        ('model_size', '--model-size'),
        ('max_lr', '--max-lr'),
        ('min_lr', '--min-lr'),
        ('warmup_steps', '--warmup-steps'),
        ('ckpt_dir', '--ckpt-dir'),
        ('data_dir', '--data-dir'),
        ('save_every', '--save-every'),
        ('eval_every', '--eval-every'),
        ('log_every', '--log-every'),
        ('weight_decay', '--weight-decay'),
        ('grad_clip', '--grad-clip'),
        ('aicl_repeat', '--aicl-repeat'),
        ('resume', '--resume'),
    )
    for attr, flag in optional_values:
        value = getattr(args, attr, None)
        if value is not None:
            arg_list.extend([flag, str(value)])

    if args.no_compile:
        arg_list.append('--no-compile')
    if args.no_fsdp:
        arg_list.append('--no-fsdp')

    # Boolean switches that were previously dropped.
    optional_switches = (
        ('no_cuda_prefetch', '--no-cuda-prefetch'),
        ('no_seq_warmup', '--no-seq-warmup'),
        ('no_async_ckpt', '--no-async-ckpt'),
        ('no_bf16', '--no-bf16'),
        ('8bit', '--8bit'),
        ('compile_step', '--compile-step'),
        ('skip_data_prep', '--skip-data-prep'),
        ('prep_only', '--prep-only'),
    )
    arg_list.extend(
        flag for attr, flag in optional_switches if getattr(args, attr, False)
    )
    return arg_list


# ══════════════════════════════════════════════════════════════════
#  Checkpoint inspection
# ══════════════════════════════════════════════════════════════════

def check_existing_checkpoints(ckpt_dir):
    """Collect information about the checkpoints found in *ckpt_dir*.

    Looks for ``cognet_1b_latest.pt``, ``cognet_1b_best.pt`` and
    ``cognet_1b_final.pt``. The latest and best files are loaded with
    ``torch.load`` to read their step and loss.

    Args:
        ckpt_dir: checkpoint directory (str or path-like).

    Returns:
        ``None`` when the directory does not exist; otherwise a dict that may
        contain ``latest_step``/``latest_loss``/``latest_path``,
        ``best_step``/``best_loss``/``best_path`` and ``final_path``. The dict
        is empty when no checkpoint could be read.
    """
    ckpt_path = Path(ckpt_dir)
    if not ckpt_path.exists():
        return None

    info = {}
    # (kind, key under which that checkpoint stores its loss)
    for kind, loss_key in (('latest', 'loss'), ('best', 'best_loss')):
        candidate = ckpt_path / f'cognet_1b_{kind}.pt'
        if not candidate.exists():
            continue
        try:
            # torch is not imported at module level; without this import the
            # load raised NameError and the metadata was silently dropped.
            import torch
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # point this at checkpoints you trust.
            data = torch.load(str(candidate), map_location='cpu', weights_only=False)
            entry = {
                f'{kind}_step': data.get('step', 0),
                f'{kind}_loss': data.get(loss_key, float('inf')),
                f'{kind}_path': str(candidate),
            }
        except Exception as e:
            # Best effort: an unreadable checkpoint must not stop the launcher,
            # but the reason should be visible.
            print(f'[CKPT] ATTENTION: lecture impossible de {candidate}: {e}')
            continue
        info.update(entry)

    final = ckpt_path / 'cognet_1b_final.pt'
    if final.exists():
        info['final_path'] = str(final)

    return info


# ══════════════════════════════════════════════════════════════════
#  Main
# ══════════════════════════════════════════════════════════════════

def main():
    """Command-line entry point: check the setup, prepare data, launch training.

    Steps, in order: parse the options, print the banner, detect GPUs, check
    the Python dependencies (offering to pip-install missing ones), verify
    that the training script and the model file exist, report existing
    checkpoints, print the configuration, then dispatch to one of:
    check-only, SLURM submission, data-preparation-only, or full training.

    Exits with status 1 when a required file is missing, when the user
    declines to install missing dependencies, when data preparation fails and
    the user does not choose to continue, or when training fails.
    """
    def ask(prompt):
        """Read a lower-cased answer; an unreadable stdin counts as Enter."""
        try:
            return input(prompt).strip().lower()
        except EOFError:
            # Non-interactive run (e.g. a batch job): use the prompt's default.
            print()
            return ''

    parser = argparse.ArgumentParser(
        description='CogNet-1B — Lanceur Python (remplace acil_submit.sh)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Exemples:
  python run.py                                    # Tout automatique
  python run.py --max-steps 50000                  # 50k steps
  python run.py --hf-token hf_xxx                  # Avec token HF
  python run.py --resume ./checkpoints_1b/cognet_1b_latest.pt  # Reprendre
  python run.py --prep-only                        # Seulement data prep
  python run.py --slurm --gpus 4 --time 72:00:00   # SLURM auto
  python run.py --no-fsdp                          # Single GPU
        """
    )

    # Training configuration
    parser.add_argument('--model-size', type=str, default=DEFAULTS['model_size'], choices=['1b', '350m'])
    parser.add_argument('--batch-size', type=int, default=DEFAULTS['batch_size'])
    parser.add_argument('--grad-accum', type=int, default=DEFAULTS['grad_accum'])
    parser.add_argument('--seq-len', type=int, default=DEFAULTS['seq_len'])
    parser.add_argument('--max-lr', type=float, default=DEFAULTS['max_lr'])
    parser.add_argument('--min-lr', type=float, default=DEFAULTS['min_lr'])
    parser.add_argument('--warmup-steps', type=int, default=DEFAULTS['warmup_steps'])
    parser.add_argument('--max-steps', type=int, default=DEFAULTS['max_steps'])
    parser.add_argument('--ckpt-dir', type=str, default=DEFAULTS['ckpt_dir'])
    parser.add_argument('--data-dir', type=str, default=DEFAULTS['data_dir'])
    parser.add_argument('--save-every', type=int, default=DEFAULTS['save_every'])
    parser.add_argument('--eval-every', type=int, default=DEFAULTS['eval_every'])
    parser.add_argument('--log-every', type=int, default=DEFAULTS['log_every'])
    parser.add_argument('--weight-decay', type=float, default=DEFAULTS['weight_decay'])
    parser.add_argument('--grad-clip', type=float, default=DEFAULTS['grad_clip'])

    # Token & data repetition
    parser.add_argument('--hf-token', type=str, default=os.environ.get('HF_TOKEN', ''),
                        help='HuggingFace API token')
    parser.add_argument('--aicl-repeat', type=int, default=10,
                        help='Nombre de répétitions des données AICL')

    # Optimizations (enabled by default, hence the "no-*" switches)
    parser.add_argument('--no-compile', action='store_true', help='Désactiver torch.compile')
    parser.add_argument('--no-fsdp', action='store_true', help='Désactiver FSDP (single GPU)')
    parser.add_argument('--no-cuda-prefetch', action='store_true', help='Désactiver CUDA prefetch')
    parser.add_argument('--no-seq-warmup', action='store_true', help='Désactiver seq length warmup')
    parser.add_argument('--no-async-ckpt', action='store_true', help='Désactiver async checkpointing')
    parser.add_argument('--no-bf16', action='store_true', help='Désactiver BF16 (utiliser FP16)')
    parser.add_argument('--8bit', action='store_true', help='Activer 8-bit optimizer (bitsandbytes)')
    parser.add_argument('--compile-step', action='store_true', help='Compiler forward+backward ensemble')

    # Resume
    parser.add_argument('--resume', type=str, default=None, help='Chemin du checkpoint à reprendre')

    # Special modes
    parser.add_argument('--prep-only', action='store_true', help='Seulement préparer les données')
    parser.add_argument('--skip-data-prep', action='store_true', help='Sauter la préparation des données')
    parser.add_argument('--check-only', action='store_true', help='Seulement vérifier le setup')

    # SLURM
    parser.add_argument('--slurm', action='store_true', help='Soumettre via SLURM')
    parser.add_argument('--gpus', type=int, default=None, help='Nombre de GPUs pour SLURM')
    parser.add_argument('--time', type=str, default='72:00:00', help='Temps SLURM')

    args = parser.parse_args()

    # Derive the positive booleans from the "no-*" switches.
    args.bf16 = not args.no_bf16
    args.compile = not args.no_compile
    args.use_fsdp = not args.no_fsdp
    args.cuda_prefetch = not args.no_cuda_prefetch
    args.seq_warmup = not args.no_seq_warmup
    args.async_ckpt = not args.no_async_ckpt
    # '8bit' is not a valid identifier, so it can only be read with getattr.
    args.use_8bit = getattr(args, '8bit', False)

    # --- Banner ---
    print()
    print('╔' + '═' * 58 + '╗')
    print('║       CogNet-1B — Lanceur Python V2                     ║')
    print('║       Les performances seront mesurées par benchmark     ║')
    print('╚' + '═' * 58 + '╝')
    print()

    # --- GPU detection ---
    num_gpus, gpus = detect_gpus()
    gpu_type = get_gpu_type(gpus)

    print(f'[GPU] {num_gpus} GPU(s) détecté(s):')
    for i, gpu in enumerate(gpus):
        print(f'  GPU {i}: {gpu["name"]} ({gpu["vram_mb"]:.0f} MB VRAM)')
    print(f'  Type: {gpu_type}')

    if num_gpus == 0:
        print('[GPU] ATTENTION: Aucun GPU détecté — entraînement sur CPU (très lent!)')
        print('[GPU] Vérifiez que nvidia-smi fonctionne et que CUDA est installé')

    # A CPU-only run still uses one process, so the effective batch is never 0.
    effective_batch = args.batch_size * args.grad_accum * max(num_gpus, 1)

    # --- Dependency check ---
    missing, optional = check_dependencies()
    if missing:
        print(f'\n[DEPS] Packages manquants: {", ".join(missing)}')
        response = ask('[DEPS] Installer automatiquement? (o/n) [o] ')
        if response in ('', 'o', 'oui', 'y', 'yes'):
            install_dependencies(missing)
        else:
            print('[DEPS] Installation annulée. Installez manuellement:')
            print(f'  pip install {" ".join(missing)}')
            sys.exit(1)

    if optional:
        print(f'[DEPS] Optionnels non installés: {", ".join(optional)}')

    # --- Required files ---
    if not os.path.exists(TRAIN_SCRIPT):
        print(f'[ERREUR] Script d\'entraînement introuvable: {TRAIN_SCRIPT}')
        sys.exit(1)

    if not os.path.exists(os.path.join(WORKSPACE, 'cognet_1b_optimized.py')):
        print(f'[ERREUR] Modèle optimisé introuvable: cognet_1b_optimized.py')
        sys.exit(1)

    # --- Existing checkpoints ---
    ckpt_info = check_existing_checkpoints(args.ckpt_dir)
    if ckpt_info:
        print(f'\n[CKPT] Checkpoints existants dans {args.ckpt_dir}:')
        if 'latest_step' in ckpt_info:
            print(f'  Latest: step {ckpt_info["latest_step"]}, loss={ckpt_info["latest_loss"]:.4f}')
        if 'best_step' in ckpt_info:
            print(f'  Best:   step {ckpt_info["best_step"]}, loss={ckpt_info["best_loss"]:.4f}')

    else:
        print(f'\n[CKPT] Aucun checkpoint existant')

    # --- Timing ---
    # No estimate is computed here: the message below states that the real
    # benchmark is run at the start of training.
    if num_gpus > 0 and not args.check_only:
        print(f'\n[BENCH] Les performances seront mesurées par un vrai benchmark au démarrage.')
        print(f'  GPU: {num_gpus}x {gpu_type}')
        print(f'  Batch effectif: {effective_batch} ({args.batch_size} x {args.grad_accum} x {num_gpus} GPUs)')
        print(f'  Le temps restant sera calculé à partir de la vitesse mesurée.')

    # --- Final configuration ---
    print(f'\n[CONFIG] Configuration finale:')
    # NOTE(review): the architecture figures are hard-coded and printed for
    # every --model-size — confirm they also apply to the 350m variant.
    print(f'  Model:    CogNet-{args.model_size.upper()} (16 blocks, 8 channels, 384 ch_dim, 8192 ff)')
    print(f'  Vocab:    136 (CharTokenizer)')
    print(f'  Seq len:  {args.seq_len}')
    print(f'  Batch:    {args.batch_size} x grad_accum={args.grad_accum} x GPUs={num_gpus} = {effective_batch}')
    print(f'  LR:       {args.min_lr} → {args.max_lr}')
    print(f'  Steps:    {args.max_steps:,}')
    print(f'  HF token: {"SET" if args.hf_token else "NOT SET"}')
    print(f'  BF16:     {args.bf16}')
    print(f'  Compile:  {args.compile}')
    print(f'  FSDP:     {args.use_fsdp} ({num_gpus} GPUs)')
    print(f'  Prefetch: {args.cuda_prefetch}')
    print(f'  SeqWarm:  {args.seq_warmup}')
    print(f'  AsyncCkpt:{args.async_ckpt}')
    print(f'  8-bit:    {args.use_8bit}')

    # --- Check-only mode ---
    if args.check_only:
        print('\n[CHECK] Vérification terminée — tout est prêt!')
        return

    # --- SLURM submission ---
    if args.slurm:
        # Explicit --gpus wins, then the detected count, then 4.
        gpu_count = args.gpus or num_gpus or 4
        submit_slurm(args, gpu_count)
        return

    # --- Data preparation ---
    if args.prep_only:
        ok = prepare_data_python(args.data_dir, args.hf_token, skip=False)
        print('\n[DATA] Préparation terminée!' if ok else '\n[DATA] ÉCHEC!')
        return

    if not args.skip_data_prep:
        ok = prepare_data_python(args.data_dir, args.hf_token)
        if not ok:
            print('[DATA] ÉCHEC de la préparation des données!')
            response = ask('[DATA] Continuer quand même? (o/n) [n] ')
            if response not in ('o', 'oui', 'y', 'yes'):
                sys.exit(1)

    # --- Training ---
    print('\n' + '=' * 60)
    print('  DÉMARRAGE DE L\'ENTRAÎNEMENT')
    print('=' * 60)
    print(f'  Début: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print('=' * 60 + '\n')

    success = launch_training(args, num_gpus)

    print('\n' + '=' * 60)
    if success:
        print('  ENTRAÎNEMENT TERMINÉ AVEC SUCCÈS')
    else:
        print('  ENTRAÎNEMENT TERMINÉ AVEC ERREURS')
    print(f'  Fin: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print('=' * 60)

    # Report the best checkpoint, if any.
    ckpt_info = check_existing_checkpoints(args.ckpt_dir)
    if ckpt_info and 'best_path' in ckpt_info:
        print(f'\n  Meilleur checkpoint: {ckpt_info["best_path"]}')
        if 'best_loss' in ckpt_info:
            print(f'  Meilleure loss: {ckpt_info["best_loss"]:.4f}')

    if not success:
        sys.exit(1)


# Run the launcher only when executed as a script, not when imported.
if __name__ == '__main__':
    main()