# Wildnerve-tlm01_Hybrid_Model / transformer_patches.py
# NOTE(review): the following lines are Hugging Face Hub page artifacts that
# were pasted into the source; commented out so the file parses as Python.
# WildnerveAI's picture
# Upload 3 files
# b671111 verified
"""
Patches for transformers library integration issues.
This module provides patches for various integration issues with the transformers library.
Import this module before importing any modules that use transformers to apply the patches.
"""
import sys
import logging
import importlib
from typing import Any, Dict, Optional, List
try:
    from codecarbon import create_emissions_callback
except ImportError:
    # codecarbon is optional; fall back to a factory with the same shape so
    # the rest of the patching machinery keeps working without it.
    def create_emissions_callback(*args, **kwargs):
        """No-op replacement: return a freshly built do-nothing callback class."""
        return type(
            "DummyCallback",
            (),
            {"__init__": lambda self, *a, **k: None},
        )
logger = logging.getLogger(__name__)
# Dictionary to store original modules (keyed by dotted module name).
# NOTE(review): written nowhere in this chunk — appears reserved for future
# patch-reversal support; confirm before removing.
_original_modules: Dict[str, Any] = {}
# Dictionary to store patch status: patch identifier -> True once applied.
_patch_status: Dict[str, bool] = {}
# Determine if we can import tensorflow safely
def _can_import_tensorflow() -> bool:
"""Check if tensorflow can be imported safely without initializing it fully."""
try:
# Use importlib.util instead of direct import to avoid initializing TF
import importlib.util
tf_spec = importlib.util.find_spec("tensorflow")
return tf_spec is not None
except (ImportError, ModuleNotFoundError):
return False
def safe_import(module_name: str, error_ok: bool = False) -> Any:
    """Import *module_name*, returning None instead of raising on failure.

    A warning is logged for a failed import unless ``error_ok`` is True.
    """
    try:
        module = importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        if not error_ok:
            logger.warning(f"Could not import {module_name}")
        return None
    return module
def patch_transformers_integrations():
    """Patch transformers.integrations so it works without TensorFlow deps.

    Three cases are handled:
      * the module imports but lacks ``CodeCarbonCallback`` and/or
        ``TensorBoardCallback`` -> each missing attribute is filled in
        independently (the original code only checked TensorBoardCallback
        when CodeCarbonCallback was also missing, skipping valid patches);
      * the module cannot be imported at all -> a minimal proxy object is
        installed under ``sys.modules['transformers.integrations']``;
      * everything is already present -> nothing to do.

    Returns:
        True when the integrations module is usable afterwards,
        False when transformers itself cannot be imported.
    """

    def _placeholder_tensorboard_callback():
        # Do-nothing stand-in exposing the hooks the Trainer invokes,
        # avoiding any tensorboard/tensorflow import.
        return type('PlaceholderTensorBoardCallback', (), {
            '__init__': lambda self, *args, **kwargs: None,
            'on_train_begin': lambda self, *args, **kwargs: None,
            'on_train_end': lambda self, *args, **kwargs: None,
        })

    try:
        # Try to import transformers first
        import transformers

        # Try to safely import the integrations module
        integrations = safe_import('transformers.integrations', error_ok=True)

        # If the module couldn't be imported at all, install a proxy object.
        if not integrations:
            logger.info("Creating proxy transformers.integrations module")

            class IntegrationsModule:
                """Clean implementation of transformers.integrations."""

                def __init__(self):
                    self.CodeCarbonCallback = create_emissions_callback
                    self.TensorBoardCallback = _placeholder_tensorboard_callback()
                    # Add other callbacks as needed here

            integrations_module = IntegrationsModule()
            # NOTE(review): sys.modules normally holds module objects; a plain
            # instance works for attribute access but not submodule imports.
            sys.modules['transformers.integrations'] = integrations_module
            # Make sure transformers has access to it
            if not hasattr(transformers, 'integrations'):
                transformers.integrations = integrations_module
            _patch_status['transformers.integrations'] = True
            return True

        # Module imported fine: fill in each missing callback independently.
        if not hasattr(integrations, 'CodeCarbonCallback'):
            logger.info("Adding clean implementation for missing CodeCarbonCallback")
            integrations.CodeCarbonCallback = create_emissions_callback
            _patch_status['transformers.integrations.CodeCarbonCallback'] = True
        else:
            logger.info("transformers.integrations.CodeCarbonCallback already exists")

        if not hasattr(integrations, 'TensorBoardCallback'):
            integrations.TensorBoardCallback = _placeholder_tensorboard_callback()
            _patch_status['transformers.integrations.TensorBoardCallback'] = True

        return True
    except ImportError as e:
        logger.error(f"Could not patch transformers.integrations: {e}")
        return False
