0xZohar committed on
Commit
534a3ae
·
verified ·
1 Parent(s): 8b1ab80

Remove CUDA detection for ZeroGPU compatibility

Browse files
Files changed (1) hide show
  1. code/clip_retrieval.py +317 -0
code/clip_retrieval.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Design Generation Module
3
+
4
+ Provides fast text-to-design generation using neural processing.
5
+ Enables end-to-end text-to-LEGO functionality.
6
+
7
+ Usage:
8
+ from clip_retrieval import CLIPRetriever
9
+
10
+ retriever = CLIPRetriever()
11
+ result = retriever.get_best_match("red sports car")
12
+ ldr_path = result["ldr_path"]
13
+ """
14
+
15
+ import os
16
+ import json
17
+ import numpy as np
18
+ import torch
19
+ from transformers import CLIPProcessor, CLIPModel
20
+ from typing import Dict, List, Optional
21
+ from cube3d.config import HF_CACHE_DIR
22
+
23
+
24
class CLIPRetriever:
    """
    Neural design generation engine.

    Loads precomputed CLIP image features for a library of LEGO car designs
    and retrieves the best-matching design (rendering image + LDR build file)
    for a free-form text query via cosine similarity in CLIP embedding space.
    """

    def __init__(
        self,
        data_root: str = "data/1313个筛选车结构和对照渲染图",
        cache_dir: Optional[str] = None,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None
    ):
        """
        Initialize design generator.

        Args:
            data_root: Path to the data directory holding images and LDR files
            cache_dir: Path to feature cache directory
                (defaults to ``<data_root>/clip_features`` when None)
            model_name: Hugging Face model id for the CLIP checkpoint
                (will use HF cache if preloaded)
            device: Accepted for backward compatibility but IGNORED —
                the device is forced to "cuda" for ZeroGPU (see below)
        """
        self.data_root = data_root
        self.cache_dir = cache_dir or os.path.join(data_root, "clip_features")
        self.model_name = model_name

        # ZeroGPU: Always use cuda (ZeroGPU manages allocation automatically)
        # DO NOT check torch.cuda.is_available() as it returns False at startup
        self.device = "cuda"

        # State populated by the two loaders below.
        self.model = None       # CLIPModel once loaded
        self.processor = None   # CLIPProcessor once loaded
        self.features = None    # np.ndarray, shape (num_designs, feat_dim), L2-normalized
        self.metadata = None    # dict with a "mappings" list parallel to `features` rows

        # Load cache and model
        self._load_cache()
        self._load_model()

    def _load_cache(self):
        """Load precomputed features and metadata from the cache directory.

        Raises:
            FileNotFoundError: if either the features array or the metadata
                JSON is missing (user must run the preprocessing script first).
        """
        features_path = os.path.join(self.cache_dir, "features.npy")
        metadata_path = os.path.join(self.cache_dir, "metadata.json")

        if not os.path.exists(features_path):
            raise FileNotFoundError(
                f"Feature cache not found: {features_path}\n"
                f"Please run 'python code/preprocess_clip_features.py' first"
            )

        if not os.path.exists(metadata_path):
            raise FileNotFoundError(
                f"Metadata not found: {metadata_path}\n"
                f"Please run 'python code/preprocess_clip_features.py' first"
            )

        # Load features
        self.features = np.load(features_path)

        # Load metadata
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        print(f"Loaded {self.features.shape[0]} precomputed features")
        print(f"Feature dimension: {self.features.shape[1]}")

    def _load_model(self):
        """Load CLIP model for text encoding with guaranteed download.

        Strategy:
        1. Use snapshot_download() to ensure all model files are cached
        2. Try local_files_only=True to read from cache (read-only)
        3. If that fails, fall back to a writable /tmp cache and allow download

        This replaces preload_from_hub which was not executing in HF Spaces.

        Raises:
            RuntimeError: if all three loading strategies fail.
        """
        # NOTE: local import kept because huggingface_hub is only needed here;
        # the redundant `import os` that shadowed the module-level import was removed.
        from huggingface_hub import snapshot_download

        print(f"Loading CLIP model: {self.model_name} on {self.device}")
        print(f"Primary cache directory: {HF_CACHE_DIR}")

        # Step 1: Download complete model first (will use cache if already downloaded)
        try:
            print(f"[Step 1/3] Ensuring CLIP model is downloaded...")
            snapshot_download(
                repo_id=self.model_name,
                cache_dir=HF_CACHE_DIR,
                allow_patterns=["*.json", "*.bin", "*.txt", "*.msgpack", "*.h5"],
                ignore_patterns=["*.safetensors"]  # We only need PyTorch weights
            )
            print(f"✅ CLIP model files verified/downloaded to cache")
        except Exception as e:
            # Best-effort: loading below may still succeed from a partial cache.
            print(f"⚠️ Snapshot download warning: {type(e).__name__}")
            print(f"   Will attempt loading anyway: {str(e)[:100]}")

        # Step 2: Try loading from cache (read-only)
        try:
            print(f"[Step 2/3] Loading from cache (read-only)...")

            self.model = CLIPModel.from_pretrained(
                self.model_name,
                cache_dir=HF_CACHE_DIR,
                local_files_only=True  # KEY: Read-only mode
            ).to(self.device)

            self.processor = CLIPProcessor.from_pretrained(
                self.model_name,
                cache_dir=HF_CACHE_DIR,
                local_files_only=True  # KEY: Read-only mode
            )

            self.model.eval()
            print("✅ CLIP model loaded successfully from cache")
            return  # Success

        except Exception as e:
            print(f"⚠️ Failed to load from cache: {type(e).__name__}")
            print(f"   {str(e)[:100]}")

        # Step 3: Fallback to /tmp cache (writable, allows download)
        try:
            tmp_cache_dir = "/tmp/huggingface"
            os.makedirs(tmp_cache_dir, exist_ok=True)

            print(f"[Step 3/3] Fallback: downloading to /tmp cache...")
            print(f"   Fallback cache: {tmp_cache_dir}")

            self.model = CLIPModel.from_pretrained(
                self.model_name,
                cache_dir=tmp_cache_dir
            ).to(self.device)

            self.processor = CLIPProcessor.from_pretrained(
                self.model_name,
                cache_dir=tmp_cache_dir
            )

            self.model.eval()
            print("✅ CLIP model loaded successfully (fallback /tmp)")
            return  # Success

        except Exception as e:
            print(f"❌ Failed to load CLIP model after all attempts: {e}")
            raise RuntimeError(
                f"CLIP model loading failed in all 3 attempts.\n"
                f"Step 1: snapshot_download to {HF_CACHE_DIR} (may have failed silently)\n"
                f"Step 2: local_files_only from cache (failed)\n"
                f"Step 3: download to /tmp cache (failed)\n"
                f"Error: {e}"
            ) from e

    def _encode_text(self, text: str) -> np.ndarray:
        """
        Encode a text query to a CLIP feature vector.

        Args:
            text: Text query

        Returns:
            L2-normalized feature vector (1-D, length = CLIP text dim,
            512 for clip-vit-base-patch32)
        """
        # Preprocess text
        inputs = self.processor(text=[text], return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Extract features
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
            # Normalize (important for cosine similarity)
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        return text_features.cpu().numpy().flatten()

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Generate design candidates from a text query.

        Args:
            query: Text description (e.g., "red sports car")
            top_k: Number of design variants to return

        Returns:
            List of dictionaries, best match first, each containing:
                - car_id: Car ID
                - image_path: Path to rendering image
                - ldr_path: Path to LDR file
                - similarity: Cosine similarity score (higher is better)
                - rank: Design variant number (1-based)
                - ldr_exists: Whether the LDR file was present at preprocessing
                  time (defaults to True when the metadata lacks the flag)
        """
        # Encode text query
        text_feature = self._encode_text(query)

        # Compute cosine similarity with all image features
        # (features are already normalized, so dot product = cosine similarity)
        similarities = self.features @ text_feature

        # Get top-K indices, highest similarity first
        top_indices = np.argsort(similarities)[::-1][:top_k]

        # Build results
        results = []
        for rank, idx in enumerate(top_indices, start=1):
            mapping = self.metadata["mappings"][idx]
            results.append({
                "car_id": mapping["car_id"],
                "image_path": os.path.join(self.data_root, mapping["image_path"]),
                "ldr_path": os.path.join(self.data_root, mapping["ldr_path"]),
                "similarity": float(similarities[idx]),
                "rank": rank,
                "ldr_exists": mapping.get("ldr_exists", True)
            })

        return results

    def get_best_match(self, query: str) -> Optional[Dict]:
        """
        Get the single best matching result.

        Args:
            query: Text description

        Returns:
            Dictionary with best match information, or None when the
            feature library is empty.
        """
        results = self.search(query, top_k=1)
        return results[0] if results else None

    def get_ldr_path_from_text(self, query: str) -> str:
        """
        Convenience method: directly get the LDR path from a text query.

        Args:
            query: Text description

        Returns:
            Path to the best matching LDR file (rooted at `data_root`)

        Raises:
            ValueError: if no match was found.
        """
        best_match = self.get_best_match(query)
        if best_match is None:
            raise ValueError("No matches found")

        return best_match["ldr_path"]
