File size: 3,610 Bytes
3a08af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09ebea
 
 
3a08af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Queue ETA estimation for the leaderboard.

Provides a slot-simulation based ETA calculation:

    ETA(pos) = running_remaining + ⌈pos / concurrency⌉ × task_duration

where *task_duration* depends on model size:
    - ≤ 30B params → 3 hours
    - >  30B params → 5 hours

*running_remaining* is approximated as the average estimated task duration
of currently active (Running/Waiting/Triggered) entries, because we don't
know exactly when each one started.
"""

import json
import logging
import math
import os
from datetime import datetime, timedelta

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ── Task duration estimates (hours) ──────────────────────────────────────
# estimate_task_hours() picks between these two durations by comparing the
# model size against _SIZE_THRESHOLD_B.
_SMALL_MODEL_HOURS = 3    # ≤ 30B params; also the fallback when the size is unknown
_LARGE_MODEL_HOURS = 5    # > 30B params
_SIZE_THRESHOLD_B = 30.0  # size cutoff, in billions of params


def estimate_task_hours(model_params: float | None) -> float:
    """Map a model size (billions of parameters) to a task duration in hours.

    Sizes strictly above ``_SIZE_THRESHOLD_B`` get ``_LARGE_MODEL_HOURS``.
    Everything else, including an unknown (``None``) or non-positive size,
    gets ``_SMALL_MODEL_HOURS``.
    """
    is_large = (
        model_params is not None
        and model_params > 0
        and model_params > _SIZE_THRESHOLD_B
    )
    return _LARGE_MODEL_HOURS if is_large else _SMALL_MODEL_HOURS


def compute_single_eta(
    status_path: str,
    model_params: float | None,
    concurrency: int = 2,
) -> float:
    """Compute ETA for a newly submitted model (it goes to the end of the queue).

    Used by submit.py to show the user an immediate estimate.

    Every ``*.json`` file found under *status_path* is read as one status
    entry. Entries whose ``status`` is Running/Waiting/Triggered count as
    active; entries that are Pending with ``script`` equal to ``auto_quant``
    or ``auto_eval`` count as queued. Files that cannot be read, decoded or
    parsed, or whose top-level JSON value is not an object, are skipped.

    Args:
        status_path: Directory scanned recursively for status files. If it
            is not an existing directory the queue is treated as empty.
        model_params: Size of the submitted model in billions of parameters,
            or ``None`` when unknown.
        concurrency: Number of tasks assumed to run in parallel. Values
            below 1 are treated as 1.

    Returns:
        Estimated hours until the submitted model finishes, rounded to one
        decimal place.
    """
    active_count = 0
    pending_count = 0
    active_hours_sum = 0.0

    if os.path.isdir(status_path):
        for root, _dirs, files in os.walk(status_path):
            for fname in files:
                if not fname.endswith(".json"):
                    continue
                fpath = os.path.join(root, fname)
                try:
                    with open(fpath, encoding="utf-8") as f:
                        data = json.load(f)
                except (ValueError, OSError) as exc:
                    # ValueError covers both json.JSONDecodeError and
                    # UnicodeDecodeError (file is not valid UTF-8 text).
                    logger.debug(
                        "Skipping unreadable status file %s: %s", fpath, exc
                    )
                    continue

                # Valid JSON need not be an object; a list or scalar has no
                # .get() and cannot be a status entry.
                if not isinstance(data, dict):
                    logger.debug("Skipping non-object status file %s", fpath)
                    continue

                status = data.get("status", "")
                script = data.get("script", "")

                if status in ("Running", "Waiting", "Triggered"):
                    active_count += 1
                    active_hours_sum += estimate_task_hours(_get_params(data))
                elif status == "Pending" and script in ("auto_quant", "auto_eval"):
                    pending_count += 1

    # The model's own status file has already been written by _upload_to_hub,
    # so it is already included in pending_count.
    queue_pos = max(pending_count, 1)
    task_hours = estimate_task_hours(model_params)

    # Remaining time of the running batch is approximated by the average
    # estimated duration of the active entries.
    if active_count > 0:
        running_remaining = active_hours_sum / active_count
    else:
        running_remaining = 0.0

    # Guard against a zero or negative slot count, which would otherwise
    # divide by zero or produce a negative ETA.
    slots = max(concurrency, 1)

    eta = running_remaining + math.ceil(queue_pos / slots) * task_hours
    return round(eta, 1)


def format_eta(hours: float) -> str:
    """Render an ETA given in hours as a short label such as '~6h' or '~1d 2h'.

    The value is rounded up to whole hours before formatting. Non-positive
    input yields '< 1h', and an exact number of days omits the hour part.
    """
    if hours <= 0:
        return "< 1h"

    whole_hours = int(math.ceil(hours))
    days, leftover = divmod(whole_hours, 24)

    if days == 0:
        return f"~{whole_hours}h"
    if leftover == 0:
        return f"~{days}d"
    return f"~{days}d {leftover}h"


def _get_params(data: dict) -> float | None:
    """Extract model parameter count (in billions) from a status entry."""
    for key in ("model_params", "params"):
        val = data.get(key)
        if val is not None:
            try:
                return float(val)
            except (TypeError, ValueError):
                pass
    return None