Spaces:
Sleeping
Sleeping
Commit
·
4b1bc71
1
Parent(s):
60663e2
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,42 +35,98 @@ def load_model_from_pickle(pickle_path="best_model.pkl"):
|
|
| 35 |
if not os.path.exists(pickle_path):
|
| 36 |
return f"❌ Model file not found: {pickle_path}\n\nPlease ensure best_model.pkl is uploaded to the HuggingFace Space."
|
| 37 |
|
| 38 |
-
#
|
| 39 |
-
|
| 40 |
-
import
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
|
| 44 |
-
fake_cuda.is_available = lambda: True
|
| 45 |
-
fake_cuda.device_count = lambda: 1
|
| 46 |
-
fake_cuda.current_device = lambda: 0
|
| 47 |
-
fake_cuda.get_device_name = lambda x: "CPU (mocked as CUDA)"
|
| 48 |
-
fake_cuda.set_device = lambda x: None
|
| 49 |
-
fake_cuda.device = lambda x: types.SimpleNamespace(__enter__=lambda: None, __exit__=lambda *args: None)
|
| 50 |
-
fake_cuda.init = lambda: None
|
| 51 |
-
fake_cuda.is_initialized = lambda: True
|
| 52 |
-
fake_cuda._initialization_lock = types.SimpleNamespace(__enter__=lambda: None, __exit__=lambda *args: None)
|
| 53 |
-
|
| 54 |
-
# Save original
|
| 55 |
-
original_cuda = torch.cuda
|
| 56 |
|
| 57 |
try:
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
with open(pickle_path, 'rb') as f:
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
finally:
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Success! Model loaded with one of the strategies above
|
| 76 |
# Handle a few common package shapes.
|
|
|
|
| 35 |
if not os.path.exists(pickle_path):
|
| 36 |
return f"❌ Model file not found: {pickle_path}\n\nPlease ensure best_model.pkl is uploaded to the HuggingFace Space."
|
| 37 |
|
| 38 |
+
# METHOD 1: Set environment variable BEFORE any CUDA operations
|
| 39 |
+
# This prevents PyTorch from seeing ANY CUDA devices
|
| 40 |
+
import os as os_module
|
| 41 |
+
old_cuda_visible = os_module.environ.get('CUDA_VISIBLE_DEVICES', None)
|
| 42 |
+
os_module.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Disable all CUDA devices
|
| 43 |
|
| 44 |
+
# Also set other CUDA-disabling flags
|
| 45 |
+
os_module.environ['CUDA_LAUNCH_BLOCKING'] = '0'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
try:
|
| 48 |
+
# METHOD 2: Use pickle with restricted globals to prevent CUDA imports
|
| 49 |
+
import io
|
| 50 |
+
import pickle
|
| 51 |
+
|
| 52 |
+
class CPUOnlyUnpickler(pickle.Unpickler):
|
| 53 |
+
def find_class(self, module, name):
|
| 54 |
+
# Allow transformers and torch imports
|
| 55 |
+
if module.startswith('transformers') or module.startswith('torch'):
|
| 56 |
+
return super().find_class(module, name)
|
| 57 |
+
return super().find_class(module, name)
|
| 58 |
+
|
| 59 |
+
def persistent_load(self, pid):
|
| 60 |
+
# Intercept torch storage and force CPU
|
| 61 |
+
if isinstance(pid, tuple) and len(pid) > 0:
|
| 62 |
+
if pid[0] == 'storage':
|
| 63 |
+
# Format: ('storage', storage_type, key, location, size)
|
| 64 |
+
storage_type = pid[1]
|
| 65 |
+
key = pid[2]
|
| 66 |
+
location = 'cpu' # Force CPU location
|
| 67 |
+
size = pid[4] if len(pid) > 4 else pid[3]
|
| 68 |
+
# Rebuild with CPU location
|
| 69 |
+
return super().persistent_load(('storage', storage_type, key, location, size))
|
| 70 |
+
return super().persistent_load(pid)
|
| 71 |
+
|
| 72 |
+
# Load using our custom unpickler
|
| 73 |
with open(pickle_path, 'rb') as f:
|
| 74 |
+
# First try: Custom unpickler with CUDA disabled
|
| 75 |
+
try:
|
| 76 |
+
model_package = CPUOnlyUnpickler(f).load()
|
| 77 |
+
except Exception as e1:
|
| 78 |
+
# Second try: Standard torch.load with map_location
|
| 79 |
+
f.seek(0) # Reset file pointer
|
| 80 |
+
try:
|
| 81 |
+
model_package = torch.load(
|
| 82 |
+
f,
|
| 83 |
+
map_location=torch.device('cpu'),
|
| 84 |
+
weights_only=False
|
| 85 |
+
)
|
| 86 |
+
except Exception as e2:
|
| 87 |
+
# Third try: Load with pickle directly and extract weights only
|
| 88 |
+
f.seek(0)
|
| 89 |
+
raw_package = pickle.load(f)
|
| 90 |
+
|
| 91 |
+
# Try to extract model from various package formats
|
| 92 |
+
if isinstance(raw_package, dict):
|
| 93 |
+
if 'model' in raw_package:
|
| 94 |
+
model_obj = raw_package['model']
|
| 95 |
+
elif 'state_dict' in raw_package:
|
| 96 |
+
return (f"❌ The pickle contains only state_dict. Please save the full model object.\n\n"
|
| 97 |
+
f"Use: torch.save({{'model': model, 'tokenizer': tokenizer, 'config': config}}, 'file.pkl')")
|
| 98 |
+
else:
|
| 99 |
+
return f"❌ Unknown pickle format. Keys found: {list(raw_package.keys())}"
|
| 100 |
+
|
| 101 |
+
# Move model to CPU recursively
|
| 102 |
+
def recursive_cpu(obj):
|
| 103 |
+
if hasattr(obj, 'cpu'):
|
| 104 |
+
return obj.cpu()
|
| 105 |
+
elif isinstance(obj, dict):
|
| 106 |
+
return {k: recursive_cpu(v) for k, v in obj.items()}
|
| 107 |
+
elif isinstance(obj, (list, tuple)):
|
| 108 |
+
return type(obj)(recursive_cpu(item) for item in obj)
|
| 109 |
+
return obj
|
| 110 |
+
|
| 111 |
+
model_package = {
|
| 112 |
+
'model': recursive_cpu(model_obj) if model_obj else None,
|
| 113 |
+
'tokenizer': raw_package.get('tokenizer'),
|
| 114 |
+
'config': raw_package.get('config', {})
|
| 115 |
+
}
|
| 116 |
+
else:
|
| 117 |
+
# Package is the model itself
|
| 118 |
+
model_package = {
|
| 119 |
+
'model': recursive_cpu(raw_package),
|
| 120 |
+
'tokenizer': None,
|
| 121 |
+
'config': {}
|
| 122 |
+
}
|
| 123 |
|
| 124 |
finally:
|
| 125 |
+
# Restore original CUDA_VISIBLE_DEVICES
|
| 126 |
+
if old_cuda_visible is not None:
|
| 127 |
+
os_module.environ['CUDA_VISIBLE_DEVICES'] = old_cuda_visible
|
| 128 |
+
else:
|
| 129 |
+
os_module.environ.pop('CUDA_VISIBLE_DEVICES', None)
|
| 130 |
|
| 131 |
# Success! Model loaded with one of the strategies above
|
| 132 |
# Handle a few common package shapes.
|