#!/usr/bin/env python3
"""
Minimal web viewer for cached context samples.
Loads one cache sample with the same torch.load path used by dump_cache_sample.py
and renders:
- top-level debug metadata
- cached Chart_Segment line view
- quant window boundaries
- per-window level/pattern summaries
- full per-window feature maps
Usage:
/venv/main/bin/python scripts/cache_debug_web.py
/venv/main/bin/python scripts/cache_debug_web.py --cache_dir data/cache --port 8765
/venv/main/bin/python scripts/cache_debug_web.py --file data/cache/sample_ABC.pt
"""
import argparse
import json
import math
import os
import random
import sys
from collections import Counter
from datetime import datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse

# Make the repository root importable before pulling in project modules.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import pandas as pd
import torch

from data.quant_ohlc_feature_schema import FEATURE_NAMES
from signals.support_resistance import compute_support_resistance_debug, compute_support_resistance_features
from signals.trendlines import _fit_with_trendln
from ta.trend import ema_indicator, sma_indicator
def _load_cache_sample(path: Path) -> Dict[str, Any]:
return torch.load(path, map_location="cpu", weights_only=False)
def _safe_float(value: Any) -> float:
try:
return float(value)
except Exception:
return 0.0
def _feature_map(window: Dict[str, Any]) -> Dict[str, float]:
    """Map a window's feature vector onto FEATURE_NAMES by position.

    Accepts plain lists as well as tuples and numpy/torch arrays (anything
    exposing ``tolist``). Previously only ``list`` was accepted, so cached
    tensor/array vectors silently rendered as an empty feature map. Features
    past the end of the vector default to 0.0; non-sequences return ``{}``.
    """
    vector = window.get("feature_vector", [])
    if hasattr(vector, "tolist"):
        # numpy array / torch tensor: flatten to a plain Python list.
        vector = vector.tolist()
    elif isinstance(vector, tuple):
        vector = list(vector)
    if not isinstance(vector, list):
        return {}
    return {
        name: _safe_float(vector[idx]) if idx < len(vector) else 0.0
        for idx, name in enumerate(FEATURE_NAMES)
    }
def _chart_points(chart_event: Dict[str, Any]) -> List[Dict[str, float]]:
    """Reconstruct per-candle OHLC points from a Chart_Segment event.

    Prefers the cached raw price series; otherwise exponentiates the logged
    open/close series. Highs/lows are synthesized from open/close because the
    cache stores no true extremes. Timestamps are back-filled from the event
    timestamp using the interval string (e.g. "5s"). Returns [] when there
    are no closes.
    """
    raw_opens = chart_event.get("raw_opens")
    raw_closes = chart_event.get("raw_closes")
    if isinstance(raw_opens, list) and isinstance(raw_closes, list) and raw_closes:
        opens, closes = raw_opens, raw_closes
    else:
        # Fall back to the log-space series and undo the log transform.
        opens = [float(np.exp(v)) for v in chart_event.get("opens", []) or []]
        closes = [float(np.exp(v)) for v in chart_event.get("closes", []) or []]

    end_ts = int(chart_event.get("timestamp", 0) or 0)
    if not closes:
        return []

    interval_text = str(chart_event.get("i", "1s"))
    try:
        step = max(1, int(interval_text.rstrip("s")))
    except Exception:
        step = 1  # unparseable interval: assume 1-second candles

    first_ts = end_ts - step * (len(closes) - 1)
    result: List[Dict[str, float]] = []
    for position, (open_px, close_px) in enumerate(zip(opens, closes)):
        result.append({
            "time": int(first_ts + position * step),
            "open": _safe_float(open_px),
            "high": _safe_float(max(open_px, close_px)),
            "low": _safe_float(min(open_px, close_px)),
            "close": _safe_float(close_px),
            "index": position,
        })
    return result
