| """
|
| Patches for transformers library integration issues.
|
|
|
| This module provides patches for various integration issues with the transformers library.
|
| Import this module before importing any modules that use transformers to apply the patches.
|
| """
|
| import sys
|
| import logging
|
| import importlib
|
| from typing import Any, Dict, Optional, List
|
|
|
# codecarbon is an optional dependency. When it is missing, provide a
# signature-compatible fallback so the rest of this module (and transformers
# integrations patched below) can still reference create_emissions_callback.
try:
    from codecarbon import create_emissions_callback
except ImportError:
    def create_emissions_callback(*args, **kwargs):
        # Fallback factory: accepts any arguments (mirroring the real API)
        # and returns an inert callback class whose constructor is a no-op.
        # NOTE(review): this returns the class itself, not an instance —
        # presumably callers instantiate it; confirm against call sites.
        class DummyCallback:
            def __init__(self, *a, **k): pass
        return DummyCallback
|
|
|
logger = logging.getLogger(__name__)

# Registry intended to hold originals saved before patching.
# NOTE(review): nothing in the visible code writes to it — possibly vestigial.
_original_modules = {}

# Maps a patch identifier (e.g. 'transformers.integrations') to True once
# that patch has been applied; dumped by the __main__ smoke test below.
_patch_status = {}
|
|
|
|
|
| def _can_import_tensorflow() -> bool:
|
| """Check if tensorflow can be imported safely without initializing it fully."""
|
| try:
|
|
|
| import importlib.util
|
| tf_spec = importlib.util.find_spec("tensorflow")
|
| return tf_spec is not None
|
| except (ImportError, ModuleNotFoundError):
|
| return False
|
|
|
def safe_import(module_name: str, error_ok: bool = False) -> Any:
    """Import *module_name*, returning ``None`` instead of raising.

    A warning is logged for a missing module unless ``error_ok`` is True.
    """
    try:
        module = importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        if error_ok:
            return None
        logger.warning(f"Could not import {module_name}")
        return None
    return module
|
|
|
def patch_transformers_integrations():
    """Patch transformers.integrations to provide proper implementations
    that don't require TensorFlow dependencies.

    Ensures both ``CodeCarbonCallback`` and ``TensorBoardCallback`` exist on
    ``transformers.integrations``; when that module cannot be imported at
    all, a lightweight proxy object is registered in ``sys.modules`` instead.

    Returns:
        bool: True when the integrations module is usable (patched or already
        complete); False when transformers itself cannot be imported.
    """

    def _make_tensorboard_placeholder():
        # Minimal stand-in honouring the callback hook names used here.
        return type('PlaceholderTensorBoardCallback', (), {
            '__init__': lambda self, *args, **kwargs: None,
            'on_train_begin': lambda self, *args, **kwargs: None,
            'on_train_end': lambda self, *args, **kwargs: None,
        })

    try:
        import transformers

        integrations = safe_import('transformers.integrations', error_ok=True)

        if integrations:
            # Patch each missing attribute independently. Previously the
            # TensorBoard placeholder was only installed when
            # CodeCarbonCallback was also missing, and the "CodeCarbon
            # missing but TensorBoard present" path fell through and
            # implicitly returned None (falsy).
            if not hasattr(integrations, 'CodeCarbonCallback'):
                logger.info("Adding clean implementation for missing CodeCarbonCallback")
                integrations.CodeCarbonCallback = create_emissions_callback
                _patch_status['transformers.integrations.CodeCarbonCallback'] = True
            else:
                logger.info("transformers.integrations.CodeCarbonCallback already exists")

            if not hasattr(integrations, 'TensorBoardCallback'):
                integrations.TensorBoardCallback = _make_tensorboard_placeholder()
                _patch_status['transformers.integrations.TensorBoardCallback'] = True

            return True

        logger.info("Creating proxy transformers.integrations module")

        class IntegrationsModule:
            """Clean implementation of transformers.integrations."""

            def __init__(self):
                self.CodeCarbonCallback = create_emissions_callback
                self.TensorBoardCallback = _make_tensorboard_placeholder()

        integrations_module = IntegrationsModule()
        # NOTE: this registers a plain object rather than a ModuleType;
        # CPython tolerates that for attribute-style access via sys.modules.
        sys.modules['transformers.integrations'] = integrations_module

        if not hasattr(transformers, 'integrations'):
            transformers.integrations = integrations_module

        _patch_status['transformers.integrations'] = True
        return True

    except ImportError as e:
        logger.error(f"Could not patch transformers.integrations: {e}")
        return False
