Spaces:
Running
Running
"""Fix broken TRL 0.24.0 optional dep imports for Transformers 5.x on RunPod."""
import importlib.util
import re

# Location of TRL on the RunPod image this script was written for. Used only
# when the installed package cannot be located through the import system.
_DEFAULT_TRL_BASE = "/usr/local/lib/python3.11/dist-packages/trl"


def _locate_package_dir(name, default):
    """Return the directory of top-level package *name* without importing it.

    For a top-level (undotted) name, ``importlib.util.find_spec`` only asks
    the import finders where the package lives; it does not execute the
    package's ``__init__``. That matters here, because importing ``trl`` is
    exactly what fails until its files have been patched.

    Args:
        name: Top-level package name, e.g. ``"trl"``.
        default: Value returned when the package cannot be found, or when
            *name* resolves to a plain module rather than a package.

    Returns:
        The first entry of the package's search locations, or *default*.
    """
    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError):
        # ValueError: module already in sys.modules with __spec__ = None.
        return default
    if spec is None or not spec.submodule_search_locations:
        return default
    return list(spec.submodule_search_locations)[0]


# Patch the TRL copy this interpreter would actually import (the one the
# verification step below exercises), falling back to the RunPod path.
TRL_BASE = _locate_package_dir("trl", _DEFAULT_TRL_BASE)
def disable_availability_block(filepath, guard_fn_name):
    """Rewrite ``if <guard_fn_name>():`` guards in *filepath* to ``if False:``.

    Optional-dependency imports are wrapped in availability checks such as
    ``if is_vllm_available():``. Rewriting the condition to ``False`` skips
    the guarded import while leaving any ``else:`` branch in effect.
    ``elif <guard_fn_name>():`` is rewritten the same way (``elif False:``).
    A block-form guard gets a ``# disabled: optional dep not available``
    marker comment. An inline body (``if g(): import x``) is kept as real
    code after ``if False:``.

    Args:
        filepath: Path of the Python source file to patch in place.
        guard_fn_name: Name of the zero-argument availability function,
            e.g. ``"is_vllm_available"``. Only the exact text
            ``<guard_fn_name>():`` following ``if `` is matched.

    Returns:
        True if the file was rewritten. False if the file is missing or no
        guard was found (which includes a file that was already patched).
    """
    # Capture the remainder of the line, so an inline body after the colon
    # is preserved as code instead of being folded into the marker comment.
    pattern = re.compile(rf"if {re.escape(guard_fn_name)}\(\):(?P<rest>[^\n]*)")

    def _replace(match):
        rest = match.group("rest")
        trailing = rest.strip()
        if trailing and not trailing.startswith("#"):
            # Inline statement: it becomes the (never executed) body of `if False:`.
            return f"if False:{rest}"
        return f"if False: # disabled: optional dep not available{rest}"

    try:
        # Python sources are UTF-8; newline="" keeps the file's line endings intact.
        with open(filepath, "r", encoding="utf-8", newline="") as f:
            content = f.read()
    except FileNotFoundError:
        print(f"NOT FOUND: {filepath}")
        return False

    patched = pattern.sub(_replace, content)
    if patched == content:
        print(f"SKIP: {filepath} (guard not found)")
        return False

    with open(filepath, "w", encoding="utf-8", newline="") as f:
        f.write(patched)
    print(f"FIXED: {filepath} (disabled {guard_fn_name})")
    return True
# Guards to neutralise, as (path relative to the TRL package, guard function):
#   judges.py      -> disable llm_blender (broken TRANSFORMERS_CACHE on transformers 5.x)
#   vllm_client.py -> disable vllm (not installed)
for _relpath, _guard in (
    ("trainer/judges.py", "is_llm_blender_available"),
    ("extras/vllm_client.py", "is_vllm_available"),
):
    disable_availability_block(f"{TRL_BASE}/{_relpath}", _guard)
del _relpath, _guard
# Verify: import Unsloth and apply PatchFastRL("GRPO", ...) first, then import
# the GRPO classes from TRL; a clean import means the patched files now load.
import traceback

print("\nVerifying...")
try:
    from unsloth import FastLanguageModel, PatchFastRL

    PatchFastRL("GRPO", FastLanguageModel)
    print("Unsloth patched TRL")
    from trl import GRPOConfig, GRPOTrainer  # noqa: F401  (the import itself is the check)

    print("SUCCESS: TRL GRPOTrainer imported OK - ready to train!")
except Exception as e:
    print(f"FAILED: {e}")
    # The one-line message rarely shows which file's import broke; print the
    # full traceback (to stdout, keeping output order) to find the next guard.
    print(traceback.format_exc())

import torch

_gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
print(f"CUDA: {torch.cuda.is_available()}, GPU: {_gpu_name}")