"""Patch TRL's SFT trainer so it no longer needs torchvision's NMS operator.

The script locates the installed ``trl`` package, disables every
``torchvision`` import statement in ``trl/trainer/sft_trainer.py`` and injects
a pure-PyTorch ``nms`` function in their place.  A backup of the original file
is written next to it (``sft_trainer.py.bak``) before anything is modified.
"""

import ast
import os
import sys
import importlib.util
from pathlib import Path

# First line of the injected code; also used to recognise an already-patched
# file so that a second run neither re-patches it nor clobbers the backup.
_PATCH_MARKER = "# Custom NMS implementation to avoid torchvision dependency"

# Suffix appended to every import line that gets disabled.
_DISABLED_NOTE = "  # Commented out to fix NMS error"

# Source code injected into the trainer module.  The outer literal uses ''' so
# that the """ docstring inside it does not terminate the string early.
_CUSTOM_NMS_SOURCE = _PATCH_MARKER + '''
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
    """
    import torch

    # Sort boxes by scores
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break

        i = order[0].item()
        keep.append(i)

        # Compute IoU of the remaining boxes with the largest box
        xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
        yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
        xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
        yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])

        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = intersection / (area1 + area2 - intersection)
        box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (box_area + other_area - inter)

        # Keep boxes with IoU less than threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]

    # Return the indices on the same device as the input boxes
    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)
'''


def locate_trl_module():
    """Find the location of the TRL module in the Python path.

    Returns:
        The directory of the installed ``trl`` package as a ``Path``, or
        ``None`` when the package cannot be found (a message is printed).
    """
    try:
        spec = importlib.util.find_spec('trl')
    except Exception as e:
        # Import finders may raise arbitrary errors; report instead of crashing.
        print(f"Error locating TRL module: {e}")
        return None

    if spec is None:
        print("TRL module not found in the Python path")
        return None

    if spec.origin is None:
        # Namespace packages have no origin file to derive a directory from.
        print("Error locating TRL module: package has no origin file")
        return None

    trl_path = Path(spec.origin).parent
    print(f"Found TRL module at: {trl_path}")
    return trl_path


def _is_torchvision_import(node):
    """Return True when *node* is an import statement of torchvision only."""
    def _is_tv(name):
        return name == "torchvision" or name.startswith("torchvision.")

    if isinstance(node, ast.ImportFrom):
        return node.level == 0 and _is_tv(node.module or "")
    if isinstance(node, ast.Import):
        # Statements mixing torchvision with other modules are left untouched
        # so that the unrelated imports are not lost.
        return all(_is_tv(alias.name) for alias in node.names)
    return False


def _patch_source(content):
    """Disable torchvision imports in *content* and inject the custom NMS.

    Returns:
        A ``(patched_source, disabled_count)`` tuple.  When no torchvision
        import statement is found, the source is returned unchanged with a
        count of 0.

    Raises:
        SyntaxError: if *content* is not valid Python source.
    """
    tree = ast.parse(content)
    lines = content.split("\n")
    disabled = 0

    for node in ast.walk(tree):
        if not _is_torchvision_import(node):
            continue

        first, last = node.lineno - 1, node.end_lineno - 1
        # ast column offsets are UTF-8 byte offsets, so slice encoded lines.
        head = lines[first].encode("utf-8")[:node.col_offset].strip()
        tail = lines[last].encode("utf-8")[node.end_col_offset:].strip()
        if head or (tail and not tail.startswith(b"#")):
            # The import shares its line with other code; leave it alone.
            continue

        first_line = lines[first]
        indent = first_line[:len(first_line) - len(first_line.lstrip())]
        statement = first_line.strip()
        if indent:
            # Keep the enclosing block (try/if/def) non-empty.
            lines[first] = f"{indent}pass  # {statement}{_DISABLED_NOTE}"
        else:
            lines[first] = f"# {statement}{_DISABLED_NOTE}"
        # Continuation lines of a multi-line import become plain comments.
        for idx in range(first + 1, last + 1):
            lines[idx] = "# " + lines[idx]
        disabled += 1

    if disabled == 0:
        return content, 0

    # Insert after the last module-level import that precedes the first
    # function/class definition (or after the module docstring, or at the top).
    insert_at = 0
    if ast.get_docstring(tree, clean=False) is not None:
        insert_at = tree.body[0].end_lineno
    for stmt in tree.body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            break
        if isinstance(stmt, (ast.Import, ast.ImportFrom)):
            insert_at = stmt.end_lineno

    # Line count is unchanged by the in-place edits above, so the index
    # computed from the original tree is still valid.
    lines[insert_at:insert_at] = ["", "", *_CUSTOM_NMS_SOURCE.splitlines(), "", ""]
    return "\n".join(lines), disabled


def patch_sft_trainer():
    """Patch the SFTTrainer to avoid using torchvision's NMS operator.

    Returns:
        ``True`` when the trainer file is patched (either by this call or by
        an earlier one), ``False`` when the file cannot be located, read,
        parsed or written, or when it contains no torchvision imports.
    """
    trl_path = locate_trl_module()
    if trl_path is None:
        return False

    # Path to the trainer.py file which likely contains the NMS reference
    trainer_path = trl_path / "trainer" / "sft_trainer.py"
    if not trainer_path.exists():
        print(f"Could not find the SFT trainer file at: {trainer_path}")
        return False

    print(f"Found SFT trainer file at: {trainer_path}")

    try:
        content = trainer_path.read_text(encoding="utf-8")
    except (OSError, UnicodeError) as e:
        print(f"Could not read the SFT trainer file: {e}")
        return False

    # A previous run already injected the replacement; re-patching would
    # duplicate it and overwrite the pristine backup with patched content.
    if _PATCH_MARKER in content:
        print(f"{trainer_path} is already patched; nothing to do.")
        return True

    # Cheap pre-check before parsing the file
    if "torchvision" not in content:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    try:
        patched_content, disabled = _patch_source(content)
    except SyntaxError as e:
        print(f"Could not parse the SFT trainer file: {e}")
        return False

    if disabled == 0:
        print("No torchvision imports found in the SFT trainer file.")
        return False

    backup_path = trainer_path.with_suffix(".py.bak")
    try:
        if backup_path.exists():
            # Never overwrite an existing backup: it may be the only pristine copy.
            print(f"Backup already exists at: {backup_path}")
        else:
            print(f"Creating backup at: {backup_path}")
            backup_path.write_text(content, encoding="utf-8")

        # Write the patched file
        trainer_path.write_text(patched_content, encoding="utf-8")
    except OSError as e:
        print(f"Could not write the patched SFT trainer file: {e}")
        return False

    print(f"Successfully patched {trainer_path}")
    print("The SFTTrainer should now work without requiring torchvision's NMS operator")
    return True


if __name__ == "__main__":
    success = patch_sft_trainer()
    if success:
        print("\nPatch applied successfully. You can now run the fine-tuning script.")
    else:
        print("\nFailed to apply the patch. Please check the error messages above.")