|
|
|
def patch_tensorflow_imports():
    """Install stubs for TF-only transformers modules when TF is absent.

    Each stubbed module is replaced by a proxy whose attribute access raises
    a clear ImportError pointing users at the PyTorch code path.

    Returns:
        bool: True on success (or when no patching was needed), False when an
        unexpected error occurred while patching.
    """
    try:
        if _can_import_tensorflow():
            logger.info("TensorFlow is available, no need to patch imports")
            return True

        logger.info("TensorFlow not available, patching imports")

        class TFUtilsProxy:
            """Proxy for TF utilities."""

            def __getattr__(self, name):
                raise ImportError(
                    "TensorFlow not installed. Cannot use TensorFlow models. "
                    "To use PyTorch-only functionality, import from transformers directly."
                )

        # Stub the core TF utils module plus known TF-only model modules;
        # never clobber one that was already imported successfully.
        tf_only_modules = (
            'transformers.modeling_tf_utils',
            'transformers.models.bert.modeling_tf_bert',
            'transformers.models.gpt2.modeling_tf_gpt2',
        )
        for module_name in tf_only_modules:
            if module_name not in sys.modules:
                sys.modules[module_name] = TFUtilsProxy()
                _patch_status[module_name] = True

        return True

    except Exception as e:
        logger.error(f"Error patching tensorflow imports: {e}")
        return False
|
|
|
def apply_all_patches():
    """Run every available patch, in order, and report overall success."""
    # Order matters: TF import stubs go in before the integrations patch so
    # importing transformers.integrations cannot trip over missing TF.
    results = [
        patch_tensorflow_imports(),
        patch_transformers_integrations(),
    ]
    patches_applied = sum(1 for ok in results if ok)

    logger.info(f"Applied {patches_applied} patches successfully")
    return patches_applied > 0
|
|
|
# NOTE(review): this second section previously imported transformers
# unconditionally, which made merely importing this patches module crash when
# transformers is not installed — defeating the careful guards above. Guard
# it so the module always imports; downstream code must tolerate None.
try:
    import transformers
except ImportError:
    transformers = None

import logging

logger = logging.getLogger(__name__)
|
|
|
def patch_transformers():
    """
    Apply custom patches to the transformers module.

    Wraps ``AutoTokenizer.from_pretrained`` so that every call is logged with
    its arguments, which helps debug tokenizer loading issues. The wrapper
    delegates to the original implementation unchanged.

    Best-effort: a missing transformers installation is logged and skipped
    (previously this raised, crashing the module-level call below); any other
    failure is logged and re-raised.
    """
    log = logging.getLogger(__name__)
    try:
        # Local import: works even when the module-level import was guarded.
        import transformers

        original_from_pretrained = transformers.AutoTokenizer.from_pretrained

        def patched_from_pretrained(*args, **kwargs):
            log.info(f"AutoTokenizer.from_pretrained called with args: {args}, kwargs: {kwargs}")
            return original_from_pretrained(*args, **kwargs)

        transformers.AutoTokenizer.from_pretrained = patched_from_pretrained
        log.info("Successfully patched transformers.AutoTokenizer.from_pretrained")
    except ImportError:
        log.warning("transformers not available; skipping AutoTokenizer patch")
    except Exception as e:
        log.error(f"Failed to patch transformers: {e}")
        raise  # bare raise preserves the original traceback (vs. `raise e`)
|
|
|
|
|
# Apply both patch sets at import time. Patching is best-effort: failures are
# logged as warnings so importing this module never crashes the application.
# (Previously patch_transformers() was called unguarded while
# apply_all_patches() was guarded — an inconsistency fixed here.)
try:
    patch_transformers()
except Exception as e:
    logger.warning(f"Error applying patches: {e}")