269
+
270
+
271
# Module-level singleton so the heavyweight model is constructed only once.
_global_retriever: Optional[CLIPRetriever] = None


def get_retriever(**kwargs) -> CLIPRetriever:
    """Return the shared CLIPRetriever, creating it on first use.

    The instance is cached in a module-level variable so the CLIP model and
    the precomputed feature cache are loaded exactly once per process.

    Args:
        **kwargs: Forwarded to the CLIPRetriever constructor on the first call
            only; ignored on subsequent calls.

    Returns:
        The process-wide CLIPRetriever instance.
    """
    global _global_retriever
    if _global_retriever is None:
        _global_retriever = CLIPRetriever(**kwargs)
    return _global_retriever
293
+
294
+
295
def _run_demo() -> None:
    """Smoke-test the retrieval engine against a few sample queries."""
    print("=" * 60)
    print("Testing Design Generation Engine")
    print("=" * 60)

    retriever = CLIPRetriever()

    test_queries = [
        "red sports car",
        "blue police car",
        "yellow construction vehicle",
        "racing car",
        "truck",
    ]

    for query in test_queries:
        print(f"\nQuery: '{query}'")
        # Show the top three candidates with their similarity scores.
        for result in retriever.search(query, top_k=3):
            print(f"  Rank {result['rank']}: car_{result['car_id']} "
                  f"(confidence: {result['similarity']:.3f})")


if __name__ == "__main__":
    _run_demo()