Spaces:
Running
Running
File size: 6,179 Bytes
6f2b9f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
"""
Robust model loader with compatibility fixes for scikit-learn version mismatches.
"""
import joblib
import pickle
import sys
import warnings
class _EuclideanDistanceStub:
    """Placeholder used when no real EuclideanDistance class can be located."""

    def __init__(self, *args, **kwargs):
        pass


class SklearnCompatibilityUnpickler(pickle.Unpickler):
    """Unpickler that resolves the legacy ``EuclideanDistance`` class reference.

    Pickles that reference ``sklearn.metrics._dist_metrics.EuclideanDistance``
    fail to load when the installed scikit-learn no longer exposes that name.
    This unpickler intercepts exactly that lookup and substitutes the best
    available replacement; every other lookup uses the default behaviour.
    """

    _DIST_METRICS_MODULE = 'sklearn.metrics._dist_metrics'
    _LEGACY_CLASS_NAME = 'EuclideanDistance'

    def find_class(self, module, name):
        """Return the class for ``module.name``, patching the legacy lookup.

        Args:
            module: Dotted module name recorded in the pickle stream.
            name: Attribute name recorded in the pickle stream.

        Returns:
            The resolved class. For the legacy EuclideanDistance reference
            this is the patched replacement, or a stub class when nothing
            usable can be imported (a warning is emitted in that case).
        """
        if module == self._DIST_METRICS_MODULE and name == self._LEGACY_CLASS_NAME:
            try:
                return self._resolve_euclidean_distance()
            except Exception as e:
                # Best effort: never abort the whole load because of this one
                # lookup; hand back a placeholder instead.
                warnings.warn(f"Could not patch EuclideanDistance: {e}")
                return _EuclideanDistanceStub
        # For all other classes, use default behavior
        return super().find_class(module, name)

    @staticmethod
    def _resolve_euclidean_distance():
        """Find (and install on the sklearn module) an EuclideanDistance class."""
        import sklearn.metrics._dist_metrics as dist_metrics

        if hasattr(dist_metrics, 'EuclideanDistance'):
            return dist_metrics.EuclideanDistance

        # Prefer the 64-bit class: scikit-learn's unsuffixed EuclideanDistance
        # was the float64 implementation before the dtype-suffixed names were
        # introduced, so old pickles carry float64 state. The class is aliased
        # directly (not wrapped in a local subclass) so that the loaded model
        # can be pickled again.
        for suffixed_name in ('EuclideanDistance64', 'EuclideanDistance32'):
            replacement = getattr(dist_metrics, suffixed_name, None)
            if replacement is not None:
                break
        else:
            # Last resort: older layouts kept the class in the neighbors module.
            try:
                from sklearn.neighbors._dist_metrics import (
                    EuclideanDistance as replacement,
                )
            except ImportError:
                replacement = _EuclideanDistanceStub

        dist_metrics.EuclideanDistance = replacement
        return replacement
def _alias_legacy_euclidean_distance(dist_metrics):
    """Expose a dtype-suffixed EuclideanDistance class under the legacy name.

    Returns True when ``dist_metrics.EuclideanDistance`` exists afterwards.
    """
    if hasattr(dist_metrics, 'EuclideanDistance'):
        return True
    # The 64-bit class comes first: the unsuffixed legacy class was
    # scikit-learn's float64 implementation.
    for suffixed_name in ('EuclideanDistance64', 'EuclideanDistance32'):
        candidate = getattr(dist_metrics, suffixed_name, None)
        if candidate is not None:
            dist_metrics.EuclideanDistance = candidate
            return True
    return False


def _load_in_compatibility_mode(model_path):
    """Load ``model_path`` via the custom unpickler, then via a re-patched joblib."""
    with open(model_path, 'rb') as f:
        try:
            # NOTE(review): this is a plain pickle.Unpickler subclass, not
            # joblib's own unpickler, so it presumably only succeeds for files
            # that hold an uncompressed plain pickle stream - confirm.
            return SklearnCompatibilityUnpickler(f).load()
        except Exception:
            # Patch more aggressively: reload the module, alias the legacy
            # name again, and give joblib one more chance.
            import importlib
            import sklearn.metrics._dist_metrics as dist_metrics

            dist_metrics = importlib.reload(dist_metrics)
            _alias_legacy_euclidean_distance(dist_metrics)
            return joblib.load(model_path)


def load_model_with_compatibility(model_path):
    """
    Load a joblib model with compatibility fixes.

    The legacy ``EuclideanDistance`` name is aliased on
    ``sklearn.metrics._dist_metrics`` (when scikit-learn is importable) before
    a normal ``joblib.load``. If that load fails with an error mentioning
    ``EuclideanDistance``, a compatibility-mode load is attempted.

    SECURITY: joblib/pickle files can execute arbitrary code while loading;
    only pass paths to files from a trusted source.

    Args:
        model_path: Path to the .joblib model file

    Returns:
        Loaded model object

    Raises:
        RuntimeError: If the file cannot be loaded; the original exception is
            attached as ``__cause__``.
    """
    # Do not add a function-scope ``import joblib.<submodule>`` here: it would
    # make ``joblib`` a local name for the whole function, so the
    # ``joblib.load`` below would raise UnboundLocalError before running.
    try:
        # First, try to patch the module before loading (best effort).
        try:
            import sklearn.metrics._dist_metrics as dist_metrics
        except ImportError:
            pass  # scikit-learn layout differs or is absent; joblib will report it
        else:
            _alias_legacy_euclidean_distance(dist_metrics)

        # Try standard loading first
        try:
            return joblib.load(model_path)
        except (AttributeError, ModuleNotFoundError) as e:
            if 'EuclideanDistance' not in str(e):
                raise
            warnings.warn("Using compatibility mode to load model...")
            try:
                return _load_in_compatibility_mode(model_path)
            except Exception as e2:
                raise RuntimeError(
                    f"Failed to load model even with compatibility mode: {e2}"
                ) from e2
    except Exception as e:
        raise RuntimeError(f"Error loading model from {model_path}: {e}") from e
def load_sklearn_model_safe(model_path, scaler_path=None):
    """
    Load a model and, optionally, its scaler through the compatibility loader.

    A failure while loading the model propagates to the caller. A failure
    while loading the scaler is downgraded to a warning and the scaler is
    reported as None.

    Args:
        model_path: Path to model .joblib file
        scaler_path: Path to scaler .joblib file (optional; any falsy value
            skips scaler loading)

    Returns:
        Tuple of (model, scaler); scaler is None when scaler_path is not
        provided or the scaler could not be loaded
    """
    loaded_model = load_model_with_compatibility(model_path)

    # Nothing more to do when no scaler was requested.
    if not scaler_path:
        return loaded_model, None

    try:
        loaded_scaler = load_model_with_compatibility(scaler_path)
    except Exception as e:
        # The scaler is optional: report the problem but still return the model.
        warnings.warn(f"Could not load scaler: {e}")
        loaded_scaler = None

    return loaded_model, loaded_scaler
|