"""Lightweight logging and per-turn profiling for the advisor runtime.
The numbers here are debug/operations signal only — they are written to logs, never to the
UI. Stage timings are measured by *observing the turn event stream from the main process*, so
they stay correct even when the model itself runs inside a ZeroGPU fork (where a module-global
counter would reset on every call).
"""
from __future__ import annotations
from dataclasses import dataclass, field
import logging
import os
import platform
import sys
import threading
import time
from typing import Any
# Shared logger for this module; configure_logging() sets its level and attaches a handler.
logger = logging.getLogger("hackathon_advisor")
# Held by next_message_index() while it increments _messages_processed.
_counter_lock = threading.Lock()
# Running count of processed messages; incremented by next_message_index() and read by
# messages_processed().
_messages_processed = 0
def configure_logging() -> None:
    """Attach a stream handler once, honoring ADVISOR_LOG_LEVEL (default INFO).

    The level name is read from the environment, trimmed and upper-cased, then looked up
    as an attribute of the ``logging`` module. Any value that does not resolve to an
    integer level falls back to ``logging.INFO``. A handler is only added when the
    logger has none yet, so repeated calls do not duplicate output.
    """
    level_name = os.environ.get("ADVISOR_LOG_LEVEL", "INFO").strip().upper()
    level = getattr(logging, level_name, None)
    # Not every upper-case attribute of ``logging`` is a level (BASIC_FORMAT is a format
    # string); handing such a value to setLevel() would raise instead of falling back.
    if not isinstance(level, int):
        level = logging.INFO
    logger.setLevel(level)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
        logger.addHandler(handler)
        # This logger now owns its output; do not also hand records to ancestor loggers.
        logger.propagate = False
def next_message_index() -> int:
    """Bump the lifetime count of processed advisor messages and return the new value.

    The increment happens while holding ``_counter_lock``; the value returned is the one
    this call produced.
    """
    global _messages_processed
    with _counter_lock:
        updated = _messages_processed + 1
        _messages_processed = updated
        return updated
def messages_processed() -> int:
    """Return the current lifetime count of processed advisor messages without changing it."""
    current_total = _messages_processed
    return current_total
def _ms(seconds: float) -> float:
return round(seconds * 1000.0, 1)
def resource_snapshot() -> dict[str, Any]:
    """Sample process resource usage on a best-effort basis; never raises.

    Uses the stdlib ``resource`` module for memory and CPU time, then merges in whatever
    ``_torch_memory_snapshot()`` reports. Keys that could not be sampled are simply
    absent from the returned dict.
    """
    sampled: dict[str, Any] = {}
    try:
        import resource

        rusage = resource.getrusage(resource.RUSAGE_SELF)
        # ru_maxrss is bytes on macOS, kilobytes on Linux.
        if platform.system() == "Darwin":
            units_per_mb = 1024 * 1024
        else:
            units_per_mb = 1024
        sampled["rss_mb"] = round(rusage.ru_maxrss / units_per_mb, 1)
        sampled["cpu_user_s"] = round(rusage.ru_utime, 3)
        sampled["cpu_sys_s"] = round(rusage.ru_stime, 3)
    except Exception:  # pragma: no cover - platform dependent
        # Deliberate best-effort: keep whatever was sampled before the failure.
        pass
    device_stats = _torch_memory_snapshot()
    sampled.update(device_stats)
    return sampled