def patch_tensorflow_imports():
    """Install import-error proxies for TF-only transformers modules.

    When TensorFlow is absent, attribute access on the proxied modules raises
    an ImportError with a helpful message instead of failing at import time.
    Returns True on success (or when no patching is needed), False on error.
    """
    try:
        # Nothing to do when TensorFlow is actually installed.
        if _can_import_tensorflow():
            logger.info("TensorFlow is available, no need to patch imports")
            return True

        logger.info("TensorFlow not available, patching imports")

        class TFUtilsProxy:
            """Proxy for TF utilities."""

            def __getattr__(self, name):
                raise ImportError(
                    "TensorFlow not installed. Cannot use TensorFlow models. "
                    "To use PyTorch-only functionality, import from transformers directly."
                )

        # Proxy the base TF utils module plus the known TF-only model modules;
        # each target gets its own proxy instance, and modules a user already
        # imported are left untouched.
        targets = [
            'transformers.modeling_tf_utils',
            'transformers.models.bert.modeling_tf_bert',
            'transformers.models.gpt2.modeling_tf_gpt2',
            # Add other TF modules here
        ]
        for target in targets:
            if target in sys.modules:
                continue
            sys.modules[target] = TFUtilsProxy()
            _patch_status[target] = True
        return True
    except Exception as e:
        logger.error(f"Error patching tensorflow imports: {e}")
        return False
def apply_all_patches():
    """Run every available patch; return True if at least one succeeded."""
    # TensorFlow import shims first, then the integrations fixes.
    outcomes = [
        patch_tensorflow_imports(),
        patch_transformers_integrations(),
    ]
    applied = sum(1 for ok in outcomes if ok)
    logger.info(f"Applied {applied} patches successfully")
    return applied > 0
# NOTE(review): a second variant of this module appears to start here —
# logging is re-imported and `logger` re-bound. Unlike the guarded imports
# above, a missing transformers package makes this line raise ImportError
# at module import time.
import transformers
import logging
logger = logging.getLogger(__name__)
def patch_transformers():
    """Wrap AutoTokenizer.from_pretrained so every call is logged.

    Helps debug tokenizer loading issues: the wrapper logs each call's
    arguments, then delegates to the captured original implementation.

    Raises:
        Whatever the patching itself raises (logged first, then re-raised).
    """
    try:
        import functools

        original_from_pretrained = transformers.AutoTokenizer.from_pretrained

        # functools.wraps preserves the original's name/doc/signature metadata
        # so introspection and repeated patching stay sane.
        @functools.wraps(original_from_pretrained)
        def patched_from_pretrained(*args, **kwargs):
            logger.info(f"AutoTokenizer.from_pretrained called with args: {args}, kwargs: {kwargs}")
            return original_from_pretrained(*args, **kwargs)

        transformers.AutoTokenizer.from_pretrained = patched_from_pretrained
        logger.info("Successfully patched transformers.AutoTokenizer.from_pretrained")
        # Additional patches can be applied here.
    except Exception as e:
        logger.error(f"Failed to patch transformers: {e}")
        # Bare raise re-raises with the original traceback intact
        # (`raise e` would append this frame).
        raise
