File size: 4,444 Bytes
a54a5ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
ZeroGPU quota tracking.

HF Pro gives 1500s (25 min) of H200 time per day. There's no official Python
API to query remaining quota directly — we track it locally per-call.

This module:
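- Records GPU time consumed by `@spaces.GPU`-decorated functions (callers
  report each completed operation via `record_usage`)
- Provides duration estimates for upcoming operations
- Resets daily (a single on-disk log stores its UTC date and is discarded
  once that date is no longer today)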

Note: this is approximate. The authoritative source is HF's quota error if
you go over. Our tracking is for UX (showing "~18/25 min used today").
"""

from __future__ import annotations

import json
import os
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path

from .workspace import WORKSPACE

# On-disk quota log; _load_or_new ignores it once its stored date is not today.
QUOTA_FILE = WORKSPACE / "quota_log.json"
# HF Pro daily H200 allowance, in seconds (25 minutes).
DAILY_QUOTA_SECONDS = 25 * 60
# USD charged per second beyond the allowance: $1 per 600 s (10 min).
OVERAGE_RATE_PER_SECOND = 1 / 600


@dataclass
class QuotaState:
    """Tracked ZeroGPU usage for a single UTC day."""

    date: str                # UTC calendar day, formatted YYYY-MM-DD
    used_seconds: float      # GPU seconds recorded so far for `date`
    operations: list[dict]   # log of recent operation records

    def remaining_seconds(self) -> float:
        """Seconds of the daily allowance not yet used; never negative."""
        left = DAILY_QUOTA_SECONDS - self.used_seconds
        return left if left > 0.0 else 0.0

    def usage_fraction(self) -> float:
        """Share of the daily allowance consumed, capped at 1.0."""
        ratio = self.used_seconds / DAILY_QUOTA_SECONDS
        return ratio if ratio < 1.0 else 1.0

    def overage_cost_usd(self) -> float:
        """Estimated USD charge for usage beyond the daily allowance."""
        excess = self.used_seconds - DAILY_QUOTA_SECONDS
        billable = excess if excess > 0.0 else 0.0
        return billable * OVERAGE_RATE_PER_SECOND


def _today_utc() -> str:
    """Return the current calendar date in UTC as ``YYYY-MM-DD``."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%d", utc_now)


def _load_or_new() -> QuotaState:
    """Load today's quota state from ``QUOTA_FILE``, or start a fresh one.

    A zeroed state for today is returned when the file is missing or
    unreadable, is not valid JSON, is not a JSON object, records a date other
    than today (UTC), or lacks a numeric ``used_seconds``. This function never
    writes the file.
    """
    today = _today_utc()
    try:
        with QUOTA_FILE.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing file (FileNotFoundError) and read errors;
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Tracking is best-effort UX, so fall back to a fresh state.
        data = None

    if isinstance(data, dict) and data.get("date") == today:
        try:
            used = float(data["used_seconds"])
        except (KeyError, TypeError, ValueError):
            used = None
        if used is not None:
            ops = data.get("operations", [])
            return QuotaState(
                date=today,
                used_seconds=used,
                # A non-list here would break the append() in record_usage.
                operations=ops if isinstance(ops, list) else [],
            )

    # Fresh day, missing file, or corrupted contents.
    return QuotaState(date=today, used_seconds=0.0, operations=[])


def _save(state: QuotaState) -> None:
    """Persist ``state`` to ``QUOTA_FILE``, keeping only the last 50 operations.

    The JSON is written to a temporary file in the same directory and then
    moved over ``QUOTA_FILE`` with ``os.replace``. A crash mid-write therefore
    leaves the previous log intact instead of a truncated file, which
    ``_load_or_new`` would treat as corrupted and so reset today's usage to 0.
    Note: ``tempfile.mkstemp`` creates the file with mode 0o600.
    """
    payload = {
        "date": state.date,
        "used_seconds": state.used_seconds,
        "operations": state.operations[-50:],  # keep last 50
    }
    fd, tmp_path = tempfile.mkstemp(
        dir=QUOTA_FILE.parent, prefix=QUOTA_FILE.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        os.replace(tmp_path, QUOTA_FILE)
    except BaseException:
        # Don't leave a stray temp file behind; re-raise the original error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def get_state() -> QuotaState:
    """Return today's quota state as read from disk (zeroed if the day rolled over)."""
    current = _load_or_new()
    return current


def record_usage(operation: str, seconds: float) -> QuotaState:
    """Add a completed GPU operation to today's tally and persist it.

    Args:
        operation: Label stored in the operation log entry.
        seconds: GPU time the operation consumed. Negative values (a
            wall-clock delta can go negative if the system clock steps
            backwards) and NaN are clamped to 0, so they can neither shrink
            nor poison the day's running total.

    Returns:
        The updated in-memory state for today.
    """
    # max() keeps its first argument when the comparison is False, so NaN
    # also maps to 0.0 here.
    seconds = max(0.0, float(seconds))
    state = _load_or_new()
    state.used_seconds += seconds
    state.operations.append({
        "op": operation,
        "seconds": round(seconds, 2),
        "timestamp": time.time(),
    })
    _save(state)
    return state


# ---------------------------------------------------------------------------
# Per-operation estimates (used by UI to warn before expensive operations)
# ---------------------------------------------------------------------------

ESTIMATES = dict(
    # Stage 1
    generate_trellis2_fast=30,
    generate_trellis2_balanced=60,
    generate_trellis2_hero=90,
    generate_hunyuan3d=60,
    # Stage 2 — baking (nvdiffrast is fast)
    bake_normal_2k=5,
    bake_normal_4k=12,
    bake_albedo=2,
    bake_materials=3,
    bake_ao_fast=3,
    bake_ao_standard=10,
    bake_ao_high=30,
    # Stage 2 — optional
    inpaint_sdxl=30,
    # Stage 3
    auto_rig=40,
)


def estimate(operation: str, default: int = 60) -> int:
    """Get the typical GPU duration for an operation, in seconds.

    Used to:
      - set `@spaces.GPU(duration=N)` correctly
      - show cost warnings in the UI before triggering an operation

    Args:
        operation: Key into ``ESTIMATES``.
        default: Duration returned for operations with no entry in
            ``ESTIMATES``. Defaults to 60, the previous hard-coded fallback.
    """
    return ESTIMATES.get(operation, default)


def format_status() -> str:
    """Render the one-line quota summary shown in the UI status bar.

    Three tiers: at or over the daily allowance (shows the overage charge),
    above 80% of it (shows minutes left), and normal usage.
    """
    state = get_state()
    quota_min = DAILY_QUOTA_SECONDS / 60
    spent_min = state.used_seconds / 60
    head = f"Quota: {spent_min:.1f}/{quota_min:.0f} min"

    if state.used_seconds >= DAILY_QUOTA_SECONDS:
        return f"⚠️ {head} (overage: ${state.overage_cost_usd():.2f})"
    if state.usage_fraction() > 0.8:
        left_min = state.remaining_seconds() / 60
        return f"⚠️ {head} ({left_min:.1f} min left)"
    return f"{head} H200 today"