def _torch_memory_snapshot() -> dict[str, Any]:
out: dict[str, Any] = {}
torch = sys.modules.get("torch") # do not import torch just to profile
if torch is None:
return out
try:
if torch.cuda.is_available():
out["cuda_alloc_mb"] = round(torch.cuda.memory_allocated() / 1e6, 1)
out["cuda_peak_mb"] = round(torch.cuda.max_memory_allocated() / 1e6, 1)
except Exception: # pragma: no cover - device dependent
pass
try:
mps = getattr(torch, "mps", None)
current = getattr(mps, "current_allocated_memory", None)
if current is not None:
out["mps_alloc_mb"] = round(current() / 1e6, 1)
except Exception: # pragma: no cover - device dependent
pass
return out
@dataclass
class TurnProfiler:
    """Times a single advisor turn by observing its event stream.

    Drive it by calling ``observe(event)`` for every emitted event dict, then
    ``log_summary()`` when the turn ends (in a finally block, so partial turns still get
    logged).
    """

    # Identification / context, echoed verbatim into the log lines.
    message_index: int
    compute: str
    backend: str
    device: str = ""
    message_chars: int = 0
    # perf_counter() reading taken when the profiler is constructed.
    started: float = field(default_factory=time.perf_counter)
    # First perf_counter() reading seen for each stage name.
    stage_at: dict[str, float] = field(default_factory=dict)
    # perf_counter() reading of the most recent "done" event, if any.
    ended: float | None = None
    # Largest token count reported by any "model_progress" event.
    tokens: int = 0
    # Number of "tool_event" events observed.
    tool_count: int = 0
    # True once a "fallback" event has been observed.
    fell_back: bool = False
    # Set by log_summary() so the summary is emitted at most once.
    logged: bool = False

    def log_start(self) -> None:
        """Log the turn-start line with the compute mode, backend and message size."""
        logger.info(
            "turn #%d start | compute=%s backend=%s message_chars=%d",
            self.message_index,
            self.compute,
            self.backend,
            self.message_chars,
        )

    def observe(self, event: dict[str, Any]) -> None:
        """Update timings and counters from one event dict.

        Dispatches on ``event["type"]``: ``stage`` records the first time each stage is
        seen, ``model_progress`` tracks the highest ``tokens`` value, ``tool_event``
        increments the tool counter, ``fallback`` sets the fallback flag and ``done``
        stamps the end time. Other event types are ignored.
        """
        now = time.perf_counter()
        event_type = event.get("type")
        if event_type == "stage":
            # setdefault keeps the first timestamp if a stage is announced more than once.
            self.stage_at.setdefault(str(event.get("stage")), now)
        elif event_type == "model_progress":
            try:
                reported = int(event.get("tokens") or 0)
            except (TypeError, ValueError, OverflowError):
                # Profiling is debug signal only; a malformed token count must not
                # raise into the turn being observed.
                reported = 0
            self.tokens = max(self.tokens, reported)
        elif event_type == "tool_event":
            self.tool_count += 1
        elif event_type == "fallback":
            self.fell_back = True
        elif event_type == "done":
            self.ended = now

    def durations(self) -> dict[str, float]:
        """Return elapsed times in milliseconds.

        ``total_ms`` is always present and runs from ``started`` to ``ended`` (or to the
        current time if no "done" event was seen). The remaining keys appear only when
        the stages they span were observed: ``decode_ms`` (planning to running_tool),
        ``tools_ms`` (running_tool to writing) and ``write_ms`` (writing to end).
        """
        end = self.ended if self.ended is not None else time.perf_counter()
        out: dict[str, float] = {"total_ms": _ms(end - self.started)}
        planning = self.stage_at.get("planning")
        running = self.stage_at.get("running_tool")
        writing = self.stage_at.get("writing")
        if planning is not None and running is not None:
            out["decode_ms"] = _ms(running - planning)
        if running is not None and writing is not None:
            out["tools_ms"] = _ms(writing - running)
        if writing is not None:
            out["write_ms"] = _ms(end - writing)
        return out

    def log_summary(self, error: BaseException | None = None) -> None:
        """Log the one-line turn summary; later calls on the same instance do nothing.

        With ``error`` set, the line is logged at WARNING level with the exception's
        repr appended; otherwise it is logged at INFO level.
        """
        if self.logged:
            return
        self.logged = True
        durations = self.durations()
        timing = " ".join(f"{key}={value}" for key, value in durations.items())
        resources = " ".join(f"{key}={value}" for key, value in resource_snapshot().items())
        status = "error" if error is not None else "done"
        message = (
            f"turn #{self.message_index} {status} | {timing} | "
            f"tokens={self.tokens} tools={self.tool_count} compute={self.compute} "
            f"device={self.device or '?'} backend={self.backend} fallback={self.fell_back} | {resources}"
        )
        if error is not None:
            logger.warning("%s | exception=%r", message, error)
        else:
            logger.info(message)
|