try:
    apply_all_patches()
except Exception as e:
    logger.warning(f"Error applying patches: {e}")
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: apply every patch and print a per-patch report.
    logging.basicConfig(level=logging.INFO)

    succeeded = apply_all_patches()
    if succeeded:
        print("✓ Successfully applied all patches")
    else:
        print("✗ Failed to apply patches")

    print("\nPatch status:")
    for patch, status in _patch_status.items():
        marker = '✓' if status else '✗'
        print(f"  {marker} {patch}")
|
|
|
# ---------------------------------------------------------------------------
# NOTE(review): the remainder of this file looks like additional patch
# modules concatenated together. The string literal below is a no-op
# expression statement (not a module docstring, since it is not first in the
# file), and the imports/logger here duplicate the ones at the top.
# ---------------------------------------------------------------------------
"""
Transformer patches to make the model work better with HuggingFace transformers.
This file applies monkey patches to fix compatibility issues or add functionality.
"""
import logging
from typing import Dict, Any, Optional

logger = logging.getLogger(__name__)
|
|
|
def apply_transformer_patches():
    """Apply monkey patches to transformers if needed"""
    # Currently only verifies transformers is importable and logs its
    # version; actual patches would be added after the import.
    try:
        import transformers
    except ImportError:
        logger.warning("Transformers library not found, skipping patches")
        return

    logger.info(f"Applying patches to transformers v{transformers.__version__}")
|
|
|
|
|
# Run the version-logging patch hook at import time (logs a warning and
# continues if transformers is missing).
apply_transformer_patches()

import logging

logger = logging.getLogger(__name__)
|
|
|
def apply_patch_to_layer(layer):
    """Wrap ``layer.forward`` so inputs and outputs are logged at DEBUG level.

    The wrapper reports ``shape``/``dtype`` of positional arguments and of
    the result (``None`` for objects without those attributes), then returns
    the original forward's output unchanged. Mutates ``layer`` in place.
    """
    original_forward = layer.forward

    def forward_with_debug(*args, **kwargs):
        in_shapes = [getattr(t, 'shape', None) for t in args]
        in_dtypes = [getattr(t, 'dtype', None) for t in args]
        logger.debug(f"Patch forward inputs shapes: {in_shapes}, "
                     f"dtypes: {in_dtypes}")

        out = original_forward(*args, **kwargs)

        logger.debug(f"Patch forward output shape: {getattr(out, 'shape', None)}, "
                     f"dtype: {getattr(out, 'dtype', None)}")
        return out

    layer.forward = forward_with_debug
|
|
|
# NOTE(review): third concatenated section header — the string below is a
# no-op expression and the imports duplicate earlier ones.
# ``FunctionType`` is imported but never used in the visible code.
"""
Patches for the transformers library to ensure compatibility
"""
import logging
from types import FunctionType

logger = logging.getLogger(__name__)
|
|
|
| def apply_transformers_patches():
|
| """Apply patches to transformers library"""
|
| try:
|
| import torch
|
| import transformers
|
|
|
|
|
|
|
|
|
|
|
| if hasattr(transformers, 'AutoModel'):
|
| original_from_pretrained = transformers.AutoModel.from_pretrained
|
|
|
| def safe_from_pretrained(*args, **kwargs):
|
|
|
| if 'device_map' in kwargs and not isinstance(kwargs['device_map'], (str, dict)):
|
| logger.info("Fixing invalid device_map parameter")
|
| kwargs['device_map'] = "auto" if torch.cuda.is_available() else None
|
|
|
|
|
| if 'torch_dtype' not in kwargs:
|
| kwargs['torch_dtype'] = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
| return original_from_pretrained(*args, **kwargs)
|
|
|
| transformers.AutoModel.from_pretrained = safe_from_pretrained
|
| logger.info("Applied patch to AutoModel.from_pretrained that preserves GPU usage")
|
|
|
| return True
|
| except Exception as e:
|
| logger.error(f"Failed to apply transformers patches: {e}")
|
| return False
|
|
|
|
|
# Apply the AutoModel patch at import time (best-effort; returns False and
# logs an error on failure rather than raising).
apply_transformers_patches()
|
|
|