File size: 25,463 Bytes
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e3d126
1cdc0af
 
 
 
 
 
 
 
 
 
 
4dd4ab4
8e3d126
 
4dd4ab4
 
 
 
 
 
8e3d126
 
 
 
 
 
 
 
 
 
 
 
 
1cdc0af
 
 
 
 
 
 
 
8e3d126
 
 
1cdc0af
8e3d126
 
 
 
 
 
1cdc0af
8e3d126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cdc0af
8e3d126
 
 
 
 
1cdc0af
8e3d126
4dd4ab4
8e3d126
1cdc0af
4dd4ab4
 
 
 
 
 
 
 
 
 
 
 
8e3d126
 
 
 
 
 
 
 
 
 
1cdc0af
8e3d126
1cdc0af
4dd4ab4
1cdc0af
 
 
 
 
 
8e3d126
 
4dd4ab4
8e3d126
4dd4ab4
 
 
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
 
 
1cdc0af
 
4dd4ab4
1cdc0af
 
 
 
 
 
 
7189688
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7189688
 
 
 
 
 
 
 
 
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
 
 
 
1cdc0af
 
 
 
 
 
 
 
4dd4ab4
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7189688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cdc0af
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd4ab4
1cdc0af
 
 
 
 
 
 
 
 
4dd4ab4
1cdc0af
 
 
8e3d126
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
import os
import sys
import json
import math
import argparse
from typing import Dict, List, Tuple

from clickhouse_driver import Client as ClickHouseClient

# Add parent to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.vocabulary import RETURN_THRESHOLDS

# ClickHouse connection settings; each one can be overridden via the environment.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT", "9000"))
CLICKHOUSE_USER = os.environ.get("CLICKHOUSE_USER", "default")
CLICKHOUSE_PASSWORD = os.environ.get("CLICKHOUSE_PASSWORD", "")
CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default")

# Reference launch price in USD; a token's ATH price divided by this is its return multiple.
LAUNCH_PRICE_USD = 0.000004
# Tiny value added to ratio denominators so a zero count/volume never divides by zero.
EPS = 1e-9


def get_client():
    """Build a new ClickHouse client from the module-level CLICKHOUSE_* settings."""
    connection_kwargs = {
        "host": CLICKHOUSE_HOST,
        "port": CLICKHOUSE_PORT,
        "user": CLICKHOUSE_USER,
        "password": CLICKHOUSE_PASSWORD,
        "database": CLICKHOUSE_DATABASE,
    }
    return ClickHouseClient(**connection_kwargs)


def _midrank_percentiles(items: List[Tuple[str, float]]) -> Dict[str, float]:
    """
    Compute midrank percentiles for a list of (token, value).
    Returns p in (0,1) via (rank - 0.5) / n. Ties get the same midrank.
    """
    if not items:
        return {}
    items_sorted = sorted(items, key=lambda x: x[1])
    n = len(items_sorted)
    out = {}
    i = 0
    while i < n:
        j = i
        v = items_sorted[i][1]
        while j + 1 < n and items_sorted[j + 1][1] == v:
            j += 1
        # midrank is average of ranks i..j (1-based)
        rank_lo = i + 1
        rank_hi = j + 1
        midrank = 0.5 * (rank_lo + rank_hi)
        p = (midrank - 0.5) / n
        for k in range(i, j + 1):
            out[items_sorted[k][0]] = p
        i = j + 1
    return out


def _bucket_id(ret_val: float) -> int:
    """
    Return the index of the RETURN_THRESHOLDS bucket that contains ``ret_val``.

    Bucket i is the half-open interval [RETURN_THRESHOLDS[i], RETURN_THRESHOLDS[i + 1]).
    Returns -1 when ``ret_val`` is outside every bucket, i.e. below the first
    threshold or at/above the last one.
    """
    edges = RETURN_THRESHOLDS
    for idx, (lo, hi) in enumerate(zip(edges, edges[1:])):
        if lo <= ret_val < hi:
            return idx
    return -1


