File size: 6,179 Bytes
6f2b9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Robust model loader with compatibility fixes for scikit-learn version mismatches.
"""
import joblib
import pickle
import sys
import warnings

class SklearnCompatibilityUnpickler(pickle.Unpickler):
    """Unpickler that resolves scikit-learn's legacy ``EuclideanDistance`` class.

    Pickles written by older scikit-learn releases reference
    ``sklearn.metrics._dist_metrics.EuclideanDistance``. Newer releases may
    expose only the precision-suffixed ``EuclideanDistance64`` /
    ``EuclideanDistance32`` classes, so :meth:`find_class` maps the legacy
    name onto whichever implementation is available. Every other lookup uses
    the default :class:`pickle.Unpickler` behavior.

    Security: unpickling can execute arbitrary code; only use this on files
    from a trusted source.
    """

    _LEGACY_MODULE = 'sklearn.metrics._dist_metrics'
    _LEGACY_NAME = 'EuclideanDistance'
    # Tried in order. The un-suffixed legacy class was the float64
    # implementation, so the 64-bit variant is the faithful replacement and
    # the 32-bit one is only a fallback.
    _REPLACEMENT_NAMES = ('EuclideanDistance64', 'EuclideanDistance32')

    class _EuclideanDistanceStub:
        """Last-resort placeholder that accepts any arguments and does nothing.

        Objects rebuilt from it are NOT working distance metrics.
        """

        def __init__(self, *args, **kwargs):
            pass

    def find_class(self, module, name):
        """Return the class for ``module.name``, remapping the legacy metric name."""
        if module == self._LEGACY_MODULE and name == self._LEGACY_NAME:
            return self._resolve_euclidean_distance()
        # For all other classes, use default behavior.
        return super().find_class(module, name)

    @classmethod
    def _resolve_euclidean_distance(cls):
        """Find a class usable in place of the legacy ``EuclideanDistance``.

        Resolution order: an existing ``EuclideanDistance`` attribute on
        ``sklearn.metrics._dist_metrics``, then the names in
        ``_REPLACEMENT_NAMES``, then
        ``sklearn.neighbors._dist_metrics.EuclideanDistance``. A real class
        found this way is also assigned to
        ``sklearn.metrics._dist_metrics.EuclideanDistance`` so later attribute
        lookups see it. If nothing is found, a warning is issued and a
        non-functional stub is returned (the stub is never installed into
        the sklearn module).
        """
        try:
            import sklearn.metrics._dist_metrics as dist_metrics
        except Exception as exc:  # e.g. scikit-learn not installed
            warnings.warn(f"Could not patch EuclideanDistance: {exc}")
            return cls._EuclideanDistanceStub

        existing = getattr(dist_metrics, cls._LEGACY_NAME, None)
        if existing is not None:
            return existing

        for candidate in cls._REPLACEMENT_NAMES:
            replacement = getattr(dist_metrics, candidate, None)
            if replacement is not None:
                dist_metrics.EuclideanDistance = replacement
                return replacement

        # Last real option: look for the class in the neighbors module.
        try:
            from sklearn.neighbors._dist_metrics import EuclideanDistance as legacy
        except Exception:
            legacy = None
        if legacy is not None:
            dist_metrics.EuclideanDistance = legacy
            return legacy

        warnings.warn(
            "Could not patch EuclideanDistance: no compatible class found; "
            "using a non-functional stub"
        )
        return cls._EuclideanDistanceStub


def _alias_legacy_euclidean_distance():
    """Best-effort alias of ``EuclideanDistance`` on newer scikit-learn releases.

    Older pickles reference ``sklearn.metrics._dist_metrics.EuclideanDistance``;
    newer releases may define only ``EuclideanDistance64`` /
    ``EuclideanDistance32``. The un-suffixed legacy class was the float64
    implementation, so the 64-bit variant is preferred. Does nothing if the
    attribute already exists or scikit-learn cannot be imported.
    """
    try:
        import sklearn.metrics._dist_metrics as dist_metrics
    except Exception:
        # Missing/broken scikit-learn: let the actual load report the problem.
        return
    if hasattr(dist_metrics, 'EuclideanDistance'):
        return
    for candidate in ('EuclideanDistance64', 'EuclideanDistance32'):
        replacement = getattr(dist_metrics, candidate, None)
        if replacement is not None:
            dist_metrics.EuclideanDistance = replacement
            return


def load_model_with_compatibility(model_path):
    """
    Load a joblib model with compatibility fixes for scikit-learn renames.

    First aliases ``EuclideanDistance`` on scikit-learn's ``_dist_metrics``
    module (if needed), then calls ``joblib.load``. If that still fails with
    an ``AttributeError``/``ModuleNotFoundError`` mentioning
    ``EuclideanDistance``, the file is re-read with
    ``SklearnCompatibilityUnpickler``.

    Security: both paths unpickle the file, which can execute arbitrary
    code; only load model files from a trusted source.

    Args:
        model_path: Path to the .joblib model file

    Returns:
        Loaded model object

    Raises:
        RuntimeError: If the model cannot be loaded. The underlying exception
            is chained as ``__cause__``.
    """
    _alias_legacy_euclidean_distance()

    try:
        return joblib.load(model_path)
    except (AttributeError, ModuleNotFoundError) as e:
        if 'EuclideanDistance' not in str(e):
            raise RuntimeError(f"Error loading model from {model_path}: {e}") from e
        # Otherwise fall through to compatibility mode below.
    except Exception as e:
        raise RuntimeError(f"Error loading model from {model_path}: {e}") from e

    try:
        warnings.warn("Using compatibility mode to load model...")
        # NOTE(review): this reads the file as a plain pickle stream. It does
        # not decompress joblib files or rebuild joblib's inline numpy arrays,
        # so it is presumably only effective for plain-pickled models.
        with open(model_path, 'rb') as f:
            return SklearnCompatibilityUnpickler(f).load()
    except Exception as e2:
        raise RuntimeError(
            f"Error loading model from {model_path}: "
            f"Failed to load model even with compatibility mode: {e2}"
        ) from e2


def load_sklearn_model_safe(model_path, scaler_path=None):
    """
    Load a model and, optionally, its scaler via the compatibility loader.

    The model is mandatory: any error while loading it propagates to the
    caller. The scaler is best-effort: if loading it raises, a warning is
    emitted and ``None`` is returned in its place.

    Args:
        model_path: Path to model .joblib file
        scaler_path: Path to scaler .joblib file; a falsy value skips it

    Returns:
        Tuple ``(model, scaler)``. ``scaler`` is ``None`` when ``scaler_path``
        is falsy or the scaler failed to load.
    """
    loaded_model = load_model_with_compatibility(model_path)

    if not scaler_path:
        return loaded_model, None

    try:
        loaded_scaler = load_model_with_compatibility(scaler_path)
    except Exception as exc:
        warnings.warn(f"Could not load scaler: {exc}")
        loaded_scaler = None

    return loaded_model, loaded_scaler