File size: 3,996 Bytes
201cf4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Main orchestrator: read mmap ticks, score with KAN, detect signals."""
from __future__ import annotations

import argparse
import logging
import time
from typing import Dict, List

from prediction_engine.python_bridge.kan_scorer import KANFastScorer, MarketFeatureExtractor
from prediction_engine.python_bridge.mmap_reader import CompactTick, MmapReader

# Module-level logger; its level and format are configured by
# logging.basicConfig() in main().
logger = logging.getLogger(__name__)

# Upper bound on the combined YES + NO price of a cross-venue pair.
# SignalDetector.check_arbitrage reports a pair only when
# yes_price + no_price is strictly below this value.
ARBITRAGE_THRESHOLD = 0.98


class SignalDetector:
    """Detect trading signals from market ticks."""

    def __init__(self):
        self.extractor = MarketFeatureExtractor()
        self.scorer = KANFastScorer(in_features=6)
        self.market_state: Dict[int, Dict[str, CompactTick]] = {}

    def update(self, tick: CompactTick):
        key = tick.market_id_hash
        if key not in self.market_state:
            self.market_state[key] = {}
        venue_key = f"{tick.venue}_{tick.side}"
        self.market_state[key][venue_key] = tick

    def check_arbitrage(self, market_hash: int) -> dict | None:
        state = self.market_state.get(market_hash, {})
        yes_prices = {k: v.price for k, v in state.items() if "yes" in k}
        no_prices = {k: v.price for k, v in state.items() if "no" in k}

        for yes_venue, yes_price in yes_prices.items():
            for no_venue, no_price in no_prices.items():
                if yes_venue.split("_")[0] != no_venue.split("_")[0]:
                    total = yes_price + no_price
                    if total < ARBITRAGE_THRESHOLD:
                        return {
                            "type": "cross_venue_arbitrage",
                            "yes_venue": yes_venue,
                            "no_venue": no_venue,
                            "yes_price": yes_price,
                            "no_price": no_price,
                            "total_cost": total,
                            "guaranteed_profit": 1.0 - total,
                        }
        return None

    def process_ticks(self, ticks: List[CompactTick]) -> List[dict]:
        signals = []
        for tick in ticks:
            self.update(tick)
            arb = self.check_arbitrage(tick.market_id_hash)
            if arb:
                features = self.extractor.extract({
                    "yes_price": arb["yes_price"],
                    "no_price": arb["no_price"],
                    "spread": arb["guaranteed_profit"],
                    "volume_ratio": 1.0,
                    "time_to_event_hours": 24.0,
                    "venue_count": 2,
                })
                score = self.scorer.score(features)
                arb["kan_score"] = score
                signals.append(arb)
                logger.info(
                    "ARBITRAGE: %s vs %s, profit=$%.4f, confidence=%.3f, edge=%s",
                    arb["yes_venue"], arb["no_venue"],
                    arb["guaranteed_profit"], score["confidence"],
                    score["edge_shape"],
                )
        return signals


def main():
    """Command-line entry point for the orchestrator.

    Options:
        --mmap-path  path of the tick mmap file to read
        --poll-ms    polling interval handed to the reader, in milliseconds
        --test       only construct a SignalDetector, log the outcome, exit

    Without ``--test`` the function opens the mmap reader, feeds every batch
    of ticks it yields to a SignalDetector and logs each resulting signal.
    """
    cli = argparse.ArgumentParser(description="Prediction Engine Orchestrator")
    cli.add_argument("--mmap-path", default="/tmp/prediction_ticks.mmap")
    cli.add_argument("--poll-ms", type=float, default=10.0)
    cli.add_argument("--test", action="store_true")
    opts = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")

    if not opts.test:
        logger.info("Starting orchestrator, reading from %s", opts.mmap_path)
        detector = SignalDetector()

        with MmapReader(opts.mmap_path) as reader:
            for batch in reader.poll(interval_ms=opts.poll_ms):
                for signal in detector.process_ticks(batch):
                    logger.info("SIGNAL: %s", signal)
        return

    # Smoke test: successfully constructing the detector is the whole check.
    logger.info("Smoke test mode")
    _ = SignalDetector()
    logger.info("SignalDetector initialized with KAN scorer")
    logger.info("Orchestrator smoke test PASSED")


# Run the orchestrator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()