#!/usr/bin/env python3
"""
Pre-validate all cached HuggingFace models to provide detailed feedback.

This script runs once during CI initialization (in prepare_runner.sh) to:
1. Scan snapshots in ~/.cache/huggingface/hub/ (subject to a total time budget
   and a per-snapshot weight-file cap)
2. Validate completeness (config/tokenizer/weights)
3. Output detailed failure reasons for debugging

NOTE: This script no longer writes shared validation markers. Each test run
independently validates its cache using per-run markers to avoid cross-runner
cache state pollution.
"""

import glob
import json
import os
import sys
import time
from pathlib import Path

# Add python directory to path to import sglang modules
REPO_ROOT = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(REPO_ROOT / "python"))

from sglang.srt.model_loader.ci_weight_validation import (  # noqa: E402
    _validate_diffusion_model,
    validate_cache_with_detailed_reason,
)

# Total time budget for the whole scan; validation stops once it is exceeded
# (a per-snapshot file-count cap, MAX_WEIGHT_FILES, lives in scan_weight_files)
MAX_VALIDATION_TIME_SECONDS = 300  # Max 5 minutes total


def find_all_hf_snapshots():
    """
    Find all HuggingFace snapshots in cache.

    Returns:
        List of (model_name, snapshot_dir) tuples, sorted by mtime (newest first)
    """
    hf_home = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    hub_dir = os.path.join(hf_home, "hub")

    if not os.path.isdir(hub_dir):
        print(f"HF hub directory not found: {hub_dir}")
        return []

    snapshots = []

    # Pattern: models--org--model/snapshots/hash
    for model_dir in glob.glob(os.path.join(hub_dir, "models--*")):
        # Extract model name from directory (models--org--model -> org/model)
        dir_name = os.path.basename(model_dir)
        if not dir_name.startswith("models--"):
            continue

        # models--meta-llama--Llama-2-7b-hf -> meta-llama/Llama-2-7b-hf
        # (repo names containing "--" are handled heuristically below)
        parts = dir_name.split("--")
        if len(parts) < 3 or parts[0] != "models":
            # Invalid format, skip
            continue
        # Standard format: models--org--repo -> org/repo
        # Extended format: models--org--repo--extra -> org/repo-extra (join with -)
        model_name = parts[1] + "/" + "-".join(parts[2:])

        snapshots_dir = os.path.join(model_dir, "snapshots")
        if not os.path.isdir(snapshots_dir):
            continue

        # Find all snapshot hashes
        for snapshot_hash_dir in os.listdir(snapshots_dir):
            snapshot_path = os.path.join(snapshots_dir, snapshot_hash_dir)
            if os.path.isdir(snapshot_path):
                try:
                    mtime = os.path.getmtime(snapshot_path)
                    snapshots.append((model_name, snapshot_path, mtime))
                except OSError:
                    continue

    # Sort by mtime (newest first) - prioritize recently used models
    snapshots.sort(key=lambda x: x[2], reverse=True)

    # Return without mtime
    return [(name, path) for name, path, _ in snapshots]


def is_transformers_text_model(snapshot_dir):
    """
    Check if a snapshot is a transformers text model.

    Only excludes (returns False) snapshots with strong evidence of being
    diffusers/generation pipelines. Uses conservative heuristics to avoid
    false negatives on multimodal LLMs that ship tokenizers.

    Args:
        snapshot_dir: Path to snapshot directory

    Returns:
        True if this looks like a transformers text/multimodal model,
        False if it should be treated as not applicable (e.g. a diffusion pipeline)
    """
    # Check for diffusers pipeline markers (strong evidence)
    diffusers_markers = [
        "model_index.json",  # Diffusers pipeline config
        "scheduler",  # Scheduler directory (diffusers)
    ]
    if any(
        os.path.exists(os.path.join(snapshot_dir, marker))
        for marker in diffusers_markers
    ):
        return False

    config_path = os.path.join(snapshot_dir, "config.json")
    if not os.path.exists(config_path):
        # No config.json - likely not a transformers model
        return False

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        # Check for explicit diffusers/generation model types (conservative keywords)
        model_type = config.get("_class_name") or config.get("model_type")
        if model_type:
            model_type_lower = str(model_type).lower()
            # Only exclude clear diffusion/generation models
            if any(
                keyword in model_type_lower
                for keyword in [
                    "diffusion",
                    "unet",
                    "vae",
                    "controlnet",
                    "stable-diffusion",
                    "latent-diffusion",
                ]
            ):
                return False

        # Check architectures for explicit generation/diffusion classes
        architectures = config.get("architectures", [])
        if architectures:
            arch_str = " ".join(architectures).lower()
            # Conservative: only exclude obvious diffusion/generation architectures.
            # Matching is by substring, so keywords are kept specific (e.g.
            # "ditmodel" rather than bare "dit", which would match "conditional").
            for keyword in [
                "diffusion",
                "unet2d",
                "unet3d",
                "vaedecoder",  # More specific than "vae"
                "vaeencoder",
                "controlnet",
                "autoencoder",
                "ditmodel",  # Diffusion Transformer - use more specific pattern
                "pixart",  # PixArt diffusion model
            ]:
                if keyword in arch_str:
                    return False

        # Check for standalone vision encoder/image processor (no text component)
        # Only if model name explicitly indicates non-text usage
        model_name = config.get("_name_or_path", "").lower()

        if any(
            keyword in model_name
            for keyword in [
                "image-edit-",  # Pure image editing (e.g., Qwen-Image-Edit)
                "-image-editing",
                "dit-",  # DiT generation models
                "pixart-",  # PixArt generation models
            ]
        ):
            # Additional check: does it have tokenizer? If yes, might be multimodal LLM
            has_tokenizer = any(
                os.path.exists(os.path.join(snapshot_dir, fname))
                for fname in ["tokenizer.json", "tokenizer.model", "tiktoken.model"]
            )
            if not has_tokenizer:
                # Image-edit model without tokenizer -> likely pure vision pipeline
                return False

        # Default: assume it's a transformers text/multimodal model.
        # Even if it lacks a tokenizer, let validation report the actual error
        # (a false positive here is better than silently skipping a real text model).
        return True

    except (json.JSONDecodeError, OSError, KeyError):
        # Can't parse config - assume it's transformers and let validation report failure
        return True


def scan_weight_files(snapshot_dir):
    """
    Scan for weight files in a snapshot.

    Returns:
        List of weight file paths, or empty list if scan fails
    """
    weight_files = []

    # First, look for index files
    index_patterns = ["*.safetensors.index.json", "pytorch_model.bin.index.json"]
    index_files = []
    for pattern in index_patterns:
        index_files.extend(glob.glob(os.path.join(snapshot_dir, pattern)))
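    # For reference, a shard index file looks like (abridged):
    #   {"metadata": {"total_size": ...},
    #    "weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    #                   ...}}
    # Every distinct filename in "weight_map" is a shard that should exist on disk.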

    # Collect shards from the index. Prefer the safetensors index and fall
    # back to the .bin index; both formats share the same "weight_map" schema.
    safetensors_indexes = [
        f for f in index_files if f.endswith(".safetensors.index.json")
    ]
    for index_file in safetensors_indexes or index_files:
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                index_data = json.load(f)
            weight_map = index_data.get("weight_map", {})
            for weight_file in set(weight_map.values()):
                weight_path = os.path.join(snapshot_dir, weight_file)
                if os.path.exists(weight_path):
                    weight_files.append(weight_path)
        except Exception as e:
            print(
                f"  Warning: Failed to parse index {os.path.basename(index_file)}: {e}"
            )

    # If no index found or no shards from index, do recursive glob
    if not weight_files:
        matched = glob.glob(
            os.path.join(snapshot_dir, "**/*.safetensors"), recursive=True
        )
        MAX_WEIGHT_FILES = 1000
        if len(matched) > MAX_WEIGHT_FILES:
            print(
                f"  Warning: Too many safetensors files ({len(matched)} > {MAX_WEIGHT_FILES})"
            )
            return []

        for f in matched:
            if os.path.exists(f):  # Filter out broken symlinks
                weight_files.append(f)

    return weight_files


def validate_snapshot(model_name, snapshot_dir, weight_files, validated_cache):
    """
    Validate a snapshot and return detailed status.

    Uses an in-process cache to avoid duplicate validation within the same run.

    Args:
        model_name: Model identifier
        snapshot_dir: Path to snapshot directory
        weight_files: List of weight files to validate
        validated_cache: Dict to track already-validated snapshots in this run

    Returns:
        Tuple of (result, reason):
        - (True, None) if validation passed
        - (False, reason_str) if validation failed
        - (None, None) if skipped (already validated in this run)
    """
    # Fast path: check in-process cache first
    if snapshot_dir in validated_cache:
        return None, None  # Already validated in this run, skip

    try:
        # Perform validation with detailed reason
        is_complete, reason = validate_cache_with_detailed_reason(
            snapshot_dir=snapshot_dir,
            weight_files=weight_files,
            model_name_or_path=model_name,
        )

        # Cache result to avoid re-validation in this run
        validated_cache[snapshot_dir] = (is_complete, reason)

        return is_complete, reason

    except Exception as e:
        error_msg = f"Validation raised exception: {e}"
        return False, error_msg


def main():
    start_time = time.time()

    print("=" * 70)
    print("CI_OFFLINE: Pre-validating cached HuggingFace models")
    print("=" * 70)
    print(f"Max time: {MAX_VALIDATION_TIME_SECONDS}s")
    print()

    print("Scanning HuggingFace cache for models...")
    snapshots = find_all_hf_snapshots()

    if not snapshots:
        print("No cached models found, skipping validation")
        print("=" * 70)
        return

    print(f"Found {len(snapshots)} snapshot(s) in cache")
    print()

    validated_count = 0
    failed_count = 0
    skipped_count = 0
    processed_count = 0

    # In-process cache to avoid re-validating the same snapshot in this run
    validated_cache = {}

    for model_name, snapshot_dir in snapshots:
        # Check time limit
        elapsed = time.time() - start_time
        if elapsed > MAX_VALIDATION_TIME_SECONDS:
            print()
            print(
                f"Time limit reached ({elapsed:.1f}s > {MAX_VALIDATION_TIME_SECONDS}s)"
            )
            print(
                f"Stopping validation, {len(snapshots) - processed_count} snapshots remaining"
            )
            break

        snapshot_hash = os.path.basename(snapshot_dir)
        print(
            f"[{processed_count + 1}/{len(snapshots)}] {model_name} ({snapshot_hash[:8]}...)"
        )
        processed_count += 1

        # Determine model type by checking for model_index.json (diffusers pipeline marker)
        model_index_path = os.path.join(snapshot_dir, "model_index.json")
        is_diffusion_model = os.path.exists(model_index_path)
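        # (A diffusers model_index.json enumerates the pipeline components,
        # e.g. "unet", "vae", "text_encoder", "scheduler".)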

        if is_diffusion_model:
            # This is a diffusers pipeline - use diffusion validation
            try:
                is_valid, reason = _validate_diffusion_model(snapshot_dir)

                if is_valid:
                    print("  PASS (diffusion) - Cache complete & valid")
                    validated_count += 1
                else:
                    print(f"  FAIL (diffusion) - {reason}")
                    failed_count += 1

            except Exception as e:
                print(f"  FAIL (diffusion) - Validation raised exception: {e}")
                failed_count += 1

            continue

        # Transformers model - use standard validation
        # First check if this looks like a transformers text model
        if not is_transformers_text_model(snapshot_dir):
            # Not a recognized model type, skip
            print(
                "  SKIP (unknown type) - Not a diffusers pipeline or transformers model"
            )
            skipped_count += 1
            continue

        # Scan weight files
        weight_files = scan_weight_files(snapshot_dir)

        if not weight_files:
            print("  SKIP (no weights) - empty or incomplete download")
            skipped_count += 1
            continue

        # Validate
        try:
            result, reason = validate_snapshot(
                model_name, snapshot_dir, weight_files, validated_cache
            )

            if result is True:
                print("  PASS - Cache complete & valid")
                validated_count += 1
            elif result is False:
                # Print detailed failure reason
                if reason:
                    print(f"  FAIL (incomplete) - {reason}")
                else:
                    print("  FAIL (incomplete) - cache validation failed")
                failed_count += 1
            else:  # None (skipped)
                print("  SKIP (already validated in this run)")
                skipped_count += 1

        except Exception as e:
            print(f"  FAIL (error) - Validation raised exception: {e}")
            failed_count += 1

    elapsed_total = time.time() - start_time

    print()
    print("=" * 70)
    print(f"Validation summary (completed in {elapsed_total:.1f}s):")
    print(f"  PASS (complete & valid):      {validated_count}")
    print(f"  FAIL (incomplete/corrupted):  {failed_count}")
    print(f"  SKIP (no weights/duplicate):  {skipped_count}")
    print(f"  Total processed:              {processed_count}/{len(snapshots)}")
    print("=" * 70)


if __name__ == "__main__":
    main()