Upload folder using huggingface_hub

Files changed:
- QUALITY_SCORE_ARCHITECTURE.md +125 -1
- data/context_targets.py +93 -0
- data/data_collator.py +2 -0
- data/data_loader.py +38 -3
- log.log +2 -2
- models/model.py +15 -0
- pre_cache.sh +1 -1
- scripts/cache_dataset.py +207 -34
- train.py +136 -4
QUALITY_SCORE_ARCHITECTURE.md (CHANGED)

@@ -10,4 +10,128 @@ T_cutoff at trade 700 → returns are -90%

T_cutoff at trade 900 → returns are -95%

So even for class 5 tokens, 80%+ of the cached training samples have negative Ground Truth labels. The model is correctly learning that at any random moment, even a "good" token is most likely going down. The class balancing doesn't change the fact that the actual Y labels are overwhelmingly negative across all classes.

The model isn't broken — it learned exactly what the data showed it. The issue is that the training setup doesn't teach it to recognize the pre-pump moment specifically.
**Main Issue**

Your main problem was never just "bad checkpoint choice." The core issue is training/data misalignment:

- token `class_id` is token-level
- the prediction target is context-level, taken from a random `T_cutoff`
- even a good token produces many bad windows
- so balanced token classes do not mean balanced future-return labels
- the model then learns an over-negative prior

A second major issue was cache construction:

- the cache was wasting disk and time on overwhelming numbers of garbage-token samples
- later training weights cannot fix that upstream waste
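The misalignment can be illustrated with a toy simulation (hypothetical numbers, not this project's data): even a token that pumps hard spends most of its timeline after the peak, so a uniformly random `T_cutoff` mostly yields negative forward returns.

```python
import random

random.seed(0)

# Toy price path: pump for the first 100 trades, then bleed out over 900.
prices = [1.0 * (1.1 ** t) for t in range(100)]
peak = prices[-1]
prices += [peak * (0.995 ** t) for t in range(900)]

positives = 0
trials = 1000
for _ in range(trials):
    t_cutoff = random.randrange(len(prices) - 1)
    forward_return = prices[-1] / prices[t_cutoff] - 1.0
    positives += forward_return > 0

print(positives / trials)  # small fraction: most cutoffs land after the peak
```

Balancing which tokens enter the cache does nothing to change this per-window label skew.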
**What We Figured Out**

- The model is not useless.
- Wallet signal is real: ablations showed wallet removal hurts predictions materially.
- OHLC matters, but mostly as a coarse summary, not real chart-pattern intelligence.
- No obvious future leakage was found in OHLC construction.
- Social looked basically unused.
- Graph looked weaker than expected.
- The movement head idea is valid, but only if labels are placed correctly in the pipeline.
- Movement labels should come from the data loader, not be derived later in the collator or training loop.
- Cache balancing should not depend on fragile movement thresholds.
- A single "movement class" for cache weighting was wrong because:
  - thresholds were unresolved
  - movement differs across horizons inside the same sample
**Where You Corrected the Direction**

You pushed back on several important bad assumptions:

- `return > 0` is too noisy as a label
- movement class names should be threshold-agnostic
- threshold-based movement balancing was premature
- SQL/global distribution threshold inference was conceptually wrong because labels depend on the sampled `T_cutoff`
- the cache should not be filtered by class map in a destructive way
- cache balancing must happen at cache generation time, not be delegated to train-time weights
- positive balancing should not be forced on garbage classes
- exact class sample counts matter more than approximate expected weighting
- `T_cutoff` does not need to be deterministic or pre-fixed
- if cache balancing uses movement-like signals, use threshold-free binary polarity first

Those corrections materially improved the design.
**Proposed Methods Over the Chat**

These were the main methods proposed, in order of evolution:

1. Forward-time validation and token-grouped splits
   - to reduce misleading validation results and leakage risk

2. Auxiliary head ideas
   - first, fixed pump heads
   - then an all-horizon direction head
   - then a movement-type multiclass head
   - final stable view: one multi-horizon movement head is reasonable, but labels must be created correctly

3. Runtime/loader-side label derivation
   - final agreed direction:
     - labels belong in the data loader
     - the collator should only stack them
     - the model should just consume them

4. Cache-time balancing instead of train-time rescue
   - because disk/time waste happens before training starts
   - so train weights alone are too late

5. Class-id-based cache expansion
   - proposed because class `0` dominates raw token counts
   - later refined because exact quotas matter more than soft weighting

6. Movement-class-based cache balancing
   - proposed, then rejected correctly
   - because it depended on unresolved thresholds and collapsed multi-horizon information incorrectly

7. Binary polarity cache balancing
   - the final, better version:
     - label a sample `positive` if its max valid horizon return is `> 0`
     - else `negative`
   - this is threshold-free and less brittle
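The binary polarity rule described above can be sketched in a few lines (the function name is illustrative, not from the repo):

```python
def context_polarity(horizon_returns, horizon_mask):
    # A sample is 'positive' if any valid (unmasked) horizon return is > 0.
    valid = [r for r, m in zip(horizon_returns, horizon_mask) if m > 0]
    if not valid:
        return "negative"  # fully masked samples default to negative
    return "positive" if max(valid) > 0 else "negative"
```

Because the rule only compares against zero, it needs no tuned thresholds and stays consistent regardless of how horizon-specific movement bands are later defined.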
8. Exact class quotas + class-conditional polarity quotas
   - the final, strongest cache proposal:
     - an exact, equal sample budget per class
     - positive/negative balancing only for classes that can realistically produce positive contexts
     - keep class `0` mostly negative
     - keep `T_cutoff` random
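A minimal sketch of such a quota plan, assuming class budgets are equal and garbage classes get no positive quota (all names and the 50/50 ratio are illustrative assumptions, not the repo's planner):

```python
def plan_cache_quotas(class_counts, target_total, positive_ratio=0.5, garbage_classes=(0,)):
    # Exact, equal sample budget per token class; polarity quotas only for
    # classes that can plausibly produce positive contexts.
    num_classes = max(len(class_counts), 1)
    per_class = target_total // num_classes
    plan = {}
    for cid in class_counts:
        if cid in garbage_classes:
            plan[cid] = {"positive": 0, "negative": per_class}
        else:
            pos = int(per_class * positive_ratio)
            plan[cid] = {"positive": pos, "negative": per_class - pos}
    return plan
```

The key property is that the plan is computed deterministically before any sampling, so disk usage and label balance are fixed at cache-generation time rather than rescued during training.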
**Current Best Design**

The strongest design we converged toward is:

- keep `T_cutoff` random
- make cache generation deterministic at the planning level
- assign exact sample budgets per token class
- for higher classes, request positive/negative context ratios
- for low garbage classes, do not force positive quotas
- build labels in the data loader
- keep the model's main task as future return prediction
- use the movement head only as auxiliary supervision

So:

- token-type balance is controlled explicitly
- context-type bias is controlled explicitly
- disk usage is controlled at cache creation time
- training does not need to rescue a broken corpus
**What We Learned About the Model**

- the backbone has useful signal
- wallet features matter
- OHLC is being used, but not in the rich pattern-detection way you hoped
- the current main failure mode is much more about data construction and sampling than about the backbone learning nothing
**High-Level Conclusions**

1. The main issue is data/sample construction, not just optimizer/checkpoint behavior.
2. Cache balancing is the right place to solve the waste and prior-bias problem.
3. Threshold-dependent movement balancing was too early and too brittle.
4. Binary short-horizon max-return polarity is a better first cache-balancing signal.
5. Exact per-class sample quotas plus class-conditional polarity balancing is the most coherent cache design we arrived at.
If you want, I can next turn this into a short final blueprint with only:

- `Main issue`
- `Accepted design`
- `Rejected ideas`
- `Next implementation order`
data/context_targets.py (ADDED)

@@ -0,0 +1,93 @@
from __future__ import annotations

from typing import Dict, List, Sequence


MOVEMENT_STRONG_DOWN_THRESHOLD = -0.40
MOVEMENT_DOWN_THRESHOLD = -0.30
MOVEMENT_PUMP_50_THRESHOLD = 0.50
MOVEMENT_PUMP_100_THRESHOLD = 1.00
MOVEMENT_PUMP_300_THRESHOLD = 3.00

MOVEMENT_CLASS_NAMES = [
    "strong_down",
    "down",
    "flat",
    "up",
    "strong_up",
    "extreme_up",
]
MOVEMENT_CLASS_TO_ID = {name: idx for idx, name in enumerate(MOVEMENT_CLASS_NAMES)}
MOVEMENT_ID_TO_CLASS = {idx: name for name, idx in MOVEMENT_CLASS_TO_ID.items()}

DEFAULT_MOVEMENT_LABEL_CONFIG = {
    "strong_down_threshold": MOVEMENT_STRONG_DOWN_THRESHOLD,
    "down_threshold": MOVEMENT_DOWN_THRESHOLD,
    "pump_50_threshold": MOVEMENT_PUMP_50_THRESHOLD,
    "pump_100_threshold": MOVEMENT_PUMP_100_THRESHOLD,
    "pump_300_threshold": MOVEMENT_PUMP_300_THRESHOLD,
}


def classify_movement_return(
    return_value: float,
    movement_label_config: Dict[str, float] | None = None,
) -> int:
    cfg = dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
    if movement_label_config:
        cfg.update({k: float(v) for k, v in movement_label_config.items() if k in cfg})

    strong_down_threshold = min(cfg["strong_down_threshold"], cfg["down_threshold"])
    down_threshold = cfg["down_threshold"]
    pump_50_threshold = cfg["pump_50_threshold"]
    pump_100_threshold = cfg["pump_100_threshold"]
    pump_300_threshold = cfg["pump_300_threshold"]

    if return_value <= strong_down_threshold:
        return MOVEMENT_CLASS_TO_ID["strong_down"]
    if return_value < down_threshold:
        return MOVEMENT_CLASS_TO_ID["down"]
    if return_value < pump_50_threshold:
        return MOVEMENT_CLASS_TO_ID["flat"]
    if return_value < pump_100_threshold:
        return MOVEMENT_CLASS_TO_ID["up"]
    if return_value < pump_300_threshold:
        return MOVEMENT_CLASS_TO_ID["strong_up"]
    return MOVEMENT_CLASS_TO_ID["extreme_up"]


def derive_movement_targets(
    horizon_returns: Sequence[float],
    horizon_mask: Sequence[float],
    movement_label_config: Dict[str, float] | None = None,
) -> Dict[str, List[int]]:
    class_targets: List[int] = []
    class_mask: List[int] = []
    class_names: List[str] = []

    usable = min(len(horizon_returns), len(horizon_mask))
    for idx in range(usable):
        if float(horizon_mask[idx]) <= 0:
            class_targets.append(MOVEMENT_CLASS_TO_ID["flat"])
            class_mask.append(0)
            class_names.append("masked")
            continue

        class_id = classify_movement_return(
            float(horizon_returns[idx]),
            movement_label_config=movement_label_config,
        )
        class_targets.append(class_id)
        class_mask.append(1)
        class_names.append(MOVEMENT_ID_TO_CLASS[class_id])

    return {
        "movement_class_targets": class_targets,
        "movement_class_mask": class_mask,
        "movement_class_names": class_names,
    }


def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]:
    del valid_returns
    return dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
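A quick sanity check of the banding above. This standalone snippet inlines the default-threshold logic rather than importing the repo module, so it can run anywhere:

```python
def classify(r):
    # Mirrors classify_movement_return with the default thresholds from the file.
    if r <= -0.40:
        return "strong_down"
    if r < -0.30:
        return "down"
    if r < 0.50:
        return "flat"
    if r < 1.00:
        return "up"
    if r < 3.00:
        return "strong_up"
    return "extreme_up"

print(classify(-0.5), classify(-0.35), classify(0.0), classify(0.7), classify(1.5), classify(5.0))
# → strong_down down flat up strong_up extreme_up
```

Note that returns in [-0.30, 0.50) all land in "flat", so mild gains are deliberately not treated as pumps.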
data/data_collator.py (CHANGED)

@@ -719,6 +719,8 @@ class MemecoinCollator:
             # Labels
             'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
             'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
+            'movement_class_targets': torch.stack([item['movement_class_targets'] for item in batch]) if batch and 'movement_class_targets' in batch[0] else None,
+            'movement_class_mask': torch.stack([item['movement_class_mask'] for item in batch]) if batch and 'movement_class_mask' in batch[0] else None,
             'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None,
             'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
             # Debug info
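The collator's optional-field stacking pattern can be sketched standalone (a minimal illustration assuming torch is available, not the repo's MemecoinCollator):

```python
import torch

def stack_optional(batch, key):
    # Stack a per-item tensor field, or return None if the field is absent.
    if batch and key in batch[0]:
        return torch.stack([item[key] for item in batch])
    return None

batch = [
    {"labels": torch.zeros(5), "movement_class_targets": torch.zeros(5, dtype=torch.long)},
    {"labels": torch.ones(5), "movement_class_targets": torch.ones(5, dtype=torch.long)},
]
labels = stack_optional(batch, "labels")        # shape (2, 5)
missing = stack_optional(batch, "not_present")  # None
```

Returning None for absent keys keeps the collator backward compatible with caches written before the movement fields existed, matching the guarded `if batch and ... in batch[0]` expressions in the diff.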
data/data_loader.py (CHANGED)

@@ -17,6 +17,7 @@ import json
 import models.vocabulary as vocab
 from models.multi_modal_processor import MultiModalEncoder
 from data.data_fetcher import DataFetcher  # NEW: Import the DataFetcher
+from data.context_targets import derive_movement_targets
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry

@@ -128,7 +129,8 @@ class OracleDataset(Dataset):
                  start_date: Optional[datetime.datetime] = None,
                  min_trade_usd: float = 0.0,
                  max_seq_len: int = 8192,
-                 p99_clamps: Optional[Dict[str, float]] = None
+                 p99_clamps: Optional[Dict[str, float]] = None,
+                 movement_label_config: Optional[Dict[str, float]] = None):

         self.max_seq_len = max_seq_len

@@ -315,6 +317,7 @@ class OracleDataset(Dataset):

         self.min_trade_usd = min_trade_usd
         self._uri_fail_counts: Dict[str, int] = {}
+        self.movement_label_config = movement_label_config

     def _init_http_session(self) -> None:
         # Configure robust HTTP session

@@ -1199,6 +1202,23 @@ class OracleDataset(Dataset):
         # This is fully deterministic - no runtime sampling or processing
         _timings['total'] = _time.perf_counter() - _total_start

+        if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
+            labels = cached_data['labels']
+            labels_mask = cached_data['labels_mask']
+            movement_targets = derive_movement_targets(
+                labels.tolist() if isinstance(labels, torch.Tensor) else labels,
+                labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask,
+                movement_label_config=self.movement_label_config,
+            )
+            cached_data['movement_class_targets'] = torch.tensor(
+                movement_targets['movement_class_targets'],
+                dtype=torch.long,
+            )
+            cached_data['movement_class_mask'] = torch.tensor(
+                movement_targets['movement_class_mask'],
+                dtype=torch.long,
+            )
+
         if idx % 100 == 0:
             print(f"[Sample {idx}] CONTEXT mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
                   f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")

@@ -2449,6 +2469,11 @@ class OracleDataset(Dataset):

         if not all_trades:
             # No valid trades for label computation
+            movement_targets = derive_movement_targets(
+                [0.0] * len(self.horizons_seconds),
+                [0.0] * len(self.horizons_seconds),
+                movement_label_config=self.movement_label_config,
+            )
             return {
                 'event_sequence': event_sequence,
                 'wallets': wallet_data,

@@ -2457,7 +2482,9 @@ class OracleDataset(Dataset):
                 'embedding_pooler': pooler,
                 'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
                 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
-                'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
+                'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
+                'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
+                'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
             }

         # Ensure sorted

@@ -2537,6 +2564,12 @@ class OracleDataset(Dataset):

         # DEBUG: Mask summaries removed after validation

+        movement_targets = derive_movement_targets(
+            label_values,
+            mask_values,
+            movement_label_config=self.movement_label_config,
+        )
+
         return {
             'sample_idx': sample_idx if sample_idx is not None else -1,  # Debug trace
             'token_address': token_address,  # For debugging

@@ -2548,7 +2581,9 @@ class OracleDataset(Dataset):
             'embedding_pooler': pooler,
             'labels': torch.tensor(label_values, dtype=torch.float32),
             'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
-            'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
+            'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
+            'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
+            'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
         }

     def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
log.log (CHANGED)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:935233e4d7669b2a25173d7ae164317e85f1a5e8b0fc1d8d1832ab0893fca471
+size 19258
models/model.py (CHANGED)

@@ -17,6 +17,7 @@ from models.ohlc_embedder import OHLCEmbedder
 from models.HoldersEncoder import HolderDistributionEncoder  # NEW
 from models.SocialEncoders import SocialEncoder  # NEW
 import models.vocabulary as vocab  # For vocab sizes
+from data.context_targets import MOVEMENT_CLASS_NAMES

 class Oracle(nn.Module):
     """

@@ -51,6 +52,7 @@ class Oracle(nn.Module):
         self.quantiles = quantiles
         self.horizons_seconds = horizons_seconds
         self.num_outputs = len(quantiles) * len(horizons_seconds)
+        self.num_movement_classes = len(MOVEMENT_CLASS_NAMES)
         self.dtype = dtype

         # --- 2. Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) ---

@@ -103,6 +105,11 @@ class Oracle(nn.Module):
             nn.GELU(),
             nn.Linear(self.d_model, 1)
         )
+        self.movement_head = nn.Sequential(
+            nn.Linear(self.d_model, self.d_model),
+            nn.GELU(),
+            nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes)
+        )

         self.event_type_to_id = event_type_to_id

@@ -1008,9 +1015,11 @@ class Oracle(nn.Module):
             empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
             empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
             empty_quality = torch.empty(0, device=device, dtype=self.dtype)
+            empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype)
             return {
                 'quantile_logits': empty_quantiles,
                 'quality_logits': empty_quality,
+                'movement_logits': empty_movement,
                 'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
                 'hidden_states': empty_hidden,
                 'attention_mask': empty_mask

@@ -1131,10 +1140,16 @@ class Oracle(nn.Module):
         pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
         quantile_logits = self.quantile_head(pooled_states)
         quality_logits = self.quality_head(pooled_states).squeeze(-1)
+        movement_logits = self.movement_head(pooled_states).view(
+            pooled_states.shape[0],
+            len(self.horizons_seconds),
+            self.num_movement_classes,
+        )

         return {
             'quantile_logits': quantile_logits,
             'quality_logits': quality_logits,
+            'movement_logits': movement_logits,
             'pooled_states': pooled_states,
             'hidden_states': sequence_hidden,
             'attention_mask': hf_attention_mask
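A sketch of how the reshaped `movement_logits` might be consumed with a masked per-horizon cross-entropy. The loss wiring is an assumption about the training side (not part of this commit); only the (batch, horizons, classes) shape convention comes from the diff above:

```python
import torch
import torch.nn.functional as F

B, H, C = 4, 5, 6  # batch, horizons, movement classes

movement_logits = torch.randn(B, H, C)   # head output after .view(B, H, C)
targets = torch.randint(0, C, (B, H))    # movement_class_targets
mask = torch.ones(B, H)
mask[0, 3] = 0                           # e.g. a horizon with no valid label

per_element = F.cross_entropy(
    movement_logits.reshape(B * H, C),
    targets.reshape(B * H),
    reduction="none",
).reshape(B, H)
movement_loss = (per_element * mask).sum() / mask.sum().clamp(min=1)
```

Using `reduction="none"` plus an explicit mask keeps masked horizons (the "flat" placeholder targets from `derive_movement_targets`) out of the gradient, which matters because those placeholder class ids are arbitrary.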
pre_cache.sh (CHANGED)

@@ -39,7 +39,7 @@ python3 scripts/cache_dataset.py \
     --num_workers "$NUM_WORKERS" \
     --horizons_seconds "${HORIZONS_SECONDS[@]}" \
     --quantiles "${QUANTILES[@]}" \
-    --max_samples
+    --max_samples 10 \
     "$@"

echo "Done!"
scripts/cache_dataset.py (CHANGED)

@@ -27,13 +27,154 @@ from scripts.compute_quality_score import get_token_quality_scores, fetch_token_

 from clickhouse_driver import Client as ClickHouseClient
 from neo4j import GraphDatabase
-
 _worker_dataset = None
 _worker_return_class_map = None
 _worker_quality_scores_map = None
 _worker_encoder = None
 def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
     global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
     from data.data_loader import OracleDataset

@@ -73,7 +214,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map

 def _process_single_token_context(args):
-    idx, mint_addr, samples_per_token, output_dir = args
     global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
     try:
         class_id = _worker_return_class_map.get(mint_addr)

@@ -87,7 +228,17 @@ def _process_single_token_context(args):
         if encoder is None:
             print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)
-
         if not contexts:
             return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
         q_score = _worker_quality_scores_map.get(mint_addr)

@@ -102,7 +253,16 @@ def _process_single_token_context(args):

             torch.save(ctx, output_path)
             saved_files.append(filename)
-        return {
     except Exception as e:
         import traceback
         return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}

@@ -132,6 +292,10 @@ def main():
     parser.add_argument("--context_length", type=int, default=8192)
     parser.add_argument("--min_trades", type=int, default=10)
     parser.add_argument("--samples_per_token", type=int, default=1)
     parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
     parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
     parser.add_argument("--num_workers", type=int, default=1)

@@ -223,7 +387,6 @@ def main():

     print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")

-    # Compute balanced samples_per_token for each class
     num_classes = len(eligible_class_counts)
     if args.max_samples:
         target_total = args.max_samples

@@ -231,44 +394,46 @@ def main():
         target_total = 15000  # Default target: 15k balanced files
     target_per_class = target_total // max(num_classes, 1)

-            class_token_caps[cid] = target_per_class
-        else:
-            # Not enough tokens — multi-sample, use all tokens
-            class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
-            class_token_caps[cid] = count

     print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
-    print(f"INFO:
-    print(f"INFO: Class

     # Build balanced task list
     tasks = []

     random.shuffle(tasks)  # Shuffle tasks for even load distribution across workers
-    expected_files = sum(
-        class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
-        for cid, ml in mints_by_class.items()
-    )
     print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")

     success_count, skipped_count, error_count = 0, 0, 0
     class_distribution = {}

     # --- Resume support: skip tokens that already have cached files ---
     existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))

@@ -334,6 +499,8 @@ def main():
         if result['status'] == 'success':
             success_count += 1
             class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
         elif result['status'] == 'skipped':
             skipped_count += 1
         else:

@@ -360,6 +527,8 @@ def main():
         if result['status'] == 'success':
             success_count += 1
             class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
         elif result['status'] == 'skipped':
             skipped_count += 1
         else:

@@ -398,10 +567,14 @@ def main():
         'num_workers': args.num_workers,
         'horizons_seconds': args.horizons_seconds,
         'quantiles': args.quantiles,
-        'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
-        'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
|
| 403 |
'target_total': target_total,
|
| 404 |
'target_per_class': target_per_class,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
}, f, indent=2)
|
| 406 |
|
| 407 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
|
|
|
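For comparison with the new exact per-class allocation, the removed heuristic above multi-sampled under-represented classes with a capped multiplier. A minimal reconstruction of that heuristic from the removed lines (the function name, the `count >= target_per_class` guard, and the surrounding loop are assumptions — the diff only preserves the `else:` arm):

```python
import math

def legacy_class_plan(eligible_class_counts, target_per_class):
    """Hypothetical sketch of the removed balancing heuristic: classes with
    enough tokens are capped at target_per_class tokens (1 sample each);
    smaller classes multi-sample every eligible token, at most 10x."""
    class_multipliers, class_token_caps = {}, {}
    for cid, count in eligible_class_counts.items():
        if count >= target_per_class:  # assumed guard; not visible in the diff
            class_multipliers[cid] = 1
            class_token_caps[cid] = target_per_class
        else:
            # Not enough tokens — multi-sample, use all tokens (from the diff)
            class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
            class_token_caps[cid] = count
    return class_multipliers, class_token_caps

mult, caps = legacy_class_plan({0: 5000, 1: 120}, target_per_class=2500)
```

With 120 eligible tokens and a 2500-file target, the multiplier saturates at the 10x cap, which is exactly why rare classes could still undershoot their target — a gap the new `_allocate_class_targets` closes by assigning exact per-token sample counts.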
```diff
 from clickhouse_driver import Client as ClickHouseClient
 from neo4j import GraphDatabase
 _worker_dataset = None
 _worker_return_class_map = None
 _worker_quality_scores_map = None
 _worker_encoder = None

+def _to_int_list(values):
+    if values is None:
+        return []
+    if isinstance(values, torch.Tensor):
+        return [int(v) for v in values.tolist()]
+    return [int(v) for v in values]
+
+
+def _to_float_list(values):
+    if values is None:
+        return []
+    if isinstance(values, torch.Tensor):
+        return [float(v) for v in values.tolist()]
+    return [float(v) for v in values]
+
+
+def _representative_context_polarity(context):
+    labels = _to_float_list(context.get("labels"))
+    mask = _to_int_list(context.get("labels_mask"))
+    valid_returns = [label for label, keep in zip(labels, mask) if keep > 0]
+    if not valid_returns:
+        return "negative"
+    return "positive" if max(valid_returns) > 0.0 else "negative"
+
+
+def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
+    if len(contexts) <= max_keep:
+        polarity_counts = {}
+        for context in contexts:
+            polarity = _representative_context_polarity(context)
+            polarity_counts[polarity] = polarity_counts.get(polarity, 0) + 1
+            context["representative_context_polarity"] = polarity
+        return contexts, polarity_counts
+
+    positive_bucket = []
+    negative_bucket = []
+    for context in contexts:
+        polarity = _representative_context_polarity(context)
+        context["representative_context_polarity"] = polarity
+        if polarity == "positive":
+            positive_bucket.append(context)
+        else:
+            negative_bucket.append(context)
+
+    selected = []
+    polarity_counts = {"positive": 0, "negative": 0}
+    desired_positive = max(0, int(desired_positive)) if desired_positive is not None else None
+    desired_negative = max(0, int(desired_negative)) if desired_negative is not None else None
+
+    if desired_positive is not None or desired_negative is not None:
+        target_positive = min(desired_positive or 0, max_keep, len(positive_bucket))
+        target_negative = min(desired_negative or 0, max_keep - target_positive, len(negative_bucket))
+
+        while polarity_counts["positive"] < target_positive and positive_bucket:
+            selected.append(positive_bucket.pop())
+            polarity_counts["positive"] += 1
+        while polarity_counts["negative"] < target_negative and negative_bucket:
+            selected.append(negative_bucket.pop())
+            polarity_counts["negative"] += 1
+
+    prefer_positive = len(positive_bucket) >= len(negative_bucket)
+
+    while len(selected) < max_keep and (positive_bucket or negative_bucket):
+        if prefer_positive and positive_bucket:
+            selected.append(positive_bucket.pop())
+            polarity_counts["positive"] += 1
+        elif not prefer_positive and negative_bucket:
+            selected.append(negative_bucket.pop())
+            polarity_counts["negative"] += 1
+        elif positive_bucket:
+            selected.append(positive_bucket.pop())
+            polarity_counts["positive"] += 1
+        elif negative_bucket:
+            selected.append(negative_bucket.pop())
+            polarity_counts["negative"] += 1
+        prefer_positive = not prefer_positive
+
+    return selected[:max_keep], polarity_counts
+
+
+def _allocate_class_targets(mints_by_class, target_total, positive_balance_min_class, positive_ratio):
+    from collections import defaultdict
+    import random
+
+    class_ids = sorted(mints_by_class.keys())
+    if not class_ids:
+        return {}, {}, {}
+
+    target_per_class = target_total // len(class_ids)
+    remainder = target_total % len(class_ids)
+
+    token_plans = {}
+    class_targets = {}
+    class_polarity_targets = {}
+
+    for pos, class_id in enumerate(class_ids):
+        class_target = target_per_class + (1 if pos < remainder else 0)
+        class_targets[class_id] = class_target
+
+        token_list = list(mints_by_class[class_id])
+        random.shuffle(token_list)
+        if not token_list or class_target <= 0:
+            class_polarity_targets[class_id] = {"positive": 0, "negative": 0}
+            continue
+
+        if class_id >= positive_balance_min_class:
+            positive_target = int(round(class_target * positive_ratio))
+            positive_target = min(max(positive_target, 0), class_target)
+        else:
+            positive_target = 0
+        negative_target = class_target - positive_target
+        class_polarity_targets[class_id] = {
+            "positive": positive_target,
+            "negative": negative_target,
+        }
+
+        assigned_positive = 0
+        assigned_negative = 0
+        token_count = len(token_list)
+        for sample_num in range(class_target):
+            token_idx, mint_record = token_list[sample_num % token_count]
+            mint_addr = mint_record["mint_address"]
+            plan_key = (token_idx, mint_addr)
+            if plan_key not in token_plans:
+                token_plans[plan_key] = {
+                    "samples_to_keep": 0,
+                    "desired_positive": 0,
+                    "desired_negative": 0,
+                    "class_id": class_id,
+                }
+            token_plans[plan_key]["samples_to_keep"] += 1
+
+            if assigned_positive < positive_target:
+                token_plans[plan_key]["desired_positive"] += 1
+                assigned_positive += 1
+            else:
+                token_plans[plan_key]["desired_negative"] += 1
+                assigned_negative += 1
+
+    return token_plans, class_targets, class_polarity_targets
+
+
 def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
     global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
     from data.data_loader import OracleDataset
@@
 def _process_single_token_context(args):
+    idx, mint_addr, samples_per_token, output_dir, oversample_factor, desired_positive, desired_negative = args
     global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
     try:
         class_id = _worker_return_class_map.get(mint_addr)
@@
         if encoder is None:
             print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)

+        candidate_contexts = _worker_dataset.__cacheitem_context__(
+            idx,
+            num_samples_per_token=max(samples_per_token, samples_per_token * max(1, oversample_factor)),
+            encoder=encoder,
+        )
+        contexts, polarity_counts = _select_contexts_by_polarity(
+            candidate_contexts,
+            samples_per_token,
+            desired_positive=desired_positive,
+            desired_negative=desired_negative,
+        )
         if not contexts:
             return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
         q_score = _worker_quality_scores_map.get(mint_addr)
@@
             torch.save(ctx, output_path)
             saved_files.append(filename)
+        return {
+            'status': 'success',
+            'mint': mint_addr,
+            'class_id': class_id,
+            'q_score': q_score,
+            'n_contexts': len(contexts),
+            'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
+            'files': saved_files,
+            'polarity_counts': polarity_counts,
+        }
     except Exception as e:
         import traceback
         return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
@@
     parser.add_argument("--context_length", type=int, default=8192)
     parser.add_argument("--min_trades", type=int, default=10)
     parser.add_argument("--samples_per_token", type=int, default=1)
+    parser.add_argument("--context_oversample_factor", type=int, default=4)
+    parser.add_argument("--cache_balance_mode", type=str, default="hybrid", choices=["class", "uniform", "hybrid"])
+    parser.add_argument("--positive_balance_min_class", type=int, default=2)
+    parser.add_argument("--positive_context_ratio", type=float, default=0.5)
     parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
     parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
     parser.add_argument("--num_workers", type=int, default=1)
@@
     print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")

     num_classes = len(eligible_class_counts)
     if args.max_samples:
         target_total = args.max_samples
@@
         target_total = 15000  # Default target: 15k balanced files
     target_per_class = target_total // max(num_classes, 1)

+    token_plans, class_targets, class_polarity_targets = _allocate_class_targets(
+        mints_by_class=mints_by_class,
+        target_total=target_total,
+        positive_balance_min_class=args.positive_balance_min_class,
+        positive_ratio=args.positive_context_ratio,
+    )

     print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
+    print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
+    print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")

     # Build balanced task list
     tasks = []
+    if args.cache_balance_mode == "uniform":
+        target_tokens = len(filtered_mints)
+        if args.max_samples:
+            target_tokens = min(len(filtered_mints), max(1, math.ceil(args.max_samples / max(args.samples_per_token, 1))))
+        mint_pool = list(enumerate(filtered_mints))
+        random.shuffle(mint_pool)
+        for i, m in mint_pool[:target_tokens]:
+            tasks.append((i, m['mint_address'], args.samples_per_token, str(output_dir), args.context_oversample_factor, 0, args.samples_per_token))
+    else:
+        for (token_idx, mint_addr), plan in token_plans.items():
+            tasks.append((
+                token_idx,
+                mint_addr,
+                plan["samples_to_keep"],
+                str(output_dir),
+                args.context_oversample_factor,
+                plan["desired_positive"],
+                plan["desired_negative"],
+            ))

     random.shuffle(tasks)  # Shuffle tasks for even load distribution across workers
+    expected_files = sum(task[2] for task in tasks)
     print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")

     success_count, skipped_count, error_count = 0, 0, 0
     class_distribution = {}
+    polarity_distribution = {}

     # --- Resume support: skip tokens that already have cached files ---
     existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
@@
         if result['status'] == 'success':
             success_count += 1
             class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
+            for polarity, count in result.get('polarity_counts', {}).items():
+                polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
         elif result['status'] == 'skipped':
             skipped_count += 1
         else:
@@
         if result['status'] == 'success':
             success_count += 1
             class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
+            for polarity, count in result.get('polarity_counts', {}).items():
+                polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
         elif result['status'] == 'skipped':
             skipped_count += 1
         else:
@@
             'num_workers': args.num_workers,
             'horizons_seconds': args.horizons_seconds,
             'quantiles': args.quantiles,
             'target_total': target_total,
             'target_per_class': target_per_class,
+            'cache_balance_mode': args.cache_balance_mode,
+            'context_polarity_distribution': polarity_distribution,
+            'class_targets': {str(k): v for k, v in class_targets.items()},
+            'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
+            'positive_balance_min_class': args.positive_balance_min_class,
+            'positive_context_ratio': args.positive_context_ratio,
         }, f, indent=2)

     print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
```
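The core of the new caching strategy is polarity-balanced selection: each worker oversamples candidate contexts per token, then keeps a quota of positive and negative ones, alternating between buckets for any leftover slots. A simplified, standalone sketch of that selection step (function and argument names are mine, condensed from `_select_contexts_by_polarity` above):

```python
def pick_balanced(items, polarity_of, max_keep, want_pos, want_neg):
    """Sketch of polarity-balanced selection: fill explicit positive/negative
    quotas first, then alternate over whatever remains until max_keep."""
    pos = [x for x in items if polarity_of(x) == "positive"]
    neg = [x for x in items if polarity_of(x) != "positive"]
    out = []
    # Fill the explicit quotas first, bounded by availability and max_keep.
    out += [pos.pop() for _ in range(min(want_pos, len(pos), max_keep))]
    out += [neg.pop() for _ in range(min(want_neg, len(neg), max_keep - len(out)))]
    # Alternate buckets for leftover capacity, starting with the larger one.
    prefer_pos = len(pos) >= len(neg)
    while len(out) < max_keep and (pos or neg):
        bucket = pos if (prefer_pos and pos) or not neg else neg
        out.append(bucket.pop())
        prefer_pos = not prefer_pos
    return out

contexts = [3, -1, 5, -2, 8, -4, 9]  # toy stand-ins for contexts
polarity = lambda x: "positive" if x > 0 else "negative"
kept = pick_balanced(contexts, polarity, max_keep=4, want_pos=2, want_neg=2)
```

This is what makes a 50/50 `positive_context_ratio` enforceable per class: with oversampling, a class-5 token that would otherwise yield mostly negative T_cutoff windows can still contribute its positive windows to the cache.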
scripts/evaluate_sample.py
CHANGED
```diff
@@ -2,6 +2,8 @@ import os
 import sys
 import argparse
 import random
 import torch
 from pathlib import Path
@@ -14,6 +16,7 @@ from torch.utils.data import DataLoader, Subset
 from data.data_loader import OracleDataset
 from data.data_collator import MemecoinCollator
 from models.multi_modal_processor import MultiModalEncoder
 from models.helper_encoders import ContextualTimeEncoder
 from models.token_encoder import TokenEncoder
@@ -29,6 +32,25 @@ from neo4j import GraphDatabase
 from data.data_fetcher import DataFetcher
 from scripts.analyze_distribution import get_return_class_map

 def unlog_transform(tensor):
     """Invert the log1p transform applied during training."""
     # During training: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
@@ -42,10 +64,418 @@ def parse_args():
     parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
     parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
     parser.add_argument("--seed", type=int, default=None)
-    parser.add_argument("--min_class", type=int, default=
-    parser.add_argument("--cutoff_trade_idx", type=int, default=
     return parser.parse_args()

 def get_latest_checkpoint(checkpoint_dir):
     ckpt_dir = Path(checkpoint_dir)
     if ckpt_dir.exists():
@@ -59,6 +489,10 @@ def get_latest_checkpoint(checkpoint_dir):
 def main():
     load_dotenv()
     args = parse_args()

     accelerator = Accelerator(mixed_precision=args.mixed_precision)
     device = accelerator.device
@@ -186,134 +620,186 @@ def main():
     model.eval()

     retries = 0
-    while
-                sample_idx = len(dataset.sampled_mints) - 1
-            else:
-                sample_idx = found_idx
-        else:
-            sample_idx = int(args.sample_idx)
-            if sample_idx >= len(dataset):
-                raise ValueError(f"Sample index {sample_idx} out of range")
-    else:
-        sample_idx = random.randint(0, len(dataset.sampled_mints) - 1)
-
     sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
     print(f"Trying Token Address: {sample_mint_addr}")
-    contexts = dataset.__cacheitem_context__(
     raw_sample = contexts[0]

-    print(f"\nEvaluating precisely on Token Address: {sample_mint_addr}")

-    # Yes, train.py does: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
-    real_preds = unlog_transform(preds)

-                continue
-            # Ground truth (raw)
-            gt_ret = gt_labels[h_idx].item()
-            print(f"  Ground Truth: {gt_ret * 100:.2f}%")
-
-            # Predictions
-            print("  Predictions:")
-            for q_idx, q in enumerate(args.quantiles):
-                flat_idx = h_idx * num_quantiles + q_idx
-                pred_ret = real_preds[flat_idx].item()
-                log_pred = preds[flat_idx].item()
-
-                print(f"    - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})")
-
-    print("=============================================\n")

 if __name__ == "__main__":
     main()
```
| 2 |
import sys
|
| 3 |
import argparse
|
| 4 |
import random
|
| 5 |
+
import copy
|
| 6 |
+
import math
|
| 7 |
import torch
|
| 8 |
from pathlib import Path
|
| 9 |
|
|
|
|
| 16 |
|
| 17 |
from data.data_loader import OracleDataset
|
| 18 |
from data.data_collator import MemecoinCollator
|
| 19 |
+
from data.context_targets import MOVEMENT_ID_TO_CLASS
|
| 20 |
from models.multi_modal_processor import MultiModalEncoder
|
| 21 |
from models.helper_encoders import ContextualTimeEncoder
|
| 22 |
from models.token_encoder import TokenEncoder
|
|
|
|
| 32 |
from data.data_fetcher import DataFetcher
|
| 33 |
from scripts.analyze_distribution import get_return_class_map
|
| 34 |
|
| 35 |
+
ABLATION_SWEEP_MODES = [
|
| 36 |
+
"wallet",
|
| 37 |
+
"graph",
|
| 38 |
+
"social",
|
| 39 |
+
"token",
|
| 40 |
+
"holder",
|
| 41 |
+
"ohlc",
|
| 42 |
+
"ohlc_wallet",
|
| 43 |
+
"trade",
|
| 44 |
+
"onchain",
|
| 45 |
+
"wallet_graph",
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
OHLC_PROBE_MODES = [
|
| 49 |
+
"ohlc_reverse",
|
| 50 |
+
"ohlc_shuffle_chunks",
|
| 51 |
+
"ohlc_mask_recent",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
def unlog_transform(tensor):
|
| 55 |
"""Invert the log1p transform applied during training."""
|
| 56 |
# During training: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
|
|
|
|
| 64 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
|
| 65 |
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 66 |
parser.add_argument("--seed", type=int, default=None)
|
| 67 |
+
parser.add_argument("--min_class", type=int, default=3, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
|
| 68 |
+
parser.add_argument("--cutoff_trade_idx", type=int, default=200, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
|
| 69 |
+
parser.add_argument("--num_samples", type=int, default=1, help="Number of valid samples to evaluate and aggregate.")
|
| 70 |
+
parser.add_argument("--max_retries", type=int, default=100, help="Maximum attempts to find valid contexts across samples.")
|
| 71 |
+
parser.add_argument("--show_each", action="store_true", help="Print per-sample details for every evaluated sample.")
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--ablation",
|
| 74 |
+
type=str,
|
| 75 |
+
default="none",
|
| 76 |
+
choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe"],
|
| 77 |
+
help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
|
| 78 |
+
)
|
| 79 |
return parser.parse_args()
|
| 80 |
|
| 81 |
+
|
| 82 |
+
def clone_batch(batch):
|
| 83 |
+
cloned = {}
|
| 84 |
+
for key, value in batch.items():
|
| 85 |
+
if isinstance(value, torch.Tensor):
|
| 86 |
+
cloned[key] = value.clone()
|
| 87 |
+
else:
|
| 88 |
+
cloned[key] = copy.deepcopy(value)
|
| 89 |
+
return cloned
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _empty_wallet_encoder_inputs(device):
|
| 93 |
+
return {
|
| 94 |
+
'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
|
| 95 |
+
'profile_rows': [],
|
| 96 |
+
'social_rows': [],
|
| 97 |
+
'holdings_batch': [],
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _empty_token_encoder_inputs(device):
|
| 102 |
+
return {
|
| 103 |
+
'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
|
| 104 |
+
'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
|
| 105 |
+
'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
|
| 106 |
+
'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
|
| 107 |
+
'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
|
| 108 |
+
'_addresses_for_lookup': [],
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def apply_ablation(batch, mode, device):
|
| 113 |
+
if mode == "none":
|
| 114 |
+
return batch
|
| 115 |
+
|
| 116 |
+
ablated = clone_batch(batch)
|
| 117 |
+
|
| 118 |
+
if mode in {"wallet", "wallet_graph", "ohlc_wallet", "all"}:
|
| 119 |
+
for key in (
|
| 120 |
+
"wallet_indices",
|
| 121 |
+
"dest_wallet_indices",
|
| 122 |
+
"original_author_indices",
|
| 123 |
+
"holder_snapshot_indices",
|
| 124 |
+
):
|
| 125 |
+
if key in ablated:
|
| 126 |
+
ablated[key].zero_()
|
| 127 |
+
ablated["wallet_encoder_inputs"] = _empty_wallet_encoder_inputs(device)
|
| 128 |
+
ablated["wallet_addr_to_batch_idx"] = {}
|
| 129 |
+
ablated["holder_snapshot_raw_data"] = []
|
| 130 |
+
ablated["graph_updater_links"] = {}
|
| 131 |
+
|
| 132 |
+
if mode in {"graph", "wallet_graph", "all"}:
|
| 133 |
+
ablated["graph_updater_links"] = {}
|
| 134 |
+
|
| 135 |
+
if mode in {"social", "all"}:
|
| 136 |
+
if "textual_event_indices" in ablated:
|
| 137 |
+
ablated["textual_event_indices"].zero_()
|
| 138 |
+
ablated["textual_event_data"] = []
|
| 139 |
+
|
| 140 |
+
if mode in {"token", "all"}:
|
| 141 |
+
for key in (
|
| 142 |
+
"token_indices",
|
| 143 |
+
"quote_token_indices",
|
| 144 |
+
"trending_token_indices",
|
| 145 |
+
"boosted_token_indices",
|
| 146 |
+
):
|
| 147 |
+
if key in ablated:
|
| 148 |
+
ablated[key].zero_()
|
| 149 |
+
ablated["token_encoder_inputs"] = _empty_token_encoder_inputs(device)
|
| 150 |
+
|
| 151 |
+
if mode in {"holder", "all"}:
|
| 152 |
+
if "holder_snapshot_indices" in ablated:
|
| 153 |
+
ablated["holder_snapshot_indices"].zero_()
|
| 154 |
+
ablated["holder_snapshot_raw_data"] = []
|
| 155 |
+
|
| 156 |
+
if mode in {"ohlc", "ohlc_wallet", "all"}:
|
| 157 |
+
if "ohlc_indices" in ablated:
|
| 158 |
+
ablated["ohlc_indices"].zero_()
|
| 159 |
+
if "ohlc_price_tensors" in ablated:
|
| 160 |
+
ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
|
| 161 |
+
if "ohlc_interval_ids" in ablated:
|
| 162 |
+
ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
|
| 163 |
+
|
| 164 |
+
if mode in {"trade", "all"}:
|
| 165 |
+
for key in (
|
| 166 |
+
"trade_numerical_features",
|
| 167 |
+
"deployer_trade_numerical_features",
|
| 168 |
+
"smart_wallet_trade_numerical_features",
|
| 169 |
+
"transfer_numerical_features",
|
| 170 |
+
"pool_created_numerical_features",
|
| 171 |
+
"liquidity_change_numerical_features",
|
| 172 |
+
"fee_collected_numerical_features",
|
| 173 |
+
"token_burn_numerical_features",
|
| 174 |
+
"supply_lock_numerical_features",
|
| 175 |
+
"boosted_token_numerical_features",
|
| 176 |
+
"trending_token_numerical_features",
|
| 177 |
+
"dexboost_paid_numerical_features",
|
| 178 |
+
"global_trending_numerical_features",
|
| 179 |
+
"chainsnapshot_numerical_features",
|
| 180 |
+
"lighthousesnapshot_numerical_features",
|
| 181 |
+
"dexprofile_updated_flags",
|
| 182 |
+
):
|
| 183 |
+
if key in ablated:
|
| 184 |
+
ablated[key] = torch.zeros_like(ablated[key])
|
| 185 |
+
for key in (
|
| 186 |
+
"trade_dex_ids",
|
| 187 |
+
"trade_direction_ids",
|
| 188 |
+
"trade_mev_protection_ids",
|
| 189 |
+
"trade_is_bundle_ids",
|
| 190 |
+
"pool_created_protocol_ids",
|
| 191 |
+
"liquidity_change_type_ids",
|
| 192 |
+
"trending_token_source_ids",
|
| 193 |
+
"trending_token_timeframe_ids",
|
| 194 |
+
"lighthousesnapshot_protocol_ids",
|
| 195 |
+
"lighthousesnapshot_timeframe_ids",
|
| 196 |
+
"migrated_protocol_ids",
|
| 197 |
+
"alpha_group_ids",
|
| 198 |
+
"channel_ids",
|
| 199 |
+
"exchange_ids",
|
| 200 |
+
):
|
| 201 |
+
if key in ablated:
|
| 202 |
+
ablated[key] = torch.zeros_like(ablated[key])
|
| 203 |
+
|
| 204 |
+
if mode == "onchain":
|
| 205 |
+
if "onchain_snapshot_numerical_features" in ablated:
|
| 206 |
+
ablated["onchain_snapshot_numerical_features"] = torch.zeros_like(ablated["onchain_snapshot_numerical_features"])
|
| 207 |
+
|
| 208 |
+
return ablated
|
| 209 |
+
|
| 210 |
+
+
+
+def _chunk_permutation_indices(length, chunk_size):
+    if length <= 0:
+        return []
+    chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)]
+    if len(chunks) <= 1:
+        return list(range(length))
+    permuted = list(reversed(chunks))
+    out = []
+    for chunk in permuted:
+        out.extend(chunk)
+    return out
+
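Despite its name, the helper is deterministic: it splits the index range into fixed-size chunks and reverses the chunk order, so repeated probe runs stay comparable across checkpoints. A standalone sketch of the behavior (renamed `chunk_permutation_indices` for illustration):

```python
def chunk_permutation_indices(length, chunk_size):
    # Split [0, length) into consecutive chunks, then emit them in reverse order.
    if length <= 0:
        return []
    chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)]
    if len(chunks) <= 1:
        return list(range(length))
    out = []
    for chunk in reversed(chunks):
        out.extend(chunk)
    return out

print(chunk_permutation_indices(7, 3))  # [6, 3, 4, 5, 0, 1, 2]
```

Within each chunk the local ordering is preserved, so short-range price structure survives while long-range ordering is destroyed.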
+
+
+def apply_ohlc_probe(batch, mode):
+    probed = clone_batch(batch)
+    if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
+        return probed
+
+    ohlc = probed["ohlc_price_tensors"].clone()
+    seq_len = ohlc.shape[-1]
+
+    if mode == "ohlc_reverse":
+        probed["ohlc_price_tensors"] = torch.flip(ohlc, dims=[-1])
+    elif mode == "ohlc_shuffle_chunks":
+        perm = _chunk_permutation_indices(seq_len, chunk_size=30)
+        idx = torch.tensor(perm, device=ohlc.device, dtype=torch.long)
+        probed["ohlc_price_tensors"] = ohlc.index_select(-1, idx)
+    elif mode == "ohlc_mask_recent":
+        keep = max(seq_len - 60, 0)
+        if keep < seq_len and keep > 0:
+            fill = ohlc[..., keep - 1:keep].expand_as(ohlc[..., keep:])
+            ohlc[..., keep:] = fill
+        elif keep == 0:
+            ohlc.zero_()
+        probed["ohlc_price_tensors"] = ohlc
+
+    return probed
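As a quick illustration of the `ohlc_mask_recent` probe: the most recent steps are overwritten with the last still-visible value, flattening recent price action while leaving earlier history intact. A toy tensor (assuming PyTorch is available; the 4-step mask window is illustrative, the script masks 60 steps):

```python
import torch

series = torch.arange(10, dtype=torch.float32).unsqueeze(0)  # shape [1, 10]
mask_n = 4
keep = max(series.shape[-1] - mask_n, 0)
if 0 < keep < series.shape[-1]:
    # Hold the last visible value flat over the masked tail.
    series[..., keep:] = series[..., keep - 1:keep].expand_as(series[..., keep:])
print(series.tolist())  # [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0]]
```

If the model's predictions barely move under this probe, it is not relying on the most recent candles.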
+
+
+def run_inference(model, batch):
+    with torch.no_grad():
+        outputs = model(batch)
+    preds = outputs["quantile_logits"][0].detach().cpu()
+    quality_pred = outputs["quality_logits"][0].detach().cpu() if "quality_logits" in outputs else None
+    movement_pred = outputs["movement_logits"][0].detach().cpu() if "movement_logits" in outputs else None
+    return preds, quality_pred, movement_pred
+
+
+
+def print_results(title, batch, preds, quality_pred, movement_pred, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, reference_preds=None, reference_quality=None):
+    real_preds = unlog_transform(preds)
+    num_quantiles = len(quantiles)
+    num_gt_horizons = len(gt_mask)
+
+    print(f"\n================== {title} ==================")
+    print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
+    if gt_quality is not None:
+        quality_line = f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_pred.item() if quality_pred is not None else 'N/A'}"
+        if reference_quality is not None and quality_pred is not None:
+            quality_delta = quality_pred.item() - reference_quality.item()
+            quality_line += f" | Delta vs Full = {quality_delta:+.6f}"
+        print(quality_line)
+    if movement_pred is not None:
+        movement_targets = batch.get("movement_class_targets")
+        movement_mask = batch.get("movement_class_mask")
+        print("Movement Classes:")
+        for h_idx, horizon in enumerate(horizons_seconds):
+            if h_idx >= movement_pred.shape[0]:
+                break
+            target_txt = "N/A"
+            if movement_targets is not None and movement_mask is not None and bool(movement_mask[0, h_idx].item()):
+                target_txt = MOVEMENT_ID_TO_CLASS.get(int(movement_targets[0, h_idx].item()), "unknown")
+            pred_class = int(movement_pred[h_idx].argmax().item())
+            pred_name = MOVEMENT_ID_TO_CLASS.get(pred_class, "unknown")
+            pred_prob = float(torch.softmax(movement_pred[h_idx], dim=-1)[pred_class].item())
+            print(
+                f"  {horizon:>4}s GT = {target_txt:<12} | "
+                f"Pred = {pred_name:<12} | "
+                f"Conf = {pred_prob:.4f}"
+            )
+    if "context_class_name" in batch:
+        print(f"Context Class: {batch['context_class_name'][0]}")
+
+    print("\nReturns per Horizon:")
+    for h_idx, horizon in enumerate(horizons_seconds):
+        horizon_min = horizon // 60
+        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
+
+        if h_idx >= num_gt_horizons:
+            print("  [No Ground Truth Available for this Horizon - Not in Dataset]")
+            valid = False
+        else:
+            valid = gt_mask[h_idx].item()
+
+        if not valid:
+            print("  [No Ground Truth Available for this Horizon - Masked]")
+        else:
+            gt_ret = gt_labels[h_idx].item()
+            print(f"  Ground Truth: {gt_ret * 100:.2f}%")
+
+        print("  Predictions:")
+        for q_idx, q in enumerate(quantiles):
+            flat_idx = h_idx * num_quantiles + q_idx
+            pred_ret = real_preds[flat_idx].item()
+            log_pred = preds[flat_idx].item()
+            line = f"    - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})"
+            if reference_preds is not None:
+                ref_ret = unlog_transform(reference_preds)[flat_idx].item()
+                line += f" | Delta vs Full: {(pred_ret - ref_ret) * 100:+7.2f}%"
+            print(line)
+
+    print("=============================================\n")
+
+
+
+def resolve_sample_index(dataset, sample_idx_arg, rng):
+    if sample_idx_arg is not None:
+        if isinstance(sample_idx_arg, str) and not sample_idx_arg.isdigit():
+            found_idx = next((i for i, m in enumerate(dataset.sampled_mints) if m['mint_address'] == sample_idx_arg), None)
+            if found_idx is None:
+                raise ValueError(f"Mint address {sample_idx_arg} not found in filtered dataset")
+            return found_idx
+        resolved = int(sample_idx_arg)
+        if resolved >= len(dataset):
+            raise ValueError(f"Sample index {resolved} out of range")
+        return resolved
+    return rng.randint(0, len(dataset.sampled_mints) - 1)
+
+
+
+def move_batch_to_device(batch, device):
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            batch[k] = v.to(device)
+        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
+            batch[k] = [t.to(device) for t in v]
+    if 'textual_event_indices' not in batch:
+        B, L = batch['event_type_ids'].shape
+        batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=device)
+    if 'textual_event_data' not in batch:
+        batch['textual_event_data'] = []
+    return batch
+
+
+
+def init_aggregate(horizons_seconds, quantiles):
+    return {
+        "count": 0,
+        "quality_full_sum": 0.0,
+        "quality_abl_sum": 0.0,
+        "quality_delta_sum": 0.0,
+        "gt_quality_sum": 0.0,
+        "per_hq": {
+            (h, q): {
+                "full_sum": 0.0,
+                "abl_sum": 0.0,
+                "delta_sum": 0.0,
+                "abs_delta_sum": 0.0,
+                "gt_sum": 0.0,
+                "valid_count": 0,
+            }
+            for h in horizons_seconds for q in quantiles
+        },
+    }
+
+
+
+def update_aggregate(stats, full_preds, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, ablated_preds=None, full_quality=None, ablated_quality=None):
+    stats["count"] += 1
+    if gt_quality is not None:
+        stats["gt_quality_sum"] += float(gt_quality)
+    if full_quality is not None:
+        stats["quality_full_sum"] += float(full_quality.item())
+    if ablated_quality is not None:
+        stats["quality_abl_sum"] += float(ablated_quality.item())
+    if full_quality is not None and ablated_quality is not None:
+        stats["quality_delta_sum"] += float(ablated_quality.item() - full_quality.item())
+
+    full_real = unlog_transform(full_preds)
+    ablated_real = unlog_transform(ablated_preds) if ablated_preds is not None else None
+    num_quantiles = len(quantiles)
+
+    for h_idx, horizon in enumerate(horizons_seconds):
+        valid = h_idx < len(gt_mask) and bool(gt_mask[h_idx].item())
+        gt_ret = float(gt_labels[h_idx].item()) if valid else math.nan
+        for q_idx, q in enumerate(quantiles):
+            flat_idx = h_idx * num_quantiles + q_idx
+            bucket = stats["per_hq"][(horizon, q)]
+            full_val = float(full_real[flat_idx].item())
+            bucket["full_sum"] += full_val
+            if ablated_real is not None:
+                abl_val = float(ablated_real[flat_idx].item())
+                delta = abl_val - full_val
+                bucket["abl_sum"] += abl_val
+                bucket["delta_sum"] += delta
+                bucket["abs_delta_sum"] += abs(delta)
+            if valid:
+                bucket["gt_sum"] += gt_ret
+                bucket["valid_count"] += 1
+
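The `flat_idx` arithmetic above assumes the quantile head emits a flat `[num_horizons * num_quantiles]` vector in horizon-major order. A minimal check of that layout (toy horizon and quantile values):

```python
horizons_seconds = [60, 300, 900]
quantiles = [0.1, 0.5, 0.9]

def flat_index(h_idx, q_idx, num_quantiles=len(quantiles)):
    # Horizon-major layout: all quantiles of horizon 0, then horizon 1, ...
    return h_idx * num_quantiles + q_idx

assert flat_index(0, 0) == 0  # 60s, p10
assert flat_index(1, 2) == 5  # 300s, p90
assert flat_index(2, 1) == 7  # 900s, p50
```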
+
+def print_aggregate_summary(stats, horizons_seconds, quantiles, ablation_mode):
+    n = stats["count"]
+    print("\n================== Aggregate Summary ==================")
+    print(f"Evaluated Samples: {n}")
+    if n == 0:
+        print("No valid samples collected.")
+        print("=======================================================\n")
+        return
+
+    if ablation_mode != "none":
+        print(
+            f"Quality Mean: full={stats['quality_full_sum'] / n:.6f} | "
+            f"ablated={stats['quality_abl_sum'] / n:.6f} | "
+            f"delta={stats['quality_delta_sum'] / n:+.6f}"
+        )
+
+    for horizon in horizons_seconds:
+        horizon_min = horizon // 60
+        print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
+        valid_counts = [stats["per_hq"][(horizon, q)]["valid_count"] for q in quantiles]
+        valid_count = max(valid_counts) if valid_counts else 0
+        if valid_count > 0:
+            gt_mean = stats["per_hq"][(horizon, quantiles[0])]["gt_sum"] / valid_count
+            print(f"  Mean Ground Truth over valid labels: {gt_mean * 100:.2f}% (n={valid_count})")
+        else:
+            print("  Mean Ground Truth over valid labels: N/A")
+
+        for q in quantiles:
+            bucket = stats["per_hq"][(horizon, q)]
+            full_mean = bucket["full_sum"] / n
+            line = f"  p{int(q*100):02d} mean full: {full_mean * 100:>8.2f}%"
+            if ablation_mode != "none":
+                abl_mean = bucket["abl_sum"] / n
+                delta_mean = bucket["delta_sum"] / n
+                abs_delta_mean = bucket["abs_delta_sum"] / n
+                line += (
+                    f" | ablated: {abl_mean * 100:>8.2f}%"
+                    f" | delta: {delta_mean * 100:+8.2f}%"
+                    f" | mean|delta|: {abs_delta_mean * 100:>8.2f}%"
+                )
+            print(line)
+    print("=======================================================\n")
+
+
+def summarize_influence_score(stats, horizons_seconds, quantiles):
+    n = stats["count"]
+    if n == 0:
+        return 0.0
+    total = 0.0
+    denom = 0
+    for horizon in horizons_seconds:
+        for q in quantiles:
+            total += stats["per_hq"][(horizon, q)]["abs_delta_sum"] / n
+            denom += 1
+    return total / max(denom, 1)
+
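`summarize_influence_score` averages twice: `abs_delta_sum / n` is the mean absolute prediction shift per sample for each (horizon, quantile) cell, and those cell means are then averaged into one scalar. A toy calculation with a hypothetical two-sample `stats` dict:

```python
stats = {
    "count": 2,
    "per_hq": {
        (60, 0.5): {"abs_delta_sum": 0.4},   # mean |delta| = 0.2 per sample
        (300, 0.5): {"abs_delta_sum": 0.8},  # mean |delta| = 0.4 per sample
    },
}
n = stats["count"]
cell_means = [b["abs_delta_sum"] / n for b in stats["per_hq"].values()]
score = sum(cell_means) / max(len(cell_means), 1)
print(round(score, 6))  # 0.3
```

A larger score means zeroing that modality moved the quantile predictions more, i.e. the model leaned on it more heavily.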
+
+
+def print_probe_summary(mode_to_stats, horizons_seconds, quantiles):
+    rankings = []
+    for mode in OHLC_PROBE_MODES:
+        score = summarize_influence_score(mode_to_stats[mode], horizons_seconds, quantiles)
+        rankings.append((mode, score))
+    rankings.sort(key=lambda x: x[1], reverse=True)
+
+    print("\n================== OHLC Probe Ranking ==================")
+    for rank, (mode, score) in enumerate(rankings, start=1):
+        print(f"{rank:>2}. {mode:<20} mean|delta| = {score * 100:8.2f}%")
+    print("========================================================\n")
+
+    for mode, _ in rankings:
+        print_aggregate_summary(mode_to_stats[mode], horizons_seconds, quantiles, mode)
+
 def get_latest_checkpoint(checkpoint_dir):
     ckpt_dir = Path(checkpoint_dir)
     if ckpt_dir.exists():
 def main():
     load_dotenv()
     args = parse_args()
+    rng = random.Random(args.seed)
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
 
     accelerator = Accelerator(mixed_precision=args.mixed_precision)
     device = accelerator.device
 
     model.eval()
 
+    stats = init_aggregate(args.horizons_seconds, args.quantiles)
+    if args.ablation in ("none", "ohlc_probe"):
+        selected_modes = []
+    elif args.ablation == "sweep":
+        selected_modes = ABLATION_SWEEP_MODES
+    else:
+        selected_modes = [args.ablation]
+    mode_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in selected_modes}
+    probe_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in OHLC_PROBE_MODES} if args.ablation == "ohlc_probe" else {}
+    max_target_samples = max(1, args.num_samples)
     retries = 0
+    collected = 0
+    seen_indices = set()
+
+    while collected < max_target_samples and retries < args.max_retries:
+        sample_idx = resolve_sample_index(dataset, args.sample_idx, rng)
+        if args.sample_idx is None and sample_idx in seen_indices and len(seen_indices) < len(dataset.sampled_mints):
+            retries += 1
+            continue
+        seen_indices.add(sample_idx)
         sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
         print(f"Trying Token Address: {sample_mint_addr}")
+
+        contexts = dataset.__cacheitem_context__(
+            sample_idx,
+            num_samples_per_token=1,
+            encoder=multi_modal_encoder,
+            forced_cutoff_trade_idx=args.cutoff_trade_idx,
+        )
+
+        if not contexts or contexts[0] is None:
+            print("  [Failed to generate valid context pattern, skipping...]")
+            retries += 1
+            if args.sample_idx is not None:
+                print("Specific sample requested but failed to generate context. Exiting.")
+                return
+            continue
+
         raw_sample = contexts[0]
+        batch = move_batch_to_device(collator([raw_sample]), device)
+        gt_labels = batch["labels"][0].cpu()
+        gt_mask = batch["labels_mask"][0].cpu().bool()
+        gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None
 
+        if collected == 0 or args.show_each:
+            print(f"\nEvaluating sample {collected + 1}/{max_target_samples} on Token Address: {sample_mint_addr}")
+            print("\n--- Running Inference ---")
 
+        full_preds, full_quality, full_direction = run_inference(model, batch)
+        ablation_outputs = {}
+        for mode in selected_modes:
+            ablated_batch = apply_ablation(batch, mode, device)
+            ablated_preds, ablated_quality, ablated_direction = run_inference(model, ablated_batch)
+            ablation_outputs[mode] = (ablated_batch, ablated_preds, ablated_quality, ablated_direction)
+        probe_outputs = {}
+        if args.ablation == "ohlc_probe":
+            for mode in OHLC_PROBE_MODES:
+                probe_batch = apply_ohlc_probe(batch, mode)
+                probe_preds, probe_quality, probe_direction = run_inference(model, probe_batch)
+                probe_outputs[mode] = (probe_batch, probe_preds, probe_quality, probe_direction)
 
+        if collected == 0 or args.show_each:
+            print_results(
+                title="Full Results",
+                batch=batch,
+                preds=full_preds,
+                quality_pred=full_quality,
+                movement_pred=full_direction,
+                gt_labels=gt_labels,
+                gt_mask=gt_mask,
+                gt_quality=gt_quality,
+                horizons_seconds=args.horizons_seconds,
+                quantiles=args.quantiles,
+            )
+            if args.ablation != "none":
+                if args.ablation == "sweep":
+                    print(f"Collected full predictions for {len(selected_modes)} ablation families on this sample. Aggregate ranking will be printed at the end.")
+                elif args.ablation == "ohlc_probe":
+                    for mode in OHLC_PROBE_MODES:
+                        probe_batch, probe_preds, probe_quality, probe_direction = probe_outputs[mode]
+                        print_results(
+                            title=f"OHLC Probe ({mode})",
+                            batch=probe_batch,
+                            preds=probe_preds,
+                            quality_pred=probe_quality,
+                            movement_pred=probe_direction,
+                            gt_labels=gt_labels,
+                            gt_mask=gt_mask,
+                            gt_quality=gt_quality,
+                            horizons_seconds=args.horizons_seconds,
+                            quantiles=args.quantiles,
+                            reference_preds=full_preds,
+                            reference_quality=full_quality,
+                        )
+                else:
+                    ablated_batch, ablated_preds, ablated_quality, ablated_direction = ablation_outputs[args.ablation]
+                    print_results(
+                        title=f"Ablation Results ({args.ablation})",
+                        batch=ablated_batch,
+                        preds=ablated_preds,
+                        quality_pred=ablated_quality,
+                        movement_pred=ablated_direction,
+                        gt_labels=gt_labels,
+                        gt_mask=gt_mask,
+                        gt_quality=gt_quality,
+                        horizons_seconds=args.horizons_seconds,
+                        quantiles=args.quantiles,
+                        reference_preds=full_preds,
+                        reference_quality=full_quality,
+                    )
 
+        update_aggregate(
+            stats=stats,
+            full_preds=full_preds,
+            gt_labels=gt_labels,
+            gt_mask=gt_mask,
+            gt_quality=gt_quality,
+            horizons_seconds=args.horizons_seconds,
+            quantiles=args.quantiles,
+            full_quality=full_quality,
+        )
+        for mode, (_, ablated_preds, ablated_quality, _) in ablation_outputs.items():
+            update_aggregate(
+                stats=mode_to_stats[mode],
+                full_preds=full_preds,
+                gt_labels=gt_labels,
+                gt_mask=gt_mask,
+                gt_quality=gt_quality,
+                horizons_seconds=args.horizons_seconds,
+                quantiles=args.quantiles,
+                ablated_preds=ablated_preds,
+                full_quality=full_quality,
+                ablated_quality=ablated_quality,
+            )
+        for mode, (_, probe_preds, probe_quality, _) in probe_outputs.items():
+            update_aggregate(
+                stats=probe_to_stats[mode],
+                full_preds=full_preds,
+                gt_labels=gt_labels,
+                gt_mask=gt_mask,
+                gt_quality=gt_quality,
+                horizons_seconds=args.horizons_seconds,
+                quantiles=args.quantiles,
+                ablated_preds=probe_preds,
+                full_quality=full_quality,
+                ablated_quality=probe_quality,
+            )
+        collected += 1
+        retries += 1
 
+        if args.sample_idx is not None:
+            break
 
+    if collected == 0:
+        print(f"Could not find a valid context after {args.max_retries} attempts.")
+        return
+
+    if collected < max_target_samples:
+        print(f"WARNING: Requested {max_target_samples} samples but only evaluated {collected}.")
+
+    if args.ablation == "none":
+        print_aggregate_summary(stats, args.horizons_seconds, args.quantiles, args.ablation)
+        return
+
+    if args.ablation == "ohlc_probe":
+        print_probe_summary(probe_to_stats, args.horizons_seconds, args.quantiles)
+        return
+
+    if args.ablation == "sweep":
+        rankings = []
+        for mode in selected_modes:
+            score = summarize_influence_score(mode_to_stats[mode], args.horizons_seconds, args.quantiles)
+            rankings.append((mode, score))
+        rankings.sort(key=lambda x: x[1], reverse=True)
+
+        print("\n================== Influence Ranking ==================")
+        for rank, (mode, score) in enumerate(rankings, start=1):
+            print(f"{rank:>2}. {mode:<12} mean|delta| = {score * 100:8.2f}%")
+        print("=======================================================\n")
+
+        for mode, _ in rankings:
+            print_aggregate_summary(mode_to_stats[mode], args.horizons_seconds, args.quantiles, mode)
+    else:
+        print_aggregate_summary(mode_to_stats[args.ablation], args.horizons_seconds, args.quantiles, args.ablation)
 
 if __name__ == "__main__":
     main()
train.py
CHANGED

@@ -52,6 +52,7 @@ from neo4j import GraphDatabase
 from data.data_fetcher import DataFetcher
 from data.data_loader import OracleDataset
 from data.data_collator import MemecoinCollator
 from models.multi_modal_processor import MultiModalEncoder
 from models.helper_encoders import ContextualTimeEncoder
 from models.token_encoder import TokenEncoder
@@ -148,6 +149,89 @@ def quantile_pinball_loss_per_sample(
     return per_sample_num / per_sample_den
 
 
 def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
     """
     Create train/val split with balanced classes in validation set.
@@ -207,6 +291,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
     total_loss = 0.0
     total_return_loss = 0.0
     total_quality_loss = 0.0
     n_batches = 0
 
     # Per-class metrics
@@ -228,9 +314,12 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
 
         preds = outputs["quantile_logits"]
         quality_preds = outputs["quality_logits"]
         labels = batch["labels"]
         labels_mask = batch["labels_mask"]
         quality_targets = batch["quality_score"].to(accelerator.device, dtype=quality_preds.dtype)
 
         if labels_mask.sum() == 0:
             return_loss = torch.tensor(0.0, device=accelerator.device)
@@ -240,11 +329,22 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
         return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
 
         quality_loss = quality_loss_fn(quality_preds, quality_targets)
-
 
         total_loss += loss.item()
         total_return_loss += return_loss.item()
         total_quality_loss += quality_loss.item()
         n_batches += 1
 
         # Track per-class losses
@@ -264,6 +364,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
         'val/loss': total_loss / n_batches,
         'val/return_loss': total_return_loss / n_batches,
         'val/quality_loss': total_quality_loss / n_batches,
         'val/n_batches': n_batches,
         'class_losses': {k: v['loss'] / max(v['count'], 1) for k, v in class_losses.items()}
     }
@@ -446,6 +548,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
     parser.add_argument("--val_samples_per_class", type=int, default=1, help="Number of validation samples per class (default 1)")
     parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
     return parser.parse_args()
@@ -815,6 +918,10 @@ def main() -> None:
 
     # --- 7. Training Loop ---
     quality_loss_fn = nn.MSELoss()
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(dataset)}")
@@ -865,6 +972,7 @@ def main() -> None:
 
         preds = outputs["quantile_logits"]
         quality_preds = outputs["quality_logits"]
         labels = batch["labels"]
         labels_mask = batch["labels_mask"]
         if "quality_score" not in batch:
@@ -920,6 +1028,21 @@ def main() -> None:
         per_sample_return = quantile_pinball_loss_per_sample(preds, labels, labels_mask, quantiles)
 
         quality_loss = quality_loss_fn(quality_preds, quality_targets)
         per_sample_quality = (quality_preds - quality_targets).pow(2)
 
         # Apply per-sample class weighting to return loss
@@ -928,10 +1051,10 @@ def main() -> None:
             sample_weights = class_loss_weights[batch_class_ids]  # [B]
             # Scale return loss by mean class weight for this batch
             class_weight_factor = sample_weights.mean()
-            loss = return_loss * class_weight_factor + quality_loss
         else:
             class_weight_factor = torch.tensor(1.0, device=accelerator.device, dtype=return_loss.dtype)
-            loss = return_loss + quality_loss
         per_sample_total = per_sample_return * class_weight_factor + per_sample_quality
 
         if not torch.isfinite(loss).all().item():
@@ -957,8 +1080,10 @@ def main() -> None:
             "loss": loss.unsqueeze(0),
             "return_loss": return_loss.unsqueeze(0),
             "quality_loss": quality_loss.unsqueeze(0),
             "preds": preds,
             "quality_preds": quality_preds,
             "labels_raw": batch.get("labels"),
             "labels_log": labels,
             "labels_mask": labels_mask,
@@ -1065,6 +1190,7 @@ def main() -> None:
         current_loss = loss.item()
         current_return_loss = return_loss.item()
         current_quality_loss = quality_loss.item()
         current_class_weight_factor = float(class_weight_factor.item()) if isinstance(class_weight_factor, torch.Tensor) else float(class_weight_factor)
         current_mask_coverage = float(labels_mask.float().mean().item()) if labels_mask is not None else 0.0
         epoch_loss += current_loss
@@ -1080,6 +1206,8 @@ def main() -> None:
             "train/loss": current_loss,
             "train/return_loss": current_return_loss,
             "train/quality_loss": current_quality_loss,
             "train/class_weight_factor": current_class_weight_factor,
             "train/mask_coverage": current_mask_coverage,
             "train/loss_ema": loss_ema if loss_ema is not None else current_loss,
@@ -1149,7 +1277,9 @@ def main() -> None:
         logger.info(
             f"Validation - Loss: {val_metrics['val/loss']:.4f} | "
             f"Return: {val_metrics['val/return_loss']:.4f} | "
-            f"Quality: {val_metrics['val/quality_loss']:.4f}"
         )
         # Log per-class losses
         class_loss_str = " | ".join(
@@ -1162,6 +1292,8 @@ def main() -> None:
             "val/loss": val_metrics['val/loss'],
             "val/return_loss": val_metrics['val/return_loss'],
             "val/quality_loss": val_metrics['val/quality_loss'],
         }, step=total_steps)
         # Log per-class losses to tensorboard
         for class_id, class_loss in val_metrics['class_losses'].items():
 from data.data_fetcher import DataFetcher
 from data.data_loader import OracleDataset
 from data.data_collator import MemecoinCollator
+from data.context_targets import MOVEMENT_CLASS_NAMES
 from models.multi_modal_processor import MultiModalEncoder
 from models.helper_encoders import ContextualTimeEncoder
 from models.token_encoder import TokenEncoder
     return per_sample_num / per_sample_den
 
 
+def masked_movement_cross_entropy(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    mask: torch.Tensor,
+    class_weights: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if mask.sum() == 0:
+        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+
+    flat_logits = logits.reshape(-1, logits.shape[-1])
+    flat_targets = targets.reshape(-1)
+    flat_mask = mask.reshape(-1).bool()
+
+    valid_logits = flat_logits[flat_mask]
+    valid_targets = flat_targets[flat_mask]
+    if valid_logits.numel() == 0:
+        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+
+    return nn.functional.cross_entropy(valid_logits, valid_targets, weight=class_weights)
+
+
+def movement_accuracy(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    mask: torch.Tensor,
+) -> float:
+    if mask.sum().item() == 0:
+        return 0.0
+    preds = logits.argmax(dim=-1)
+    valid = mask.bool()
+    correct = (preds[valid] == targets[valid]).float()
+    if correct.numel() == 0:
+        return 0.0
+    return float(correct.mean().item())
+
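The masking pattern in `masked_movement_cross_entropy` flattens the `[batch, horizons, classes]` logits and keeps only cells whose mask bit is set, so horizons without a defined movement label contribute nothing to the loss. A toy check of the flattening step (assuming PyTorch; the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 4)                # [batch, horizons, classes]
targets = torch.randint(0, 4, (2, 3))        # class id per horizon
mask = torch.tensor([[1, 1, 0], [0, 1, 1]])  # which horizons are labeled

flat_mask = mask.reshape(-1).bool()
valid_logits = logits.reshape(-1, 4)[flat_mask]  # only labeled cells survive
valid_targets = targets.reshape(-1)[flat_mask]
loss = F.cross_entropy(valid_logits, valid_targets)
print(valid_logits.shape[0])  # 4 labeled cells out of 6
```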
| 188 |
+
def estimate_movement_class_weights(
|
| 189 |
+
dataset,
|
| 190 |
+
indices,
|
| 191 |
+
movement_label_config: Optional[Dict[str, float]] = None,
|
| 192 |
+
sample_cap: int = 4096,
|
| 193 |
+
) -> torch.Tensor:
|
| 194 |
+
del movement_label_config
|
| 195 |
+
counts = torch.ones(len(MOVEMENT_CLASS_NAMES), dtype=torch.float32)
|
| 196 |
+
if not indices:
|
| 197 |
+
return counts
|
| 198 |
+
|
| 199 |
+
for idx in indices[: min(len(indices), sample_cap)]:
|
| 200 |
+
try:
|
| 201 |
+
item = dataset[idx]
|
| 202 |
+
except Exception:
|
| 203 |
+
continue
|
| 204 |
+
if not item:
|
| 205 |
+
continue
|
| 206 |
+
labels = item.get("labels")
|
| 207 |
+
labels_mask = item.get("labels_mask")
|
| 208 |
+
movement_targets = item.get("movement_class_targets")
|
| 209 |
+
movement_mask = item.get("movement_class_mask")
|
| 210 |
+
if movement_targets is None or movement_mask is None:
|
| 211 |
+
if labels is None or labels_mask is None:
|
| 212 |
+
continue
|
| 213 |
+
batch_targets = collator_like_targets(labels, labels_mask)
|
| 214 |
+
targets = batch_targets["movement_class_targets"]
|
| 215 |
+
mask = batch_targets["movement_class_mask"]
|
| 216 |
+
else:
|
| 217 |
+
targets = movement_targets.tolist() if isinstance(movement_targets, torch.Tensor) else movement_targets
|
| 218 |
+
mask = movement_mask.tolist() if isinstance(movement_mask, torch.Tensor) else movement_mask
|
| 219 |
+
for target, target_mask in zip(targets, mask):
|
| 220 |
+
if int(target_mask) > 0:
|
| 221 |
+
counts[int(target)] += 1.0
|
| 222 |
+
|
| 223 |
+
weights = counts.sum() / counts.clamp_min(1.0)
|
| 224 |
+
return weights / weights.mean().clamp_min(1e-6)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def collator_like_targets(labels, labels_mask, movement_label_config: Optional[Dict[str, float]] = None):
|
| 228 |
+
from data.context_targets import derive_movement_targets
|
| 229 |
+
|
| 230 |
+
labels_list = labels.tolist() if isinstance(labels, torch.Tensor) else labels
|
| 231 |
+
mask_list = labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask
|
| 232 |
+
return derive_movement_targets(labels_list, mask_list, movement_label_config=movement_label_config)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
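For intuition, the inverse-frequency weighting used by `estimate_movement_class_weights` can be sketched in plain Python. This is a simplified scalar stand-in for the tensor version above, and `inverse_frequency_weights` is a hypothetical helper name, not part of the repo:

```python
def inverse_frequency_weights(counts):
    # Same arithmetic as estimate_movement_class_weights:
    # weight_c = total / count_c, then rescale so the mean weight is ~1.
    total = sum(counts)
    weights = [total / max(c, 1.0) for c in counts]
    mean_w = max(sum(weights) / len(weights), 1e-6)
    return [w / mean_w for w in weights]

# Rare classes receive proportionally larger weights; the mean stays ~1.
print(inverse_frequency_weights([100.0, 10.0, 1.0]))
```

Normalizing to mean 1 keeps the overall loss scale comparable whether or not class weighting is enabled.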
 def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
     """
     Create train/val split with balanced classes in validation set.
     total_loss = 0.0
     total_return_loss = 0.0
     total_quality_loss = 0.0
+    total_movement_loss = 0.0
+    total_movement_acc = 0.0
     n_batches = 0

     # Per-class metrics
         preds = outputs["quantile_logits"]
         quality_preds = outputs["quality_logits"]
+        movement_logits = outputs.get("movement_logits")
         labels = batch["labels"]
         labels_mask = batch["labels_mask"]
         quality_targets = batch["quality_score"].to(accelerator.device, dtype=quality_preds.dtype)
+        movement_targets = batch.get("movement_class_targets")
+        movement_mask = batch.get("movement_class_mask")

         if labels_mask.sum() == 0:
             return_loss = torch.tensor(0.0, device=accelerator.device)
         return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)

         quality_loss = quality_loss_fn(quality_preds, quality_targets)
+        movement_loss = torch.tensor(0.0, device=accelerator.device)
+        movement_acc = 0.0
+        if movement_logits is not None and movement_targets is not None and movement_mask is not None:
+            movement_targets = movement_targets.to(accelerator.device)
+            movement_mask = movement_mask.to(accelerator.device)
+            movement_loss = masked_movement_cross_entropy(
+                movement_logits, movement_targets, movement_mask
+            )
+            movement_acc = movement_accuracy(movement_logits, movement_targets, movement_mask)
+        loss = return_loss + quality_loss + movement_loss

         total_loss += loss.item()
         total_return_loss += return_loss.item()
         total_quality_loss += quality_loss.item()
+        total_movement_loss += movement_loss.item()
+        total_movement_acc += movement_acc
         n_batches += 1

         # Track per-class losses
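A pure-Python sketch of what `masked_movement_cross_entropy` plausibly computes: softmax cross-entropy averaged over positions where the mask is set, with an optional per-class weight (this mirrors the assumed behavior of the tensor version, including its early return of 0.0 when nothing is masked in; the function name `masked_cross_entropy` is illustrative):

```python
import math

def masked_cross_entropy(logits, targets, mask, class_weights=None):
    # Cross-entropy over masked positions only; when class_weights is
    # given, use a weighted mean (as torch's weighted CE reduction does).
    total = denom = 0.0
    for row, t, m in zip(logits, targets, mask):
        if not m:
            continue
        z = max(row)  # stabilize log-sum-exp
        log_norm = z + math.log(sum(math.exp(x - z) for x in row))
        w = class_weights[t] if class_weights is not None else 1.0
        total += w * (log_norm - row[t])
        denom += w
    return total / denom if denom else 0.0
```

With uniform logits over two classes the loss is log(2); a fully masked-out batch contributes 0.0, matching the guard at the top of the helper.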
         'val/loss': total_loss / n_batches,
         'val/return_loss': total_return_loss / n_batches,
         'val/quality_loss': total_quality_loss / n_batches,
+        'val/movement_loss': total_movement_loss / n_batches,
+        'val/movement_acc': total_movement_acc / n_batches,
         'val/n_batches': n_batches,
         'class_losses': {k: v['loss'] / max(v['count'], 1) for k, v in class_losses.items()}
     }
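`movement_accuracy`, which feeds `val/movement_acc` above, reduces to a masked hit-rate. A list-based sketch of the same logic (a stand-in, not the tensor implementation):

```python
def masked_accuracy(preds, targets, mask):
    # Count agreement only where mask is nonzero; an empty mask yields
    # 0.0, matching the guard in movement_accuracy.
    pairs = [(p, t) for p, t, m in zip(preds, targets, mask) if m]
    if not pairs:
        return 0.0
    return sum(p == t for p, t in pairs) / len(pairs)
```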
     parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
     parser.add_argument("--val_samples_per_class", type=int, default=1, help="Number of validation samples per class (default 1)")
     parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
+    parser.add_argument("--movement_loss_weight", type=float, default=1.0, help="Auxiliary loss weight for movement classification head.")
     return parser.parse_args()
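The `--movement_loss_weight` flag scales only the auxiliary movement term. With scalars, the objective assembled later in the training loop looks like this (a sketch; the real code operates on tensors and applies `class_weight_factor` only when class weighting is enabled):

```python
def combined_loss(return_loss, quality_loss, movement_loss,
                  class_weight_factor=1.0, movement_loss_weight=1.0):
    # total = weighted return loss + quality loss + weighted movement loss
    return (return_loss * class_weight_factor
            + quality_loss
            + movement_loss_weight * movement_loss)
```

Setting `--movement_loss_weight 0.0` silences the auxiliary head's gradient without removing the head itself.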
     # --- 7. Training Loop ---
     quality_loss_fn = nn.MSELoss()
+    movement_class_weights = estimate_movement_class_weights(
+        dataset,
+        train_indices,
+    ).to(accelerator.device)

     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(dataset)}")
             preds = outputs["quantile_logits"]
             quality_preds = outputs["quality_logits"]
+            movement_logits = outputs.get("movement_logits")
             labels = batch["labels"]
             labels_mask = batch["labels_mask"]
             if "quality_score" not in batch:
             per_sample_return = quantile_pinball_loss_per_sample(preds, labels, labels_mask, quantiles)

             quality_loss = quality_loss_fn(quality_preds, quality_targets)
+            movement_targets = batch.get("movement_class_targets")
+            movement_mask = batch.get("movement_class_mask")
+            if movement_logits is not None and movement_targets is not None and movement_mask is not None:
+                movement_targets = movement_targets.to(accelerator.device)
+                movement_mask = movement_mask.to(accelerator.device)
+                movement_loss = masked_movement_cross_entropy(
+                    movement_logits,
+                    movement_targets,
+                    movement_mask,
+                    class_weights=movement_class_weights,
+                )
+                current_movement_acc = movement_accuracy(movement_logits, movement_targets, movement_mask)
+            else:
+                movement_loss = torch.tensor(0.0, device=accelerator.device, dtype=quality_loss.dtype)
+                current_movement_acc = 0.0
             per_sample_quality = (quality_preds - quality_targets).pow(2)

             # Apply per-sample class weighting to return loss
                 sample_weights = class_loss_weights[batch_class_ids]  # [B]
                 # Scale return loss by mean class weight for this batch
                 class_weight_factor = sample_weights.mean()
+                loss = return_loss * class_weight_factor + quality_loss + (args.movement_loss_weight * movement_loss)
             else:
                 class_weight_factor = torch.tensor(1.0, device=accelerator.device, dtype=return_loss.dtype)
+                loss = return_loss + quality_loss + (args.movement_loss_weight * movement_loss)
             per_sample_total = per_sample_return * class_weight_factor + per_sample_quality

             if not torch.isfinite(loss).all().item():
                 "loss": loss.unsqueeze(0),
                 "return_loss": return_loss.unsqueeze(0),
                 "quality_loss": quality_loss.unsqueeze(0),
+                "movement_loss": movement_loss.unsqueeze(0),
                 "preds": preds,
                 "quality_preds": quality_preds,
+                "movement_logits": movement_logits,
                 "labels_raw": batch.get("labels"),
                 "labels_log": labels,
                 "labels_mask": labels_mask,
             current_loss = loss.item()
             current_return_loss = return_loss.item()
             current_quality_loss = quality_loss.item()
+            current_movement_loss = movement_loss.item()
             current_class_weight_factor = float(class_weight_factor.item()) if isinstance(class_weight_factor, torch.Tensor) else float(class_weight_factor)
             current_mask_coverage = float(labels_mask.float().mean().item()) if labels_mask is not None else 0.0
             epoch_loss += current_loss
                 "train/loss": current_loss,
                 "train/return_loss": current_return_loss,
                 "train/quality_loss": current_quality_loss,
+                "train/movement_loss": current_movement_loss,
+                "train/movement_acc": current_movement_acc,
                 "train/class_weight_factor": current_class_weight_factor,
                 "train/mask_coverage": current_mask_coverage,
                 "train/loss_ema": loss_ema if loss_ema is not None else current_loss,
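The `train/loss_ema` entry above suggests an exponential moving average of the step loss. A generic sketch of such an update, seeded from the raw loss on the first step (the smoothing constant here is illustrative, not taken from the source):

```python
def update_loss_ema(loss_ema, current_loss, beta=0.98):
    # First step: seed the EMA with the raw loss, mirroring the
    # "loss_ema if loss_ema is not None else current_loss" fallback.
    if loss_ema is None:
        return current_loss
    return beta * loss_ema + (1.0 - beta) * current_loss
```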
             logger.info(
                 f"Validation - Loss: {val_metrics['val/loss']:.4f} | "
                 f"Return: {val_metrics['val/return_loss']:.4f} | "
+                f"Quality: {val_metrics['val/quality_loss']:.4f} | "
+                f"Movement: {val_metrics['val/movement_loss']:.4f} | "
+                f"MoveAcc: {val_metrics['val/movement_acc']:.4f}"
             )
             # Log per-class losses
             class_loss_str = " | ".join(

                 "val/loss": val_metrics['val/loss'],
                 "val/return_loss": val_metrics['val/return_loss'],
                 "val/quality_loss": val_metrics['val/quality_loss'],
+                "val/movement_loss": val_metrics['val/movement_loss'],
+                "val/movement_acc": val_metrics['val/movement_acc'],
             }, step=total_steps)
             # Log per-class losses to tensorboard
             for class_id, class_loss in val_metrics['class_losses'].items():