def fetch_token_metrics(client) -> List[dict]:
    """
    Fetch the per-token lifetime metrics used for quality scoring.

    One row is returned per mint in ``mints`` whose decimal-adjusted supply is
    positive, left-joined with trade aggregates, token_metrics returns/holders,
    and sniper / bundled / dev-hold concentration percentages.

    Args:
        client: ClickHouse client exposing ``execute(query) -> rows``.

    Returns:
        A list of dicts, one per token, with keys: token_address, ret,
        unique_holders, fees_sol, volume_usd, n_trades, time_to_ath_sec,
        snipers_pct, bundled_pct, dev_hold_pct.
    """
    # Sniper ranking (CTE 5): a maker's first buy is ordered by the
    # (slot, transaction_index) pair of the *same* trade. Taking min(slot) and
    # min(transaction_index) independently could pair the first slot with an
    # index from a later trade and mis-order makers who bought in the same slot.
    #
    # NOTE(review): bundled_pct divides summed trades.base_amount by the raw
    # total_supply, while snipers_pct / dev_hold_pct divide wallet balances by the
    # decimal-adjusted adj_supply. That is only consistent if base_amount is in
    # raw units and current_balance in adjusted units -- confirm against the schemas.
    #
    # NOTE(review): unless join_use_nulls=1, ClickHouse LEFT JOINs fill missing
    # right-side columns with type defaults (0) rather than NULL, so downstream
    # `is None` checks may not fire for tokens lacking trades/metrics -- confirm
    # the session setting.
    query = f"""
    WITH
        -- 1. Aggregated trade stats (unchanged)
        trade_agg AS (
            SELECT
                base_address,
                sum(priority_fee + coin_creator_fee) AS fees_sol,
                sum(total_usd) AS volume_usd,
                count() AS n_trades,
                min(timestamp) AS t0,
                argMax(timestamp, price_usd) AS t_ath
            FROM trades
            GROUP BY base_address
        ),
        -- 2. "Token list derived MINTS.
        token_meta_raw AS (
            SELECT
                mint_address AS token_address,
                argMax(creator_address, timestamp) AS creator_address,
                argMax(total_supply, timestamp) AS total_supply,
                argMax(token_decimals, timestamp) AS decimals
            FROM mints
            GROUP BY mint_address
        ),
        token_meta AS (
            SELECT
                token_address,
                creator_address,
                total_supply,
                decimals,
                -- Derived adjusted supply for percentage calcs
                (total_supply / pow(10, decimals)) AS adj_supply
            FROM token_meta_raw
            WHERE adj_supply > 0
        ),
        -- 3. Token lifetimes metrics (returns, holders)
        ret_agg AS (
            SELECT
                token_address,
                (argMax(ath_price_usd, updated_at) / {LAUNCH_PRICE_USD}) AS ret,
                argMax(unique_holders, updated_at) AS unique_holders
            FROM token_metrics
            GROUP BY token_address
        ),
        -- 4. WALLET PEAKS: Pre-calculate the Peak Balance (max current_balance) for every wallet
        --    This handles the "User Sold" case by taking their highest ever balance.
        wallet_peaks AS (
            SELECT
                mint_address,
                wallet_address,
                max(current_balance) AS peak_balance,
                max(history_transfer_in) AS max_transfer_in
            FROM wallet_holdings
            GROUP BY mint_address, wallet_address
        ),
        
        -- 5. SNIPERS: Identify sniper addresses (rank <= 70), then sum their PEAK balances
        snipers_list AS (
             SELECT
                 base_address,
                 maker
             FROM (
                 SELECT
                     base_address,
                     maker,
                     dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) AS buyer_rank
                 FROM (
                     SELECT
                         base_address,
                         maker,
                         min(slot) AS min_slot,
                         argMin(transaction_index, tuple(slot, transaction_index)) AS min_idx
                     FROM trades
                     WHERE trade_type = 0 -- buy
                     GROUP BY base_address, maker
                 )
             )
             WHERE buyer_rank <= 70
        ),
        snipers_agg AS (
            SELECT
                s.base_address AS token_address,
                sum(wp.peak_balance) AS snipers_total_peak
            FROM snipers_list s
            JOIN wallet_peaks wp ON s.base_address = wp.mint_address AND s.maker = wp.wallet_address
            GROUP BY s.base_address
        ),

        -- 6. BUNDLED: Sum the base_amount of ALL trades that happened in a slot with multiple buys
        bundled_agg AS (
            SELECT
                t.base_address AS token_address,
                sum(t.base_amount) AS bundled_total_peak
            FROM trades t
            WHERE (t.base_address, t.slot) IN (
                 SELECT base_address, slot
                 FROM trades
                 WHERE trade_type = 0 -- buy
                 GROUP BY base_address, slot
                 HAVING count() > 1
            )
            AND t.trade_type = 0 -- buy
            GROUP BY t.base_address
        ),

        -- 7. DEV HOLD: Creator's Peak Balance
        dev_hold_agg AS (
            SELECT
                t.token_address,
                max(wp.peak_balance) AS dev_peak -- max in case of dupe, but should be 1:1
            FROM token_meta t
            JOIN wallet_peaks wp ON t.token_address = wp.mint_address AND t.creator_address = wp.wallet_address
            GROUP BY t.token_address
        )

    SELECT
        t.token_address,
        r.ret,
        r.unique_holders,
        f.fees_sol,
        f.volume_usd,
        f.n_trades,
        (f.t_ath - f.t0) AS time_to_ath_sec,
        -- Calculate Percentages using Peak Sums / Total Supply
        (COALESCE(s.snipers_total_peak, 0) / t.adj_supply * 100) AS snipers_pct,
        (COALESCE(b.bundled_total_peak, 0) / t.total_supply * 100) AS bundled_pct,
        (COALESCE(d.dev_peak, 0)           / t.adj_supply * 100) AS dev_hold_pct
    FROM token_meta t
    LEFT JOIN ret_agg r ON t.token_address = r.token_address
    LEFT JOIN trade_agg f ON t.token_address = f.base_address
    LEFT JOIN snipers_agg s ON t.token_address = s.token_address
    LEFT JOIN bundled_agg b ON t.token_address = b.token_address
    LEFT JOIN dev_hold_agg d ON t.token_address = d.token_address
    """
    rows = client.execute(query)
    # Must match the SELECT column order above.
    cols = [
        "token_address",
        "ret",
        "unique_holders",
        "fees_sol",
        "volume_usd",
        "n_trades",
        "time_to_ath_sec",
        "snipers_pct",
        "bundled_pct",
        "dev_hold_pct",
    ]
    return [dict(zip(cols, r)) for r in rows]


