0xZohar committed on
Commit
8526d06
·
verified ·
1 Parent(s): 5cb5de0

Fix: Safe device selection for CPU/GPU compatibility

Browse files

- clip_retrieval.py: Add _resolve_device() with CUDA detection & CPU fallback
- demo.py: Auto-select GPT engine (EngineFast for CUDA, Engine for CPU)
- engine.py: Normalize device_map to string for CLIP text encoder

Resolves: CUDA availability issues on HF Spaces CPU instances
Support: Works on both GPU and CPU tiers

File: code/clip_retrieval.py

Files changed (1) hide show
  1. code/clip_retrieval.py +73 -26
code/clip_retrieval.py CHANGED
@@ -48,9 +48,8 @@ class CLIPRetriever:
48
  self.cache_dir = cache_dir or os.path.join(data_root, "clip_features")
49
  self.model_name = model_name
50
 
51
- # ZeroGPU: Always use cuda (ZeroGPU manages allocation automatically)
52
- # DO NOT check torch.cuda.is_available() as it returns False at startup
53
- self.device = "cuda"
54
 
55
  # State
56
  self.model = None
@@ -62,6 +61,30 @@ class CLIPRetriever:
62
  self._load_cache()
63
  self._load_model()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def _load_cache(self):
66
  """Load precomputed features and metadata"""
67
  features_path = os.path.join(self.cache_dir, "features.npy")
@@ -103,29 +126,53 @@ class CLIPRetriever:
103
  print(f"Loading CLIP model: {self.model_name} on {self.device}")
104
  print(f"Cache directory: {HF_CACHE_DIR}")
105
 
106
- try:
107
- self.model = CLIPModel.from_pretrained(
108
- self.model_name,
109
- cache_dir=HF_CACHE_DIR,
110
- use_safetensors=True # Force safetensors to bypass CVE-2025-32434
111
- ).to(self.device)
112
-
113
- self.processor = CLIPProcessor.from_pretrained(
114
- self.model_name,
115
- cache_dir=HF_CACHE_DIR,
116
- use_safetensors=True # Force safetensors to bypass CVE-2025-32434
117
- )
118
-
119
- self.model.eval()
120
- print("✅ CLIP model loaded successfully")
121
-
122
- except Exception as e:
123
- print(f"❌ CLIP model loading failed: {e}")
124
- raise RuntimeError(
125
- f"Failed to load CLIP model from {self.model_name}\n"
126
- f"Cache directory: {HF_CACHE_DIR}\n"
127
- f"Error: {e}"
128
- ) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  def _encode_text(self, text: str) -> np.ndarray:
131
  """
 
48
  self.cache_dir = cache_dir or os.path.join(data_root, "clip_features")
49
  self.model_name = model_name
50
 
51
+ # Resolve runtime device with safe CPU fallback (HF Spaces cpu/basic instances)
52
+ self.device = self._resolve_device(device)
 
53
 
54
  # State
55
  self.model = None
 
61
  self._load_cache()
62
  self._load_model()
63
 
64
+ def _resolve_device(self, device_override: Optional[str]) -> str:
65
+ """
66
+ Decide which device to use for the CLIP encoder.
67
+
68
+ Priority:
69
+ 1) Explicit argument
70
+ 2) Environment override: CLIP_DEVICE
71
+ 3) CUDA if available
72
+ 4) CPU fallback (avoids HF Spaces "no NVIDIA driver" failures)
73
+ """
74
+ if device_override:
75
+ return device_override
76
+
77
+ env_device = os.getenv("CLIP_DEVICE")
78
+ if env_device:
79
+ print(f"🔧 Using device from CLIP_DEVICE env: {env_device}")
80
+ return env_device
81
+
82
+ if torch.cuda.is_available():
83
+ return "cuda"
84
+
85
+ print("ℹ️ CUDA not available; defaulting CLIP to CPU")
86
+ return "cpu"
87
+
88
  def _load_cache(self):
89
  """Load precomputed features and metadata"""
90
  features_path = os.path.join(self.cache_dir, "features.npy")
 
126
  print(f"Loading CLIP model: {self.model_name} on {self.device}")
127
  print(f"Cache directory: {HF_CACHE_DIR}")
128
 
129
+ # Try preferred device first, then fall back to CPU if GPU is unavailable
130
+ preferred_device = self.device
131
+ device_attempts = [preferred_device]
132
+ if preferred_device != "cpu":
133
+ device_attempts.append("cpu")
134
+
135
+ last_error = None
136
+
137
+ for target_device in device_attempts:
138
+ try:
139
+ torch_dtype = torch.float16 if target_device.startswith("cuda") else torch.float32
140
+
141
+ model = CLIPModel.from_pretrained(
142
+ self.model_name,
143
+ cache_dir=HF_CACHE_DIR,
144
+ use_safetensors=True, # Force safetensors to bypass CVE-2025-32434
145
+ torch_dtype=torch_dtype,
146
+ ).to(target_device)
147
+
148
+ processor = CLIPProcessor.from_pretrained(
149
+ self.model_name,
150
+ cache_dir=HF_CACHE_DIR,
151
+ use_safetensors=True # Force safetensors to bypass CVE-2025-32434
152
+ )
153
+
154
+ self.model = model
155
+ self.processor = processor
156
+ self.device = target_device
157
+ self.model.eval()
158
+
159
+ if target_device != preferred_device:
160
+ print(f"ℹ️ CLIP loaded on {target_device} (fallback from {preferred_device})")
161
+ else:
162
+ print("✅ CLIP model loaded successfully")
163
+ return
164
+
165
+ except Exception as e:
166
+ last_error = e
167
+ print(f"⚠️ CLIP load failed on {target_device}: {e}")
168
+ continue
169
+
170
+ # If we reach here, all attempts failed
171
+ raise RuntimeError(
172
+ f"Failed to load CLIP model from {self.model_name}\n"
173
+ f"Cache directory: {HF_CACHE_DIR}\n"
174
+ f"Error: {last_error}"
175
+ ) from last_error
176
 
177
  def _encode_text(self, text: str) -> np.ndarray:
178
  """