Spaces:
Sleeping
Sleeping
Commit
·
6fb2aa6
1
Parent(s):
2790442
Default to Gemma router and limit prefetch
Browse files
app.py
CHANGED
|
@@ -73,18 +73,20 @@ def _prefetch_repo(repo_id: str) -> None:
|
|
| 73 |
print(f"Prefetch skipped for {repo_id}: {exc}")
|
| 74 |
|
| 75 |
|
| 76 |
-
def _start_prefetch_workers():
|
| 77 |
global PREFETCH_EXECUTOR
|
| 78 |
if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
|
| 79 |
return
|
| 80 |
if PREFETCH_EXECUTOR is not None:
|
| 81 |
return
|
| 82 |
-
|
|
|
|
|
|
|
| 83 |
PREFETCH_EXECUTOR = ThreadPoolExecutor(max_workers=worker_count, thread_name_prefix="prefetch")
|
| 84 |
submitted = set()
|
| 85 |
-
for model_name
|
| 86 |
-
repos = {
|
| 87 |
-
tokenizer_repo =
|
| 88 |
if tokenizer_repo:
|
| 89 |
repos.add(tokenizer_repo)
|
| 90 |
for repo in repos:
|
|
@@ -95,13 +97,6 @@ def _start_prefetch_workers():
|
|
| 95 |
|
| 96 |
|
| 97 |
MODELS = {
|
| 98 |
-
"Router-Qwen3-32B-AWQ": {
|
| 99 |
-
"repo_id": "Alovestocode/router-qwen3-32b-merged-awq", # AWQ quantized model
|
| 100 |
-
"tokenizer_repo": "Alovestocode/router-qwen3-32b-merged", # Tokenizer from original repo
|
| 101 |
-
"description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization via vLLM.",
|
| 102 |
-
"params_b": 32.0,
|
| 103 |
-
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 104 |
-
},
|
| 105 |
"Router-Gemma3-27B-AWQ": {
|
| 106 |
"repo_id": "Alovestocode/router-gemma3-merged-awq", # AWQ quantized model
|
| 107 |
"tokenizer_repo": "Alovestocode/router-gemma3-merged", # Tokenizer from original repo
|
|
@@ -109,9 +104,38 @@ MODELS = {
|
|
| 109 |
"params_b": 27.0,
|
| 110 |
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 111 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
}
|
| 113 |
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# Try to import LLM Compressor (for quantization - optional, vLLM has native AWQ support)
|
| 117 |
# Note: llm-compressor is only needed for quantizing models, not for loading pre-quantized AWQ models
|
|
@@ -1070,7 +1094,7 @@ def build_ui():
|
|
| 1070 |
model_choice = gr.Dropdown(
|
| 1071 |
label="Router Checkpoint",
|
| 1072 |
choices=list(MODELS.keys()),
|
| 1073 |
-
value=
|
| 1074 |
allow_custom_value=False,
|
| 1075 |
)
|
| 1076 |
difficulty = gr.Radio(
|
|
@@ -1130,16 +1154,7 @@ def build_ui():
|
|
| 1130 |
|
| 1131 |
|
| 1132 |
def _prefetch_from_env() -> None:
|
| 1133 |
-
|
| 1134 |
-
if entries:
|
| 1135 |
-
names = [item.strip() for item in entries.split(",") if item.strip()]
|
| 1136 |
-
else:
|
| 1137 |
-
single = os.environ.get("ROUTER_PREFETCH_MODEL")
|
| 1138 |
-
names = [single] if single else []
|
| 1139 |
-
|
| 1140 |
-
if names == ["ALL"] or names == ["all"]:
|
| 1141 |
-
names = list(MODELS.keys())
|
| 1142 |
-
|
| 1143 |
for name in names:
|
| 1144 |
if name not in MODELS:
|
| 1145 |
print(f"Prefetch skipped, unknown model: {name}")
|
|
|
|
| 73 |
print(f"Prefetch skipped for {repo_id}: {exc}")
|
| 74 |
|
| 75 |
|
| 76 |
+
def _start_prefetch_workers(model_names: list[str]):
|
| 77 |
global PREFETCH_EXECUTOR
|
| 78 |
if PREFETCH_DISABLED or not HF_HUB_AVAILABLE:
|
| 79 |
return
|
| 80 |
if PREFETCH_EXECUTOR is not None:
|
| 81 |
return
|
| 82 |
+
if not model_names:
|
| 83 |
+
return
|
| 84 |
+
worker_count = max(1, min(PREFETCH_THREADS, len(model_names) * 2))
|
| 85 |
PREFETCH_EXECUTOR = ThreadPoolExecutor(max_workers=worker_count, thread_name_prefix="prefetch")
|
| 86 |
submitted = set()
|
| 87 |
+
for model_name in model_names:
|
| 88 |
+
repos = {MODELS[model_name]["repo_id"]}
|
| 89 |
+
tokenizer_repo = MODELS[model_name].get("tokenizer_repo")
|
| 90 |
if tokenizer_repo:
|
| 91 |
repos.add(tokenizer_repo)
|
| 92 |
for repo in repos:
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
MODELS = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"Router-Gemma3-27B-AWQ": {
|
| 101 |
"repo_id": "Alovestocode/router-gemma3-merged-awq", # AWQ quantized model
|
| 102 |
"tokenizer_repo": "Alovestocode/router-gemma3-merged", # Tokenizer from original repo
|
|
|
|
| 104 |
"params_b": 27.0,
|
| 105 |
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 106 |
},
|
| 107 |
+
"Router-Qwen3-32B-AWQ": {
|
| 108 |
+
"repo_id": "Alovestocode/router-qwen3-32b-merged-awq", # AWQ quantized model
|
| 109 |
+
"tokenizer_repo": "Alovestocode/router-qwen3-32b-merged", # Tokenizer from original repo
|
| 110 |
+
"description": "Router checkpoint on Qwen3 32B merged, optimized with AWQ quantization via vLLM.",
|
| 111 |
+
"params_b": 32.0,
|
| 112 |
+
"quantization": "awq", # vLLM will auto-detect AWQ
|
| 113 |
+
},
|
| 114 |
}
|
| 115 |
|
| 116 |
+
# Model preselected in the UI dropdown and used as the prefetch fallback.
# DEFAULT_ROUTER_MODEL overrides it; a value that is not a key of MODELS is
# replaced by the first registered model, or None when MODELS is empty.
DEFAULT_MODEL = os.environ.get("DEFAULT_ROUTER_MODEL", "Router-Gemma3-27B-AWQ")
if DEFAULT_MODEL not in MODELS:
    DEFAULT_MODEL = next(iter(MODELS), None)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _resolve_prefetch_model_names(include_default: bool) -> list[str]:
    """Resolve which ``MODELS`` entries should be prefetched from the environment.

    ``ROUTER_PREFETCH_MODELS`` (comma-separated) takes precedence. When it is
    unset or empty, ``ROUTER_PREFETCH_MODEL`` (a single name) is used instead.
    A lone ``all`` entry, in any letter case, selects every key of ``MODELS``.

    Args:
        include_default: When True and none of the requested names is a key
            of ``MODELS``, fall back to ``DEFAULT_MODEL`` if it is set.

    Returns:
        The requested names that are keys of ``MODELS``, in request order and
        without duplicates. Unknown names are dropped without a message.
    """
    entries = os.environ.get("ROUTER_PREFETCH_MODELS")
    if entries:
        raw_names = entries.split(",")
    else:
        single = os.environ.get("ROUTER_PREFETCH_MODEL")
        raw_names = [single] if single else []

    # Normalise both sources identically so a padded single name is accepted too.
    names = [item.strip() for item in raw_names if item and item.strip()]

    # Accept "ALL", "all", "All", ... as the select-everything sentinel.
    if len(names) == 1 and names[0].lower() == "all":
        names = list(MODELS)

    # dict.fromkeys removes repeated names while keeping first-seen order.
    valid = list(dict.fromkeys(name for name in names if name in MODELS))
    if not valid and include_default and DEFAULT_MODEL:
        valid = [DEFAULT_MODEL]
    return valid
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
_start_prefetch_workers(_resolve_prefetch_model_names(include_default=True))
|
| 139 |
|
| 140 |
# Try to import LLM Compressor (for quantization - optional, vLLM has native AWQ support)
|
| 141 |
# Note: llm-compressor is only needed for quantizing models, not for loading pre-quantized AWQ models
|
|
|
|
| 1094 |
model_choice = gr.Dropdown(
|
| 1095 |
label="Router Checkpoint",
|
| 1096 |
choices=list(MODELS.keys()),
|
| 1097 |
+
value=DEFAULT_MODEL,
|
| 1098 |
allow_custom_value=False,
|
| 1099 |
)
|
| 1100 |
difficulty = gr.Radio(
|
|
|
|
| 1154 |
|
| 1155 |
|
| 1156 |
def _prefetch_from_env() -> None:
|
| 1157 |
+
names = _resolve_prefetch_model_names(include_default=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1158 |
for name in names:
|
| 1159 |
if name not in MODELS:
|
| 1160 |
print(f"Prefetch skipped, unknown model: {name}")
|