File size: 1,785 Bytes
acc4df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Patch T5 specroute gen_scripts to add P100 GPU detection and BSZ."""
import re, os

# Directory containing the generator shell scripts that this patch edits.
BASE = '/Users/nnminh322/Desktop/personal/Continual/improve_gainlora'

# Full paths of the T5 specroute generator scripts, in processing order.
T5_SCRIPTS = [
    os.path.join(BASE, script_name)
    for script_name in (
        'gen_script_superni_order1_t5_specroute.sh',
        'gen_script_superni_order2_t5_specroute.sh',
        'gen_script_long_order3_t5_specroute.sh',
        'gen_script_long_order4_t5_specroute.sh',
    )
]

# Text searched for in each script: the closing "else ... fi" of the
# GPU-detection chain, which selects the A100 settings.
GPU_OLD = '\n'.join((
    'else',
    '    GPU_MODE="a100"',
    '    GPU_IDS="${1:-0}"',
    '    FP16_FLAG=""',
    '    echo "[GPU] Strategy: A100 (single GPU, fp32)"',
    'fi',
))

# Replacement text: a P100 "elif" branch followed by the unchanged A100
# fallback.  By construction GPU_NEW ends with the exact text of GPU_OLD.
# NOTE(review): the inserted test is `-gt 16000` with no upper bound, so any
# GPU_MEM above 16000 takes the p100 branch before the final `else` can run.
# Confirm that larger GPUs are matched by an earlier branch of the target
# scripts (not visible from this file), otherwise the a100 branch is shadowed.
GPU_NEW = '\n'.join((
    'elif [ "$GPU_MEM" -gt 16000 ]; then',
    '    GPU_MODE="p100"',
    '    GPU_IDS="${1:-0}"',
    '    FP16_FLAG="--gradient_checkpointing"',
    '    echo "[GPU] Strategy: P100 16GB (fp32 + gradient_checkpointing)"',
    GPU_OLD,
))

# One indented "BSZ=<n>; GA=<n>; EVAL_BSZ=<n>" settings line (regex source).
_BSZ_LINE = r'    BSZ=\d+; GA=\d+; EVAL_BSZ=\d+\n'

# Matches the t4_1gpu batch-size branch (group 1) immediately followed by the
# final "else" batch-size branch (group 2), so a new branch can be inserted
# between the two.
BSZ_PAT = re.compile(
    rf'(elif \[ "\$GPU_MODE" = "t4_1gpu" \]; then\n{_BSZ_LINE})'
    rf'(else\n{_BSZ_LINE})'
)

def add_p100(m):
    """Return the matched text with a P100 batch-size branch inserted.

    Intended as the replacement callback for ``BSZ_PAT.sub``.

    Args:
        m: Match object with two groups; group 1 is the text to keep before
            the new branch and group 2 the text to keep after it.

    Returns:
        Group 1, then the fixed ``elif`` branch for ``GPU_MODE`` ``"p100"``
        (BSZ=8, GA=4, EVAL_BSZ=4), then group 2.
    """
    before, after = m.group(1, 2)
    p100_branch = 'elif [ "$GPU_MODE" = "p100" ]; then\n    BSZ=8; GA=4; EVAL_BSZ=4\n'
    return ''.join((before, p100_branch, after))

# Patch every script in place.  For each file:
#   1. insert the P100 branch into the GPU-detection chain (GPU_OLD -> GPU_NEW);
#   2. insert a P100 batch-size branch wherever BSZ_PAT matches.
# A summary line with the number of matches of each kind is printed per file.
for name in T5_SCRIPTS:
    # Open directly and handle the missing-file case, rather than checking
    # for existence first and then opening.
    try:
        with open(name, encoding='utf-8') as f:
            original = f.read()
    except FileNotFoundError:
        print(f'SKIP (not found): {name}')
        continue

    patched = original

    # GPU_NEW ends with the exact text of GPU_OLD, so an already-patched file
    # still contains GPU_OLD.  Replacing it again would stack a second P100
    # branch on every re-run; skip the step when the new text is present.
    if GPU_NEW in patched:
        n_detect = 0
    else:
        n_detect = patched.count(GPU_OLD)
        # Only the first occurrence is replaced, even if several were counted.
        patched = patched.replace(GPU_OLD, GPU_NEW, 1)

    # subn() performs the substitution and reports how many matches it
    # replaced.  After add_p100 has run, the t4_1gpu branch is followed by
    # "elif" rather than "else", so a re-run finds nothing to replace here.
    patched, n_bsz = BSZ_PAT.subn(add_p100, patched)

    # Leave the file untouched when there is nothing to change.
    if patched != original:
        with open(name, 'w', encoding='utf-8') as f:
            f.write(patched)
    print(f'{name}: gpu_detect={n_detect} bsz_blocks={n_bsz}')

print('Done.')