"""
Design Retrieval Module

Provides fast text-to-design retrieval: a text query is encoded with CLIP and
matched against precomputed image features to pick the closest LEGO design.
Enables end-to-end text-to-LEGO functionality.

Usage:
    from clip_retrieval import CLIPRetriever

    retriever = CLIPRetriever()
    result = retriever.get_best_match("red sports car")
    ldr_path = result["ldr_path"]
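
Alternative (module-level singleton; loads the model only once per process):
    from clip_retrieval import get_retriever

    retriever = get_retriever()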
"""

import os
import json
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import Dict, List, Optional
from cube3d.config import HF_CACHE_DIR


class CLIPRetriever:
    """
    CLIP-based design retrieval engine

    Loads precomputed image features and provides fast text-to-design retrieval.
    """

    def __init__(
        self,
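        # The default folder name below is Chinese for
        # "1313 filtered car structures and paired render images"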
        data_root: str = "data/1313个筛选车结构和对照渲染图",
        cache_dir: Optional[str] = None,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None
    ):
        """
        Initialize the design retriever

        Args:
            data_root: Path to the data directory
            cache_dir: Path to the feature cache directory (defaults to
                <data_root>/clip_features if None)
            model_name: CLIP model to use (served from the HF cache if preloaded)
            device: Device for the CLIP model ("cuda", "cpu", or None for auto)
        """
        self.data_root = data_root
        self.cache_dir = cache_dir or os.path.join(data_root, "clip_features")
        self.model_name = model_name

        # Resolve runtime device with safe CPU fallback (HF Spaces cpu/basic instances)
        self.device = self._resolve_device(device)

        # State
        self.model = None
        self.processor = None
        self.features = None
        self.metadata = None

        # Load cache and model
        self._load_cache()
        self._load_model()

    def _resolve_device(self, device_override: Optional[str]) -> str:
        """
        Decide which device to use for the CLIP encoder.

        Priority:
        1) Explicit argument
        2) Environment override: CLIP_DEVICE
        3) CUDA if available
        4) CPU fallback (avoids HF Spaces "no NVIDIA driver" failures)
        """
        if device_override:
            return device_override

        env_device = os.getenv("CLIP_DEVICE")
        if env_device:
            print(f"🔧 Using device from CLIP_DEVICE env: {env_device}")
            return env_device

        if torch.cuda.is_available():
            return "cuda"

        print("ℹ️  CUDA not available; defaulting CLIP to CPU")
        return "cpu"

    def _load_cache(self):
        """Load precomputed features and metadata"""
        features_path = os.path.join(self.cache_dir, "features.npy")
        metadata_path = os.path.join(self.cache_dir, "metadata.json")

        if not os.path.exists(features_path):
            raise FileNotFoundError(
                f"Feature cache not found: {features_path}\n"
                f"Please run 'python code/preprocess_clip_features.py' first"
            )

        if not os.path.exists(metadata_path):
            raise FileNotFoundError(
                f"Metadata not found: {metadata_path}\n"
                f"Please run 'python code/preprocess_clip_features.py' first"
            )

        # Load features
        self.features = np.load(features_path)

        # Load metadata
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

        print(f"Loaded {self.features.shape[0]} precomputed features")
        print(f"Feature dimension: {self.features.shape[1]}")

    def _load_model(self):
        """Load CLIP model using /data persistent cache

        Simplified loading strategy:
        - Use HF_CACHE_DIR (/data/.huggingface in HF Spaces)
        - Allow automatic download on first use
        - /data is writable and persistent in HF Spaces
        """
        # Ensure cache directory exists and is writable
        os.makedirs(HF_CACHE_DIR, exist_ok=True)

        print(f"Loading CLIP model: {self.model_name} on {self.device}")
        print(f"Cache directory: {HF_CACHE_DIR}")

        # Try preferred device first, then fall back to CPU if GPU is unavailable
        preferred_device = self.device
        device_attempts = [preferred_device]
        if preferred_device != "cpu":
            device_attempts.append("cpu")

        last_error = None

        for target_device in device_attempts:
            try:
                torch_dtype = torch.float16 if target_device.startswith("cuda") else torch.float32

                model = CLIPModel.from_pretrained(
                    self.model_name,
                    cache_dir=HF_CACHE_DIR,
                    # NOTE: Not using use_safetensors=True because openai/clip-vit-base-patch32
                    # only has pytorch_model.bin in main branch (model.safetensors exists in
                    # revision d15b5f2 but not merged). Using pytorch_model.bin is safe for
                    # official OpenAI model with local_files_only=True (prevents malicious replacements)
                    torch_dtype=torch_dtype,
                    local_files_only=True  # Use pre-downloaded model from build
                ).to(target_device)

                processor = CLIPProcessor.from_pretrained(
                    self.model_name,
                    cache_dir=HF_CACHE_DIR,
                    # Processor doesn't have weight files, use_safetensors not applicable
                    local_files_only=True  # Use pre-downloaded model from build
                )

                self.model = model
                self.processor = processor
                self.device = target_device
                self.model.eval()

                if target_device != preferred_device:
                    print(f"ℹ️  CLIP loaded on {target_device} (fallback from {preferred_device})")
                else:
                    print("✅ CLIP model loaded successfully")
                return

            except Exception as e:
                last_error = e
                print(f"⚠️  CLIP load failed on {target_device}: {e}")
                continue

        # If we reach here, all attempts failed
        raise RuntimeError(
            f"Failed to load CLIP model from {self.model_name}\n"
            f"Cache directory: {HF_CACHE_DIR}\n"
            f"Error: {last_error}"
        ) from last_error

    def _encode_text(self, text: str) -> np.ndarray:
        """
        Encode text query to CLIP feature vector

        Args:
            text: Text query

        Returns:
            Normalized feature vector (shape: [feature_dim]; 512 for clip-vit-base-patch32)
        """
        # Preprocess text
        inputs = self.processor(text=[text], return_tensors="pt", padding=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Extract features
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
            # Normalize (important for cosine similarity)
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        return text_features.cpu().numpy().flatten()

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Retrieve design candidates for a text query

        Args:
            query: Text description (e.g., "red sports car")
            top_k: Number of candidates to return

        Returns:
            List of dictionaries containing:
                - car_id: Car ID
                - image_path: Path to the rendering image
                - ldr_path: Path to the LDR file
                - similarity: Cosine similarity to the query (higher is better)
                - rank: Result rank (1-based)
                - ldr_exists: Whether the LDR file exists according to the metadata
        """
        # Encode text query
        text_feature = self._encode_text(query)

        # Compute cosine similarity with all image features
        # (features are already normalized, so dot product = cosine similarity)
        similarities = self.features @ text_feature

        # Get top-K indices
        top_indices = np.argsort(similarities)[::-1][:top_k]

        # Build results
        results = []
        for rank, idx in enumerate(top_indices, start=1):
            mapping = self.metadata["mappings"][idx]
            results.append({
                "car_id": mapping["car_id"],
                "image_path": os.path.join(self.data_root, mapping["image_path"]),
                "ldr_path": os.path.join(self.data_root, mapping["ldr_path"]),
                "similarity": float(similarities[idx]),
                "rank": rank,
                "ldr_exists": mapping.get("ldr_exists", True)
            })

        return results

    def get_best_match(self, query: str) -> Optional[Dict]:
        """
        Get the single best matching result

        Args:
            query: Text description

        Returns:
            Dictionary with best match information, or None if there are no results
        """
        results = self.search(query, top_k=1)
        return results[0] if results else None

    def get_ldr_path_from_text(self, query: str) -> str:
        """
        Convenience method: directly get LDR path from text query

        Args:
            query: Text description

        Returns:
            Path to the best matching LDR file (joined onto data_root)
        """
        best_match = self.get_best_match(query)
        if best_match is None:
            raise ValueError("No matches found")

        return best_match["ldr_path"]


# Singleton instance for global access
_global_retriever: Optional[CLIPRetriever] = None


def get_retriever(**kwargs) -> CLIPRetriever:
    """
    Get or create global retriever instance

    This ensures the model is only loaded once.

    Args:
        **kwargs: Passed to CLIPRetriever constructor

    Returns:
        CLIPRetriever instance
    """
    global _global_retriever

    if _global_retriever is None:
        _global_retriever = CLIPRetriever(**kwargs)

    return _global_retriever


if __name__ == "__main__":
    # Simple test
    print("=" * 60)
    print("Testing Design Generation Engine")
    print("=" * 60)

    retriever = CLIPRetriever()

    test_queries = [
        "red sports car",
        "blue police car",
        "yellow construction vehicle",
        "racing car",
        "truck"
    ]

    for query in test_queries:
        print(f"\nQuery: '{query}'")
        results = retriever.search(query, top_k=3)

        for result in results:
            print(f"  Rank {result['rank']}: car_{result['car_id']} "
                  f"(confidence: {result['similarity']:.3f})")