Prajwal782007 commited on
Commit
e9f731a
·
1 Parent(s): 7597057

feat: add script to automate TRL dependency updates in the GRPO Colab notebook

Browse files
Files changed (1) hide show
  1. scratch/fix_trl_version.py +9 -4
scratch/fix_trl_version.py CHANGED
@@ -5,13 +5,18 @@ file_path = r"c:\Projects\gridmind\scripts\gridmind_grpo_colab.ipynb"
5
  with open(file_path, 'r', encoding='utf-8') as f:
6
  nb = json.load(f)
7
 
 
8
  for cell in nb['cells']:
9
  if cell['cell_type'] == 'code':
10
- for i, line in enumerate(cell['source']):
11
- if '!pip install trl==0.8.6' in line:
12
- cell['source'][i] = line.replace('trl==0.8.6', 'trl>=0.14.0')
 
 
 
 
13
 
14
  with open(file_path, 'w', encoding='utf-8') as f:
15
  json.dump(nb, f, indent=1)
16
 
17
- print("Updated notebook to use trl>=0.14.0")
 
5
  with open(file_path, 'r', encoding='utf-8') as f:
6
  nb = json.load(f)
7
 
8
+ # Replace the first code cell's source
9
  for cell in nb['cells']:
10
  if cell['cell_type'] == 'code':
11
+ cell['source'] = [
12
+ "!pip install trl transformers accelerate datasets unsloth requests pandas matplotlib\n",
13
+ "import os\n",
14
+ "os.makedirs('results', exist_ok=True)\n",
15
+ "print(\"✔ All dependencies installed\")\n"
16
+ ]
17
+ break # Only replace the very first code cell
18
 
19
  with open(file_path, 'w', encoding='utf-8') as f:
20
  json.dump(nb, f, indent=1)
21
 
22
+ print("Updated Cell 1 successfully.")