# apishift-env / scripts /fix_trl.py
# yaswanth169's picture
# Initial APIShift env push
# 3040bf7 verified
"""Fix broken TRL 0.24.0 optional dep imports for Transformers 5.x on RunPod."""
import re
TRL_BASE = "/usr/local/lib/python3.11/dist-packages/trl"
def disable_availability_block(filepath, guard_fn_name):
    """Rewrite ``if <guard_fn_name>():`` guards in a file to ``if False:``.

    Every occurrence of the literal text ``if <guard_fn_name>():`` in the
    file is replaced, so the guarded block (and the imports inside it) is
    never executed. The file is rewritten in place only when at least one
    guard was found. One status line (FIXED / SKIP / NOT FOUND) is printed
    per call.

    Args:
        filepath: Path of the Python source file to patch in place.
        guard_fn_name: Name of the availability function whose guard should
            be disabled, e.g. ``"is_vllm_available"``. It is matched
            literally (regex metacharacters are escaped).

    Returns:
        True if at least one guard was rewritten and the file was saved;
        False if the guard does not occur in the file or the file does not
        exist.
    """
    guard_pattern = rf'if {re.escape(guard_fn_name)}\(\):'
    replacement = 'if False: # disabled: optional dep not available'

    try:
        # Python source files are UTF-8 by default; be explicit so the
        # result does not depend on the locale of the machine running this.
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        print(f"NOT FOUND: {filepath}")
        return False

    patched, replaced_count = re.subn(guard_pattern, replacement, content)
    if not replaced_count:
        # Also the outcome of a second run on an already-patched file.
        print(f"SKIP: {filepath} (guard not found)")
        return False

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print(f"FIXED: {filepath} (disabled {guard_fn_name})")
    return True
# TRL modules to patch (paths relative to TRL_BASE) and the guard to turn off:
#   trainer/judges.py     -> llm_blender (broken TRANSFORMERS_CACHE on transformers 5.x)
#   extras/vllm_client.py -> vllm (not installed)
for relative_path, guard_name in (
    ("trainer/judges.py", "is_llm_blender_available"),
    ("extras/vllm_client.py", "is_vllm_available"),
):
    disable_availability_block(f"{TRL_BASE}/{relative_path}", guard_name)

# Smoke test: apply the Unsloth GRPO patch, then import the TRL GRPO classes.
# Any exception is reported on stdout instead of aborting the script.
print("\nVerifying...")
try:
    from unsloth import FastLanguageModel, PatchFastRL

    PatchFastRL("GRPO", FastLanguageModel)
    print("Unsloth patched TRL")

    from trl import GRPOConfig, GRPOTrainer

    print("SUCCESS: TRL GRPOTrainer imported OK - ready to train!")
except Exception as err:
    print(f"FAILED: {err}")

# Report whether CUDA is usable and, if so, the name of device 0.
import torch

cuda_ready = torch.cuda.is_available()
if cuda_ready:
    gpu_label = torch.cuda.get_device_name(0)
else:
    gpu_label = 'N/A'
print(f"CUDA: {cuda_ready}, GPU: {gpu_label}")