Tomas
Add initial project setup with model configuration, requirements, and upload script
58af2e6
unverified
| import os | |
| import sys | |
| import importlib.util | |
| from pathlib import Path | |
def locate_trl_module():
    """Find the directory of the TRL package on the Python path.

    The package is located with ``importlib.util.find_spec`` and is not
    imported by this function itself.

    Returns:
        pathlib.Path | None: Directory that contains the ``trl`` package, or
        ``None`` when the package cannot be found or the lookup fails. A
        short status message is printed in every case.
    """
    # Only the lookup itself is guarded. Import hooks installed by other
    # packages may raise arbitrary errors, and this helper reports failures
    # instead of propagating them.
    try:
        spec = importlib.util.find_spec("trl")
    except Exception as e:
        print(f"Error locating TRL module: {e}")
        return None

    if spec is None:
        print("TRL module not found in the Python path")
        return None

    if spec.origin not in (None, "namespace"):
        # Regular package: origin is the path of its __init__ file.
        trl_path = Path(spec.origin).parent
    else:
        # Namespace package (a directory without __init__.py) has no file
        # origin; fall back to the first directory that contributes to it.
        locations = list(spec.submodule_search_locations or [])
        if not locations:
            print("TRL module not found in the Python path")
            return None
        trl_path = Path(locations[0])

    print(f"Found TRL module at: {trl_path}")
    return trl_path
# First line of the injected snippet. Its presence in the trainer file means
# the patch was already applied, so a second run must not patch again.
_NMS_PATCH_MARKER = "# Custom NMS implementation to avoid torchvision dependency"

# Source code injected into the trainer module. The outer literal uses single
# quotes so that the triple-double-quoted docstring inside it does not end it.
_CUSTOM_NMS_SOURCE = '''
# Custom NMS implementation to avoid torchvision dependency
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according to their
    intersection-over-union (IoU).

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
    """
    import torch

    # Sort boxes by scores
    _, order = scores.sort(0, descending=True)
    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(order.item())
            break
        i = order[0].item()
        keep.append(i)

        # Compute IoU of the remaining boxes with the largest box
        xx1 = torch.max(boxes[i, 0], boxes[order[1:], 0])
        yy1 = torch.max(boxes[i, 1], boxes[order[1:], 1])
        xx2 = torch.min(boxes[i, 2], boxes[order[1:], 2])
        yy2 = torch.min(boxes[i, 3], boxes[order[1:], 3])
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = intersection / (area1 + area2 - intersection)
        box_area = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        other_area = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (box_area + other_area - inter)

        # Keep boxes with IoU less than threshold
        inds = torch.where(iou <= iou_threshold)[0]
        order = order[inds + 1]
    return torch.tensor(keep, dtype=torch.int64, device=boxes.device)
'''


def _is_torchvision_import(statement):
    """Return True if the left-stripped line starts a torchvision import."""
    for keyword in ("import torchvision", "from torchvision"):
        if statement.startswith(keyword):
            tail = statement[len(keyword):]
            # Reject look-alike modules such as ``torchvision_extras``.
            return not tail or not (tail[0].isalnum() or tail[0] == "_")
    return False


def _comment_out_torchvision_imports(content):
    """Comment out every torchvision import statement in the given source.

    Works line by line so that text already commented out is never matched
    again. Continuation lines of a parenthesised or backslash-continued
    import are commented out as well. An indented import is replaced by
    ``pass`` so that the enclosing ``try:``/``if:`` block keeps a body.

    Returns:
        tuple[str, int]: The rewritten source and the number of import
        statements that were disabled.
    """
    note = "  # Commented out to fix NMS error"
    rewritten = []
    disabled = 0
    open_parens = 0    # > 0 while inside a parenthesised torchvision import
    backslash = False  # True when the previous import line ended with a backslash

    for line in content.split("\n"):
        statement = line.lstrip()
        indent = line[: len(line) - len(statement)]

        if open_parens > 0 or backslash:
            rewritten.append(f"{indent}# {statement}")
        elif _is_torchvision_import(statement):
            disabled += 1
            open_parens = 0
            placeholder = "pass  " if indent else ""
            rewritten.append(f"{indent}{placeholder}# {statement}{note}")
        else:
            rewritten.append(line)
            continue

        # Import statements contain no string literals, so the first "#"
        # always starts a trailing comment that must not be counted.
        code = statement.split("#", 1)[0].rstrip()
        open_parens += code.count("(") - code.count(")")
        backslash = code.endswith("\\")

    return "\n".join(rewritten), disabled


