Upload server_runtime.py
Browse files- server_runtime.py +43 -15
server_runtime.py
CHANGED
|
@@ -145,6 +145,9 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
|
|
| 145 |
join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
|
| 146 |
max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
|
| 147 |
max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 150 |
|
|
@@ -366,32 +369,57 @@ def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
|
|
| 366 |
nonlocal model, tokenizer, worker_tasks, max_workers, device
|
| 367 |
|
| 368 |
logger.info("Loading model %s on %s", config.model_name, device)
|
| 369 |
-
tokenizer_kwargs: Dict[str, Any] = {
|
|
|
|
|
|
|
|
|
|
| 370 |
if config.tokenizer_use_fast is not None:
|
| 371 |
tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
|
| 372 |
-
tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
|
| 373 |
model_load_kwargs: Dict[str, Any] = {
|
| 374 |
"trust_remote_code": True,
|
| 375 |
"device_map": "auto" if device == "cuda" else None,
|
|
|
|
| 376 |
}
|
| 377 |
if device == "cuda":
|
| 378 |
model_load_kwargs["dtype"] = "auto"
|
| 379 |
else:
|
| 380 |
model_load_kwargs["torch_dtype"] = torch.float32
|
| 381 |
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
**
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
|
| 396 |
if device != "cuda":
|
| 397 |
model = model.to("cpu")
|
|
|
|
| 145 |
join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
|
| 146 |
max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
|
| 147 |
max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
|
| 148 |
+
model_load_retries = max(1, int(os.getenv("HF_MODEL_LOAD_RETRIES", "4")))
|
| 149 |
+
model_load_retry_delay = max(1.0, float(os.getenv("HF_MODEL_LOAD_RETRY_DELAY_SECONDS", "8")))
|
| 150 |
+
local_files_only = _is_truthy(os.getenv("HF_LOCAL_FILES_ONLY", "0"))
|
| 151 |
|
| 152 |
base_dir = os.path.dirname(os.path.abspath(__file__))
|
| 153 |
|
|
|
|
| 369 |
nonlocal model, tokenizer, worker_tasks, max_workers, device
|
| 370 |
|
| 371 |
logger.info("Loading model %s on %s", config.model_name, device)
|
| 372 |
+
tokenizer_kwargs: Dict[str, Any] = {
|
| 373 |
+
"trust_remote_code": True,
|
| 374 |
+
"local_files_only": local_files_only,
|
| 375 |
+
}
|
| 376 |
if config.tokenizer_use_fast is not None:
|
| 377 |
tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
|
|
|
|
| 378 |
model_load_kwargs: Dict[str, Any] = {
|
| 379 |
"trust_remote_code": True,
|
| 380 |
"device_map": "auto" if device == "cuda" else None,
|
| 381 |
+
"local_files_only": local_files_only,
|
| 382 |
}
|
| 383 |
if device == "cuda":
|
| 384 |
model_load_kwargs["dtype"] = "auto"
|
| 385 |
else:
|
| 386 |
model_load_kwargs["torch_dtype"] = torch.float32
|
| 387 |
|
| 388 |
+
last_load_error: Optional[Exception] = None
|
| 389 |
+
for attempt in range(1, model_load_retries + 1):
|
| 390 |
+
try:
|
| 391 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
|
| 392 |
+
try:
|
| 393 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 394 |
+
config.model_name,
|
| 395 |
+
**model_load_kwargs,
|
| 396 |
+
)
|
| 397 |
+
except TypeError:
|
| 398 |
+
# Backward compatibility for older transformers that do not accept `dtype`.
|
| 399 |
+
if "dtype" in model_load_kwargs:
|
| 400 |
+
model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
|
| 401 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 402 |
+
config.model_name,
|
| 403 |
+
**model_load_kwargs,
|
| 404 |
+
)
|
| 405 |
+
break
|
| 406 |
+
except Exception as exc:
|
| 407 |
+
last_load_error = exc
|
| 408 |
+
logger.warning(
|
| 409 |
+
"Model load attempt %d/%d failed: %s",
|
| 410 |
+
attempt,
|
| 411 |
+
model_load_retries,
|
| 412 |
+
str(exc),
|
| 413 |
+
)
|
| 414 |
+
if attempt < model_load_retries:
|
| 415 |
+
await asyncio.sleep(model_load_retry_delay)
|
| 416 |
+
else:
|
| 417 |
+
logger.error(
|
| 418 |
+
"Model loading failed after %d attempts (local_files_only=%s)",
|
| 419 |
+
model_load_retries,
|
| 420 |
+
str(local_files_only),
|
| 421 |
+
)
|
| 422 |
+
raise last_load_error
|
| 423 |
|
| 424 |
if device != "cuda":
|
| 425 |
model = model.to("cpu")
|