# Spaces: Sleeping
# File size: 2,548 Bytes
import json
import re
# Patch the GridMind GRPO Colab notebook in place so it requires a CUDA GPU,
# pins tensors/models to "cuda", and trains with fp16. Safe to re-run: every
# patch checks whether it has already been applied.

# Notebook this script rewrites in place.
file_path = r"c:\Projects\gridmind\scripts\gridmind_grpo_colab.ipynb"

# Presence of this text means the GPU guard was already appended (idempotence).
GPU_CHECK_MARKER = 'RuntimeError("❌ No GPU found!'

# Lines appended to the dependency-install cell so the notebook fails fast
# when no GPU runtime is attached.
GPU_CHECK_LINES = [
    "import torch\n",
    "if not torch.cuda.is_available():\n",
    (
        '    raise RuntimeError("❌ No GPU found! Go to Runtime → '
        'Change runtime type → Select T4 GPU")\n'
    ),
    'print(f"✔ GPU ready: {torch.cuda.get_device_name(0)}")\n',
]

# CPU fallback in the Step 4 model-loading cell that gets pinned to CUDA.
DEVICE_MAP_FALLBACK = 'device_map="cuda" if torch.cuda.is_available() else "cpu"'

# Step 7 tokenization line whose `.to(model.device)` is replaced by `.to("cuda")`.
TOKENIZE_LINE = (
    'inputs = tokenizer(prompt, return_tensors="pt", truncation=True, '
    'max_length=400).to(model.device)'
)

# Start of the Step 6 GRPO config call that should receive fp16=True.
GRPO_CONFIG_CALL = "config = GRPOConfig("

# Any existing fp16 keyword, whatever its value or spacing (but not `==`),
# so a second, duplicate `fp16=` keyword is never inserted.
_FP16_KWARG = re.compile(r"\bfp16\s*=(?!=)")


def _as_lines(source):
    """Return a cell's ``source`` as a mutable list of lines.

    nbformat allows ``source`` to be either a list of strings or a single
    multi-line string; the patchers below edit a list in place.
    """
    if isinstance(source, str):
        return source.splitlines(keepends=True)
    return source


def _add_gpu_check(lines):
    """Append the GPU guard to the dependency-install cell.

    Args:
        lines: the cell's source lines; modified in place.

    Returns:
        True if the guard was appended, False otherwise (not the install
        cell, or the guard is already present).
    """
    text = "".join(lines)
    is_install_cell = (
        "!pip install trl" in text and "✔ All dependencies installed" in text
    )
    if not is_install_cell or GPU_CHECK_MARKER in text:
        return False
    # Jupyter typically saves the last source line without a trailing newline;
    # without this, "import torch" would be glued onto it and break the cell.
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"
    lines.extend(GPU_CHECK_LINES)
    return True


def _rewrite_lines(lines, needle, old, new):
    """Replace *old* with *new* on every line that contains *needle*.

    Args:
        lines: the cell's source lines; modified in place.
        needle: substring identifying the lines to edit.
        old: substring to replace on those lines.
        new: replacement text.

    Returns:
        True if at least one line matched *needle*.
    """
    changed = False
    for i, line in enumerate(lines):
        if needle in line:
            lines[i] = line.replace(old, new)
            changed = True
    return changed


def _add_fp16(lines):
    """Make the ``GRPOConfig(...)`` call pass ``fp16=True``.

    Args:
        lines: the cell's source lines; modified in place.

    Returns:
        None if the cell has no GRPOConfig call; False if an ``fp16``
        keyword is already present (any value is left untouched); True if
        ``fp16=True`` was inserted.
    """
    text = "".join(lines)
    if GRPO_CONFIG_CALL not in text:
        return None
    if _FP16_KWARG.search(text):
        return False
    for i, line in enumerate(lines):
        head, sep, tail = line.partition(GRPO_CONFIG_CALL)
        if not sep:
            continue
        if tail.strip():
            # Arguments or the closing paren share this line: insert inline so
            # a one-line call such as GRPOConfig(max_steps=60) stays valid.
            lines[i] = f"{head}{sep}fp16=True, {tail}"
        else:
            # Multi-line call: put the kwarg on its own line, one level deeper
            # than the line that opens the call.
            indent = head[: len(head) - len(head.lstrip())]
            lines.insert(i + 1, f"{indent}    fp16=True,\n")
        return True
    return False


def main(path=file_path):
    """Apply every patch to the notebook at *path* and rewrite it in place."""
    with open(path, "r", encoding="utf-8") as f:
        nb = json.load(f)

    for cell in nb["cells"]:
        if cell["cell_type"] != "code":
            continue
        lines = cell["source"] = _as_lines(cell["source"])

        # 1. Dependency-install cell: fail fast when no GPU is attached.
        if _add_gpu_check(lines):
            print("Updated Cell 1")
        # 2. Step 4 cell: load the model on CUDA unconditionally.
        if _rewrite_lines(
            lines, DEVICE_MAP_FALLBACK, DEVICE_MAP_FALLBACK, 'device_map="cuda"'
        ):
            print("Updated Step 4 cell")
        # 3. Step 7 cell: move tokenized inputs straight to CUDA.
        if _rewrite_lines(lines, TOKENIZE_LINE, ".to(model.device)", '.to("cuda")'):
            print("Updated Step 7 cell")
        # 4. Step 6 cell: enable fp16 in the GRPO config.
        fp16_added = _add_fp16(lines)
        if fp16_added:
            print("Added fp16=True, to Step 6 cell")
        elif fp16_added is False:
            print("fp16 already set in Step 6 cell")

    # ensure_ascii=False keeps the emoji/arrows literal and the trailing newline
    # matches nbformat's own writer, avoiding noisy notebook diffs.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(nb, f, indent=1, ensure_ascii=False)
        f.write("\n")
    print("All updates applied.")


if __name__ == "__main__":
    main()
|