def _compute_level_overlays(points: List[Dict[str, float]], windows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Recompute support/resistance level overlays from the rendered points.

    *windows* is accepted only for signature parity with the other overlay
    builders and is unused. Falls back to the "all_*" level lists when the
    filtered lists come back empty.
    """
    del windows
    if not points:
        return {"support_levels": [], "resistance_levels": []}
    debug = compute_support_resistance_debug(
        closes=[_safe_float(p["close"]) for p in points],
        highs=[_safe_float(p["high"]) for p in points],
        lows=[_safe_float(p["low"]) for p in points],
        timestamps=[int(p["time"]) for p in points],
    )
    all_support = debug.get("all_support_levels", [])
    all_resistance = debug.get("all_resistance_levels", [])
    return {
        "support_levels": debug.get("support_levels", []) or all_support,
        "resistance_levels": debug.get("resistance_levels", []) or all_resistance,
        "all_support_levels": all_support,
        "all_resistance_levels": all_resistance,
        "sr_available": debug.get("sr_available", 0.0),
    }
def _compute_trendline_overlays(points: List[Dict[str, float]]) -> List[Dict[str, Any]]:
    """Fit lower/upper trendlines and return them as two-point line overlays.

    Fits on closes first; if either side fails, retries the lower line on
    lows and the upper line on highs. Requires at least 5 points.
    """
    if len(points) < 5:
        return []

    def _try_fit(series: np.ndarray):
        # Trendln can raise on degenerate data; treat failure as "no line".
        try:
            return _fit_with_trendln(series)
        except Exception:
            return None, None

    closes = np.asarray([p["close"] for p in points], dtype=np.float64)
    lower_line, upper_line = _try_fit(closes)
    if lower_line is None:
        lows = np.asarray([p["low"] for p in points], dtype=np.float64)
        lower_line = _try_fit(lows)[0]
    if upper_line is None:
        highs = np.asarray([p["high"] for p in points], dtype=np.float64)
        upper_line = _try_fit(highs)[1]

    last = len(points) - 1
    segments: List[Dict[str, Any]] = []
    for name, line, color in (
        ("lower_trendline", lower_line, "#0f766e"),
        ("upper_trendline", upper_line, "#b91c1c"),
    ):
        if line is None:
            continue
        slope, intercept = line
        segments.append({
            "name": name,
            "color": color,
            "points": [
                {"time": points[0]["time"], "value": _safe_float(slope * 0 + intercept)},
                {"time": points[last]["time"], "value": _safe_float(slope * last + intercept)},
            ],
        })
    return segments
def _compute_window_boundaries(points: List[Dict[str, float]], windows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build one vertical marker per quant window at its end timestamp.

    Each marker spans the chart's full price range. Windows whose key-level
    flags include an active breakout or flip signal get the darker tint.
    """
    if not points:
        return []
    low_bound = min(point["low"] for point in points)
    high_bound = max(point["high"] for point in points)
    breakout_flags = ("breakout_up", "breakout_down", "flip_to_support", "flip_to_resistance")
    markers: List[Dict[str, Any]] = []
    for position, window in enumerate(windows):
        flags = window.get("keylevel_flags", {}) or {}
        has_breakout = any(_safe_float(flags.get(name, 0.0)) > 0.0 for name in breakout_flags)
        end_ts = int(window.get("end_ts", 0) or 0)
        markers.append({
            "name": f"window_{position}",
            "window_idx": position,
            "color": "#8b1e3f" if has_breakout else "#c84c2d",
            "points": [
                {"time": end_ts, "value": _safe_float(low_bound)},
                {"time": end_ts, "value": _safe_float(high_bound)},
            ],
        })
    return markers
def _compute_indicator_overlays(points: List[Dict[str, float]]) -> List[Dict[str, Any]]:
    """Build EMA/SMA overlay line series (8- and 21-candle) over the closes."""
    if not points:
        return []
    closes = pd.Series([_safe_float(point["close"]) for point in points], dtype="float64")
    times = [point["time"] for point in points]

    def _line(name: str, series: pd.Series, color: str, dash: str) -> Dict[str, Any]:
        # One rendered polyline; dash style distinguishes SMA from EMA.
        return {
            "name": name,
            "color": color,
            "dash": dash,
            "points": [
                {"time": times[i], "value": _safe_float(v)}
                for i, v in enumerate(series.tolist())
            ],
        }

    return [
        _line("ema_fast_8", ema_indicator(closes, window=8, fillna=True), "#2563eb", "solid"),
        _line("ema_medium_21", ema_indicator(closes, window=21, fillna=True), "#7c3aed", "solid"),
        _line("sma_fast_8", sma_indicator(closes, window=8, fillna=True), "#ea580c", "dot"),
        _line("sma_medium_21", sma_indicator(closes, window=21, fillna=True), "#0891b2", "dot"),
    ]
def _recompute_window_keylevel_flags(
    points: List[Dict[str, float]],
    windows: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Recompute each window's key-level flags from the rendered candles.

    Windows whose end timestamp cannot be matched to a candle are passed
    through unchanged; matched windows get fresh ``keylevel_flags`` plus a
    ``top_signal_name`` (the highest-priority active signal, or "none").
    """
    if not points or not windows:
        return windows
    closes = [_safe_float(point["close"]) for point in points]
    highs = [_safe_float(point["high"]) for point in points]
    lows = [_safe_float(point["low"]) for point in points]
    timestamps = [int(point["time"]) for point in points]
    idx_by_time = {ts: i for i, ts in enumerate(timestamps)}

    # Flags in the order they should appear in the payload.
    flag_names = (
        "breakout_up",
        "breakout_down",
        "hold_above",
        "hold_below",
        "failed_breakout_up",
        "failed_breakout_down",
        "flip_to_support",
        "flip_to_resistance",
    )
    # Priority order for picking the single headline signal.
    priority = (
        "breakout_up",
        "breakout_down",
        "flip_to_support",
        "flip_to_resistance",
        "failed_breakout_up",
        "failed_breakout_down",
    )

    result: List[Dict[str, Any]] = []
    for window in windows:
        end_ts = int(window.get("end_ts", 0) or 0)
        if end_ts not in idx_by_time:
            # Cannot anchor this window on the chart; keep it untouched.
            result.append(window)
            continue
        end_idx = idx_by_time[end_ts]
        start_ts = int(window.get("start_ts", 0) or 0)
        start_idx = idx_by_time.get(start_ts, max(0, end_idx))
        start_idx = min(start_idx, end_idx)
        sr_features = compute_support_resistance_features(
            closes=closes,
            highs=highs,
            lows=lows,
            end_idx=end_idx,
            window_start=start_idx,
            window_end=end_idx + 1,
            timestamps=timestamps,
        )
        flags = {name: sr_features.get(f"keylevel_{name}", 0.0) for name in flag_names}
        top_signal = next(
            (name for name in priority if _safe_float(flags.get(name, 0.0)) > 0.0),
            "none",
        )
        enriched = dict(window)
        enriched["keylevel_flags"] = flags
        enriched["top_signal_name"] = top_signal
        result.append(enriched)
    return result
def _compute_keylevel_signal_overlays(
    points: List[Dict[str, float]],
    windows: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Build one marker per active key-level signal at each window's end candle.

    Markers anchored to highs are nudged up 0.3% and lows nudged down 0.3%
    so they do not sit directly on the candle. Windows whose end timestamp
    has no matching candle are skipped.
    """
    if not points or not windows:
        return []
    candle_at = {int(point["time"]): point for point in points}
    specs = {
        "breakout_up": {"color": "#15803d", "symbol": "triangle-up", "y_key": "high"},
        "breakout_down": {"color": "#b91c1c", "symbol": "triangle-down", "y_key": "low"},
        "flip_to_support": {"color": "#1d4ed8", "symbol": "diamond", "y_key": "close"},
        "flip_to_resistance": {"color": "#7c2d12", "symbol": "diamond", "y_key": "close"},
        "failed_breakout_up": {"color": "#ea580c", "symbol": "x", "y_key": "high"},
        "failed_breakout_down": {"color": "#9333ea", "symbol": "x", "y_key": "low"},
    }
    markers: List[Dict[str, Any]] = []
    for window in windows:
        end_ts = int(window.get("end_ts", 0) or 0)
        candle = candle_at.get(end_ts)
        if candle is None:
            continue
        active_flags = window.get("keylevel_flags", {}) or {}
        for signal, spec in specs.items():
            if _safe_float(active_flags.get(signal, 0.0)) <= 0.0:
                continue
            anchor = spec["y_key"]
            marker_value = _safe_float(candle.get(anchor, candle["close"]))
            # Offset off the wick so the symbol stays visible.
            if anchor == "high":
                marker_value *= 1.003
            elif anchor == "low":
                marker_value *= 0.997
            markers.append({
                "name": signal,
                "time": end_ts,
                "value": marker_value,
                "color": spec["color"],
                "symbol": spec["symbol"],
                "window_idx": int(window.get("idx", -1)),
            })
    return markers
def _sample_to_payload(sample: Dict[str, Any], source_file: Path) -> Dict[str, Any]:
    """Convert one loaded cache sample into the JSON payload served to the UI.

    Extracts the first Chart_Segment event, reconstructs candles, recomputes
    per-window key-level flags, and attaches all chart overlays plus the
    top-level sample metadata.
    """
    events = sample.get("event_sequence", [])
    type_counts = Counter(event.get("event_type", "Unknown") for event in events)
    chart_event = next(
        (event for event in events if event.get("event_type") == "Chart_Segment"),
        {},
    )

    def _listify(value: Any) -> Any:
        # Tensors/arrays become plain lists so they JSON-serialize.
        return value.tolist() if hasattr(value, "tolist") else value

    windows_payload: List[Dict[str, Any]] = []
    for idx, window in enumerate(chart_event.get("quant_ohlc_features", []) or []):
        features = _feature_map(window)
        windows_payload.append({
            "idx": idx,
            "start_ts": window.get("start_ts"),
            "end_ts": window.get("end_ts"),
            "window_seconds": window.get("window_seconds"),
            "level_snapshot": window.get("level_snapshot", {}) or {},
            "keylevel_flags": window.get("keylevel_flags", {}) or {},
            "top_signal_name": "none",  # refreshed by the recompute pass below
            "feature_map": features,
            "sr_available": _safe_float(features.get("sr_available", 0.0)),
            "trendline_available": _safe_float(features.get("trendline_available", 0.0)),
        })

    chart_points = _chart_points(chart_event) if chart_event else []
    windows_payload = _recompute_window_keylevel_flags(chart_points, windows_payload)

    return {
        "source_file": str(source_file),
        "sample": {
            "token_address": sample.get("token_address"),
            "source_token": sample.get("source_token"),
            "sample_idx": sample.get("sample_idx"),
            "class_id": sample.get("class_id"),
            "context_bucket": sample.get("context_bucket"),
            "context_score": sample.get("context_score"),
            "quality_score": _safe_float(sample.get("quality_score", 0.0)),
            "t_cutoff": sample.get("t_cutoff"),
            "labels": _listify(sample.get("labels")),
            "labels_mask": _listify(sample.get("labels_mask")),
            "event_counts": dict(type_counts),
            "n_events": len(events),
            "n_wallets": len(sample.get("wallets", {})),
            "n_tokens": len(sample.get("tokens", {})),
            "n_graph_link_types": len(sample.get("graph_links", {})),
        },
        "chart": {
            "present": bool(chart_event),
            "timestamp": chart_event.get("timestamp"),
            "relative_ts": chart_event.get("relative_ts"),
            "interval": chart_event.get("i"),
            "opens": chart_event.get("opens", []) or [],
            "closes": chart_event.get("closes", []) or [],
            "windows": windows_payload,
            "points": chart_points,
            "overlays": {
                "levels": _compute_level_overlays(chart_points, windows_payload),
                "trendlines": _compute_trendline_overlays(chart_points),
                "boundaries": _compute_window_boundaries(chart_points, windows_payload),
                "indicators": _compute_indicator_overlays(chart_points),
                "signals": _compute_keylevel_signal_overlays(chart_points, windows_payload),
            },
        },
    }
# Inline single-page UI served at "/".
# NOTE(review): only text fragments of the template survive in this copy —
# the HTML tags appear to have been stripped in transit. Do not edit this
# string here; confirm against the original source before changing it.
HTML = """
Cache Debug
Quant Windows
| # |
Range |
SR |
Trend |
Top Signal |
Support |
Resistance |
Window Detail
Select a quant window.
"""
class CacheDebugHandler(BaseHTTPRequestHandler):
    """HTTP handler: serves the debug UI at "/" and sample JSON at "/api/sample"."""

    # Directory scanned for sample_*.pt files; overridden from main().
    cache_dir: Path = Path("data/cache")
    # When set, every request renders this one file regardless of query args.
    fixed_file: Optional[Path] = None

    def _send(self, body: bytes, content_type: str, code: int) -> None:
        # Shared response writer for both the JSON and HTML reply paths.
        self.send_response(code)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _json(self, payload: Dict[str, Any], code: int = 200) -> None:
        """Serialize *payload* and send it as a JSON response."""
        self._send(json.dumps(payload).encode("utf-8"), "application/json; charset=utf-8", code)

    def _html(self, body: str, code: int = 200) -> None:
        """Send *body* as an HTML response."""
        self._send(body.encode("utf-8"), "text/html; charset=utf-8", code)

    def _pick_sample(self, qs: Dict[str, List[str]]) -> Path:
        """Resolve which cache file to render for this request.

        Precedence: the fixed file (if configured), then an explicit ?file=
        path, then ?index= into the sorted sample list, else a random sample.
        Raises FileNotFoundError when the choice does not exist.
        """
        if self.fixed_file is not None:
            return self.fixed_file
        requested = (qs.get("file") or [""])[0]
        if requested:
            candidate = Path(requested).expanduser()
            if not candidate.is_absolute():
                candidate = (Path.cwd() / candidate).resolve()
            if not candidate.exists():
                raise FileNotFoundError(candidate)
            return candidate
        files = sorted(self.cache_dir.glob("sample_*.pt"))
        if not files:
            raise FileNotFoundError(f"No sample_*.pt files found in {self.cache_dir}")
        index_args = qs.get("index")
        if index_args:
            # Clamp the requested index into the valid range.
            position = max(0, min(int(index_args[0]), len(files) - 1))
            return files[position]
        return random.choice(files)

    def do_GET(self) -> None:
        """Route GET requests: UI page, sample API, or 404."""
        route = urlparse(self.path)
        if route.path == "/":
            self._html(HTML)
        elif route.path == "/api/sample":
            try:
                source = self._pick_sample(parse_qs(route.query))
                self._json(_sample_to_payload(_load_cache_sample(source), source))
            except Exception as exc:
                # Surface the failure to the UI instead of a blank page.
                self._json({"error": str(exc)}, code=500)
        else:
            self.send_response(404)
            self.end_headers()
def main() -> int:
    """CLI entry point: parse arguments and serve the viewer until interrupted."""
    parser = argparse.ArgumentParser(description="View cached samples in a simple web UI.")
    parser.add_argument("--cache_dir", type=str, default="data/cache", help="Cache directory containing sample_*.pt")
    parser.add_argument("--file", type=str, default=None, help="Optional fixed sample file to always render")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Bind host")
    parser.add_argument("--port", type=int, default=8765, help="Bind port")
    opts = parser.parse_args()

    # The handler is configured through class attributes because the HTTP
    # server instantiates a fresh handler object per request.
    CacheDebugHandler.cache_dir = Path(opts.cache_dir)
    CacheDebugHandler.fixed_file = Path(opts.file).resolve() if opts.file else None

    httpd = ThreadingHTTPServer((opts.host, opts.port), CacheDebugHandler)
    print(f"Cache debug viewer running at http://{opts.host}:{opts.port}")
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the viewer.
    finally:
        httpd.server_close()
    return 0
# Script entry point: run the viewer and propagate main()'s exit status.
if __name__ == "__main__":
    raise SystemExit(main())