def compute_quality_scores(
    client,
    max_ret: float = 10000.0,
    rerank: bool = True,
    with_debug: bool = False,
):
    """
    Score token quality relative to peers in the same return bucket.

    Pipeline:
      1. Fetch per-token lifetime metrics with ``fetch_token_metrics``.
      2. Keep tokens with 0 < ret <= ``max_ret`` that fall inside a
         RETURN_THRESHOLDS bucket (``_bucket_id`` != -1) and group them by
         bucket. Each kept input dict gets a ``bucket_id`` key added in place.
      3. Within a bucket, convert every feature to a midrank percentile p, map
         it to s = 2p - 1 (negated for features that are not positive-when-high),
         clip s to [-0.99, 0.99] and average the available s values into q_raw.
         Tokens with no usable feature receive no score.
      4. q is the within-bucket midrank percentile of q_raw mapped to (-1, 1)
         when ``rerank`` is True; otherwise q equals q_raw.

    Args:
        client: ClickHouse client handed to ``fetch_token_metrics``.
        max_ret: Tokens whose ret exceeds this are skipped.
        rerank: Re-rank q_raw within each bucket to produce q.
        with_debug: Also collect the diagnostic series consumed by
            ``print_diagnostics``.

    Returns:
        A list of dicts with keys token_address, bucket_id, ret, q_raw, q,
        bundled_pct, snipers_pct and fees_sol; or ``(scores, debug)`` when
        ``with_debug`` is True.
    """
    data = fetch_token_metrics(client)

    # feature spec: (name, getter, positive_when_high). Getters return None when an
    # input is missing so that the token simply contributes no score for the feature.
    # NOTE(review): every feature -- including snipers_pct, bundled_pct and
    # dev_hold_pct -- is marked positive_when_high=True, so higher concentration
    # raises q. Confirm this is intended (the debug notes below suggest bundling
    # may be expected to correlate with lower returns).
    feature_defs = [
        ("fees_log", lambda d: math.log1p(d["fees_sol"]) if d["fees_sol"] is not None else None, True),
        ("volume_log", lambda d: math.log1p(d["volume_usd"]) if d["volume_usd"] is not None else None, True),
        ("holders_log", lambda d: math.log1p(d["unique_holders"]) if d["unique_holders"] is not None else None, True),
        ("time_to_ath_log", lambda d: math.log1p(d["time_to_ath_sec"]) if d["time_to_ath_sec"] is not None else None, True),
        ("fees_per_volume", lambda d: (d["fees_sol"] / (d["volume_usd"] + EPS)) if d["fees_sol"] is not None and d["volume_usd"] is not None else None, True),
        ("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d["fees_sol"] is not None and d["n_trades"] is not None else None, True),
        ("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d["unique_holders"] is not None and d["n_trades"] is not None else None, True),
        ("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d["unique_holders"] is not None and d["volume_usd"] is not None else None, True),
        ("snipers_pct", lambda d: d["snipers_pct"], True),
        ("bundled_pct", lambda d: d["bundled_pct"], True),
        ("dev_hold_pct", lambda d: d["dev_hold_pct"], True),
    ]

    # Raw columns tracked for diagnostics (fees_sol included for logging only).
    raw_metrics = ["snipers_pct", "bundled_pct", "dev_hold_pct", "fees_sol"]

    debug = None
    if with_debug:
        debug = {
            "q_raw": [],
            "feature_pairs": {f[0]: [] for f in feature_defs},
            "raw_pairs": {m: [] for m in raw_metrics},
            # For checking assumptions like "higher return buckets have lower bundled_pct".
            # Store raw metric distributions per return bucket and (ret, metric) pairs overall.
            "bucket_raw": {},  # bucket_id -> metric -> [raw vals]
            "ret_pairs": {m: [] for m in raw_metrics},  # metric -> [(ret, raw_val)]
        }

    # Group tokens by return bucket; tokens without a usable ret are dropped.
    buckets: Dict[int, List[dict]] = {}
    for d in data:
        ret_val = d.get("ret")
        if ret_val is None or ret_val <= 0 or ret_val > max_ret:
            continue
        b = _bucket_id(ret_val)
        if b == -1:
            continue
        d["bucket_id"] = b
        buckets.setdefault(b, []).append(d)

    token_scores = []
    for b, items in buckets.items():
        if with_debug:
            debug["bucket_raw"].setdefault(b, {m: [] for m in raw_metrics})
            for d in items:
                ret_val = d.get("ret")
                for metric in raw_metrics:
                    raw_val = d.get(metric)
                    if raw_val is None:
                        continue
                    debug["bucket_raw"][b][metric].append(raw_val)
                    if ret_val is not None:
                        debug["ret_pairs"][metric].append((ret_val, raw_val))

        # Percentile of every token for every feature, within this bucket only.
        # Missing, NaN and infinite feature values are excluded from the ranking.
        feature_percentiles: Dict[str, Dict[str, float]] = {}
        for fname, fget, _pos in feature_defs:
            vals = []
            for d in items:
                v = fget(d)
                if v is None or (isinstance(v, float) and not math.isfinite(v)):
                    continue
                vals.append((d["token_address"], v))
            feature_percentiles[fname] = _midrank_percentiles(vals)

        # q_raw = mean of the clipped signed feature scores available for the token.
        q_raw_map = {}
        for d in items:
            s_vals = []
            s_map = {}
            for fname, _fget, pos in feature_defs:
                p = feature_percentiles[fname].get(d["token_address"])
                if p is None:
                    continue
                s = 2.0 * p - 1.0
                if not pos:
                    s = -s
                s = max(-0.99, min(0.99, s))
                s_vals.append(s)
                s_map[fname] = s
            if not s_vals:
                continue
            q_raw = sum(s_vals) / len(s_vals)
            q_raw_map[d["token_address"]] = q_raw
            if with_debug:
                debug["q_raw"].append(q_raw)
                for fname, s in s_map.items():
                    debug["feature_pairs"][fname].append((q_raw, s))
                for metric in raw_metrics:
                    raw_val = d.get(metric)
                    if raw_val is None:
                        continue
                    debug["raw_pairs"][metric].append((q_raw, raw_val))

        # Final score: within-bucket percentile of q_raw mapped to (-1, 1) when
        # rerank is set, otherwise q_raw itself. A single code path builds the rows
        # so both modes always emit the same fields.
        q_pct = _midrank_percentiles(list(q_raw_map.items())) if rerank else {}
        for d in items:
            t = d["token_address"]
            if t not in q_raw_map:
                continue
            q_raw = q_raw_map[t]
            token_scores.append(
                {
                    "token_address": t,
                    "bucket_id": b,
                    "ret": d["ret"],
                    "q_raw": q_raw,
                    "q": (2.0 * q_pct[t] - 1.0) if rerank else q_raw,
                    # Pass through raw metrics for analysis
                    "bundled_pct": d.get("bundled_pct"),
                    "snipers_pct": d.get("snipers_pct"),
                    "fees_sol": d.get("fees_sol"),
                }
            )

    if with_debug:
        return token_scores, debug
    return token_scores





def write_jsonl(path: str, rows: List[dict]) -> None:
    """
    Write ``rows`` to ``path`` as JSON Lines (one ``json.dumps`` object per line),
    overwriting any existing file.

    Missing parent directories are created. A bare filename (no directory
    component) is written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a real directory part.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in rows)


def _percentile(sorted_vals: List[float], p: float) -> float:
    if not sorted_vals:
        return float("nan")
    n = len(sorted_vals)
    if n == 1:
        return sorted_vals[0]
    pos = p * (n - 1)
    lo = int(math.floor(pos))
    hi = int(math.ceil(pos))
    if lo == hi:
        return sorted_vals[lo]
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def _summary_stats(vals: List[float]) -> Dict[str, float]:
    if not vals:
        return {}
    vals_sorted = sorted(vals)
    return {
        "mean": sum(vals_sorted) / len(vals_sorted),
        "min": vals_sorted[0],
        "max": vals_sorted[-1],
        "p10": _percentile(vals_sorted, 0.10),
        "p50": _percentile(vals_sorted, 0.50),
        "p90": _percentile(vals_sorted, 0.90),
        "p99": _percentile(vals_sorted, 0.99),
    }


def _pearson_corr(xs: List[float], ys: List[float]) -> float:
    if not xs or not ys or len(xs) != len(ys) or len(xs) < 2:
        return float("nan")
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    num = 0.0
    den_x = 0.0
    den_y = 0.0
    for i in range(n):
        dx = xs[i] - mean_x
        dy = ys[i] - mean_y
        num += dx * dy
        den_x += dx * dx
        den_y += dy * dy
    denom = math.sqrt(den_x * den_y)
    if denom == 0.0:
        return float("nan")
    return num / denom


def _bucket_label(b: int) -> str:
    """
    Human-readable label for return bucket ``b``.

    Gives "<lower>x - <upper>x" from consecutive RETURN_THRESHOLDS entries, or
    ">= <lower>x" when ``b`` indexes the last threshold.
    """
    lower = RETURN_THRESHOLDS[b]
    if b + 1 >= len(RETURN_THRESHOLDS):
        return f">= {lower}x"
    return f"{lower}x - {RETURN_THRESHOLDS[b + 1]}x"


def print_summary(scores: List[dict]) -> None:
    """
    Print distribution statistics of q and q_raw, overall and per return bucket.

    For each bucket it also prints three example tokens: the lowest q, the q
    closest to 0.0, and the highest q.
    """

    def emit(stats: Dict[str, float], pad: str) -> None:
        # Two-line stats block shared by the overall and per-bucket sections.
        print(f"{pad}Mean: {stats['mean']:.4f} | Min: {stats['min']:.4f} | Max: {stats['max']:.4f}")
        print(f"{pad}Q: p10={stats['p10']:.2f} p50={stats['p50']:.2f} p90={stats['p90']:.2f} p99={stats['p99']:.2f}")

    print("=== QUALITY SCORE SUMMARY ===")
    print(f"Total tokens scored: {len(scores)}")
    if not scores:
        return

    for key in ("q", "q_raw"):
        overall = _summary_stats([s[key] for s in scores if key in s])
        if overall:
            print(f"\nOverall {key}:")
            emit(overall, "  ")

    by_bucket: Dict[int, List[dict]] = {}
    for s in scores:
        by_bucket.setdefault(s["bucket_id"], []).append(s)

    for b in sorted(by_bucket):
        members = by_bucket[b]
        print(f"\nSEGMENT: {b}. {_bucket_label(b)}")
        print(f"Tokens in segment: {len(members)}")
        for key in ("q", "q_raw"):
            seg_stats = _summary_stats([m[key] for m in members if key in m])
            if seg_stats:
                print(f"  {key}:")
                emit(seg_stats, "    ")

        if members:
            # Lowest / closest-to-zero / highest q within the segment.
            ranked = sorted(members, key=lambda x: x.get("q", 0))
            lowest = ranked[0]
            highest = ranked[-1]
            middle = min(ranked, key=lambda x: abs(x.get("q", 0)))

            print("  Examples:")
            for tag, ex in (("Low  (-1.0)", lowest), ("Mid  (~0.0)", middle), ("High ( 1.0)", highest)):
                print(f"    {tag}: {ex['token_address']} (q={ex.get('q',0):.4f}, ret={ex.get('ret',0):.2f}x)")


def print_diagnostics(debug: dict) -> None:
    """
    Print diagnostics collected by ``compute_quality_scores(..., with_debug=True)``.

    Sections, each skipped when its data is missing or empty:
      * Pearson correlation between q_raw and each signed feature score.
      * Mean of each raw metric among tokens in the bottom vs top decile of q_raw.
      * Per-return-bucket distribution of each raw metric, with its nonzero rate.
      * Global Pearson correlation of log(ret) with each raw metric.

    Prints nothing when ``debug`` is falsy or holds no q_raw values.
    """
    if not debug:
        return
    q_raw_vals = debug.get("q_raw", [])
    if not q_raw_vals:
        return
    print("\n=== QUALITY SCORE DIAGNOSTICS ===")

    feature_pairs = debug.get("feature_pairs", {})
    if feature_pairs:
        print("Correlation with q_raw (signed features):")
        for fname in sorted(feature_pairs.keys()):
            pairs = feature_pairs[fname]
            xs = [p[0] for p in pairs]
            ys = [p[1] for p in pairs]
            corr = _pearson_corr(xs, ys)
            print(f"  {fname}: {corr:.4f} (n={len(pairs)})")

    raw_pairs = debug.get("raw_pairs", {})
    if raw_pairs:
        # Decile cut-offs are taken over all q_raw values (all buckets pooled).
        q_sorted = sorted(q_raw_vals)
        p10 = _percentile(q_sorted, 0.10)
        p90 = _percentile(q_sorted, 0.90)
        print("\nTop/bottom decile raw means (by q_raw):")
        for metric in sorted(raw_pairs.keys()):
            pairs = raw_pairs[metric]
            lows = [v for q, v in pairs if q <= p10]
            highs = [v for q, v in pairs if q >= p90]
            if not lows or not highs:
                continue
            low_mean = sum(lows) / len(lows)
            high_mean = sum(highs) / len(highs)
            print(f"  {metric}: bottom_mean={low_mean:.4f} top_mean={high_mean:.4f} (n_low={len(lows)}, n_high={len(highs)})")

    # Return bucket -> raw metric distributions (answers questions like "do higher-return tokens bundle less?")
    bucket_raw = debug.get("bucket_raw", {})
    if bucket_raw:
        print("\n=== RETURN BUCKET RAW METRICS ===")
        for b in sorted(bucket_raw.keys()):
            print(f"\nSEGMENT: {b}. {_bucket_label(b)}")
            for metric in sorted(bucket_raw[b].keys()):
                vals = [v for v in bucket_raw[b][metric] if v is not None]
                if not vals:
                    continue
                stats = _summary_stats(vals)
                # Also report how often the metric is > 0 (useful since many pct metrics are 0).
                nz = sum(1 for v in vals if v > 0)
                nz_rate = nz / len(vals)
                print(
                    f"  {metric}: mean={stats['mean']:.4f} p50={stats['p50']:.4f} "
                    f"p90={stats['p90']:.4f} p99={stats['p99']:.4f} nonzero_rate={nz_rate:.3f} (n={len(vals)})"
                )

    # Overall return-vs-metric correlation (not bucketed). Use log(ret) to reduce tail leverage.
    ret_pairs = debug.get("ret_pairs", {})
    if ret_pairs:
        print("\n=== RETURN VS RAW METRICS (GLOBAL) ===")
        for metric in sorted(ret_pairs.keys()):
            pairs = ret_pairs[metric]
            xs = []
            ys = []
            for r, v in pairs:
                if r is None or r <= 0:
                    continue
                xs.append(math.log(r))
                ys.append(v)
            if len(xs) < 3:
                continue
            corr = _pearson_corr(xs, ys)
            print(f"  log(ret) vs {metric}: {corr:.4f} (n={len(xs)})")


def print_high_ret_analysis(
    scores: List[dict],
    ret_lo: float = 10.0,
    ret_hi: float = 20.0,
) -> None:
    """
    Compare low- vs high-bundled tokens inside a single return band.

    The cohort is every score with ``ret_lo <= ret < ret_hi`` (default 10x-20x).
    Cohort tokens that have a bundled_pct are split at the cohort's median
    bundled_pct (<= median -> LOW, > median -> HIGH), and the mean fees_sol and
    mean ret of each group are printed. Tokens without bundled_pct are left out
    of both groups so the split covers the same population as the median.

    Args:
        scores: Score dicts as produced by ``compute_quality_scores``.
        ret_lo: Inclusive lower bound of the return band.
        ret_hi: Exclusive upper bound of the return band.
    """
    band = f"{ret_lo:g}x - {ret_hi:g}x"
    band_compact = f"{ret_lo:g}x-{ret_hi:g}x"
    print(f"\n=== MID-HIGH RETURN SPLIT ANALYSIS ({band}) ===")

    # 1. Filter for the return cohort
    cohort = [s for s in scores if s.get("ret") is not None and ret_lo <= s["ret"] < ret_hi]
    if not cohort:
        print(f"No tokens {band_compact} found.")
        return

    print(f"Total tokens {band_compact}: {len(cohort)}")

    # 2. Only tokens with bundled data take part in the median and the split
    with_bundled = [s for s in cohort if s.get("bundled_pct") is not None]
    if not with_bundled:
        print("No bundled_pct data found.")
        return

    median_bundled = _percentile(sorted(s["bundled_pct"] for s in with_bundled), 0.50)
    print(f"Median Bundled% for Cohort: {median_bundled:.2f}%")

    # 3. Split at the median
    low_group = [s for s in with_bundled if s["bundled_pct"] <= median_bundled]
    high_group = [s for s in with_bundled if s["bundled_pct"] > median_bundled]

    # 4. Group means; an empty group reports 0.0
    def _mean(values: List[float]) -> float:
        return sum(values) / len(values) if values else 0.0

    def _mean_fees(group: List[dict]) -> float:
        return _mean([s["fees_sol"] for s in group if s.get("fees_sol") is not None])

    def _mean_ret(group: List[dict]) -> float:
        return _mean([s["ret"] for s in group])

    mean_fees_low = _mean_fees(low_group)
    mean_fees_high = _mean_fees(high_group)

    print(f"\nGroup 1: LOW Bundled (<= {median_bundled:.2f}%)")
    print(f"  Count: {len(low_group)}")
    print(f"  Mean Fees: {mean_fees_low:.4f} SOL")

    print(f"\nGroup 2: HIGH Bundled (> {median_bundled:.2f}%)")
    print(f"  Count: {len(high_group)}")
    print(f"  Mean Fees: {mean_fees_high:.4f} SOL")

    # Extra: compare returns too
    print(f"  Mean Ret:  {_mean_ret(high_group):.2f}x (vs Low: {_mean_ret(low_group):.2f}x)")


def get_token_quality_scores(client):
    """
    Return a dict mapping token_address -> q (final quality score).

    Scores are computed with ``max_ret=1e9`` (a very high return cap) and
    within-bucket re-ranking enabled. ``with_debug`` is left at its default
    (False), so ``compute_quality_scores`` returns a plain list of score dicts.
    Tokens whose score dict lacks "q" map to 0.0.
    """
    scored = compute_quality_scores(client, max_ret=1e9, rerank=True)
    mapping = {}
    for row in scored:
        mapping[row["token_address"]] = row.get("q", 0.0)
    return mapping


def main():
    """
    Command-line entry point.

    Connects to ClickHouse, computes quality scores, prints the summary unless
    --no-summary is given, and prints diagnostics plus the bundled split
    analysis unless --no-diagnostics is given.
    """
    parser = argparse.ArgumentParser(description="Compute token quality/health score.")
    parser.add_argument("--max-ret", type=float, default=10000.0, help="Max return to include")
    parser.add_argument("--no-rerank", action="store_true", help="Disable final rerank within bucket")
    parser.add_argument("--no-summary", action="store_true", help="Disable summary logging")
    parser.add_argument("--no-diagnostics", action="store_true", help="Disable diagnostics logging")
    args = parser.parse_args()

    want_diagnostics = not args.no_diagnostics
    client = get_client()
    # Debug data is only collected when diagnostics will actually be printed.
    result = compute_quality_scores(
        client,
        max_ret=args.max_ret,
        rerank=not args.no_rerank,
        with_debug=want_diagnostics,
    )
    if want_diagnostics:
        scores, debug = result
    else:
        scores, debug = result, None

    if not args.no_summary:
        print_summary(scores)
    if want_diagnostics:
        print_diagnostics(debug)
        print_high_ret_analysis(scores)


if __name__ == "__main__":
    main()