# Apply patches automatically on module import to ensure stability at runtime.
# NOTE(review): patch_transformers() re-raises on failure and is NOT wrapped
# in try/except, so a patching error here aborts the whole module import.
patch_transformers()
# Apply patches when the module is imported
try:
    apply_all_patches()
except Exception as e:
    logger.warning(f"Error applying patches: {e}")
    # Still continue execution in case of errors
if __name__ == "__main__":
    # Configure logging only for direct execution of this module.
    logging.basicConfig(level=logging.INFO)

    # Run the patches and report the overall outcome.
    if apply_all_patches():
        print("✓ Successfully applied all patches")
    else:
        print("✗ Failed to apply patches")

    # Show the per-patch status table.
    print("\nPatch status:")
    for name, ok in _patch_status.items():
        print(f"  {'✓' if ok else '✗'} {name}")
"""
Transformer patches to make the model work better with HuggingFace transformers.
This file applies monkey patches to fix compatibility issues or add functionality.
"""
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
def apply_transformer_patches():
    """Apply monkey patches to transformers if needed"""
    try:
        import transformers
    except ImportError:
        # transformers is optional here; absence is logged, not fatal.
        logger.warning("Transformers library not found, skipping patches")
        return
    logger.info(f"Applying patches to transformers v{transformers.__version__}")
    # No patches needed currently, but you can add them here if needed in future


# Apply patches when imported
apply_transformer_patches()
# NOTE(review): yet another duplicate logging setup; harmless re-binding of
# the module-level `logger` to the same logger object.
import logging
logger = logging.getLogger(__name__)
def apply_patch_to_layer(layer):
    """Wrap ``layer.forward`` so inputs/outputs are logged at DEBUG level.

    The wrapper logs the shape/dtype of each positional argument and of the
    return value (None when an object has no such attribute), then delegates
    to the original forward.  The patch is installed in place by rebinding
    ``layer.forward`` on the given object.
    """
    wrapped = layer.forward

    def forward_with_debug(*args, **kwargs):
        shapes = [getattr(t, 'shape', None) for t in args]
        dtypes = [getattr(t, 'dtype', None) for t in args]
        logger.debug(f"Patch forward inputs shapes: {shapes}, "
                     f"dtypes: {dtypes}")
        result = wrapped(*args, **kwargs)
        logger.debug(f"Patch forward output shape: {getattr(result, 'shape', None)}, "
                     f"dtype: {getattr(result, 'dtype', None)}")
        return result

    layer.forward = forward_with_debug
"""
Patches for the transformers library to ensure compatibility
"""
import logging
from types import FunctionType
logger = logging.getLogger(__name__)
def apply_transformers_patches():
    """Apply patches to transformers library"""
    try:
        import torch
        import transformers

        # Only apply safe patches that don't interfere with GPU usage
        # (never replace torch.device with a CPU-only version).
        if hasattr(transformers, 'AutoModel'):
            unpatched_from_pretrained = transformers.AutoModel.from_pretrained

            def safe_from_pretrained(*args, **kwargs):
                # Keep a valid device_map (str or dict); replace anything else.
                if 'device_map' in kwargs and not isinstance(kwargs['device_map'], (str, dict)):
                    logger.info("Fixing invalid device_map parameter")
                    kwargs['device_map'] = "auto" if torch.cuda.is_available() else None
                # Default to half precision when CUDA is available, full
                # precision otherwise, unless the caller chose a dtype.
                if 'torch_dtype' not in kwargs:
                    kwargs['torch_dtype'] = torch.float16 if torch.cuda.is_available() else torch.float32
                return unpatched_from_pretrained(*args, **kwargs)

            transformers.AutoModel.from_pretrained = safe_from_pretrained
            logger.info("Applied patch to AutoModel.from_pretrained that preserves GPU usage")
        return True
    except Exception as e:
        logger.error(f"Failed to apply transformers patches: {e}")
        return False


# Apply patches when module is imported
apply_transformers_patches()