File size: 2,548 Bytes
dc14955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json

file_path = r"c:\Projects\gridmind\scripts\gridmind_grpo_colab.ipynb"

with open(file_path, 'r', encoding='utf-8') as f:
    nb = json.load(f)

for cell in nb['cells']:
    if cell['cell_type'] != 'code':
        continue
        
    source = cell['source']
    source_text = "".join(source)
    
    # 1. Cell 1: Check for dependency installation cell
    if "!pip install trl" in source_text and "✔ All dependencies installed" in source_text:
        # Check if already added
        if "RuntimeError(\"❌ No GPU found!" not in source_text:
            cell['source'].extend([
                "import torch\n",
                "if not torch.cuda.is_available():\n",
                "    raise RuntimeError(\"❌ No GPU found! Go to Runtime → Change runtime type → Select T4 GPU\")\n",
                "print(f\"✔ GPU ready: {torch.cuda.get_device_name(0)}\")\n"
            ])
            print("Updated Cell 1")
            
    # 2. Step 4 cell
    if 'device_map="cuda" if torch.cuda.is_available() else "cpu"' in source_text:
        for i, line in enumerate(source):
            if 'device_map="cuda" if torch.cuda.is_available() else "cpu"' in line:
                source[i] = line.replace('device_map="cuda" if torch.cuda.is_available() else "cpu"', 'device_map="cuda"')
        print("Updated Step 4 cell")

    # 3. Step 7 cell
    if 'inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400).to(model.device)' in source_text:
        for i, line in enumerate(source):
            if 'inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400).to(model.device)' in line:
                source[i] = line.replace('.to(model.device)', '.to("cuda")')
        print("Updated Step 7 cell")

    # 4. Step 6 cell (GRPO config)
    if 'config = GRPOConfig(' in source_text:
        fp16_found = False
        for i, line in enumerate(source):
            if 'fp16=True,' in line:
                fp16_found = True
                break
        
        if not fp16_found:
            # Add fp16=True, after max_steps=60, or just inside config
            for i, line in enumerate(source):
                if 'config = GRPOConfig(' in line:
                    source.insert(i + 1, "    fp16=True,\n")
                    break
            print("Added fp16=True, to Step 6 cell")
        else:
            print("fp16=True, already present in Step 6 cell")

with open(file_path, 'w', encoding='utf-8') as f:
    json.dump(nb, f, indent=1)
    
print("All updates applied.")