File size: 4,877 Bytes
41c0a9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
agent_trace.py
──────────────
Logs the 3-step GharScan agentic reasoning chain per inference call.
Traces are periodically pushed to HuggingFace Hub as a public dataset.

Qualifies for:  📑 Sharing is Caring badge
Dataset target: <hf_username>/gharscan-agent-traces
"""

import json
import uuid
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from loguru import logger
from huggingface_hub import HfApi, login
import os


# Target HF Hub dataset repo for uploaded traces; override via the HF_TRACE_REPO env var.
HF_TRACE_REPO   = os.getenv("HF_TRACE_REPO", "ritvik360/gharscan-agent-traces")
# Directory holding the local JSONL trace buffer; override via GHARSCAN_TRACE_DIR
# (defaults to /tmp/gharscan_traces, which is ephemeral on most hosts).
LOCAL_TRACE_DIR = Path(os.getenv("GHARSCAN_TRACE_DIR", "/tmp/gharscan_traces"))
FLUSH_EVERY_N   = 20          # Push to Hub every N traces to avoid rate-limiting


class TraceSession:
    """
    Reasoning trace for one user inference session.

    Holds an ordered list of step records, each stamped with the whole
    milliseconds elapsed since the session object was created, plus an
    optional final report and the total duration (set by ``finalize``).
    """

    def __init__(self, session_id: Optional[str] = None):
        if not session_id:
            # No (or empty) id supplied: use the first 8 chars of a random UUID4.
            session_id = str(uuid.uuid4())[:8]
        self.session_id = session_id
        self.started_at = datetime.now(timezone.utc).isoformat()
        self.steps: list = []
        self.final_report: Optional[dict] = None
        self.duration_ms: Optional[int] = None
        # Monotonic reference point for all elapsed-time stamps.
        self._t0 = time.monotonic()

    def _elapsed_ms(self) -> int:
        """Return whole milliseconds elapsed since this session was created."""
        return int((time.monotonic() - self._t0) * 1000)

    def log_step(self, step_name: str, step_input: dict, step_output: dict) -> None:
        """
        Append one reasoning step to ``steps``.

        Args:
            step_name: Step label ("classify" | "severity" | "cost").
            step_input: Input given to the step.
            step_output: Output produced by the step.
        """
        record = dict(
            step=step_name,
            input=step_input,
            output=step_output,
            elapsed_ms=self._elapsed_ms(),
        )
        self.steps.append(record)

    def finalize(self, final_report: dict) -> None:
        """Store the final report and record the total session duration in ms."""
        self.duration_ms = self._elapsed_ms()
        self.final_report = final_report

    def to_dict(self) -> dict:
        """Return the session as a plain dict; ``steps`` is the live list, not a copy."""
        fields = ("session_id", "started_at", "duration_ms", "steps", "final_report")
        return {name: getattr(self, name) for name in fields}


class AgentTraceLogger:
    """
    Manages trace collection and periodic HF Hub uploads.

    Each saved trace is appended as one JSON line to a local buffer file under
    LOCAL_TRACE_DIR. Every FLUSH_EVERY_N saved traces (counted per instance) the
    buffer is uploaded to the HF_TRACE_REPO dataset repo as a new file under
    ``data/`` and then truncated. Without HF_TOKEN, or if Hub setup fails,
    traces are kept locally only.

    NOTE(review): nothing here is synchronised. If save_trace is called from
    concurrent request handlers, a trace appended while an upload is in flight
    would be truncated with the buffer. Consider a lock if that applies.
    """

    def __init__(self):
        LOCAL_TRACE_DIR.mkdir(parents=True, exist_ok=True)
        self._buffer_path = LOCAL_TRACE_DIR / "traces_buffer.jsonl"
        self._count       = 0        # traces saved by this instance (cumulative)
        self._api         = None     # HfApi client; set only when Hub setup succeeds
        self._hf_ready    = False    # True only after auth AND repo creation succeed
        self._init_hf()

    def _init_hf(self):
        """
        Try to authenticate with HF Hub and ensure the dataset repo exists.

        Reads the token from the HF_TOKEN environment variable. If it is unset,
        or login / repo creation raises, a warning is logged and uploads stay
        disabled; no exception propagates to the caller.
        """
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            logger.warning("HF_TOKEN not set — traces saved locally only.")
            return

        try:
            login(token=hf_token)
            api = HfApi(token=hf_token)
            # Ensure dataset repo exists *before* marking the logger ready, so a
            # failure here leaves uploads disabled rather than half-initialised.
            api.create_repo(
                repo_id=HF_TRACE_REPO,
                repo_type="dataset",
                exist_ok=True,
                private=False,
            )
        except Exception as e:
            logger.warning(f"HF trace upload disabled: {e}")
            return

        self._api      = api
        self._hf_ready = True
        logger.info(f"AgentTraceLogger ready → {HF_TRACE_REPO}")

    def start_trace(self) -> TraceSession:
        """Return a new TraceSession with a freshly generated session id."""
        return TraceSession()

    def save_trace(self, session: TraceSession) -> None:
        """
        Append *session* to the local buffer; flush to HF Hub every N traces.

        The flush is attempted only when Hub setup succeeded. Errors writing
        the local buffer (e.g. JSON serialisation or disk errors) propagate.
        """
        record = json.dumps(session.to_dict(), ensure_ascii=False)
        with open(self._buffer_path, "a", encoding="utf-8") as f:
            f.write(record + "\n")

        self._count += 1
        logger.debug(f"Trace saved [{self._count}]: session={session.session_id}")

        if self._hf_ready and self._count % FLUSH_EVERY_N == 0:
            self._flush_to_hub()

    def _count_buffered(self) -> int:
        """Return the number of non-blank lines (traces) in the buffer; 0 if it is absent."""
        if not self._buffer_path.exists():
            return 0
        with open(self._buffer_path, encoding="utf-8") as f:
            return sum(1 for line in f if line.strip())

    def _flush_to_hub(self) -> None:
        """
        Upload the local JSONL buffer to the HF Hub dataset repo.

        Does nothing when the buffer is missing or empty, so no empty files or
        commits are pushed. On success the buffer is truncated. On failure the
        error is logged and the buffer is kept, so its traces are retried on
        the next flush.
        """
        try:
            n_traces = self._count_buffered()
            if n_traces == 0:
                logger.debug("Trace flush skipped: local buffer is empty.")
                return

            # Microseconds keep remote names unique when two flushes land in the
            # same second; otherwise the second upload would overwrite the first.
            timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
            remote_filename = f"traces_{timestamp}.jsonl"

            self._api.upload_file(
                path_or_fileobj=str(self._buffer_path),
                path_in_repo=f"data/{remote_filename}",
                repo_id=HF_TRACE_REPO,
                repo_type="dataset",
                # Report the traces in *this* upload, not the cumulative count.
                commit_message=f"Auto-flush: {n_traces} traces",
            )

            # Reset local buffer only after a successful upload.
            self._buffer_path.write_text("", encoding="utf-8")
            logger.info(f"Traces flushed to Hub → {HF_TRACE_REPO}/data/{remote_filename}")

        except Exception as e:
            logger.error(f"Trace flush failed: {e}")

    def force_flush(self) -> None:
        """
        Push any remaining buffered traces; call manually before Space shutdown.

        No-op when Hub uploads are disabled or the buffer is empty.
        """
        if self._hf_ready:
            self._flush_to_hub()