Commit 7064310 (verified) by zirobtc · Parent: 2c39730

Upload folder using huggingface_hub
QUALITY_SCORE_ARCHITECTURE.md CHANGED
@@ -10,4 +10,128 @@ T_cutoff at trade 700 → returns are -90%
 T_cutoff at trade 900 → returns are -95%
 So even for class 5 tokens, 80%+ of the cached training samples have negative Ground Truth labels. The model is correctly learning that at any random moment, even a "good" token is most likely going down. The class balancing doesn't change the fact that the actual Y labels are overwhelmingly negative across all classes.
 
-The model isn't broken — it learned exactly what the data showed it. The issue is that the training setup doesn't teach it to recognize the pre-pump moment specifically.
+The model isn't broken — it learned exactly what the data showed it. The issue is that the training setup doesn't teach it to recognize the pre-pump moment specifically.
+
+
+**Main Issue**
+Your main problem was never just “bad checkpoint choice.” The core issue is training/data misalignment:
+
+- token `class_id` is token-level
+- the prediction target is context-level from a random `T_cutoff`
+- even a good token produces many bad windows
+- so balanced token classes do not mean balanced future-return labels
+- the model then learns an over-negative prior
+
+A second major issue was cache construction:
+- the cache was wasting disk and time on overwhelming numbers of garbage-token samples
+- later training weights cannot fix that upstream waste
+
+**What We Figured Out**
+- The model is not useless.
+- Wallet signal is real: ablations showed wallet removal hurts predictions materially.
+- OHLC matters, but mostly as a coarse summary, not real chart-pattern intelligence.
+- No obvious future leakage was found in OHLC construction.
+- Social looked basically unused.
+- Graph looked weaker than expected.
+- The movement head idea is valid, but only if labels are placed correctly in the pipeline.
+- Movement labels should come from the data loader, not be derived later in the collator or training loop.
+- Cache balancing should not depend on fragile movement thresholds.
+- A single “movement class” for cache weighting was wrong because:
+  - thresholds were unresolved
+  - movement differs across horizons inside the same sample
+
+**Where You Corrected the Direction**
+You pushed back on several important bad assumptions:
+
+- `return > 0` is too noisy as a label
+- movement class names should be threshold-agnostic
+- threshold-based movement balancing was premature
+- SQL/global-distribution threshold inference was conceptually wrong because labels depend on the sampled `T_cutoff`
+- the cache should not be filtered by class map in a destructive way
+- cache balancing must happen at cache generation time, not be delegated to training weights
+- positive balancing should not be forced on garbage classes
+- exact class sample counts matter more than approximate expected weighting
+- `T_cutoff` does not need to be deterministic or pre-fixed
+- if cache balancing uses movement-like signals, use threshold-free binary polarity first
+
+Those corrections materially improved the design.
+
+**Proposed Methods Over the Chat**
+These were the main methods proposed, in order of evolution:
+
+1. Forward-time validation and token-grouped splits
+   - to reduce misleading val results and leakage risk
+
+2. Auxiliary head ideas
+   - first fixed pump heads
+   - then an all-horizon direction head
+   - then a movement-type multiclass head
+   - final stable view: one multi-horizon movement head is reasonable, but labels must be created correctly
+
+3. Runtime/loader-side label derivation
+   - final agreed direction:
+     - labels belong in the data loader
+     - the collator should only stack them
+     - the model should just consume them
+
+4. Cache-time balancing instead of train-time rescue
+   - because disk/time waste happens before training starts
+   - so train weights alone are too late
+
+5. Class-id-based cache expansion
+   - proposed because class `0` dominates raw token counts
+   - later refined because exact quotas matter more than soft weighting
+
+6. Movement-class-based cache balancing
+   - proposed, then correctly rejected
+   - because it depended on unresolved thresholds and collapsed multi-horizon information incorrectly
+
+7. Binary polarity cache balancing
+   - final better version:
+     - use `positive` if the max valid horizon return in a sample is `> 0`
+     - else `negative`
+   - this is threshold-free and less brittle
+
+8. Exact class quotas + class-conditional polarity quotas
+   - final strongest cache proposal:
+     - exact equal sample budget per class
+     - positive/negative balancing only for classes that can realistically produce positive contexts
+     - keep class `0` mostly negative
+     - keep `T_cutoff` random
+
+**Current Best Design**
+The strongest design we converged toward is:
+
+- keep `T_cutoff` random
+- make cache generation deterministic at the planning level
+- assign exact sample budgets per token class
+- for higher classes, request positive/negative context ratios
+- for low garbage classes, do not force positive quotas
+- build labels in the data loader
+- keep the model’s main task as future return prediction
+- use the movement head only as auxiliary supervision
+
+So:
+- token-type balance is controlled explicitly
+- context-type bias is controlled explicitly
+- disk usage is controlled at cache creation time
+- training does not need to rescue a broken corpus
+
+**What We Learned About the Model**
+- the backbone has useful signal
+- wallet features matter
+- OHLC is being used, but not in the rich pattern-detection way you hoped
+- the current main failure mode is much more about data construction and sampling than about the backbone learning nothing
+
+**High-Level Conclusions**
+1. The main issue is data/sample construction, not just optimizer/checkpoint behavior.
+2. Cache balancing is the right place to solve the waste and prior-bias problem.
+3. Threshold-dependent movement balancing was too early and too brittle.
+4. Binary short-horizon max-return polarity is a better first cache-balancing signal.
+5. Exact per-class sample quotas plus class-conditional polarity balancing is the most coherent cache design we arrived at.
+
+If you want, I can next turn this into a short final blueprint with only:
+- `Main issue`
+- `Accepted design`
+- `Rejected ideas`
+- `Next implementation order`
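The binary-polarity rule (item 7) and the exact per-class budgets (item 8) are simple to state in code. A minimal sketch, assuming per-sample lists of horizon returns plus a validity mask; `context_polarity` and `plan_class_quotas` are hypothetical helper names, not part of this commit:

```python
def context_polarity(horizon_returns, horizon_mask):
    """Threshold-free polarity: positive iff any valid horizon return is > 0."""
    valid = [r for r, m in zip(horizon_returns, horizon_mask) if m > 0]
    if not valid:
        return "negative"  # fully masked contexts default to negative
    return "positive" if max(valid) > 0.0 else "negative"


def plan_class_quotas(token_counts_by_class, target_total):
    """Exact equal sample budget per class; remainder goes to the lowest class ids."""
    class_ids = sorted(token_counts_by_class)
    base = target_total // len(class_ids)
    remainder = target_total % len(class_ids)
    return {cid: base + (1 if i < remainder else 0)
            for i, cid in enumerate(class_ids)}
```

Because the quotas are exact counts rather than sampling weights, the cache's class mix is known before any worker runs.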
data/context_targets.py ADDED
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from typing import Dict, List, Sequence
+
+
+MOVEMENT_STRONG_DOWN_THRESHOLD = -0.40
+MOVEMENT_DOWN_THRESHOLD = -0.30
+MOVEMENT_PUMP_50_THRESHOLD = 0.50
+MOVEMENT_PUMP_100_THRESHOLD = 1.00
+MOVEMENT_PUMP_300_THRESHOLD = 3.00
+
+MOVEMENT_CLASS_NAMES = [
+    "strong_down",
+    "down",
+    "flat",
+    "up",
+    "strong_up",
+    "extreme_up",
+]
+MOVEMENT_CLASS_TO_ID = {name: idx for idx, name in enumerate(MOVEMENT_CLASS_NAMES)}
+MOVEMENT_ID_TO_CLASS = {idx: name for name, idx in MOVEMENT_CLASS_TO_ID.items()}
+
+DEFAULT_MOVEMENT_LABEL_CONFIG = {
+    "strong_down_threshold": MOVEMENT_STRONG_DOWN_THRESHOLD,
+    "down_threshold": MOVEMENT_DOWN_THRESHOLD,
+    "pump_50_threshold": MOVEMENT_PUMP_50_THRESHOLD,
+    "pump_100_threshold": MOVEMENT_PUMP_100_THRESHOLD,
+    "pump_300_threshold": MOVEMENT_PUMP_300_THRESHOLD,
+}
+
+
+def classify_movement_return(
+    return_value: float,
+    movement_label_config: Dict[str, float] | None = None,
+) -> int:
+    cfg = dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
+    if movement_label_config:
+        cfg.update({k: float(v) for k, v in movement_label_config.items() if k in cfg})
+
+    strong_down_threshold = min(cfg["strong_down_threshold"], cfg["down_threshold"])
+    down_threshold = cfg["down_threshold"]
+    pump_50_threshold = cfg["pump_50_threshold"]
+    pump_100_threshold = cfg["pump_100_threshold"]
+    pump_300_threshold = cfg["pump_300_threshold"]
+
+    if return_value <= strong_down_threshold:
+        return MOVEMENT_CLASS_TO_ID["strong_down"]
+    if return_value < down_threshold:
+        return MOVEMENT_CLASS_TO_ID["down"]
+    if return_value < pump_50_threshold:
+        return MOVEMENT_CLASS_TO_ID["flat"]
+    if return_value < pump_100_threshold:
+        return MOVEMENT_CLASS_TO_ID["up"]
+    if return_value < pump_300_threshold:
+        return MOVEMENT_CLASS_TO_ID["strong_up"]
+    return MOVEMENT_CLASS_TO_ID["extreme_up"]
+
+
+def derive_movement_targets(
+    horizon_returns: Sequence[float],
+    horizon_mask: Sequence[float],
+    movement_label_config: Dict[str, float] | None = None,
+) -> Dict[str, List]:
+    class_targets: List[int] = []
+    class_mask: List[int] = []
+    class_names: List[str] = []
+
+    usable = min(len(horizon_returns), len(horizon_mask))
+    for idx in range(usable):
+        if float(horizon_mask[idx]) <= 0:
+            class_targets.append(MOVEMENT_CLASS_TO_ID["flat"])
+            class_mask.append(0)
+            class_names.append("masked")
+            continue
+
+        class_id = classify_movement_return(
+            float(horizon_returns[idx]),
+            movement_label_config=movement_label_config,
+        )
+        class_targets.append(class_id)
+        class_mask.append(1)
+        class_names.append(MOVEMENT_ID_TO_CLASS[class_id])
+
+    return {
+        "movement_class_targets": class_targets,
+        "movement_class_mask": class_mask,
+        "movement_class_names": class_names,
+    }
+
+
+def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]:
+    del valid_returns
+    return dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
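For quick orientation, here is a self-contained sketch of how the default thresholds above bucket a single return value; `movement_class` is a hypothetical name that mirrors `classify_movement_return`, with the boundaries reproduced from `DEFAULT_MOVEMENT_LABEL_CONFIG`:

```python
# Default thresholds reproduced from data/context_targets.py
THRESHOLDS = [
    (-0.40, "strong_down"),  # return <= -40%
    (-0.30, "down"),         # return < -30%
    (0.50, "flat"),          # return < +50%
    (1.00, "up"),            # return < +100%
    (3.00, "strong_up"),     # return < +300%
]

def movement_class(return_value: float) -> str:
    """Bucket a horizon return into one of the six movement classes."""
    if return_value <= THRESHOLDS[0][0]:
        return "strong_down"
    for bound, name in THRESHOLDS[1:]:
        if return_value < bound:
            return name
    return "extreme_up"  # anything at or above +300%
```

Note the lower boundary is inclusive (`<=`) while all others are exclusive (`<`), matching the original function.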
data/data_collator.py CHANGED
@@ -719,6 +719,8 @@ class MemecoinCollator:
             # Labels
             'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
             'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
+            'movement_class_targets': torch.stack([item['movement_class_targets'] for item in batch]) if batch and 'movement_class_targets' in batch[0] else None,
+            'movement_class_mask': torch.stack([item['movement_class_mask'] for item in batch]) if batch and 'movement_class_mask' in batch[0] else None,
             'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None,
             'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
             # Debug info
data/data_loader.py CHANGED
@@ -17,6 +17,7 @@ import json
 import models.vocabulary as vocab
 from models.multi_modal_processor import MultiModalEncoder
 from data.data_fetcher import DataFetcher  # NEW: Import the DataFetcher
+from data.context_targets import derive_movement_targets
 from requests.adapters import HTTPAdapter
 from urllib3.util.retry import Retry
 
@@ -128,7 +129,8 @@ class OracleDataset(Dataset):
                  start_date: Optional[datetime.datetime] = None,
                  min_trade_usd: float = 0.0,
                  max_seq_len: int = 8192,
-                 p99_clamps: Optional[Dict[str, float]] = None):
+                 p99_clamps: Optional[Dict[str, float]] = None,
+                 movement_label_config: Optional[Dict[str, float]] = None):
 
         self.max_seq_len = max_seq_len
 
@@ -315,6 +317,7 @@ class OracleDataset(Dataset):
 
         self.min_trade_usd = min_trade_usd
         self._uri_fail_counts: Dict[str, int] = {}
+        self.movement_label_config = movement_label_config
 
     def _init_http_session(self) -> None:
         # Configure robust HTTP session
@@ -1199,6 +1202,23 @@ class OracleDataset(Dataset):
         # This is fully deterministic - no runtime sampling or processing
         _timings['total'] = _time.perf_counter() - _total_start
 
+        if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data:
+            labels = cached_data['labels']
+            labels_mask = cached_data['labels_mask']
+            movement_targets = derive_movement_targets(
+                labels.tolist() if isinstance(labels, torch.Tensor) else labels,
+                labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask,
+                movement_label_config=self.movement_label_config,
+            )
+            cached_data['movement_class_targets'] = torch.tensor(
+                movement_targets['movement_class_targets'],
+                dtype=torch.long,
+            )
+            cached_data['movement_class_mask'] = torch.tensor(
+                movement_targets['movement_class_mask'],
+                dtype=torch.long,
+            )
+
         if idx % 100 == 0:
             print(f"[Sample {idx}] CONTEXT mode | cache_load: {_timings['cache_load']*1000:.1f}ms | "
                   f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}")
@@ -2449,6 +2469,11 @@ class OracleDataset(Dataset):
 
         if not all_trades:
             # No valid trades for label computation
+            movement_targets = derive_movement_targets(
+                [0.0] * len(self.horizons_seconds),
+                [0.0] * len(self.horizons_seconds),
+                movement_label_config=self.movement_label_config,
+            )
             return {
                 'event_sequence': event_sequence,
                 'wallets': wallet_data,
@@ -2457,7 +2482,9 @@
                 'embedding_pooler': pooler,
                 'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
                 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
-                'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
+                'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
+                'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
+                'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
             }
 
         # Ensure sorted
@@ -2537,6 +2564,12 @@
 
         # DEBUG: Mask summaries removed after validation
 
+        movement_targets = derive_movement_targets(
+            label_values,
+            mask_values,
+            movement_label_config=self.movement_label_config,
+        )
+
        return {
            'sample_idx': sample_idx if sample_idx is not None else -1,  # Debug trace
            'token_address': token_address,  # For debugging
@@ -2548,7 +2581,9 @@
            'embedding_pooler': pooler,
            'labels': torch.tensor(label_values, dtype=torch.float32),
            'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
-           'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
+           'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
+           'movement_class_targets': torch.tensor(movement_targets['movement_class_targets'], dtype=torch.long),
+           'movement_class_mask': torch.tensor(movement_targets['movement_class_mask'], dtype=torch.long),
        }
 
    def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None:
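The division of labor this diff enforces (the loader derives movement labels once, the collator only stacks them) can be sketched with plain lists; `make_sample` and `collate` are hypothetical stand-ins for the dataset's sample construction and `MemecoinCollator`, and the `classify` argument stands in for `classify_movement_return`:

```python
FLAT = 2  # placeholder class id used for masked horizons (mirrors "flat")

def make_sample(horizon_returns, horizon_mask, classify):
    """Loader-side: derive per-horizon movement targets once, at load time."""
    targets, mask = [], []
    for r, m in zip(horizon_returns, horizon_mask):
        if m <= 0:
            targets.append(FLAT)  # placeholder id, excluded via the mask
            mask.append(0)
        else:
            targets.append(classify(r))
            mask.append(1)
    return {"movement_class_targets": targets, "movement_class_mask": mask}

def collate(batch):
    """Collator-side: only stack what the loader produced, never re-derive."""
    return {
        "movement_class_targets": [s["movement_class_targets"] for s in batch],
        "movement_class_mask": [s["movement_class_mask"] for s in batch],
    }
```

Keeping derivation out of the collator means cached samples, freshly built samples, and any future label scheme all flow through the same stacking code unchanged.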
log.log CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:11ac013f8e91ad65475b8106a5a072dc42f67e0773ddc4a50825e316c578e0d4
-size 3472
+oid sha256:935233e4d7669b2a25173d7ae164317e85f1a5e8b0fc1d8d1832ab0893fca471
+size 19258
models/model.py CHANGED
@@ -17,6 +17,7 @@ from models.ohlc_embedder import OHLCEmbedder
 from models.HoldersEncoder import HolderDistributionEncoder  # NEW
 from models.SocialEncoders import SocialEncoder  # NEW
 import models.vocabulary as vocab  # For vocab sizes
+from data.context_targets import MOVEMENT_CLASS_NAMES
 
 class Oracle(nn.Module):
     """
@@ -51,6 +52,7 @@ class Oracle(nn.Module):
         self.quantiles = quantiles
         self.horizons_seconds = horizons_seconds
         self.num_outputs = len(quantiles) * len(horizons_seconds)
+        self.num_movement_classes = len(MOVEMENT_CLASS_NAMES)
         self.dtype = dtype
 
         # --- 2. Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) ---
@@ -103,6 +105,11 @@ class Oracle(nn.Module):
             nn.GELU(),
             nn.Linear(self.d_model, 1)
         )
+        self.movement_head = nn.Sequential(
+            nn.Linear(self.d_model, self.d_model),
+            nn.GELU(),
+            nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes)
+        )
 
         self.event_type_to_id = event_type_to_id
 
@@ -1008,9 +1015,11 @@ class Oracle(nn.Module):
         empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
         empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
         empty_quality = torch.empty(0, device=device, dtype=self.dtype)
+        empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype)
         return {
             'quantile_logits': empty_quantiles,
             'quality_logits': empty_quality,
+            'movement_logits': empty_movement,
             'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
             'hidden_states': empty_hidden,
             'attention_mask': empty_mask
@@ -1131,10 +1140,16 @@ class Oracle(nn.Module):
         pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
         quantile_logits = self.quantile_head(pooled_states)
         quality_logits = self.quality_head(pooled_states).squeeze(-1)
+        movement_logits = self.movement_head(pooled_states).view(
+            pooled_states.shape[0],
+            len(self.horizons_seconds),
+            self.num_movement_classes,
+        )
 
         return {
             'quantile_logits': quantile_logits,
             'quality_logits': quality_logits,
+            'movement_logits': movement_logits,
             'pooled_states': pooled_states,
             'hidden_states': sequence_hidden,
             'attention_mask': hf_attention_mask
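The commit adds `movement_logits` of shape `[batch, horizons, classes]` but does not include the auxiliary loss that would consume them together with `movement_class_targets` and `movement_class_mask`. A plausible sketch of that masked cross-entropy, written in plain Python for clarity (real training code would more likely use `torch.nn.functional.cross_entropy` with the mask applied); `masked_movement_loss` is a hypothetical name, not part of this commit:

```python
import math

def masked_movement_loss(logits, targets, mask):
    """Mean cross-entropy over the (batch, horizon) cells where mask == 1.

    logits:  nested lists [B][H][C] of raw scores
    targets: [B][H] class ids, mask: [B][H] with 0/1 entries
    """
    total, count = 0.0, 0
    for b in range(len(logits)):
        for h in range(len(logits[b])):
            if mask[b][h] <= 0:
                continue  # masked horizons contribute nothing to the loss
            row = logits[b][h]
            m = max(row)  # stabilized log-sum-exp
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            total += log_z - row[targets[b][h]]
            count += 1
    return total / count if count else 0.0
```

Averaging only over unmasked cells keeps short-lived tokens (with many masked horizons) from diluting the gradient.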
pre_cache.sh CHANGED
@@ -39,7 +39,7 @@ python3 scripts/cache_dataset.py \
     --num_workers "$NUM_WORKERS" \
     --horizons_seconds "${HORIZONS_SECONDS[@]}" \
     --quantiles "${QUANTILES[@]}" \
-    --max_samples 300000 \
+    --max_samples 10 \
     "$@"
 
 echo "Done!"
  echo "Done!"
scripts/cache_dataset.py CHANGED
@@ -27,13 +27,154 @@ from scripts.compute_quality_score import get_token_quality_scores, fetch_token_
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
29
  from neo4j import GraphDatabase
30
-
31
  _worker_dataset = None
32
  _worker_return_class_map = None
33
  _worker_quality_scores_map = None
34
  _worker_encoder = None
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
38
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
39
  from data.data_loader import OracleDataset
@@ -73,7 +214,7 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
73
 
74
 
75
  def _process_single_token_context(args):
76
- idx, mint_addr, samples_per_token, output_dir = args
77
  global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
78
  try:
79
  class_id = _worker_return_class_map.get(mint_addr)
@@ -87,7 +228,17 @@ def _process_single_token_context(args):
87
  if encoder is None:
88
  print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)
89
 
90
- contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token, encoder=encoder)
 
 
 
 
 
 
 
 
 
 
91
  if not contexts:
92
  return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
93
  q_score = _worker_quality_scores_map.get(mint_addr)
@@ -102,7 +253,16 @@ def _process_single_token_context(args):
102
 
103
  torch.save(ctx, output_path)
104
  saved_files.append(filename)
105
- return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
 
 
 
 
 
 
 
 
 
106
  except Exception as e:
107
  import traceback
108
  return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
@@ -132,6 +292,10 @@ def main():
132
  parser.add_argument("--context_length", type=int, default=8192)
133
  parser.add_argument("--min_trades", type=int, default=10)
134
  parser.add_argument("--samples_per_token", type=int, default=1)
 
 
 
 
135
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
136
  parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
137
  parser.add_argument("--num_workers", type=int, default=1)
@@ -223,7 +387,6 @@ def main():
223
 
224
  print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
225
 
226
- # Compute balanced samples_per_token for each class
227
  num_classes = len(eligible_class_counts)
228
  if args.max_samples:
229
  target_total = args.max_samples
@@ -231,44 +394,46 @@ def main():
231
  target_total = 15000 # Default target: 15k balanced files
232
  target_per_class = target_total // max(num_classes, 1)
233
 
234
- class_multipliers = {}
235
- class_token_caps = {}
236
- for cid, count in eligible_class_counts.items():
237
- if count >= target_per_class:
238
- # Enough tokens — 1 sample each, cap token count
239
- class_multipliers[cid] = 1
240
- class_token_caps[cid] = target_per_class
241
- else:
242
- # Not enough tokens — multi-sample, use all tokens
243
- class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
244
- class_token_caps[cid] = count
245
 
246
  print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
247
- print(f"INFO: Class multipliers: {dict(sorted(class_multipliers.items()))}")
248
- print(f"INFO: Class token caps: {dict(sorted(class_token_caps.items()))}")
249
 
250
  # Build balanced task list
251
  tasks = []
252
- for cid, mint_list in mints_by_class.items():
253
- random.shuffle(mint_list)
254
- cap = class_token_caps.get(cid, len(mint_list))
255
- spt = class_multipliers.get(cid, 1)
256
- # Override with CLI --samples_per_token if explicitly set > 1
257
- if args.samples_per_token > 1:
258
- spt = args.samples_per_token
259
- for i, m in mint_list[:cap]:
260
- mint_addr = m['mint_address']
261
- tasks.append((i, mint_addr, spt, str(output_dir)))
 
 
 
 
 
 
 
 
 
262
 
263
  random.shuffle(tasks) # Shuffle tasks for even load distribution across workers
264
- expected_files = sum(
265
- class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
266
- for cid, ml in mints_by_class.items()
267
- )
268
  print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
269
 
270
  success_count, skipped_count, error_count = 0, 0, 0
271
  class_distribution = {}
 
272
 
273
  # --- Resume support: skip tokens that already have cached files ---
274
  existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
@@ -334,6 +499,8 @@ def main():
334
  if result['status'] == 'success':
335
  success_count += 1
336
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
 
 
337
  elif result['status'] == 'skipped':
338
  skipped_count += 1
339
  else:
@@ -360,6 +527,8 @@ def main():
360
  if result['status'] == 'success':
361
  success_count += 1
362
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
 
 
363
  elif result['status'] == 'skipped':
364
  skipped_count += 1
365
  else:
@@ -398,10 +567,14 @@ def main():
398
  'num_workers': args.num_workers,
399
  'horizons_seconds': args.horizons_seconds,
400
  'quantiles': args.quantiles,
401
- 'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
402
- 'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
403
  'target_total': target_total,
404
  'target_per_class': target_per_class,
 
 
 
 
 
 
405
  }, f, indent=2)
406
 
407
  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
 
  from clickhouse_driver import Client as ClickHouseClient
  from neo4j import GraphDatabase

  _worker_dataset = None
  _worker_return_class_map = None
  _worker_quality_scores_map = None
  _worker_encoder = None


+ def _to_int_list(values):
+     if values is None:
+         return []
+     if isinstance(values, torch.Tensor):
+         return [int(v) for v in values.tolist()]
+     return [int(v) for v in values]
+
+
+ def _to_float_list(values):
+     if values is None:
+         return []
+     if isinstance(values, torch.Tensor):
+         return [float(v) for v in values.tolist()]
+     return [float(v) for v in values]
+
+
+ def _representative_context_polarity(context):
+     labels = _to_float_list(context.get("labels"))
+     mask = _to_int_list(context.get("labels_mask"))
+     valid_returns = [label for label, keep in zip(labels, mask) if keep > 0]
+     if not valid_returns:
+         return "negative"
+     return "positive" if max(valid_returns) > 0.0 else "negative"
+
+
+ def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
+     if len(contexts) <= max_keep:
+         polarity_counts = {}
+         for context in contexts:
+             polarity = _representative_context_polarity(context)
+             polarity_counts[polarity] = polarity_counts.get(polarity, 0) + 1
+             context["representative_context_polarity"] = polarity
+         return contexts, polarity_counts
+
+     positive_bucket = []
+     negative_bucket = []
+     for context in contexts:
+         polarity = _representative_context_polarity(context)
+         context["representative_context_polarity"] = polarity
+         if polarity == "positive":
+             positive_bucket.append(context)
+         else:
+             negative_bucket.append(context)
+
+     selected = []
+     polarity_counts = {"positive": 0, "negative": 0}
+     desired_positive = max(0, int(desired_positive)) if desired_positive is not None else None
+     desired_negative = max(0, int(desired_negative)) if desired_negative is not None else None
+
+     if desired_positive is not None or desired_negative is not None:
+         target_positive = min(desired_positive or 0, max_keep, len(positive_bucket))
+         target_negative = min(desired_negative or 0, max_keep - target_positive, len(negative_bucket))
+
+         while polarity_counts["positive"] < target_positive and positive_bucket:
+             selected.append(positive_bucket.pop())
+             polarity_counts["positive"] += 1
+         while polarity_counts["negative"] < target_negative and negative_bucket:
+             selected.append(negative_bucket.pop())
+             polarity_counts["negative"] += 1
+
+     prefer_positive = len(positive_bucket) >= len(negative_bucket)
+
+     while len(selected) < max_keep and (positive_bucket or negative_bucket):
+         if prefer_positive and positive_bucket:
+             selected.append(positive_bucket.pop())
+             polarity_counts["positive"] += 1
+         elif not prefer_positive and negative_bucket:
+             selected.append(negative_bucket.pop())
+             polarity_counts["negative"] += 1
+         elif positive_bucket:
+             selected.append(positive_bucket.pop())
+             polarity_counts["positive"] += 1
+         elif negative_bucket:
+             selected.append(negative_bucket.pop())
+             polarity_counts["negative"] += 1
+         prefer_positive = not prefer_positive
+
+     return selected[:max_keep], polarity_counts
+
+
+ def _allocate_class_targets(mints_by_class, target_total, positive_balance_min_class, positive_ratio):
+     from collections import defaultdict
+     import random
+
+     class_ids = sorted(mints_by_class.keys())
+     if not class_ids:
+         return {}, {}, {}
+
+     target_per_class = target_total // len(class_ids)
+     remainder = target_total % len(class_ids)
+
+     token_plans = {}
+     class_targets = {}
+     class_polarity_targets = {}
+
+     for pos, class_id in enumerate(class_ids):
+         class_target = target_per_class + (1 if pos < remainder else 0)
+         class_targets[class_id] = class_target
+
+         token_list = list(mints_by_class[class_id])
+         random.shuffle(token_list)
+         if not token_list or class_target <= 0:
+             class_polarity_targets[class_id] = {"positive": 0, "negative": 0}
+             continue
+
+         if class_id >= positive_balance_min_class:
+             positive_target = int(round(class_target * positive_ratio))
+             positive_target = min(max(positive_target, 0), class_target)
+         else:
+             positive_target = 0
+         negative_target = class_target - positive_target
+         class_polarity_targets[class_id] = {
+             "positive": positive_target,
+             "negative": negative_target,
+         }
+
+         assigned_positive = 0
+         assigned_negative = 0
+         token_count = len(token_list)
+         for sample_num in range(class_target):
+             token_idx, mint_record = token_list[sample_num % token_count]
+             mint_addr = mint_record["mint_address"]
+             plan_key = (token_idx, mint_addr)
+             if plan_key not in token_plans:
+                 token_plans[plan_key] = {
+                     "samples_to_keep": 0,
+                     "desired_positive": 0,
+                     "desired_negative": 0,
+                     "class_id": class_id,
+                 }
+             token_plans[plan_key]["samples_to_keep"] += 1
+
+             if assigned_positive < positive_target:
+                 token_plans[plan_key]["desired_positive"] += 1
+                 assigned_positive += 1
+             else:
+                 token_plans[plan_key]["desired_negative"] += 1
+                 assigned_negative += 1
+
+     return token_plans, class_targets, class_polarity_targets
+
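The positive/negative quota logic above can be illustrated with a small self-contained sketch. This is a simplified illustration, not the script's code: it applies the polarity quotas but omits the round-robin top-up that `_select_contexts_by_polarity` performs, and `context_polarity` / `select_balanced` are hypothetical names.

```python
def context_polarity(context):
    # a context counts as "positive" when any unmasked horizon return is > 0
    valid = [r for r, keep in zip(context["labels"], context["labels_mask"]) if keep]
    return "positive" if valid and max(valid) > 0.0 else "negative"

def select_balanced(contexts, max_keep, want_pos, want_neg):
    pos = [c for c in contexts if context_polarity(c) == "positive"]
    neg = [c for c in contexts if context_polarity(c) == "negative"]
    # fill the positive quota first, then the negative quota from what is left
    picked = pos[:min(want_pos, max_keep)]
    picked += neg[:min(want_neg, max_keep - len(picked))]
    return picked[:max_keep]

contexts = [
    {"labels": [0.4, -0.2], "labels_mask": [1, 1]},   # positive
    {"labels": [-0.6, -0.1], "labels_mask": [1, 1]},  # negative
    {"labels": [0.9, 0.1], "labels_mask": [0, 0]},    # all masked -> treated as negative
    {"labels": [-0.3, 0.2], "labels_mask": [1, 1]},   # positive
]
picked = select_balanced(contexts, max_keep=2, want_pos=1, want_neg=1)
print([context_polarity(c) for c in picked])  # -> ['positive', 'negative']
```

Note that a context whose labels are fully masked falls into the negative bucket, matching `_representative_context_polarity`'s fallback.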
  def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
      global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
      from data.data_loader import OracleDataset

  def _process_single_token_context(args):
+     idx, mint_addr, samples_per_token, output_dir, oversample_factor, desired_positive, desired_negative = args
      global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
      try:
          class_id = _worker_return_class_map.get(mint_addr)

          if encoder is None:
              print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)

+         candidate_contexts = _worker_dataset.__cacheitem_context__(
+             idx,
+             num_samples_per_token=max(samples_per_token, samples_per_token * max(1, oversample_factor)),
+             encoder=encoder,
+         )
+         contexts, polarity_counts = _select_contexts_by_polarity(
+             candidate_contexts,
+             samples_per_token,
+             desired_positive=desired_positive,
+             desired_negative=desired_negative,
+         )
          if not contexts:
              return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
          q_score = _worker_quality_scores_map.get(mint_addr)

          torch.save(ctx, output_path)
          saved_files.append(filename)
+         return {
+             'status': 'success',
+             'mint': mint_addr,
+             'class_id': class_id,
+             'q_score': q_score,
+             'n_contexts': len(contexts),
+             'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
+             'files': saved_files,
+             'polarity_counts': polarity_counts,
+         }
      except Exception as e:
          import traceback
          return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
 
      parser.add_argument("--context_length", type=int, default=8192)
      parser.add_argument("--min_trades", type=int, default=10)
      parser.add_argument("--samples_per_token", type=int, default=1)
+     parser.add_argument("--context_oversample_factor", type=int, default=4)
+     parser.add_argument("--cache_balance_mode", type=str, default="hybrid", choices=["class", "uniform", "hybrid"])
+     parser.add_argument("--positive_balance_min_class", type=int, default=2)
+     parser.add_argument("--positive_context_ratio", type=float, default=0.5)
      parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
      parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
      parser.add_argument("--num_workers", type=int, default=1)
 
  print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")

  num_classes = len(eligible_class_counts)
  if args.max_samples:
      target_total = args.max_samples

      target_total = 15000  # Default target: 15k balanced files
  target_per_class = target_total // max(num_classes, 1)

+ token_plans, class_targets, class_polarity_targets = _allocate_class_targets(
+     mints_by_class=mints_by_class,
+     target_total=target_total,
+     positive_balance_min_class=args.positive_balance_min_class,
+     positive_ratio=args.positive_context_ratio,
+ )

  print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
+ print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
+ print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
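As a sanity check on the allocation arithmetic: `_allocate_class_targets` splits `target_total` evenly across classes and hands the remainder, one file each, to the lowest class IDs. A minimal sketch of just that split (the function name `split_targets` is illustrative, not part of the codebase):

```python
def split_targets(class_ids, target_total):
    # even split; the first `remainder` classes (by sorted ID) get one extra sample
    base, remainder = divmod(target_total, len(class_ids))
    return {cid: base + (1 if i < remainder else 0)
            for i, cid in enumerate(sorted(class_ids))}

print(split_targets([0, 1, 2, 3, 4, 5], 15000))  # divides evenly: 2500 per class
print(split_targets([0, 1, 2], 100))             # -> {0: 34, 1: 33, 2: 33}
```

The per-class totals always sum back to `target_total`, so the cache size stays exact regardless of how many classes survive the eligibility filter.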
  # Build balanced task list
  tasks = []
+ if args.cache_balance_mode == "uniform":
+     target_tokens = len(filtered_mints)
+     if args.max_samples:
+         target_tokens = min(len(filtered_mints), max(1, math.ceil(args.max_samples / max(args.samples_per_token, 1))))
+     mint_pool = list(enumerate(filtered_mints))
+     random.shuffle(mint_pool)
+     for i, m in mint_pool[:target_tokens]:
+         tasks.append((i, m['mint_address'], args.samples_per_token, str(output_dir), args.context_oversample_factor, 0, args.samples_per_token))
+ else:
+     for (token_idx, mint_addr), plan in token_plans.items():
+         tasks.append((
+             token_idx,
+             mint_addr,
+             plan["samples_to_keep"],
+             str(output_dir),
+             args.context_oversample_factor,
+             plan["desired_positive"],
+             plan["desired_negative"],
+         ))

  random.shuffle(tasks)  # Shuffle tasks for even load distribution across workers
+ expected_files = sum(task[2] for task in tasks)

  print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")

  success_count, skipped_count, error_count = 0, 0, 0
  class_distribution = {}
+ polarity_distribution = {}

  # --- Resume support: skip tokens that already have cached files ---
  existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
 
  if result['status'] == 'success':
      success_count += 1
      class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
+     for polarity, count in result.get('polarity_counts', {}).items():
+         polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
  elif result['status'] == 'skipped':
      skipped_count += 1
  else:

  if result['status'] == 'success':
      success_count += 1
      class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
+     for polarity, count in result.get('polarity_counts', {}).items():
+         polarity_distribution[polarity] = polarity_distribution.get(polarity, 0) + count
  elif result['status'] == 'skipped':
      skipped_count += 1
  else:
 
  'num_workers': args.num_workers,
  'horizons_seconds': args.horizons_seconds,
  'quantiles': args.quantiles,

  'target_total': target_total,
  'target_per_class': target_per_class,
+ 'cache_balance_mode': args.cache_balance_mode,
+ 'context_polarity_distribution': polarity_distribution,
+ 'class_targets': {str(k): v for k, v in class_targets.items()},
+ 'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
+ 'positive_balance_min_class': args.positive_balance_min_class,
+ 'positive_context_ratio': args.positive_context_ratio,
  }, f, indent=2)

  print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
scripts/evaluate_sample.py CHANGED
@@ -2,6 +2,8 @@ import os
  import sys
  import argparse
  import random

  import torch
  from pathlib import Path

@@ -14,6 +16,7 @@ from torch.utils.data import DataLoader, Subset

  from data.data_loader import OracleDataset
  from data.data_collator import MemecoinCollator

  from models.multi_modal_processor import MultiModalEncoder
  from models.helper_encoders import ContextualTimeEncoder
  from models.token_encoder import TokenEncoder
@@ -29,6 +32,25 @@ from neo4j import GraphDatabase
  from data.data_fetcher import DataFetcher
  from scripts.analyze_distribution import get_return_class_map

  def unlog_transform(tensor):
      """Invert the log1p transform applied during training."""
      # During training: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
@@ -42,10 +64,418 @@ def parse_args():
      parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
      parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
      parser.add_argument("--seed", type=int, default=None)
-     parser.add_argument("--min_class", type=int, default=5, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
-     parser.add_argument("--cutoff_trade_idx", type=int, default=600, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
      return parser.parse_args()
 
  def get_latest_checkpoint(checkpoint_dir):
      ckpt_dir = Path(checkpoint_dir)
      if ckpt_dir.exists():
@@ -59,6 +489,10 @@ def get_latest_checkpoint(checkpoint_dir):
  def main():
      load_dotenv()
      args = parse_args()

      accelerator = Accelerator(mixed_precision=args.mixed_precision)
      device = accelerator.device
@@ -186,134 +620,186 @@ def main():

      model.eval()

-     # Find a valid sample
-     valid_context_found = False
-     max_retries = 20
      retries = 0
-     raw_sample = None
-     sample_mint_addr = None
-
-     while not valid_context_found and retries < max_retries:
-         if args.sample_idx is not None:
-             if isinstance(args.sample_idx, str) and not args.sample_idx.isdigit():
-                 found_idx = next((i for i, m in enumerate(dataset.sampled_mints) if m['mint_address'] == args.sample_idx), None)
-                 if found_idx is None:
-                     import datetime
-                     dataset.sampled_mints.append({'mint_address': args.sample_idx, 'creator_address': '', 'timestamp': datetime.datetime.now(datetime.timezone.utc)})
-                     sample_idx = len(dataset.sampled_mints) - 1
-                 else:
-                     sample_idx = found_idx
-             else:
-                 sample_idx = int(args.sample_idx)
-                 if sample_idx >= len(dataset):
-                     raise ValueError(f"Sample index {sample_idx} out of range")
-         else:
-             sample_idx = random.randint(0, len(dataset.sampled_mints) - 1)
-
          sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
          print(f"Trying Token Address: {sample_mint_addr}")
-
-         contexts = dataset.__cacheitem_context__(sample_idx, num_samples_per_token=1, encoder=multi_modal_encoder, forced_cutoff_trade_idx=args.cutoff_trade_idx)
-
-         if not contexts or len(contexts) == 0 or contexts[0] is None:
-             print("  [Failed to generate valid context pattern, skipping...]")
-             retries += 1
-             if args.sample_idx is not None:
-                 print("Specific sample requested but failed to generate context. Exiting.")
-                 return
-             continue
-
          raw_sample = contexts[0]
-         valid_context_found = True

-     if not valid_context_found:
-         print(f"Could not find a valid context after {max_retries} attempts.")
-         return
-
-     print(f"\nEvaluating precisely on Token Address: {sample_mint_addr}")

-     batch = collator([raw_sample])
-
-     # Move batch to device
-     for k, v in batch.items():
-         if isinstance(v, torch.Tensor):
-             batch[k] = v.to(device)
-         elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
-             batch[k] = [t.to(device) for t in v]

-     # Add missing keys needed by model safety checks
-     if 'textual_event_indices' not in batch:
-         B, L = batch['event_type_ids'].shape
-         batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=device)
-     if 'textual_event_data' not in batch:
-         batch['textual_event_data'] = []

-     print("\n--- Running Inference ---")
-     with torch.no_grad():
-         outputs = model(batch)
-
-     preds = outputs["quantile_logits"][0].cpu()  # shape [Horizons * Quantiles]
-     quality_preds = outputs["quality_logits"][0].cpu() if "quality_logits" in outputs else None
-
-     # Raw labels from dataset (these are NOT log-transformed yet)
-     gt_labels = batch["labels"][0].cpu()
-     gt_mask = batch["labels_mask"][0].cpu().bool()
-
-     # Quality target if available
-     gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None

-     # Un-log the predictions since model was trained on log-transformed returns
-     # But wait, did the user train with log transformed returns?
-     # Yes, train.py does: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
-     real_preds = unlog_transform(preds)

-     print("\n================== Results ==================")
-     print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
-     if gt_quality is not None:
-         print(f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_preds.item() if quality_preds is not None else 'N/A'}")
-
-     print("\nReturns per Horizon:")
-     num_quantiles = len(args.quantiles)
-     # The models outputs all defined horizons, but the dataset labels might be truncated
-     # if it was generated with fewer horizons.
-     num_gt_horizons = len(gt_mask)  # Shape is [H]
-
-     for h_idx, horizon in enumerate(args.horizons_seconds):
-         horizon_min = horizon // 60
-         print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
-
-         if h_idx >= num_gt_horizons:
-             print("  [No Ground Truth Available for this Horizon - Not in Dataset]")
-             valid = False
-         else:
-             # Mask format is [H]
-             valid = gt_mask[h_idx].item()
-
-         if not valid:
-             print("  [No Ground Truth Available for this Horizon - Masked]")
-             # We still print predictions even if GT is masked/missing
-             print("  Predictions:")
-             for q_idx, q in enumerate(args.quantiles):
-                 flat_idx = h_idx * num_quantiles + q_idx
-                 pred_ret = real_preds[flat_idx].item()
-                 log_pred = preds[flat_idx].item()
-                 print(f"    - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})")
-             continue
-
-         # Ground truth (raw)
-         gt_ret = gt_labels[h_idx].item()
-         print(f"  Ground Truth: {gt_ret * 100:.2f}%")
-
-         # Predictions
-         print("  Predictions:")
-         for q_idx, q in enumerate(args.quantiles):
-             flat_idx = h_idx * num_quantiles + q_idx
-             pred_ret = real_preds[flat_idx].item()
-             log_pred = preds[flat_idx].item()
-
-             print(f"    - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})")
-
-     print("=============================================\n")

  if __name__ == "__main__":
      main()
 
  import sys
  import argparse
  import random
+ import copy
+ import math
  import torch
  from pathlib import Path

  from data.data_loader import OracleDataset
  from data.data_collator import MemecoinCollator
+ from data.context_targets import MOVEMENT_ID_TO_CLASS
  from models.multi_modal_processor import MultiModalEncoder
  from models.helper_encoders import ContextualTimeEncoder
  from models.token_encoder import TokenEncoder

  from data.data_fetcher import DataFetcher
  from scripts.analyze_distribution import get_return_class_map

+ ABLATION_SWEEP_MODES = [
+     "wallet",
+     "graph",
+     "social",
+     "token",
+     "holder",
+     "ohlc",
+     "ohlc_wallet",
+     "trade",
+     "onchain",
+     "wallet_graph",
+ ]
+
+ OHLC_PROBE_MODES = [
+     "ohlc_reverse",
+     "ohlc_shuffle_chunks",
+     "ohlc_mask_recent",
+ ]
+
  def unlog_transform(tensor):
      """Invert the log1p transform applied during training."""
      # During training: labels = torch.sign(labels) * torch.log1p(torch.abs(labels))
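The round trip behind `unlog_transform` can be checked numerically. A minimal sketch, assuming only what the docstring above states (training applies `sign(x) * log1p(|x|)` to the return labels); the scalar helpers `log1p_signed` and `unlog_signed` are illustrative, not part of the codebase:

```python
import math

def log1p_signed(x):
    # forward transform applied to labels at training time
    return math.copysign(math.log1p(abs(x)), x)

def unlog_signed(y):
    # inverse: sign(y) * (exp(|y|) - 1), i.e. a signed expm1
    return math.copysign(math.expm1(abs(y)), y)

# returns like -95% or +900% survive the round trip exactly
for r in (-0.95, -0.5, 0.0, 0.3, 9.0):
    assert abs(unlog_signed(log1p_signed(r)) - r) < 1e-12
print("round trip ok")
```

The sign-preserving form matters because the transform must stay monotonic through zero for the quantile outputs to remain ordered after un-logging.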
 
      parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[300, 900, 1800, 3600, 7200])
      parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
      parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument("--min_class", type=int, default=3, help="Filter out tokens with return class beneath this ID (e.g., 1 for >= 3x returns)")
+     parser.add_argument("--cutoff_trade_idx", type=int, default=200, help="Force the T_cutoff at this exact trade index (e.g., 10 = right after the 10th trade)")
+     parser.add_argument("--num_samples", type=int, default=1, help="Number of valid samples to evaluate and aggregate.")
+     parser.add_argument("--max_retries", type=int, default=100, help="Maximum attempts to find valid contexts across samples.")
+     parser.add_argument("--show_each", action="store_true", help="Print per-sample details for every evaluated sample.")
+     parser.add_argument(
+         "--ablation",
+         type=str,
+         default="none",
+         choices=["none", "wallet", "graph", "wallet_graph", "social", "token", "holder", "ohlc", "ohlc_wallet", "trade", "onchain", "all", "sweep", "ohlc_probe"],
+         help="Run inference with selected signal families removed, or use 'sweep' to rank multiple families.",
+     )
      return parser.parse_args()

+ def clone_batch(batch):
+     cloned = {}
+     for key, value in batch.items():
+         if isinstance(value, torch.Tensor):
+             cloned[key] = value.clone()
+         else:
+             cloned[key] = copy.deepcopy(value)
+     return cloned
+
+
+ def _empty_wallet_encoder_inputs(device):
+     return {
+         'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+         'profile_rows': [],
+         'social_rows': [],
+         'holdings_batch': [],
+     }
+
+
+ def _empty_token_encoder_inputs(device):
+     return {
+         'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+         'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+         'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+         'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
+         'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
+         '_addresses_for_lookup': [],
+     }


+ def apply_ablation(batch, mode, device):
+     if mode == "none":
+         return batch
+
+     ablated = clone_batch(batch)
+
+     if mode in {"wallet", "wallet_graph", "ohlc_wallet", "all"}:
+         for key in (
+             "wallet_indices",
+             "dest_wallet_indices",
+             "original_author_indices",
+             "holder_snapshot_indices",
+         ):
+             if key in ablated:
+                 ablated[key].zero_()
+         ablated["wallet_encoder_inputs"] = _empty_wallet_encoder_inputs(device)
+         ablated["wallet_addr_to_batch_idx"] = {}
+         ablated["holder_snapshot_raw_data"] = []
+         ablated["graph_updater_links"] = {}
+
+     if mode in {"graph", "wallet_graph", "all"}:
+         ablated["graph_updater_links"] = {}
+
+     if mode in {"social", "all"}:
+         if "textual_event_indices" in ablated:
+             ablated["textual_event_indices"].zero_()
+         ablated["textual_event_data"] = []
+
+     if mode in {"token", "all"}:
+         for key in (
+             "token_indices",
+             "quote_token_indices",
+             "trending_token_indices",
+             "boosted_token_indices",
+         ):
+             if key in ablated:
+                 ablated[key].zero_()
+         ablated["token_encoder_inputs"] = _empty_token_encoder_inputs(device)
+
+     if mode in {"holder", "all"}:
+         if "holder_snapshot_indices" in ablated:
+             ablated["holder_snapshot_indices"].zero_()
+         ablated["holder_snapshot_raw_data"] = []
+
+     if mode in {"ohlc", "ohlc_wallet", "all"}:
+         if "ohlc_indices" in ablated:
+             ablated["ohlc_indices"].zero_()
+         if "ohlc_price_tensors" in ablated:
+             ablated["ohlc_price_tensors"] = torch.zeros_like(ablated["ohlc_price_tensors"])
+         if "ohlc_interval_ids" in ablated:
+             ablated["ohlc_interval_ids"] = torch.zeros_like(ablated["ohlc_interval_ids"])
+
+     if mode in {"trade", "all"}:
+         for key in (
+             "trade_numerical_features",
+             "deployer_trade_numerical_features",
+             "smart_wallet_trade_numerical_features",
+             "transfer_numerical_features",
+             "pool_created_numerical_features",
+             "liquidity_change_numerical_features",
+             "fee_collected_numerical_features",
+             "token_burn_numerical_features",
+             "supply_lock_numerical_features",
+             "boosted_token_numerical_features",
+             "trending_token_numerical_features",
+             "dexboost_paid_numerical_features",
+             "global_trending_numerical_features",
+             "chainsnapshot_numerical_features",
+             "lighthousesnapshot_numerical_features",
+             "dexprofile_updated_flags",
+         ):
+             if key in ablated:
+                 ablated[key] = torch.zeros_like(ablated[key])
+         for key in (
+             "trade_dex_ids",
+             "trade_direction_ids",
+             "trade_mev_protection_ids",
+             "trade_is_bundle_ids",
+             "pool_created_protocol_ids",
+             "liquidity_change_type_ids",
+             "trending_token_source_ids",
+             "trending_token_timeframe_ids",
+             "lighthousesnapshot_protocol_ids",
+             "lighthousesnapshot_timeframe_ids",
+             "migrated_protocol_ids",
+             "alpha_group_ids",
+             "channel_ids",
+             "exchange_ids",
+         ):
+             if key in ablated:
+                 ablated[key] = torch.zeros_like(ablated[key])
+
+     if mode == "onchain":
+         if "onchain_snapshot_numerical_features" in ablated:
+             ablated["onchain_snapshot_numerical_features"] = torch.zeros_like(ablated["onchain_snapshot_numerical_features"])
+
+     return ablated
+

+ def _chunk_permutation_indices(length, chunk_size):
+     if length <= 0:
+         return []
+     chunks = [list(range(i, min(i + chunk_size, length))) for i in range(0, length, chunk_size)]
+     if len(chunks) <= 1:
+         return list(range(length))
+     permuted = list(reversed(chunks))
+     out = []
+     for chunk in permuted:
+         out.extend(chunk)
+     return out
+
+
+ def apply_ohlc_probe(batch, mode):
+     probed = clone_batch(batch)
+     if "ohlc_price_tensors" not in probed or probed["ohlc_price_tensors"].numel() == 0:
+         return probed
+
+     ohlc = probed["ohlc_price_tensors"].clone()
+     seq_len = ohlc.shape[-1]
+
+     if mode == "ohlc_reverse":
+         probed["ohlc_price_tensors"] = torch.flip(ohlc, dims=[-1])
+     elif mode == "ohlc_shuffle_chunks":
+         perm = _chunk_permutation_indices(seq_len, chunk_size=30)
+         idx = torch.tensor(perm, device=ohlc.device, dtype=torch.long)
+         probed["ohlc_price_tensors"] = ohlc.index_select(-1, idx)
+     elif mode == "ohlc_mask_recent":
+         keep = max(seq_len - 60, 0)
+         if keep < seq_len and keep > 0:
+             fill = ohlc[..., keep - 1:keep].expand_as(ohlc[..., keep:])
+             ohlc[..., keep:] = fill
+         elif keep == 0:
+             ohlc.zero_()
+         probed["ohlc_price_tensors"] = ohlc
+
+     return probed
+
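The index permutation behind the `ohlc_shuffle_chunks` probe keeps each chunk internally ordered but reverses the chunk order, so local candle structure survives while the global timeline is scrambled. A standalone sketch of the same index construction (`chunk_reverse_indices` is an illustrative name, not the script's function):

```python
def chunk_reverse_indices(length, chunk_size):
    # split [0, length) into consecutive chunks, reverse the chunk ORDER,
    # but keep the order of indices inside each chunk
    chunks = [list(range(i, min(i + chunk_size, length)))
              for i in range(0, length, chunk_size)]
    if len(chunks) <= 1:
        return list(range(length))
    out = []
    for chunk in reversed(chunks):
        out.extend(chunk)
    return out

print(chunk_reverse_indices(7, 3))  # -> [6, 3, 4, 5, 0, 1, 2]
```

If the model's predictions barely move under this permutation, it is likely reading only local (within-chunk) price shape rather than the long-range trend.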
+ def run_inference(model, batch):
+     with torch.no_grad():
+         outputs = model(batch)
+     preds = outputs["quantile_logits"][0].detach().cpu()
+     quality_pred = outputs["quality_logits"][0].detach().cpu() if "quality_logits" in outputs else None
+     movement_pred = outputs["movement_logits"][0].detach().cpu() if "movement_logits" in outputs else None
+     return preds, quality_pred, movement_pred
+

+ def print_results(title, batch, preds, quality_pred, movement_pred, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, reference_preds=None, reference_quality=None):
+     real_preds = unlog_transform(preds)
+     num_quantiles = len(quantiles)
+     num_gt_horizons = len(gt_mask)
+
+     print(f"\n================== {title} ==================")
+     print(f"Token Address: {batch.get('token_addresses', ['Unknown'])[0]}")
+     if gt_quality is not None:
+         quality_line = f"Quality Score: GT = {gt_quality:.4f} | Pred = {quality_pred.item() if quality_pred is not None else 'N/A'}"
+         if reference_quality is not None and quality_pred is not None:
+             quality_delta = quality_pred.item() - reference_quality.item()
+             quality_line += f" | Delta vs Full = {quality_delta:+.6f}"
+         print(quality_line)
+     if movement_pred is not None:
+         movement_targets = batch.get("movement_class_targets")
+         movement_mask = batch.get("movement_class_mask")
+         print("Movement Classes:")
+         for h_idx, horizon in enumerate(horizons_seconds):
+             if h_idx >= movement_pred.shape[0]:
+                 break
+             target_txt = "N/A"
+             if movement_targets is not None and movement_mask is not None and bool(movement_mask[0, h_idx].item()):
+                 target_txt = MOVEMENT_ID_TO_CLASS.get(int(movement_targets[0, h_idx].item()), "unknown")
+             pred_class = int(movement_pred[h_idx].argmax().item())
+             pred_name = MOVEMENT_ID_TO_CLASS.get(pred_class, "unknown")
+             pred_prob = float(torch.softmax(movement_pred[h_idx], dim=-1)[pred_class].item())
+             print(
+                 f"  {horizon:>4}s GT = {target_txt:<12} | "
+                 f"Pred = {pred_name:<12} | "
+                 f"Conf = {pred_prob:.4f}"
+             )
+     if "context_class_name" in batch:
+         print(f"Context Class: {batch['context_class_name'][0]}")
+
+     print("\nReturns per Horizon:")
+     for h_idx, horizon in enumerate(horizons_seconds):
+         horizon_min = horizon // 60
+         print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
+
+         if h_idx >= num_gt_horizons:
+             print("  [No Ground Truth Available for this Horizon - Not in Dataset]")
+             valid = False
+         else:
+             valid = gt_mask[h_idx].item()
+
+         if not valid:
+             print("  [No Ground Truth Available for this Horizon - Masked]")
+         else:
+             gt_ret = gt_labels[h_idx].item()
+             print(f"  Ground Truth: {gt_ret * 100:.2f}%")
+
+         print("  Predictions:")
+         for q_idx, q in enumerate(quantiles):
+             flat_idx = h_idx * num_quantiles + q_idx
+             pred_ret = real_preds[flat_idx].item()
+             log_pred = preds[flat_idx].item()
+             line = f"    - p{int(q*100):02d}: {pred_ret * 100:>8.2f}% (raw log-val: {log_pred:7.4f})"
+             if reference_preds is not None:
+                 ref_ret = unlog_transform(reference_preds)[flat_idx].item()
+                 line += f" | Delta vs Full: {(pred_ret - ref_ret) * 100:+7.2f}%"
+             print(line)
+
+     print("=============================================\n")
+

+ def resolve_sample_index(dataset, sample_idx_arg, rng):
+     if sample_idx_arg is not None:
+         if isinstance(sample_idx_arg, str) and not sample_idx_arg.isdigit():
+             found_idx = next((i for i, m in enumerate(dataset.sampled_mints) if m['mint_address'] == sample_idx_arg), None)
+             if found_idx is None:
+                 raise ValueError(f"Mint address {sample_idx_arg} not found in filtered dataset")
+             return found_idx
+         resolved = int(sample_idx_arg)
+         if resolved >= len(dataset):
+             raise ValueError(f"Sample index {resolved} out of range")
+         return resolved
+     return rng.randint(0, len(dataset.sampled_mints) - 1)
+

+ def move_batch_to_device(batch, device):
+     for k, v in batch.items():
+         if isinstance(v, torch.Tensor):
+             batch[k] = v.to(device)
+         elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
+             batch[k] = [t.to(device) for t in v]
+     if 'textual_event_indices' not in batch:
+         B, L = batch['event_type_ids'].shape
+         batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=device)
+     if 'textual_event_data' not in batch:
+         batch['textual_event_data'] = []
+     return batch
+

+ def init_aggregate(horizons_seconds, quantiles):
+     return {
+         "count": 0,
+         "quality_full_sum": 0.0,
+         "quality_abl_sum": 0.0,
+         "quality_delta_sum": 0.0,
+         "gt_quality_sum": 0.0,
+         "per_hq": {
+             (h, q): {
+                 "full_sum": 0.0,
+                 "abl_sum": 0.0,
+                 "delta_sum": 0.0,
+                 "abs_delta_sum": 0.0,
+                 "gt_sum": 0.0,
+                 "valid_count": 0,
+             }
+             for h in horizons_seconds for q in quantiles
+         },
+     }
+

+ def update_aggregate(stats, full_preds, gt_labels, gt_mask, gt_quality, horizons_seconds, quantiles, ablated_preds=None, full_quality=None, ablated_quality=None):
+     stats["count"] += 1
+     if gt_quality is not None:
+         stats["gt_quality_sum"] += float(gt_quality)
+     if full_quality is not None:
+         stats["quality_full_sum"] += float(full_quality.item())
+     if ablated_quality is not None:
+         stats["quality_abl_sum"] += float(ablated_quality.item())
+     if full_quality is not None and ablated_quality is not None:
+         stats["quality_delta_sum"] += float(ablated_quality.item() - full_quality.item())
+
+     full_real = unlog_transform(full_preds)
+     ablated_real = unlog_transform(ablated_preds) if ablated_preds is not None else None
+     num_quantiles = len(quantiles)
+
+     for h_idx, horizon in enumerate(horizons_seconds):
+         valid = h_idx < len(gt_mask) and bool(gt_mask[h_idx].item())
+         gt_ret = float(gt_labels[h_idx].item()) if valid else math.nan
+         for q_idx, q in enumerate(quantiles):
+             flat_idx = h_idx * num_quantiles + q_idx
+             bucket = stats["per_hq"][(horizon, q)]
+             full_val = float(full_real[flat_idx].item())
+             bucket["full_sum"] += full_val
+             if ablated_real is not None:
+                 abl_val = float(ablated_real[flat_idx].item())
+                 delta = abl_val - full_val
+                 bucket["abl_sum"] += abl_val
+                 bucket["delta_sum"] += delta
+                 bucket["abs_delta_sum"] += abs(delta)
+             if valid:
+                 bucket["gt_sum"] += gt_ret
+                 bucket["valid_count"] += 1
+
407
+ def print_aggregate_summary(stats, horizons_seconds, quantiles, ablation_mode):
408
+ n = stats["count"]
409
+ print("\n================== Aggregate Summary ==================")
410
+ print(f"Evaluated Samples: {n}")
411
+ if n == 0:
412
+ print("No valid samples collected.")
413
+ print("=======================================================\n")
414
+ return
415
+
416
+ if ablation_mode != "none":
417
+ print(
418
+ f"Quality Mean: full={stats['quality_full_sum'] / n:.6f} | "
419
+ f"ablated={stats['quality_abl_sum'] / n:.6f} | "
420
+ f"delta={stats['quality_delta_sum'] / n:+.6f}"
421
+ )
422
+
423
+ for horizon in horizons_seconds:
424
+ horizon_min = horizon // 60
425
+ print(f"\n--- Horizon: {horizon}s ({horizon_min}m) ---")
426
+ valid_counts = [stats["per_hq"][(horizon, q)]["valid_count"] for q in quantiles]
427
+ valid_count = max(valid_counts) if valid_counts else 0
428
+ if valid_count > 0:
429
+ gt_mean = stats["per_hq"][(horizon, quantiles[0])]["gt_sum"] / valid_count
430
+ print(f" Mean Ground Truth over valid labels: {gt_mean * 100:.2f}% (n={valid_count})")
431
+ else:
432
+ print(" Mean Ground Truth over valid labels: N/A")
433
+
434
+ for q in quantiles:
435
+ bucket = stats["per_hq"][(horizon, q)]
436
+ full_mean = bucket["full_sum"] / n
437
+ line = f" p{int(q*100):02d} mean full: {full_mean * 100:>8.2f}%"
438
+ if ablation_mode != "none":
439
+ abl_mean = bucket["abl_sum"] / n
440
+ delta_mean = bucket["delta_sum"] / n
441
+ abs_delta_mean = bucket["abs_delta_sum"] / n
442
+ line += (
443
+ f" | ablated: {abl_mean * 100:>8.2f}%"
444
+ f" | delta: {delta_mean * 100:+8.2f}%"
445
+ f" | mean|delta|: {abs_delta_mean * 100:>8.2f}%"
446
+ )
447
+ print(line)
448
+ print("=======================================================\n")
449
+
450
+
451
+ def summarize_influence_score(stats, horizons_seconds, quantiles):
452
+ n = stats["count"]
453
+ if n == 0:
454
+ return 0.0
455
+ total = 0.0
456
+ denom = 0
457
+ for horizon in horizons_seconds:
458
+ for q in quantiles:
459
+ total += stats["per_hq"][(horizon, q)]["abs_delta_sum"] / n
460
+ denom += 1
461
+ return total / max(denom, 1)
462
+
463
+
464
+ def print_probe_summary(mode_to_stats, horizons_seconds, quantiles):
465
+ rankings = []
466
+ for mode in OHLC_PROBE_MODES:
467
+ score = summarize_influence_score(mode_to_stats[mode], horizons_seconds, quantiles)
468
+ rankings.append((mode, score))
469
+ rankings.sort(key=lambda x: x[1], reverse=True)
470
+
471
+ print("\n================== OHLC Probe Ranking ==================")
472
+ for rank, (mode, score) in enumerate(rankings, start=1):
473
+ print(f"{rank:>2}. {mode:<20} mean|delta| = {score * 100:8.2f}%")
474
+ print("========================================================\n")
475
+
476
+ for mode, _ in rankings:
477
+ print_aggregate_summary(mode_to_stats[mode], horizons_seconds, quantiles, mode)
478
+
479
  def get_latest_checkpoint(checkpoint_dir):
480
  ckpt_dir = Path(checkpoint_dir)
481
  if ckpt_dir.exists():
 
489
  def main():
490
  load_dotenv()
491
  args = parse_args()
492
+ rng = random.Random(args.seed)
493
+ if args.seed is not None:
494
+ random.seed(args.seed)
495
+ torch.manual_seed(args.seed)
496
 
497
  accelerator = Accelerator(mixed_precision=args.mixed_precision)
498
  device = accelerator.device
 
620
 
621
  model.eval()
622
 
623
+ stats = init_aggregate(args.horizons_seconds, args.quantiles)
624
+ selected_modes = [] if args.ablation == "none" else (ABLATION_SWEEP_MODES if args.ablation == "sweep" else ([] if args.ablation == "ohlc_probe" else [args.ablation]))
625
+ mode_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in selected_modes}
626
+ probe_to_stats = {mode: init_aggregate(args.horizons_seconds, args.quantiles) for mode in OHLC_PROBE_MODES} if args.ablation == "ohlc_probe" else {}
627
+ max_target_samples = max(1, args.num_samples)
628
  retries = 0
629
+ collected = 0
630
+ seen_indices = set()
631
+
632
+ while collected < max_target_samples and retries < args.max_retries:
633
+ sample_idx = resolve_sample_index(dataset, args.sample_idx, rng)
634
+ if args.sample_idx is None and sample_idx in seen_indices and len(seen_indices) < len(dataset.sampled_mints):
635
+ retries += 1
636
+ continue
637
+ seen_indices.add(sample_idx)
638
+
 
 
 
 
 
 
 
 
 
 
639
  sample_mint_addr = dataset.sampled_mints[sample_idx]['mint_address']
640
  print(f"Trying Token Address: {sample_mint_addr}")
641
+
642
+ contexts = dataset.__cacheitem_context__(
643
+ sample_idx,
644
+ num_samples_per_token=1,
645
+ encoder=multi_modal_encoder,
646
+ forced_cutoff_trade_idx=args.cutoff_trade_idx,
647
+ )
648
+
649
+ if not contexts or contexts[0] is None:
650
+ print(" [Failed to generate valid context pattern, skipping...]")
651
+ retries += 1
652
+ if args.sample_idx is not None:
653
+ print("Specific sample requested but failed to generate context. Exiting.")
654
+ return
655
+ continue
656
+
657
  raw_sample = contexts[0]
658
+ batch = move_batch_to_device(collator([raw_sample]), device)
659
+ gt_labels = batch["labels"][0].cpu()
660
+ gt_mask = batch["labels_mask"][0].cpu().bool()
661
+ gt_quality = batch["quality_score"][0].item() if "quality_score" in batch else None
662
 
663
+ if collected == 0 or args.show_each:
664
+ print(f"\nEvaluating sample {collected + 1}/{max_target_samples} on Token Address: {sample_mint_addr}")
665
+ print("\n--- Running Inference ---")
 
 
666
 
667
+ full_preds, full_quality, full_direction = run_inference(model, batch)
668
+ ablation_outputs = {}
669
+ for mode in selected_modes:
670
+ ablated_batch = apply_ablation(batch, mode, device)
671
+ ablated_preds, ablated_quality, ablated_direction = run_inference(model, ablated_batch)
672
+ ablation_outputs[mode] = (ablated_batch, ablated_preds, ablated_quality, ablated_direction)
673
+ probe_outputs = {}
674
+ if args.ablation == "ohlc_probe":
675
+ for mode in OHLC_PROBE_MODES:
676
+ probe_batch = apply_ohlc_probe(batch, mode)
677
+ probe_preds, probe_quality, probe_direction = run_inference(model, probe_batch)
678
+ probe_outputs[mode] = (probe_batch, probe_preds, probe_quality, probe_direction)
679
 
680
+ if collected == 0 or args.show_each:
681
+ print_results(
682
+ title="Full Results",
683
+ batch=batch,
684
+ preds=full_preds,
685
+ quality_pred=full_quality,
686
+ direction_pred=full_direction,
687
+ gt_labels=gt_labels,
688
+ gt_mask=gt_mask,
689
+ gt_quality=gt_quality,
690
+ horizons_seconds=args.horizons_seconds,
691
+ quantiles=args.quantiles,
692
+ )
693
+ if args.ablation != "none":
694
+ if args.ablation == "sweep":
695
+ print(f"Collected full predictions for {len(selected_modes)} ablation families on this sample. Aggregate ranking will be printed at the end.")
696
+ elif args.ablation == "ohlc_probe":
697
+ for mode in OHLC_PROBE_MODES:
698
+ probe_batch, probe_preds, probe_quality, probe_direction = probe_outputs[mode]
699
+ print_results(
700
+ title=f"OHLC Probe ({mode})",
701
+ batch=probe_batch,
702
+ preds=probe_preds,
703
+ quality_pred=probe_quality,
704
+ direction_pred=probe_direction,
705
+ gt_labels=gt_labels,
706
+ gt_mask=gt_mask,
707
+ gt_quality=gt_quality,
708
+ horizons_seconds=args.horizons_seconds,
709
+ quantiles=args.quantiles,
710
+ reference_preds=full_preds,
711
+ reference_quality=full_quality,
712
+ )
713
+ else:
714
+ ablated_batch, ablated_preds, ablated_quality, ablated_direction = ablation_outputs[args.ablation]
715
+ print_results(
716
+ title=f"Ablation Results ({args.ablation})",
717
+ batch=ablated_batch,
718
+ preds=ablated_preds,
719
+ quality_pred=ablated_quality,
720
+ direction_pred=ablated_direction,
721
+ gt_labels=gt_labels,
722
+ gt_mask=gt_mask,
723
+ gt_quality=gt_quality,
724
+ horizons_seconds=args.horizons_seconds,
725
+ quantiles=args.quantiles,
726
+ reference_preds=full_preds,
727
+ reference_quality=full_quality,
728
+ )
729
 
730
+ update_aggregate(
731
+ stats=stats,
732
+ full_preds=full_preds,
733
+ gt_labels=gt_labels,
734
+ gt_mask=gt_mask,
735
+ gt_quality=gt_quality,
736
+ horizons_seconds=args.horizons_seconds,
737
+ quantiles=args.quantiles,
738
+ full_quality=full_quality,
739
+ )
740
+ for mode, (_, ablated_preds, ablated_quality, _) in ablation_outputs.items():
741
+ update_aggregate(
742
+ stats=mode_to_stats[mode],
743
+ full_preds=full_preds,
744
+ gt_labels=gt_labels,
745
+ gt_mask=gt_mask,
746
+ gt_quality=gt_quality,
747
+ horizons_seconds=args.horizons_seconds,
748
+ quantiles=args.quantiles,
749
+ ablated_preds=ablated_preds,
750
+ full_quality=full_quality,
751
+ ablated_quality=ablated_quality,
752
+ )
753
+ for mode, (_, probe_preds, probe_quality, _) in probe_outputs.items():
754
+ update_aggregate(
755
+ stats=probe_to_stats[mode],
756
+ full_preds=full_preds,
757
+ gt_labels=gt_labels,
758
+ gt_mask=gt_mask,
759
+ gt_quality=gt_quality,
760
+ horizons_seconds=args.horizons_seconds,
761
+ quantiles=args.quantiles,
762
+ ablated_preds=probe_preds,
763
+ full_quality=full_quality,
764
+ ablated_quality=probe_quality,
765
+ )
766
+ collected += 1
767
+ retries += 1
768
 
769
+ if args.sample_idx is not None:
770
+ break
 
 
771
 
772
+ if collected == 0:
773
+ print(f"Could not find a valid context after {args.max_retries} attempts.")
774
+ return
775
+
776
+ if collected < max_target_samples:
777
+ print(f"WARNING: Requested {max_target_samples} samples but only evaluated {collected}.")
778
+
779
+ if args.ablation == "none":
780
+ print_aggregate_summary(stats, args.horizons_seconds, args.quantiles, args.ablation)
781
+ return
782
+
783
+ if args.ablation == "ohlc_probe":
784
+ print_probe_summary(probe_to_stats, args.horizons_seconds, args.quantiles)
785
+ return
786
+
787
+ if args.ablation == "sweep":
788
+ rankings = []
789
+ for mode in selected_modes:
790
+ score = summarize_influence_score(mode_to_stats[mode], args.horizons_seconds, args.quantiles)
791
+ rankings.append((mode, score))
792
+ rankings.sort(key=lambda x: x[1], reverse=True)
793
+
794
+ print("\n================== Influence Ranking ==================")
795
+ for rank, (mode, score) in enumerate(rankings, start=1):
796
+ print(f"{rank:>2}. {mode:<12} mean|delta| = {score * 100:8.2f}%")
797
+ print("=======================================================\n")
798
+
799
+ for mode, _ in rankings:
800
+ print_aggregate_summary(mode_to_stats[mode], args.horizons_seconds, args.quantiles, mode)
801
+ else:
802
+ print_aggregate_summary(mode_to_stats[args.ablation], args.horizons_seconds, args.quantiles, args.ablation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
 
804
  if __name__ == "__main__":
805
  main()
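The aggregate bookkeeping in the eval script above (`init_aggregate` / `update_aggregate` / `summarize_influence_score`) boils down to: accumulate |ablated − full| per (horizon, quantile) bucket, average over samples, then average over buckets. A minimal pure-Python sketch of just that reduction, with hypothetical prediction vectors (no torch, no `unlog_transform`):

```python
# Sketch of the influence-score aggregation: each (horizon, quantile)
# bucket accumulates |ablated - full| per sample; the final score is the
# mean over samples, then the mean over buckets. All values hypothetical.

def init_aggregate(horizons, quantiles):
    return {"count": 0,
            "per_hq": {(h, q): {"abs_delta_sum": 0.0}
                       for h in horizons for q in quantiles}}

def update_aggregate(stats, full, ablated, horizons, quantiles):
    stats["count"] += 1
    nq = len(quantiles)
    for h_idx, h in enumerate(horizons):
        for q_idx, q in enumerate(quantiles):
            flat_idx = h_idx * nq + q_idx  # same flattening as the eval script
            stats["per_hq"][(h, q)]["abs_delta_sum"] += abs(ablated[flat_idx] - full[flat_idx])

def summarize_influence_score(stats, horizons, quantiles):
    n = stats["count"]
    if n == 0:
        return 0.0
    buckets = [b["abs_delta_sum"] / n for b in stats["per_hq"].values()]
    return sum(buckets) / max(len(buckets), 1)

horizons, quantiles = [60, 300], [0.1, 0.5, 0.9]
stats = init_aggregate(horizons, quantiles)
update_aggregate(stats, [0.0] * 6, [0.1] * 6, horizons, quantiles)
update_aggregate(stats, [0.0] * 6, [0.3] * 6, horizons, quantiles)
print(round(summarize_influence_score(stats, horizons, quantiles), 6))  # → 0.2
```

Because every bucket is averaged identically, an ablation that shifts all quantiles uniformly and one that only distorts a single horizon can score the same; the per-bucket printout in `print_aggregate_summary` is what disambiguates them.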
train.py CHANGED
@@ -52,6 +52,7 @@ from neo4j import GraphDatabase
 from data.data_fetcher import DataFetcher
 from data.data_loader import OracleDataset
 from data.data_collator import MemecoinCollator
+from data.context_targets import MOVEMENT_CLASS_NAMES
 from models.multi_modal_processor import MultiModalEncoder
 from models.helper_encoders import ContextualTimeEncoder
 from models.token_encoder import TokenEncoder
@@ -148,6 +149,89 @@ def quantile_pinball_loss_per_sample(
     return per_sample_num / per_sample_den
 
 
+def masked_movement_cross_entropy(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    mask: torch.Tensor,
+    class_weights: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if mask.sum() == 0:
+        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+
+    flat_logits = logits.reshape(-1, logits.shape[-1])
+    flat_targets = targets.reshape(-1)
+    flat_mask = mask.reshape(-1).bool()
+
+    valid_logits = flat_logits[flat_mask]
+    valid_targets = flat_targets[flat_mask]
+    if valid_logits.numel() == 0:
+        return torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
+
+    return nn.functional.cross_entropy(valid_logits, valid_targets, weight=class_weights)
+
+
+def movement_accuracy(
+    logits: torch.Tensor,
+    targets: torch.Tensor,
+    mask: torch.Tensor,
+) -> float:
+    if mask.sum().item() == 0:
+        return 0.0
+    preds = logits.argmax(dim=-1)
+    valid = mask.bool()
+    correct = (preds[valid] == targets[valid]).float()
+    if correct.numel() == 0:
+        return 0.0
+    return float(correct.mean().item())
+
+
+def estimate_movement_class_weights(
+    dataset,
+    indices,
+    movement_label_config: Optional[Dict[str, float]] = None,
+    sample_cap: int = 4096,
+) -> torch.Tensor:
+    del movement_label_config
+    counts = torch.ones(len(MOVEMENT_CLASS_NAMES), dtype=torch.float32)
+    if not indices:
+        return counts
+
+    for idx in indices[: min(len(indices), sample_cap)]:
+        try:
+            item = dataset[idx]
+        except Exception:
+            continue
+        if not item:
+            continue
+        labels = item.get("labels")
+        labels_mask = item.get("labels_mask")
+        movement_targets = item.get("movement_class_targets")
+        movement_mask = item.get("movement_class_mask")
+        if movement_targets is None or movement_mask is None:
+            if labels is None or labels_mask is None:
+                continue
+            batch_targets = collator_like_targets(labels, labels_mask)
+            targets = batch_targets["movement_class_targets"]
+            mask = batch_targets["movement_class_mask"]
+        else:
+            targets = movement_targets.tolist() if isinstance(movement_targets, torch.Tensor) else movement_targets
+            mask = movement_mask.tolist() if isinstance(movement_mask, torch.Tensor) else movement_mask
+        for target, target_mask in zip(targets, mask):
+            if int(target_mask) > 0:
+                counts[int(target)] += 1.0
+
+    weights = counts.sum() / counts.clamp_min(1.0)
+    return weights / weights.mean().clamp_min(1e-6)
+
+
+def collator_like_targets(labels, labels_mask, movement_label_config: Optional[Dict[str, float]] = None):
+    from data.context_targets import derive_movement_targets
+
+    labels_list = labels.tolist() if isinstance(labels, torch.Tensor) else labels
+    mask_list = labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask
+    return derive_movement_targets(labels_list, mask_list, movement_label_config=movement_label_config)
+
+
 def create_balanced_split(dataset, n_val_per_class: int = 1, seed: int = 42):
     """
     Create train/val split with balanced classes in validation set.
@@ -207,6 +291,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
     total_loss = 0.0
     total_return_loss = 0.0
     total_quality_loss = 0.0
+    total_movement_loss = 0.0
+    total_movement_acc = 0.0
     n_batches = 0
 
     # Per-class metrics
@@ -228,9 +314,12 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
 
         preds = outputs["quantile_logits"]
        quality_preds = outputs["quality_logits"]
+        movement_logits = outputs.get("movement_logits")
         labels = batch["labels"]
         labels_mask = batch["labels_mask"]
         quality_targets = batch["quality_score"].to(accelerator.device, dtype=quality_preds.dtype)
+        movement_targets = batch.get("movement_class_targets")
+        movement_mask = batch.get("movement_class_mask")
 
         if labels_mask.sum() == 0:
             return_loss = torch.tensor(0.0, device=accelerator.device)
@@ -240,11 +329,22 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
             return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
 
         quality_loss = quality_loss_fn(quality_preds, quality_targets)
-        loss = return_loss + quality_loss
+        movement_loss = torch.tensor(0.0, device=accelerator.device)
+        movement_acc = 0.0
+        if movement_logits is not None and movement_targets is not None and movement_mask is not None:
+            movement_targets = movement_targets.to(accelerator.device)
+            movement_mask = movement_mask.to(accelerator.device)
+            movement_loss = masked_movement_cross_entropy(
+                movement_logits, movement_targets, movement_mask
+            )
+            movement_acc = movement_accuracy(movement_logits, movement_targets, movement_mask)
+        loss = return_loss + quality_loss + movement_loss
 
         total_loss += loss.item()
         total_return_loss += return_loss.item()
         total_quality_loss += quality_loss.item()
+        total_movement_loss += movement_loss.item()
+        total_movement_acc += movement_acc
         n_batches += 1
 
         # Track per-class losses
@@ -264,6 +364,8 @@ def run_validation(model, val_dataloader, accelerator, quantiles, quality_loss_f
         'val/loss': total_loss / n_batches,
         'val/return_loss': total_return_loss / n_batches,
         'val/quality_loss': total_quality_loss / n_batches,
+        'val/movement_loss': total_movement_loss / n_batches,
+        'val/movement_acc': total_movement_acc / n_batches,
         'val/n_batches': n_batches,
         'class_losses': {k: v['loss'] / max(v['count'], 1) for k, v in class_losses.items()}
     }
@@ -446,6 +548,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
     parser.add_argument("--val_samples_per_class", type=int, default=1, help="Number of validation samples per class (default 1)")
     parser.add_argument("--val_every", type=int, default=1000, help="Run validation every N steps (default 1000)")
+    parser.add_argument("--movement_loss_weight", type=float, default=1.0, help="Auxiliary loss weight for movement classification head.")
    return parser.parse_args()
 
 
@@ -815,6 +918,10 @@ def main() -> None:
 
     # --- 7. Training Loop ---
     quality_loss_fn = nn.MSELoss()
+    movement_class_weights = estimate_movement_class_weights(
+        dataset,
+        train_indices,
+    ).to(accelerator.device)
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len(dataset)}")
@@ -865,6 +972,7 @@ def main() -> None:
 
             preds = outputs["quantile_logits"]
             quality_preds = outputs["quality_logits"]
+            movement_logits = outputs.get("movement_logits")
             labels = batch["labels"]
             labels_mask = batch["labels_mask"]
             if "quality_score" not in batch:
@@ -920,6 +1028,21 @@ def main() -> None:
             per_sample_return = quantile_pinball_loss_per_sample(preds, labels, labels_mask, quantiles)
 
             quality_loss = quality_loss_fn(quality_preds, quality_targets)
+            movement_targets = batch.get("movement_class_targets")
+            movement_mask = batch.get("movement_class_mask")
+            if movement_logits is not None and movement_targets is not None and movement_mask is not None:
+                movement_targets = movement_targets.to(accelerator.device)
+                movement_mask = movement_mask.to(accelerator.device)
+                movement_loss = masked_movement_cross_entropy(
+                    movement_logits,
+                    movement_targets,
+                    movement_mask,
+                    class_weights=movement_class_weights,
+                )
+                current_movement_acc = movement_accuracy(movement_logits, movement_targets, movement_mask)
+            else:
+                movement_loss = torch.tensor(0.0, device=accelerator.device, dtype=quality_loss.dtype)
+                current_movement_acc = 0.0
             per_sample_quality = (quality_preds - quality_targets).pow(2)
 
             # Apply per-sample class weighting to return loss
@@ -928,10 +1051,10 @@ def main() -> None:
                 sample_weights = class_loss_weights[batch_class_ids]  # [B]
                 # Scale return loss by mean class weight for this batch
                 class_weight_factor = sample_weights.mean()
-                loss = return_loss * class_weight_factor + quality_loss
+                loss = return_loss * class_weight_factor + quality_loss + (args.movement_loss_weight * movement_loss)
             else:
                 class_weight_factor = torch.tensor(1.0, device=accelerator.device, dtype=return_loss.dtype)
-                loss = return_loss + quality_loss
+                loss = return_loss + quality_loss + (args.movement_loss_weight * movement_loss)
             per_sample_total = per_sample_return * class_weight_factor + per_sample_quality
 
             if not torch.isfinite(loss).all().item():
@@ -957,8 +1080,10 @@ def main() -> None:
                 "loss": loss.unsqueeze(0),
                 "return_loss": return_loss.unsqueeze(0),
                 "quality_loss": quality_loss.unsqueeze(0),
+                "movement_loss": movement_loss.unsqueeze(0),
                 "preds": preds,
                 "quality_preds": quality_preds,
+                "movement_logits": movement_logits,
                 "labels_raw": batch.get("labels"),
                 "labels_log": labels,
                 "labels_mask": labels_mask,
@@ -1065,6 +1190,7 @@ def main() -> None:
             current_loss = loss.item()
             current_return_loss = return_loss.item()
             current_quality_loss = quality_loss.item()
+            current_movement_loss = movement_loss.item()
             current_class_weight_factor = float(class_weight_factor.item()) if isinstance(class_weight_factor, torch.Tensor) else float(class_weight_factor)
             current_mask_coverage = float(labels_mask.float().mean().item()) if labels_mask is not None else 0.0
             epoch_loss += current_loss
@@ -1080,6 +1206,8 @@ def main() -> None:
                 "train/loss": current_loss,
                 "train/return_loss": current_return_loss,
                 "train/quality_loss": current_quality_loss,
+                "train/movement_loss": current_movement_loss,
+                "train/movement_acc": current_movement_acc,
                 "train/class_weight_factor": current_class_weight_factor,
                 "train/mask_coverage": current_mask_coverage,
                 "train/loss_ema": loss_ema if loss_ema is not None else current_loss,
@@ -1149,7 +1277,9 @@ def main() -> None:
                 logger.info(
                     f"Validation - Loss: {val_metrics['val/loss']:.4f} | "
                     f"Return: {val_metrics['val/return_loss']:.4f} | "
-                    f"Quality: {val_metrics['val/quality_loss']:.4f}"
+                    f"Quality: {val_metrics['val/quality_loss']:.4f} | "
+                    f"Movement: {val_metrics['val/movement_loss']:.4f} | "
+                    f"MoveAcc: {val_metrics['val/movement_acc']:.4f}"
                 )
                 # Log per-class losses
                 class_loss_str = " | ".join(
@@ -1162,6 +1292,8 @@ def main() -> None:
                     "val/loss": val_metrics['val/loss'],
                    "val/return_loss": val_metrics['val/return_loss'],
                     "val/quality_loss": val_metrics['val/quality_loss'],
+                    "val/movement_loss": val_metrics['val/movement_loss'],
+                    "val/movement_acc": val_metrics['val/movement_acc'],
                 }, step=total_steps)
                 # Log per-class losses to tensorboard
                 for class_id, class_loss in val_metrics['class_losses'].items():
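The `estimate_movement_class_weights` helper in the diff above reduces to inverse-frequency weighting normalized to unit mean: `weight_i = total / count_i`, then divide by the mean weight. A pure-Python sketch of just that formula, with hypothetical class counts (no torch):

```python
# Inverse-frequency class weights normalized to unit mean, mirroring the
# final two lines of estimate_movement_class_weights. Counts hypothetical.

def inverse_frequency_weights(counts):
    total = sum(counts)
    raw = [total / max(c, 1.0) for c in counts]   # rare classes get large weights
    mean = sum(raw) / len(raw)
    return [w / max(mean, 1e-6) for w in raw]     # normalize so weights average 1.0

# Two rare classes and one dominant class: the dominant class is down-weighted.
weights = inverse_frequency_weights([1.0, 1.0, 8.0])
print([round(w, 4) for w in weights])  # → [1.4118, 1.4118, 0.1765]
```

Normalizing to unit mean keeps the movement cross-entropy on the same scale regardless of how skewed the class counts are, so `--movement_loss_weight` stays meaningful across datasets.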