wenjiao's picture
refactor repo code
b15b21e
"""Queue ETA estimation for the leaderboard.
Provides a slot-simulation based ETA calculation:
ETA(pos) = running_remaining + ⌈pos / concurrency⌉ × task_duration
where *task_duration* depends on model size:
- ≤ 30B params → 3 hours
- > 30B params → 5 hours
*running_remaining* is approximated as the average estimated task duration
of currently active (Running/Waiting/Triggered) entries, because we don't
know exactly when each one started.
"""
import json
import logging
import math
import os
from datetime import datetime, timedelta
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# ── Task duration estimates (hours) ──────────────────────────────────────
_SMALL_MODEL_HOURS = 3  # estimated task duration for models with ≤ 30B params
_LARGE_MODEL_HOURS = 5  # estimated task duration for models with > 30B params
_SIZE_THRESHOLD_B = 30.0  # small/large cutoff, in billions of params
def estimate_task_hours(model_params: float | None) -> float:
    """Return the estimated task duration, in hours, for a model of the given size.

    Args:
        model_params: Model size in billions of parameters, or ``None`` when
            the size is unknown.

    Returns:
        ``_LARGE_MODEL_HOURS`` when the size is strictly above
        ``_SIZE_THRESHOLD_B``; ``_SMALL_MODEL_HOURS`` otherwise, which is also
        the default for unknown (``None``) or non-positive sizes.
    """
    # Unknown and non-positive sizes can never exceed the threshold, so they
    # fall through to the small-model default without a separate branch.
    exceeds_threshold = (
        model_params is not None and model_params > _SIZE_THRESHOLD_B
    )
    return _LARGE_MODEL_HOURS if exceeds_threshold else _SMALL_MODEL_HOURS
def compute_single_eta(
    status_path: str,
    model_params: float | None,
    concurrency: int = 2,
) -> float:
    """Compute the ETA for a newly submitted model at the end of the queue.

    Used by submit.py to show the user an immediate estimate.

    The status directory is walked recursively and every ``*.json`` file is
    inspected:

    - entries with status ``Running``, ``Waiting`` or ``Triggered`` are
      "active"; the mean of their estimated durations approximates the time
      left before a slot frees up;
    - entries with status ``Pending`` and script ``auto_quant`` or
      ``auto_eval`` are queued ahead of (or are) the new submission.

    Files that cannot be read, cannot be decoded as UTF-8 JSON, or whose
    top-level JSON value is not an object are skipped, so a single bad status
    file never breaks the estimate.

    Args:
        status_path: Root directory containing the status JSON files. A path
            that is not a directory is treated as an empty queue.
        model_params: Size of the submitted model in billions of parameters,
            or ``None`` when unknown.
        concurrency: Number of tasks processed in parallel. Non-positive
            values are treated as a single slot.

    Returns:
        Estimated hours until the submitted model finishes, rounded to one
        decimal place.
    """
    # A zero concurrency would divide by zero and a negative one would yield
    # a negative ETA; fall back to one slot in both cases.
    slots = concurrency if concurrency > 0 else 1

    active_count = 0
    pending_count = 0
    active_hours_sum = 0.0

    if os.path.isdir(status_path):
        for root, _dirs, files in os.walk(status_path):
            for fname in files:
                if not fname.endswith(".json"):
                    continue
                fpath = os.path.join(root, fname)
                try:
                    # JSON is UTF-8 by specification; do not depend on the
                    # platform's locale encoding.
                    with open(fpath, encoding="utf-8") as f:
                        data = json.load(f)
                except (OSError, ValueError):
                    # ValueError covers json.JSONDecodeError as well as
                    # UnicodeDecodeError raised while decoding the file.
                    continue
                if not isinstance(data, dict):
                    # Valid JSON, but not a status object (e.g. list/number).
                    continue

                status = data.get("status", "")
                script = data.get("script", "")
                if status in ("Running", "Waiting", "Triggered"):
                    active_count += 1
                    active_hours_sum += estimate_task_hours(_get_params(data))
                elif status == "Pending" and script in ("auto_quant", "auto_eval"):
                    pending_count += 1

    # The model's own status file has already been written by _upload_to_hub,
    # so it is already included in pending_count.
    queue_pos = max(pending_count, 1)
    task_hours = estimate_task_hours(model_params)

    # Start times of active tasks are unknown, so approximate the remaining
    # time of the running batch by the mean estimated duration.
    if active_count > 0:
        running_remaining = active_hours_sum / active_count
    else:
        running_remaining = 0.0

    eta = running_remaining + math.ceil(queue_pos / slots) * task_hours
    return round(eta, 1)
def format_eta(hours: float) -> str:
    """Render an ETA in hours as a short label such as '~6h' or '~1d 2h'.

    Args:
        hours: Estimated time in hours; fractional values are rounded up to
            the next whole hour.

    Returns:
        ``"< 1h"`` for non-positive input, ``"~Nh"`` below one day, and
        ``"~Dd"`` or ``"~Dd Hh"`` from one day upwards.
    """
    if hours <= 0:
        return "< 1h"

    whole_hours = int(math.ceil(hours))
    day_count, leftover_hours = divmod(whole_hours, 24)

    if day_count == 0:
        return f"~{whole_hours}h"
    # Omit the hour part when the ETA is an exact number of days.
    return f"~{day_count}d" if leftover_hours == 0 else f"~{day_count}d {leftover_hours}h"
def _get_params(data: dict) -> float | None:
"""Extract model parameter count (in billions) from a status entry."""
for key in ("model_params", "params"):
val = data.get(key)
if val is not None:
try:
return float(val)
except (TypeError, ValueError):
pass
return None