File size: 6,082 Bytes
6d9770a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""Lightweight logging and per-turn profiling for the advisor runtime.

The numbers here are debug/operations signal only — they are written to logs, never to the
UI. Stage timings are measured by *observing the turn event stream from the main process*, so
they stay correct even when the model itself runs inside a ZeroGPU fork (where a module-global
counter would reset on every call).
"""

from __future__ import annotations

from dataclasses import dataclass, field
import logging
import os
import platform
import sys
import threading
import time
from typing import Any

# The one named logger used throughout the advisor runtime; configure_logging() attaches
# its handler and level.
logger: logging.Logger = logging.getLogger("hackathon_advisor")

# Lifetime count of advisor messages handled by this process. Only mutate it while
# holding _counter_lock (next_message_index does this).
_counter_lock = threading.Lock()
_messages_processed: int = 0


def configure_logging() -> None:
    """Configure the module-level ``logger`` from the environment.

    ``ADVISOR_LOG_LEVEL`` may be a level name in any case (``DEBUG``, ``warning``, ...)
    or a non-negative integer level. If it is unset or blank, the level is ``INFO``.
    Any other value falls back to ``INFO`` and is reported with a warning rather than
    crashing startup.

    A stream handler is attached only when the logger has none, so calling this more
    than once does not duplicate output. Propagation to ancestor loggers is switched off
    so a root handler does not emit every record a second time.
    """
    level_name = os.environ.get("ADVISOR_LOG_LEVEL", "").strip().upper() or "INFO"
    level: int | None
    if level_name.isdecimal():
        # Numeric levels (e.g. "15") are valid for setLevel().
        level = int(level_name)
    else:
        # getattr(logging, ...) alone is unsafe. The module also exposes upper-case
        # names that are not levels (e.g. BASIC_FORMAT, a format string), and
        # setLevel() raises on those.
        candidate = getattr(logging, level_name, None)
        level = candidate if isinstance(candidate, int) else None
    logger.setLevel(logging.INFO if level is None else level)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
        logger.addHandler(handler)
    logger.propagate = False
    if level is None:
        logger.warning("unknown ADVISOR_LOG_LEVEL=%r; defaulting to INFO", level_name)


def next_message_index() -> int:
    """Bump the process-wide processed-message counter and return its new value.

    The read-modify-write happens while holding ``_counter_lock``, so concurrent callers
    each receive a distinct index. The first call returns 1.
    """
    global _messages_processed
    with _counter_lock:
        updated = _messages_processed + 1
        _messages_processed = updated
    return updated


def messages_processed() -> int:
    """Return the current lifetime count of processed advisor messages without changing it."""
    count = _messages_processed
    return count


def _ms(seconds: float) -> float:
    """Convert a duration in seconds to milliseconds, rounded to 0.1 ms for log output."""
    milliseconds = seconds * 1000.0
    return round(milliseconds, 1)


def resource_snapshot() -> dict[str, Any]:
    """Sample this process's resource usage for a log line, best effort.

    CPU time and peak resident set size come from ``resource.getrusage``. Torch device
    memory is merged in via ``_torch_memory_snapshot`` without importing torch. Anything
    that could not be sampled is simply left out; sampling failures are swallowed rather
    than raised.

    Note that ``rss_mb`` is derived from ``ru_maxrss``, i.e. the *peak* resident size
    (in MiB), not the current one.
    """
    snapshot: dict[str, Any] = {}
    try:
        import resource

        usage = resource.getrusage(resource.RUSAGE_SELF)
        # ru_maxrss units differ by OS: bytes on macOS, kilobytes on Linux.
        if platform.system() == "Darwin":
            rss_divisor = 1024 * 1024
        else:
            rss_divisor = 1024
        snapshot["rss_mb"] = round(usage.ru_maxrss / rss_divisor, 1)
        snapshot["cpu_user_s"] = round(usage.ru_utime, 3)
        snapshot["cpu_sys_s"] = round(usage.ru_stime, 3)
    except Exception:  # pragma: no cover - platform dependent
        # Best effort only, e.g. where the POSIX-only resource module is unavailable.
        pass
    torch_memory = _torch_memory_snapshot()
    snapshot.update(torch_memory)
    return snapshot


def _torch_memory_snapshot() -> dict[str, Any]:
    """Report torch accelerator memory for logging, without importing torch.

    Returns an empty dict when torch has not already been imported. Otherwise it reports
    current and peak CUDA allocation when CUDA is available, and current MPS allocation
    when ``torch.mps.current_allocated_memory`` exists. Values are bytes / 1e6, rounded to
    0.1. Device errors are swallowed, so missing keys mean "could not sample".
    """
    torch = sys.modules.get("torch")
    if torch is None:
        # Profiling must never pay the cost of importing torch itself.
        return {}
    memory: dict[str, Any] = {}
    try:
        cuda = torch.cuda
        if cuda.is_available():
            memory["cuda_alloc_mb"] = round(cuda.memory_allocated() / 1e6, 1)
            memory["cuda_peak_mb"] = round(cuda.max_memory_allocated() / 1e6, 1)
    except Exception:  # pragma: no cover - device dependent
        pass
    try:
        mps_allocated = getattr(getattr(torch, "mps", None), "current_allocated_memory", None)
        if mps_allocated is not None:
            memory["mps_alloc_mb"] = round(mps_allocated() / 1e6, 1)
    except Exception:  # pragma: no cover - device dependent
        pass
    return memory


@dataclass
class TurnProfiler:
    """Times a single advisor turn by observing its event stream. Drive it by calling
    ``observe(event)`` for every emitted event dict, then ``log_summary()`` when the turn
    ends (in a finally block, so partial turns still get logged)."""

    message_index: int
    compute: str
    backend: str
    device: str = ""
    message_chars: int = 0
    started: float = field(default_factory=time.perf_counter)
    stage_at: dict[str, float] = field(default_factory=dict)
    ended: float | None = None
    tokens: int = 0
    tool_count: int = 0
    fell_back: bool = False
    logged: bool = False

    def log_start(self) -> None:
        logger.info(
            "turn #%d start | compute=%s backend=%s message_chars=%d",
            self.message_index,
            self.compute,
            self.backend,
            self.message_chars,
        )

    def observe(self, event: dict[str, Any]) -> None:
        now = time.perf_counter()
        event_type = event.get("type")
        if event_type == "stage":
            self.stage_at.setdefault(str(event.get("stage")), now)
        elif event_type == "model_progress":
            self.tokens = max(self.tokens, int(event.get("tokens") or 0))
        elif event_type == "tool_event":
            self.tool_count += 1
        elif event_type == "fallback":
            self.fell_back = True
        elif event_type == "done":
            self.ended = now

    def durations(self) -> dict[str, float]:
        end = self.ended if self.ended is not None else time.perf_counter()
        out: dict[str, float] = {"total_ms": _ms(end - self.started)}
        planning = self.stage_at.get("planning")
        running = self.stage_at.get("running_tool")
        writing = self.stage_at.get("writing")
        if planning is not None and running is not None:
            out["decode_ms"] = _ms(running - planning)
        if running is not None and writing is not None:
            out["tools_ms"] = _ms(writing - running)
        if writing is not None:
            out["write_ms"] = _ms(end - writing)
        return out

    def log_summary(self, error: BaseException | None = None) -> None:
        if self.logged:
            return
        self.logged = True
        durations = self.durations()
        timing = " ".join(f"{key}={value}" for key, value in durations.items())
        resources = " ".join(f"{key}={value}" for key, value in resource_snapshot().items())
        status = "error" if error is not None else "done"
        message = (
            f"turn #{self.message_index} {status} | {timing} | "
            f"tokens={self.tokens} tools={self.tool_count} compute={self.compute} "
            f"device={self.device or '?'} backend={self.backend} fallback={self.fell_back} | {resources}"
        )
        if error is not None:
            logger.warning("%s | exception=%r", message, error)
        else:
            logger.info(message)