Alikestocode committed on
Commit
e829b15
·
1 Parent(s): 63c8de5

Parallelize AWQ model prefetching

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py CHANGED
@@ -10,6 +10,13 @@ import spaces
10
  import torch
11
  from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
12
  from threading import Thread
 
 
 
 
 
 
 
13
 
14
  # Enable optimizations
15
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -43,6 +50,52 @@ except ImportError:
43
  # Optional flag to disable vLLM (defaults to true on MIG due to device detection instability)
44
  DISABLE_VLLM = os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0") == "1"
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # Try to import LLM Compressor (for quantization - optional, vLLM has native AWQ support)
47
  # Note: llm-compressor is only needed for quantizing models, not for loading pre-quantized AWQ models
48
  # vLLM can load AWQ models natively without llm-compressor
 
10
  import torch
11
  from transformers import AutoTokenizer, TextIteratorStreamer, pipeline
12
  from threading import Thread
13
from concurrent.futures import ThreadPoolExecutor

# huggingface_hub is optional: when it is not installed, HF_HUB_AVAILABLE
# stays False and every prefetch code path below becomes a no-op, so the
# app still starts without the package.
try:
    from huggingface_hub import snapshot_download
    HF_HUB_AVAILABLE = True
except ImportError:  # pragma: no cover
    HF_HUB_AVAILABLE = False
20
 
21
  # Enable optimizations
22
  torch.backends.cuda.matmul.allow_tf32 = True
 
50
  # Optional flag to disable vLLM (defaults to true on MIG due to device detection instability)
51
  DISABLE_VLLM = os.environ.get("DISABLE_VLLM", "1" if MIG_VISIBLE else "0") == "1"
52
 
53
# ---------------------------------------------------------------------------
# Parallel prefetch of model weights/tokenizers to reduce first-load latency
# ---------------------------------------------------------------------------
# Set DISABLE_PREFETCH=1 to skip background downloads entirely.
PREFETCH_DISABLED = os.environ.get("DISABLE_PREFETCH", "0") == "1"

# Worker-thread count for the prefetch pool. A malformed env value falls
# back to the default of 4 instead of raising ValueError at import time.
try:
    PREFETCH_THREADS = int(os.environ.get("PREFETCH_THREADS", "4"))
except ValueError:
    PREFETCH_THREADS = 4

# Created lazily by _start_prefetch_workers(); stays None until then.
PREFETCH_EXECUTOR = None
59
+
60
+
61
def _prefetch_repo(repo_id: str) -> None:
    """Download (or resume) a snapshot of ``repo_id`` into the local HF cache.

    Best-effort by design: any failure (network, auth, missing repo) is
    printed and swallowed so a prefetch problem can never break startup.
    """
    if not HF_HUB_AVAILABLE:
        return
    try:
        # NOTE: the deprecated ``resume_download=True`` kwarg was removed —
        # modern huggingface_hub always resumes interrupted downloads and
        # passing the argument only triggers a FutureWarning.
        snapshot_download(
            repo_id=repo_id,
            etag_timeout=10,
            local_files_only=False,
        )
        print(f"Prefetched repo: {repo_id}")
    except Exception as exc:  # pragma: no cover
        # Deliberate broad catch: prefetching is purely an optimization.
        print(f"Prefetch skipped for {repo_id}: {exc}")
74
+
75
+
76
def _start_prefetch_workers():
    """Start a background thread pool that warms the local HF model cache.

    Idempotent: a second call is a no-op once the executor exists. Skipped
    entirely when prefetching is disabled, huggingface_hub is missing, or
    there is nothing to prefetch. The executor is intentionally not shut
    down here; its workers idle out once every submitted job completes.
    """
    global PREFETCH_EXECUTOR
    if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
        return
    if PREFETCH_EXECUTOR is not None:
        return

    # Collect the unique set of repos up front: every model has a weights
    # repo and may name a separate tokenizer repo.
    repos = set()
    for cfg in MODELS.values():
        repos.add(cfg["repo_id"])
        tokenizer_repo = cfg.get("tokenizer_repo")
        if tokenizer_repo:
            repos.add(tokenizer_repo)
    if not repos:
        # Nothing to do (e.g. empty MODELS) — don't spin up idle threads.
        return

    # Cap workers at twice the model count so small configs don't spawn
    # threads they can never use, but always keep at least one worker.
    worker_count = max(1, min(PREFETCH_THREADS, len(MODELS) * 2))
    PREFETCH_EXECUTOR = ThreadPoolExecutor(
        max_workers=worker_count, thread_name_prefix="prefetch"
    )
    for repo in repos:
        PREFETCH_EXECUTOR.submit(_prefetch_repo, repo)
95
+
96
+
97
+ _start_prefetch_workers()
98
+
99
  # Try to import LLM Compressor (for quantization - optional, vLLM has native AWQ support)
100
  # Note: llm-compressor is only needed for quantizing models, not for loading pre-quantized AWQ models
101
  # vLLM can load AWQ models natively without llm-compressor