0xZohar committed on
Commit
5b6a559
·
verified ·
1 Parent(s): 7a4cc65

Fix CLIP loading: Use /data cache for clip_retrieval.py

Browse files
Files changed (1) hide show
  1. code/clip_retrieval.py +12 -65
code/clip_retrieval.py CHANGED
@@ -90,88 +90,35 @@ class CLIPRetriever:
90
  print(f"Feature dimension: {self.features.shape[1]}")
91
 
92
  def _load_model(self):
93
- """Load CLIP model for text encoding with guaranteed download
94
 
95
- Strategy:
96
- 1. Use snapshot_download() to ensure all model files are cached
97
- 2. Try local_files_only=True to read from cache (read-only)
98
- 3. If fails, fallback to /tmp cache and allow download
99
-
100
- This replaces preload_from_hub which was not executing in HF Spaces.
101
  """
102
- import os
103
- from huggingface_hub import snapshot_download
104
-
105
  print(f"Loading CLIP model: {self.model_name} on {self.device}")
106
- print(f"Primary cache directory: {HF_CACHE_DIR}")
107
-
108
- # NEW: Download complete model first (will use cache if already downloaded)
109
- try:
110
- print(f"[Step 1/3] Ensuring CLIP model is downloaded...")
111
- snapshot_download(
112
- repo_id=self.model_name,
113
- cache_dir=HF_CACHE_DIR,
114
- allow_patterns=["*.json", "*.bin", "*.txt", "*.msgpack", "*.h5"],
115
- ignore_patterns=["*.safetensors"] # We only need PyTorch weights
116
- )
117
- print(f"✅ CLIP model files verified/downloaded to cache")
118
- except Exception as e:
119
- print(f"⚠️ Snapshot download warning: {type(e).__name__}")
120
- print(f" Will attempt loading anyway: {str(e)[:100]}")
121
-
122
- # Strategy 2: Try loading from cache (read-only)
123
- try:
124
- print(f"[Step 2/3] Loading from cache (read-only)...")
125
-
126
- self.model = CLIPModel.from_pretrained(
127
- self.model_name,
128
- cache_dir=HF_CACHE_DIR,
129
- local_files_only=True # KEY: Read-only mode
130
- ).to(self.device)
131
-
132
- self.processor = CLIPProcessor.from_pretrained(
133
- self.model_name,
134
- cache_dir=HF_CACHE_DIR,
135
- local_files_only=True # KEY: Read-only mode
136
- )
137
 
138
- self.model.eval()
139
- print("✅ CLIP model loaded successfully from cache")
140
- return # Success
141
-
142
- except Exception as e:
143
- print(f"⚠️ Failed to load from cache: {type(e).__name__}")
144
- print(f" {str(e)[:100]}")
145
-
146
- # Strategy 3: Fallback to /tmp cache (writable, allows download)
147
  try:
148
- tmp_cache_dir = "/tmp/huggingface"
149
- os.makedirs(tmp_cache_dir, exist_ok=True)
150
-
151
- print(f"[Step 3/3] Fallback: downloading to /tmp cache...")
152
- print(f" Fallback cache: {tmp_cache_dir}")
153
-
154
  self.model = CLIPModel.from_pretrained(
155
  self.model_name,
156
- cache_dir=tmp_cache_dir
157
  ).to(self.device)
158
 
159
  self.processor = CLIPProcessor.from_pretrained(
160
  self.model_name,
161
- cache_dir=tmp_cache_dir
162
  )
163
 
164
  self.model.eval()
165
- print("✅ CLIP model loaded successfully (fallback /tmp)")
166
- return # Success
167
 
168
  except Exception as e:
169
- print(f"❌ Failed to load CLIP model after all attempts: {e}")
170
  raise RuntimeError(
171
- f"CLIP model loading failed in all 3 attempts.\n"
172
- f"Step 1: snapshot_download to {HF_CACHE_DIR} (may have failed silently)\n"
173
- f"Step 2: local_files_only from cache (failed)\n"
174
- f"Step 3: download to /tmp cache (failed)\n"
175
  f"Error: {e}"
176
  ) from e
177
 
 
90
  print(f"Feature dimension: {self.features.shape[1]}")
91
 
92
  def _load_model(self):
93
+ """Load CLIP model using /data persistent cache
94
 
95
+ Simplified loading strategy:
96
+ - Use HF_CACHE_DIR (/data/.huggingface in HF Spaces)
97
+ - Allow automatic download on first use
98
+ - /data is writable and persistent in HF Spaces
 
 
99
  """
 
 
 
100
  print(f"Loading CLIP model: {self.model_name} on {self.device}")
101
+ print(f"Cache directory: {HF_CACHE_DIR}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
 
 
 
 
 
103
  try:
 
 
 
 
 
 
104
  self.model = CLIPModel.from_pretrained(
105
  self.model_name,
106
+ cache_dir=HF_CACHE_DIR
107
  ).to(self.device)
108
 
109
  self.processor = CLIPProcessor.from_pretrained(
110
  self.model_name,
111
+ cache_dir=HF_CACHE_DIR
112
  )
113
 
114
  self.model.eval()
115
+ print("✅ CLIP model loaded successfully")
 
116
 
117
  except Exception as e:
118
+ print(f"❌ CLIP model loading failed: {e}")
119
  raise RuntimeError(
120
+ f"Failed to load CLIP model from {self.model_name}\n"
121
+ f"Cache directory: {HF_CACHE_DIR}\n"
 
 
122
  f"Error: {e}"
123
  ) from e
124