Spaces:
Sleeping
Sleeping
Commit ·
af8efb9
1
Parent(s): ffe65ee
chore: add script to update trl dependency version in gridmind_grpo_colab.ipynb
Browse files- scratch/fix_trl_version.py +17 -0
scratch/fix_trl_version.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
file_path = r"c:\Projects\gridmind\scripts\gridmind_grpo_colab.ipynb"
|
| 4 |
+
|
| 5 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 6 |
+
nb = json.load(f)
|
| 7 |
+
|
| 8 |
+
for cell in nb['cells']:
|
| 9 |
+
if cell['cell_type'] == 'code':
|
| 10 |
+
for i, line in enumerate(cell['source']):
|
| 11 |
+
if '!pip install trl==0.8.6' in line:
|
| 12 |
+
cell['source'][i] = line.replace('trl==0.8.6', 'trl>=0.14.0')
|
| 13 |
+
|
| 14 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 15 |
+
json.dump(nb, f, indent=1)
|
| 16 |
+
|
| 17 |
+
print("Updated notebook to use trl>=0.14.0")
|