def _insert_custom_nms(content):
    """Insert the replacement ``nms`` function near the top-level imports.

    The snippet is placed at the first blank line that follows the first
    top-level import statement. When no such position exists it is appended
    to the end of the module, which is always syntactically safe.
    """
    snippet = _CUSTOM_NMS_SOURCE.strip("\n")
    lines = content.split("\n")
    seen_import = False
    depth = 0  # parenthesis depth, so a blank line inside "( ... )" is skipped

    for index, line in enumerate(lines):
        if not seen_import:
            seen_import = line.startswith("import ") or (
                line.startswith("from ") and " import" in line
            )
            if not seen_import:
                continue
        elif depth == 0 and not line.strip():
            head = "\n".join(lines[:index])
            tail = "\n".join(lines[index:])
            return head + "\n\n\n" + snippet + "\n\n" + tail
        code = line.split("#", 1)[0]
        depth += code.count("(") - code.count(")")

    return content.rstrip("\n") + "\n\n\n" + snippet + "\n"


def patch_sft_trainer():
    """Patch the SFTTrainer to avoid using torchvision's NMS operator.

    Locates ``trainer/sft_trainer.py`` inside the TRL package, comments out
    its torchvision import statements and injects a pure-PyTorch ``nms``
    function. The unpatched file is saved next to it with a ``.py.bak``
    suffix before the trainer file is overwritten.

    Only ``nms`` is replaced. Any other torchvision name the trainer file
    uses is left undefined once the imports are disabled.

    Returns:
        bool: ``True`` when the file was patched or was already patched,
        ``False`` when TRL or the trainer file cannot be found, the file has
        no torchvision import, or reading/writing/validation fails. A message
        describing the outcome is printed in every case.
    """
    trl_path = locate_trl_module()
    if trl_path is None:
        return False

    # Path to the trainer file which likely contains the NMS reference
    trainer_path = trl_path / "trainer" / "sft_trainer.py"
    if not trainer_path.exists():
        print(f"Could not find the SFT trainer file at: {trainer_path}")
        return False
    print(f"Found SFT trainer file at: {trainer_path}")

    try:
        content = trainer_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Could not read the SFT trainer file: {e}")
        return False

    # A previous run already did the work. Patching again would insert a
    # second copy of nms and overwrite the backup with patched content.
    if _NMS_PATCH_MARKER in content:
        print("SFT trainer file is already patched; nothing to do.")
        return True

    patched_content, disabled = _comment_out_torchvision_imports(content)
    if disabled == 0:
        print("No torchvision imports found in the SFT trainer file.")
        return False
    patched_content = _insert_custom_nms(patched_content)

    # Never write a file that no longer parses; compile() only checks syntax
    # and does not execute the module.
    try:
        compile(patched_content, str(trainer_path), "exec")
    except (SyntaxError, ValueError) as e:
        print(f"Patched source would not be valid Python, leaving file untouched: {e}")
        return False

    backup_path = trainer_path.with_suffix(".py.bak")
    print(f"Creating backup at: {backup_path}")
    try:
        # The backup is written first so the original is preserved even if
        # writing the trainer file fails.
        backup_path.write_text(content, encoding="utf-8")
        trainer_path.write_text(patched_content, encoding="utf-8")
    except OSError as e:
        print(f"Could not write the patched SFT trainer file: {e}")
        return False

    print(f"Successfully patched {trainer_path}")
    print("The SFTTrainer should now work without requiring torchvision's NMS operator")
    return True
# Script entry point: run the patch once and print a one-line summary.
if __name__ == "__main__":
    if patch_sft_trainer():
        outcome = "\nPatch applied successfully. You can now run the fine-tuning script."
    else:
        outcome = "\nFailed to apply the patch. Please check the error messages above."
    print(outcome)