zirobtc commited on
Commit
4dd4ab4
·
1 Parent(s): 8e3d126

Upload folder using huggingface_hub

Browse files
data/data_collator.py CHANGED
@@ -711,11 +711,15 @@ class MemecoinCollator:
711
  # Labels
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
 
714
  # Debug info
715
  'token_addresses': [item.get('token_address', 'unknown') for item in batch],
716
  't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
717
  'sample_indices': [item.get('sample_idx', -1) for item in batch]
718
  }
719
 
 
 
 
720
  # Filter out None values (e.g., if no labels provided)
721
  return {k: v for k, v in collated_batch.items() if v is not None}
 
711
  # Labels
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
714
+ 'quality_score': torch.stack([item['quality_score'] for item in batch]) if batch and 'quality_score' in batch[0] else None,
715
  # Debug info
716
  'token_addresses': [item.get('token_address', 'unknown') for item in batch],
717
  't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
718
  'sample_indices': [item.get('sample_idx', -1) for item in batch]
719
  }
720
 
721
+ if collated_batch['quality_score'] is None:
722
+ raise RuntimeError("FATAL: Missing quality_score in batch items. Rebuild cache with quality_score enabled.")
723
+
724
  # Filter out None values (e.g., if no labels provided)
725
  return {k: v for k, v in collated_batch.items() if v is not None}
data/data_loader.py CHANGED
@@ -156,43 +156,41 @@ class OracleDataset(Dataset):
156
  if not self.cached_files:
157
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
158
 
159
- # --- NEW: Strict Metadata & Weighting ---
160
- metadata_path = self.cache_dir / "metadata.jsonl"
161
- if not metadata_path.exists():
162
- raise RuntimeError(f"FATAL: metadata.jsonl not found in {self.cache_dir}. Cannot train without class-balanced sampling.")
163
-
164
- print(f"INFO: Loading metadata from {metadata_path}...")
165
  file_class_map = {}
166
  class_counts = defaultdict(int)
167
-
168
- with open(metadata_path, 'r') as f:
169
- for line in f:
 
 
170
  try:
171
- entry = json.loads(line)
172
- fname = entry['file']
173
- cid = entry['class_id']
174
- file_class_map[fname] = cid
175
- class_counts[cid] += 1
176
- except Exception as e:
177
- print(f"WARN: Failed to parse metadata line: {e}")
 
 
 
 
178
 
179
  print(f"INFO: Class Distribution: {dict(class_counts)}")
180
-
181
  # Compute Weights
182
  self.weights_list = []
183
  valid_files = []
184
-
185
  # We iterate properly sorted cached files to align with __getitem__ index
186
  for p in self.cached_files:
187
  fname = p.name
188
  if fname not in file_class_map:
189
- # Should be fatal if strict, but maybe some files were skipped?
190
- # If file exists but no metadata, we can't weight it properly.
191
- # Current pipeline writes metadata only for successful caches.
192
- # So if it's in cached_files but not metadata, it might be a stale file.
193
- print(f"WARN: File {fname} found in cache but missing metadata. Skipping.")
194
  continue
195
-
196
  cid = file_class_map[fname]
197
  count = class_counts[cid]
198
  weight = 1.0 / count if count > 0 else 0.0
@@ -976,7 +974,8 @@ class OracleDataset(Dataset):
976
  "fee_collections",
977
  "burns",
978
  "supply_locks",
979
- "migrations"
 
980
  ]
981
  missing_keys = [key for key in required_keys if key not in raw_data]
982
  if missing_keys:
@@ -1683,7 +1682,8 @@ class OracleDataset(Dataset):
1683
  'graph_links': graph_links,
1684
  'embedding_pooler': pooler,
1685
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1686
- 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32)
 
1687
  }
1688
 
1689
  # Ensure sorted
@@ -1758,5 +1758,6 @@ class OracleDataset(Dataset):
1758
  'graph_links': graph_links,
1759
  'embedding_pooler': pooler,
1760
  'labels': torch.tensor(label_values, dtype=torch.float32),
1761
- 'labels_mask': torch.tensor(mask_values, dtype=torch.float32)
 
1762
  }
 
156
  if not self.cached_files:
157
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
158
 
159
+ # --- NEW: Strict Metadata & Weighting (from cached samples) ---
 
 
 
 
 
160
  file_class_map = {}
161
  class_counts = defaultdict(int)
162
+
163
+ # Read class_id directly from each cached sample
164
+ for p in self.cached_files:
165
+ try:
166
+ # Cached samples are trusted local artifacts; allow full load.
167
  try:
168
+ cached_item = torch.load(p, map_location="cpu", weights_only=False)
169
+ except TypeError:
170
+ cached_item = torch.load(p, map_location="cpu")
171
+ cid = cached_item.get("class_id")
172
+ if cid is None:
173
+ print(f"WARN: File {p.name} missing class_id. Skipping.")
174
+ continue
175
+ file_class_map[p.name] = cid
176
+ class_counts[cid] += 1
177
+ except Exception as e:
178
+ print(f"WARN: Failed to read cached sample {p.name}: {e}")
179
 
180
  print(f"INFO: Class Distribution: {dict(class_counts)}")
181
+
182
  # Compute Weights
183
  self.weights_list = []
184
  valid_files = []
185
+
186
  # We iterate properly sorted cached files to align with __getitem__ index
187
  for p in self.cached_files:
188
  fname = p.name
189
  if fname not in file_class_map:
190
+ # If file exists but missing class_id, it might be stale or from an older cache.
191
+ print(f"WARN: File {fname} found in cache but missing class_id. Skipping.")
 
 
 
192
  continue
193
+
194
  cid = file_class_map[fname]
195
  count = class_counts[cid]
196
  weight = 1.0 / count if count > 0 else 0.0
 
974
  "fee_collections",
975
  "burns",
976
  "supply_locks",
977
+ "migrations",
978
+ "quality_score"
979
  ]
980
  missing_keys = [key for key in required_keys if key not in raw_data]
981
  if missing_keys:
 
1682
  'graph_links': graph_links,
1683
  'embedding_pooler': pooler,
1684
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1685
+ 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1686
+ 'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
1687
  }
1688
 
1689
  # Ensure sorted
 
1758
  'graph_links': graph_links,
1759
  'embedding_pooler': pooler,
1760
  'labels': torch.tensor(label_values, dtype=torch.float32),
1761
+ 'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
1762
+ 'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
1763
  }
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e8366dfe6785692219a4d4bcbe5c3b111b5b9acd3df38fba7edd5d29bea20e15
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f2c86bf03e5761e7fb319a54274e032f7aa1d01dd5873f2f44a52c9e0be5244
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b4dd7b51859975e9b53550cdda3099bd1fd899d8b335ff3b90ab5ae7d9a1e4ff
3
- size 4414
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:461e55d31752fd72f09aa30c5bcc3a619654ae86ddf1e759c9c57b0dc5db53f6
3
+ size 21794
models/model.py CHANGED
@@ -54,7 +54,9 @@ class Oracle(nn.Module):
54
  self.dtype = dtype
55
 
56
  # --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
57
- model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True)
 
 
58
  self.d_model = model_config.hidden_size
59
  self.model = AutoModel.from_config(model_config, trust_remote_code=True)
60
  self.model.to(self.device, dtype=self.dtype)
@@ -65,6 +67,11 @@ class Oracle(nn.Module):
65
  nn.GELU(),
66
  nn.Linear(self.d_model, self.num_outputs)
67
  )
 
 
 
 
 
68
 
69
  self.event_type_to_id = event_type_to_id
70
 
@@ -947,8 +954,10 @@ class Oracle(nn.Module):
947
  empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
948
  empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
949
  empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
 
950
  return {
951
  'quantile_logits': empty_quantiles,
 
952
  'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
953
  'hidden_states': empty_hidden,
954
  'attention_mask': empty_mask
@@ -1068,9 +1077,11 @@ class Oracle(nn.Module):
1068
  sequence_hidden = outputs.last_hidden_state
1069
  pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
1070
  quantile_logits = self.quantile_head(pooled_states)
 
1071
 
1072
  return {
1073
  'quantile_logits': quantile_logits,
 
1074
  'pooled_states': pooled_states,
1075
  'hidden_states': sequence_hidden,
1076
  'attention_mask': hf_attention_mask
 
54
  self.dtype = dtype
55
 
56
  # --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
57
+ hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
58
+ hf_kwargs = {"token": hf_token} if hf_token else {}
59
+ model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True, **hf_kwargs)
60
  self.d_model = model_config.hidden_size
61
  self.model = AutoModel.from_config(model_config, trust_remote_code=True)
62
  self.model.to(self.device, dtype=self.dtype)
 
67
  nn.GELU(),
68
  nn.Linear(self.d_model, self.num_outputs)
69
  )
70
+ self.quality_head = nn.Sequential(
71
+ nn.Linear(self.d_model, self.d_model),
72
+ nn.GELU(),
73
+ nn.Linear(self.d_model, 1)
74
+ )
75
 
76
  self.event_type_to_id = event_type_to_id
77
 
 
954
  empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
955
  empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
956
  empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
957
+ empty_quality = torch.empty(0, device=device, dtype=self.dtype)
958
  return {
959
  'quantile_logits': empty_quantiles,
960
+ 'quality_logits': empty_quality,
961
  'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
962
  'hidden_states': empty_hidden,
963
  'attention_mask': empty_mask
 
1077
  sequence_hidden = outputs.last_hidden_state
1078
  pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
1079
  quantile_logits = self.quantile_head(pooled_states)
1080
+ quality_logits = self.quality_head(pooled_states).squeeze(-1)
1081
 
1082
  return {
1083
  'quantile_logits': quantile_logits,
1084
+ 'quality_logits': quality_logits,
1085
  'pooled_states': pooled_states,
1086
  'hidden_states': sequence_hidden,
1087
  'attention_mask': hf_attention_mask
models/multi_modal_processor.py CHANGED
@@ -38,13 +38,16 @@ class MultiModalEncoder:
38
 
39
 
40
  try:
 
 
41
  # --- SigLIP Loading with Config Fix ---
42
  self.processor = AutoProcessor.from_pretrained(
43
  self.model_id,
44
- use_fast=True
 
45
  )
46
 
47
- config = AutoConfig.from_pretrained(self.model_id)
48
 
49
  if not hasattr(config, 'projection_dim'):
50
  # print("❗ Config missing projection_dim, patching...")
@@ -54,7 +57,8 @@ class MultiModalEncoder:
54
  self.model_id,
55
  config=config,
56
  dtype=self.dtype, # Use torch_dtype for from_pretrained
57
- trust_remote_code=False
 
58
  ).to(self.device).eval()
59
  # -----------------------------------------------
60
 
 
38
 
39
 
40
  try:
41
+ hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
42
+ hf_kwargs = {"token": hf_token} if hf_token else {}
43
  # --- SigLIP Loading with Config Fix ---
44
  self.processor = AutoProcessor.from_pretrained(
45
  self.model_id,
46
+ use_fast=True,
47
+ **hf_kwargs
48
  )
49
 
50
+ config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)
51
 
52
  if not hasattr(config, 'projection_dim'):
53
  # print("❗ Config missing projection_dim, patching...")
 
57
  self.model_id,
58
  config=config,
59
  dtype=self.dtype, # Use torch_dtype for from_pretrained
60
+ trust_remote_code=False,
61
+ **hf_kwargs
62
  ).to(self.device).eval()
63
  # -----------------------------------------------
64
 
pre_cache.sh CHANGED
@@ -4,6 +4,6 @@
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
  --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
7
- --max_samples 1000
8
 
9
  echo "Done!"
 
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
  --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
7
+ --max_samples 50
8
 
9
  echo "Done!"
scripts/analyze_distribution.py CHANGED
@@ -1,21 +1,22 @@
1
-
2
  import os
3
  import sys
4
  import datetime
 
 
5
  from clickhouse_driver import Client as ClickHouseClient
6
 
7
  # Add parent to path
8
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
 
10
- # removed dotenv
11
- # load_dotenv()
12
 
13
  CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
14
  CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
15
- # .env shows empty user/pass, which implies 'default' user and empty password for ClickHouse
16
  CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
17
  CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
18
  CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
 
 
19
 
20
  def get_client():
21
  return ClickHouseClient(
@@ -26,484 +27,331 @@ def get_client():
26
  database=CLICKHOUSE_DATABASE
27
  )
28
 
29
- def print_distribution_stats(client, metric_name, subquery, bucket_case_sql):
30
- print(f"\n -> {metric_name}")
31
-
32
- # 1. Print Basic Stats (Mean, Quantiles)
33
- stats_query = f"""
34
- SELECT
35
- avg(val),
36
- quantiles(0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)(val),
37
- min(val),
38
- max(val),
39
- count()
40
- FROM (
41
- {subquery}
42
- )
43
  """
44
- try:
45
- stats = client.execute(stats_query)[0]
46
- avg_val = stats[0]
47
- qs = stats[1]
48
- min_val = stats[2]
49
- max_val = stats[3]
50
- count_val = stats[4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- if count_val == 0:
53
- print(" No data for this segment.")
54
- return
 
 
 
 
 
 
55
 
56
- print(f" Mean: {avg_val:.4f} | Min: {min_val:.4f} | Max: {max_val:.4f}")
57
- print(f" Q: p10={qs[0]:.2f} p50={qs[2]:.2f} p90={qs[4]:.2f} p99={qs[6]:.2f}")
 
 
 
 
 
 
 
58
 
59
- except Exception as e:
60
- print(f" Error calculating stats: {e}")
61
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- # 2. Print Buckets
64
- query = f"""
65
- SELECT
66
- {bucket_case_sql} as bucket,
67
- count() as cnt
68
- FROM (
69
- {subquery}
70
- )
71
- GROUP BY bucket
72
- ORDER BY bucket
73
- """
74
- try:
75
- rows = client.execute(query)
76
- # total_count used for pct is the count_val from stats
77
- print(" Buckets:")
78
- for r in rows:
79
- pct = (r[1] / count_val * 100) if count_val > 0 else 0
80
- print(f" {r[0]}: {r[1]} ({pct:.1f}%)")
81
- except Exception as e:
82
- print(f" Error calculating buckets: {e}")
83
 
84
- def get_filtered_metric_query(inner_query, cohort_sql):
85
- """
86
- Wraps the inner metric query to only include tokens in the cohort.
87
- Assumes inner_query returns 'base_address' (or aliased) and 'val'.
88
- If the inner query returns 'token_address', it should be handled.
89
- Most of our queries return 'base_address' (from trades) or 'token_address' (from token_metrics).
90
- We will normalize to use 'base_address' via subquery alias if needed, but simplest is
91
- to filter on the outer Select.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  """
93
- # We need to know if the inner query produces 'base_address' or 'token_address'
94
- # Currently our queries produce 'base_address' mostly, except token_metrics ones.
95
- # Let's standardize inner queries in the main loop to alias the key column to 'join_key'
96
 
97
- return f"""
98
- SELECT * FROM (
99
- {inner_query}
100
- ) WHERE join_key IN ({cohort_sql})
101
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- import numpy as np
104
- from models.vocabulary import RETURN_THRESHOLDS, MANIPULATED_CLASS_ID
 
 
 
 
 
 
 
 
 
105
 
106
- def get_return_class_map(client):
107
- """
108
- Returns a dictionary mapping token_address -> class_id (int)
109
- Filters out tokens with > 10,000x return.
110
- Implements Dynamic Outlier Detection:
111
- - Calculates Median Fees, Volume, Holders for each Class (1-4).
112
- - Downgrades tokens with metrics < 10% of their class median to Class 5 (Manipulated).
113
  """
114
- print(" -> Fetching metrics for classification...")
115
- # improved query to get fees/vol/holders
116
- # aggregating trades for fees/vol to appear more robust than token_metrics snapshots
117
- print(" -> Fetching metrics for classification...")
118
- # SQL OPTIMIZATION:
119
- # 1. Use token_metrics for Volume/Holders (Pre-computed).
120
- # 2. Pre-aggregate trades for Fees in a subquery to avoid massive JOIN explosion.
121
- query = """
122
- SELECT
123
- tm.token_address,
124
- (argMax(tm.ath_price_usd, tm.updated_at) / 0.000004) as ret,
125
- any(tr.fees) as fees,
126
- argMax(tm.total_volume_usd, tm.updated_at) as vol,
127
- argMax(tm.unique_holders, tm.updated_at) as holders
128
- FROM token_metrics tm
129
- LEFT JOIN (
130
- SELECT
131
- base_address,
132
- sum(priority_fee + coin_creator_fee) as fees
133
- FROM trades
134
- GROUP BY base_address
135
- ) tr ON tm.token_address = tr.base_address
136
- GROUP BY tm.token_address
137
- HAVING ret <= 10000
138
  """
139
- rows = client.execute(query)
140
-
141
  # 1. Initial Classification
142
- temp_map = {} # token -> {class_id, fees, vol, holders}
143
-
144
- # Storage for stats calculation
145
- class_stats = {i: {'fees': [], 'vol': [], 'holders': []} for i in range(len(RETURN_THRESHOLDS)-1)}
146
 
147
- print(f" -> Initial classification of {len(rows)} tokens...")
148
- for r in rows:
149
- token_addr = r[0]
150
- ret_val = r[1]
151
- fees = r[2] or 0.0
152
- vol = r[3] or 0.0
153
- holders = r[4] or 0
154
 
155
- class_id = -1
 
156
  for i in range(len(RETURN_THRESHOLDS) - 1):
157
  lower = RETURN_THRESHOLDS[i]
158
  upper = RETURN_THRESHOLDS[i+1]
159
- if ret_val >= lower and ret_val < upper:
160
- class_id = i
 
161
  break
162
 
163
- if class_id != -1:
164
- temp_map[token_addr] = {'id': class_id, 'fees': fees, 'vol': vol, 'holders': holders}
165
- class_stats[class_id]['fees'].append(fees)
166
- class_stats[class_id]['vol'].append(vol)
167
- class_stats[class_id]['holders'].append(holders)
 
 
168
 
169
- # 2. Calculate Medians & Thresholds
 
170
  thresholds = {}
171
- print(" -> Calculating Class Medians & Thresholds (< 10% of Median)...")
172
- for i in range(1, 5): # Check classes 1, 2, 3, 4 (Profitable to PVE)
173
- # Class 0 (Garbage) is not checked/filtered
174
- if len(class_stats[i]['fees']) > 0:
175
- med_fees = np.median(class_stats[i]['fees'])
176
- med_vol = np.median(class_stats[i]['vol'])
177
- med_holders = np.median(class_stats[i]['holders'])
 
 
 
 
178
 
179
  thresholds[i] = {
180
  'fees': med_fees * 0.5,
181
  'vol': med_vol * 0.5,
182
  'holders': med_holders * 0.5
183
  }
184
- print(f" [Class {i}] Median Fees: {med_fees:.4f} (Thresh: {thresholds[i]['fees']:.4f}) | Median Vol: ${med_vol:.0f} (Thresh: ${thresholds[i]['vol']:.0f}) | Median Holders: {med_holders:.0f} (Thresh: {thresholds[i]['holders']:.0f})")
185
  else:
186
- thresholds[i] = {'fees': 0, 'vol': 0, 'holders': 0}
187
 
188
  # 3. Reclassification
189
- print(" -> Detecting Manipulated Outliers...")
190
- final_map = {}
191
- manipulated_count = 0
 
192
 
193
- for token, data in temp_map.items():
194
- cid = data['id']
195
- # Only check if it's a "successful" class (ID > 0)
196
- if cid > 0 and cid in thresholds:
197
- t = thresholds[cid]
198
- # Condition: If ANY metric is suspiciously low
199
- is_manipulated = (data['fees'] < t['fees']) or (data['vol'] < t['vol']) or (data['holders'] < t['holders'])
 
200
 
201
- if is_manipulated:
202
- final_map[token] = MANIPULATED_CLASS_ID
203
- manipulated_count += 1
204
- else:
205
- final_map[token] = cid
206
- else:
207
- final_map[token] = cid
208
 
209
- print(f" -> Reclassification Complete. identified {manipulated_count} manipulated tokens.")
210
- return final_map, thresholds
211
-
212
- def analyze():
213
- client = get_client()
214
 
215
- print("=== SEGMENTED DISTRIBUTION ANALYSIS ===")
216
-
217
- # 1. Get Classified Map AND Thresholds
218
- class_map, thresholds = get_return_class_map(client)
 
 
 
219
 
220
- # 2. Invert Map for easy lookups (still useful for counts or smaller segments)
221
- segments_tokens = {}
222
- for t, c in class_map.items():
223
- if c not in segments_tokens:
224
- segments_tokens[c] = []
225
- segments_tokens[c].append(t)
 
 
 
 
 
 
 
 
 
226
 
227
- # Define Labels from thresholds so bucket changes don't silently desync output.
228
- labels = {}
229
- for i in range(len(RETURN_THRESHOLDS) - 1):
230
- lower = RETURN_THRESHOLDS[i]
231
- upper = RETURN_THRESHOLDS[i + 1]
232
- labels[i] = f"{i}. {lower}x - {upper}x"
233
- labels[MANIPULATED_CLASS_ID] = f"{MANIPULATED_CLASS_ID}. MANIPULATED (Fake Metrics)"
234
-
235
- # Common SQL parts
236
- # We need a robust base for the WHERE clause variables (fees, vol, holders)
237
- # Since we can't easily alias in the WHERE clause of a subquery filter without re-joining,
238
- # we will rely on a standardized CTE-like structure or just simpler subqueries in the condition.
239
 
240
- # Efficient Token Metrics View
241
- # We need to filter based on: ret, fees, vol, holders
242
- # fees come from trades (sum), vol/holders/ret from token_metrics (argMax)
 
 
 
243
 
244
- # To keep query size small, we define the criteria logic in SQL.
245
- # But we need 'fees' which is an aggregate.
246
- # So we define a base cohort query that computes these 4 values for EVERY token,
247
- # and then wrap it with the WHERE clause.
248
 
249
- base_cohort_source = """
250
- SELECT
251
- tm.token_address as join_key,
252
- (argMax(tm.ath_price_usd, tm.updated_at) / 0.000004) as ret,
253
- any(tr.fees) as fees,
254
- argMax(tm.total_volume_usd, tm.updated_at) as vol,
255
- argMax(tm.unique_holders, tm.updated_at) as holders
256
- FROM token_metrics tm
257
- LEFT JOIN (
258
- SELECT base_address, sum(priority_fee + coin_creator_fee) as fees
259
- FROM trades
260
- GROUP BY base_address
261
- ) tr ON tm.token_address = tr.base_address
262
- GROUP BY tm.token_address
263
- """
264
 
265
- # Iterate through known classes
266
- for cid in sorted(labels.keys()):
267
- label = labels[cid]
268
- tokens = segments_tokens.get(cid, [])
269
- count = len(tokens)
270
-
271
- print(f"\n\n==================================================")
272
- print(f"SEGMENT: {label}")
273
- print(f"==================================================")
274
- print(f"Tokens in segment: {count}")
275
-
276
- if count == 0:
277
- continue
278
-
279
- # Construct SQL Condition based on ID
280
- condition = "1=0" # Default fail
281
 
282
- if cid == 0:
283
- # Garbage: Just Return < 3.
284
- # Note: Technically it also includes tokens that might have been >3x but <10000x...
285
- # BUT our Python/Map logic says Garbage is class 0.
286
- # The only way to be class 0 in the map is if ret < 3.
287
- # Downgraded tokens go to Class 5.
288
- condition = "ret < 3"
289
-
290
- elif cid == MANIPULATED_CLASS_ID:
291
- # Manipulated:
292
- # It's the collection of (Class K logic AND is_outlier)
293
- sub_conds = []
294
- for k in range(1, 5):
295
- if k in thresholds:
296
- t = thresholds[k]
297
- # Range for Class K
298
- lower = RETURN_THRESHOLDS[k]
299
- upper = RETURN_THRESHOLDS[k+1]
300
- # Outlier logic
301
- sub_conds.append(f"(ret >= {lower} AND ret < {upper} AND (fees < {t['fees']} OR vol < {t['vol']} OR holders < {t['holders']}))")
302
-
303
- if sub_conds:
304
- condition = " OR ".join(sub_conds)
305
-
306
  else:
307
- # Normal Classes 1-4
308
- if cid in thresholds:
309
- t = thresholds[cid]
310
- lower = RETURN_THRESHOLDS[cid]
311
- upper = RETURN_THRESHOLDS[cid+1]
312
- # Valid logic: In Range AND NOT Outlier
313
- condition = f"(ret >= {lower} AND ret < {upper} AND fees >= {t['fees']} AND vol >= {t['vol']} AND holders >= {t['holders']})"
314
 
315
- # Final Cohort SQL: Select keys satisfying the condition
316
- # We wrap the base source
317
- cohort_sql = f"""
318
- SELECT join_key FROM (
319
- {base_cohort_source}
320
- ) WHERE {condition}
321
- """
322
-
323
- # Helper to construct the full condition "join_key IN (...)"
324
- # NOW we use the subquery instead of a literal list
325
- def make_query(inner, cohort_subquery):
326
- return f"""
327
- SELECT * FROM (
328
- {inner}
329
- ) WHERE join_key IN (
330
- {cohort_subquery}
331
- )
332
- """
333
-
334
- # --- Metrics Definitions ---
335
 
336
- # 1. Fees (SOL)
337
- fees_inner = """
338
- SELECT base_address as join_key, sum(priority_fee + coin_creator_fee) as val
339
- FROM trades
340
- GROUP BY base_address
341
- """
342
- fees_buckets = """
343
- case
344
- when val < 0.001 then '1. < 0.001 SOL'
345
- when val >= 0.001 AND val < 0.01 then '2. 0.001 - 0.01'
346
- when val >= 0.01 AND val < 0.1 then '3. 0.01 - 0.1'
347
- when val >= 0.1 AND val < 1 then '4. 0.1 - 1'
348
- when val >= 1 then '5. > 1 SOL'
349
- else 'Unknown'
350
- end
351
- """
352
- print_distribution_stats(client, "Total Fees (SOL)", make_query(fees_inner, cohort_sql), fees_buckets)
353
-
354
- # 2. Volume (USD)
355
- vol_inner = """
356
- SELECT base_address as join_key, sum(total_usd) as val
357
- FROM trades
358
- GROUP BY base_address
359
- """
360
- vol_buckets = """
361
- case
362
- when val < 1000 then '1. < $1k'
363
- when val >= 1000 AND val < 10000 then '2. $1k - $10k'
364
- when val >= 10000 AND val < 100000 then '3. $10k - $100k'
365
- when val >= 100000 AND val < 1000000 then '4. $100k - $1M'
366
- when val >= 1000000 then '5. > $1M'
367
- else 'Unknown'
368
- end
369
- """
370
- print_distribution_stats(client, "Total Volume (USD)", make_query(vol_inner, cohort_sql), vol_buckets)
371
-
372
- # 3. Unique Holders
373
- holders_inner = """
374
- SELECT token_address as join_key, argMax(unique_holders, updated_at) as val
375
- FROM token_metrics
376
- GROUP BY token_address
377
- """
378
- holders_buckets = """
379
- case
380
- when val < 10 then '1. < 10'
381
- when val >= 10 AND val < 50 then '2. 10 - 50'
382
- when val >= 50 AND val < 100 then '3. 50 - 100'
383
- when val >= 100 AND val < 500 then '4. 100 - 500'
384
- when val >= 500 then '5. > 500'
385
- else 'Unknown'
386
- end
387
- """
388
- print_distribution_stats(client, "Unique Holders", make_query(holders_inner, cohort_sql), holders_buckets)
389
-
390
- # 4. Snipers % Supply
391
- snipers_inner = """
392
- SELECT
393
- m.base_address as join_key,
394
- (m.val / t.total_supply * 100) as val
395
- FROM (
396
- SELECT
397
- base_address,
398
- sumIf(base_amount, buyer_rank <= 70) as val
399
- FROM (
400
- SELECT
401
- base_address,
402
- base_amount,
403
- dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) as buyer_rank
404
- FROM (
405
- SELECT
406
- base_address,
407
- maker,
408
- min(slot) as min_slot,
409
- min(transaction_index) as min_idx,
410
- sum(base_amount) as base_amount
411
- FROM trades
412
- WHERE trade_type = 0
413
- GROUP BY base_address, maker
414
- )
415
- )
416
- GROUP BY base_address
417
- ) m
418
- JOIN (
419
- SELECT token_address, argMax(total_supply, updated_at) as total_supply
420
- FROM tokens
421
- GROUP BY token_address
422
- ) t ON m.base_address = t.token_address
423
- WHERE t.total_supply > 0
424
- """
425
- pct_buckets = """
426
- case
427
- when val < 1 then '1. < 1%'
428
- when val >= 1 AND val < 5 then '2. 1% - 5%'
429
- when val >= 5 AND val < 10 then '3. 5% - 10%'
430
- when val >= 10 AND val < 20 then '4. 10% - 20%'
431
- when val >= 20 AND val < 50 then '5. 20% - 50%'
432
- when val >= 50 then '6. > 50%'
433
- else 'Unknown'
434
- end
435
- """
436
- print_distribution_stats(client, "Snipers % Supply (Top 70)", make_query(snipers_inner, cohort_sql), pct_buckets)
437
-
438
- # 5. Bundled % Supply
439
- bundled_inner = """
440
- SELECT
441
- m.base_address as join_key,
442
- (m.val / t.total_supply * 100) as val
443
- FROM (
444
- SELECT
445
- t.base_address,
446
- sum(t.base_amount) as val
447
- FROM trades t
448
- JOIN (
449
- SELECT base_address, min(slot) as min_slot
450
- FROM trades
451
- GROUP BY base_address
452
- ) m ON t.base_address = m.base_address AND t.slot = m.min_slot
453
- WHERE t.trade_type = 0
454
- GROUP BY t.base_address
455
- ) m
456
- JOIN (
457
- SELECT token_address, argMax(total_supply, updated_at) as total_supply
458
- FROM tokens
459
- GROUP BY token_address
460
- ) t ON m.base_address = t.token_address
461
- WHERE t.total_supply > 0
462
- """
463
- print_distribution_stats(client, "Bundled % Supply", make_query(bundled_inner, cohort_sql), pct_buckets)
464
-
465
- # 6. Dev Holding % Supply
466
- dev_inner = """
467
- SELECT
468
- t.token_address as join_key,
469
- (wh.current_balance / (t.total_supply / pow(10, t.decimals)) * 100) as val
470
- FROM (
471
- SELECT token_address, argMax(creator_address, updated_at) as creator_address, argMax(total_supply, updated_at) as total_supply, argMax(decimals, updated_at) as decimals
472
- FROM tokens
473
- GROUP BY token_address
474
- ) t
475
- JOIN (
476
- SELECT mint_address, wallet_address, argMax(current_balance, updated_at) as current_balance
477
- FROM wallet_holdings
478
- GROUP BY mint_address, wallet_address
479
- ) wh ON t.token_address = wh.mint_address AND t.creator_address = wh.wallet_address
480
- WHERE t.total_supply > 0
481
- """
482
- print_distribution_stats(client, "Dev Holding % Supply", make_query(dev_inner, cohort_sql), pct_buckets)
483
-
484
-
485
-
486
- # 8. Time to ATH (Seconds)
487
- time_ath_inner = """
488
- SELECT
489
- base_address as join_key,
490
- (argMax(timestamp, price_usd) - min(timestamp)) as val
491
- FROM trades
492
- GROUP BY base_address
493
- """
494
- time_ath_buckets = """
495
- case
496
- when val < 5 then '1. < 5s'
497
- when val >= 5 AND val < 30 then '2. 5s - 30s'
498
- when val >= 30 AND val < 60 then '3. 30s - 1m'
499
- when val >= 60 AND val < 300 then '4. 1m - 5m'
500
- when val >= 300 AND val < 3600 then '5. 5m - 1h'
501
- when val >= 3600 then '6. > 1h'
502
- else 'Unknown'
503
- end
504
- """
505
- print_distribution_stats(client, "Time to ATH (Seconds)", make_query(time_ath_inner, cohort_sql), time_ath_buckets)
506
-
507
 
508
  if __name__ == "__main__":
509
  analyze()
 
 
1
  import os
2
  import sys
3
  import datetime
4
+ import numpy as np
5
+ import math
6
  from clickhouse_driver import Client as ClickHouseClient
7
 
8
  # Add parent to path
9
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
 
11
+ from models.vocabulary import RETURN_THRESHOLDS, MANIPULATED_CLASS_ID
 
12
 
13
  CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
14
  CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
 
15
  CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
16
  CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
17
  CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
18
+ LAUNCH_PRICE_USD = 0.000004
19
+ EPS = 1e-9
20
 
21
  def get_client():
22
  return ClickHouseClient(
 
27
  database=CLICKHOUSE_DATABASE
28
  )
29
 
30
+ def fetch_all_metrics(client):
31
+ """
32
+ Fetches all needed metrics for all tokens in a single query.
33
+ Base Table: MINTS (to ensure we cover all ~50k tokens).
34
+ Definitions:
35
+ - Snipers: Peak Balance Sum of top 70 buyers
36
+ - Bundles: Base Amount Sum of trades in multi-buy slots
37
+ - Dev Hold: Max Peak Balance of Creator
 
 
 
 
 
 
38
  """
39
+ print(" -> Fetching all token metrics (Unified Query)...")
40
+
41
+ query = f"""
42
+ WITH
43
+ -- 1. Aggregated trade stats (Fees, Volume, ATH Time)
44
+ trade_agg AS (
45
+ SELECT
46
+ base_address,
47
+ sum(priority_fee + coin_creator_fee) AS fees_sol,
48
+ sum(total_usd) AS volume_usd,
49
+ count() AS n_trades,
50
+ argMax(timestamp, price_usd) AS t_ath,
51
+ min(timestamp) AS t0
52
+ FROM trades
53
+ GROUP BY base_address
54
+ ),
55
+
56
+ -- 2. Token Metadata from MINTS (Base Source of Truth)
57
+ token_meta AS (
58
+ SELECT
59
+ mint_address AS token_address,
60
+ argMax(creator_address, timestamp) AS creator_address,
61
+ argMax(total_supply, timestamp) AS total_supply,
62
+ argMax(token_decimals, timestamp) AS decimals
63
+ FROM mints
64
+ GROUP BY mint_address
65
+ ),
66
 
67
+ -- 3. Returns & Holders (from Token Metrics or manual calc)
68
+ metrics AS (
69
+ SELECT
70
+ token_address,
71
+ argMax(ath_price_usd, updated_at) as ath_price_usd,
72
+ argMax(unique_holders, updated_at) as unique_holders
73
+ FROM token_metrics
74
+ GROUP BY token_address
75
+ ),
76
 
77
+ -- 4. WALLET PEAKS (normalized balance likely)
78
+ wallet_peaks AS (
79
+ SELECT
80
+ mint_address,
81
+ wallet_address,
82
+ max(current_balance) AS peak_balance
83
+ FROM wallet_holdings
84
+ GROUP BY mint_address, wallet_address
85
+ ),
86
 
87
+ -- 5. SNIPERS: Identify sniper addresses (rank <= 70)
88
+ snipers_list AS (
89
+ SELECT
90
+ base_address,
91
+ maker
92
+ FROM (
93
+ SELECT
94
+ base_address,
95
+ maker,
96
+ dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) AS buyer_rank
97
+ FROM (
98
+ SELECT
99
+ base_address,
100
+ maker,
101
+ min(slot) AS min_slot,
102
+ min(transaction_index) AS min_idx
103
+ FROM trades
104
+ WHERE trade_type = 0 -- buy
105
+ GROUP BY base_address, maker
106
+ )
107
+ )
108
+ WHERE buyer_rank <= 70
109
+ ),
110
+ snipers_agg AS (
111
+ SELECT
112
+ s.base_address AS token_address,
113
+ sum(wp.peak_balance) AS snipers_total_peak
114
+ FROM snipers_list s
115
+ JOIN wallet_peaks wp ON s.base_address = wp.mint_address AND s.maker = wp.wallet_address
116
+ GROUP BY s.base_address
117
+ ),
118
 
119
+ -- 6. BUNDLED: Sum the base_amount of ALL trades that happened in a slot with multiple buys
120
+ bundled_agg AS (
121
+ SELECT
122
+ t.base_address AS token_address,
123
+ sum(t.base_amount) AS bundled_total_peak
124
+ FROM trades t
125
+ WHERE (t.base_address, t.slot) IN (
126
+ SELECT base_address, slot
127
+ FROM trades
128
+ WHERE trade_type = 0 -- buy
129
+ GROUP BY base_address, slot
130
+ HAVING count() > 1
131
+ )
132
+ AND t.trade_type = 0 -- buy
133
+ GROUP BY t.base_address
134
+ ),
 
 
 
 
135
 
136
+ -- 7. DEV HOLD: Creator's Peak Balance
137
+ dev_hold_agg AS (
138
+ SELECT
139
+ t.token_address,
140
+ max(wp.peak_balance) AS dev_peak
141
+ FROM token_meta t
142
+ JOIN wallet_peaks wp ON t.token_address = wp.mint_address AND t.creator_address = wp.wallet_address
143
+ GROUP BY t.token_address
144
+ )
145
+
146
+ SELECT
147
+ t.token_address,
148
+ (COALESCE(m.ath_price_usd, ta.t_ath, 0) / {LAUNCH_PRICE_USD}) AS ret,
149
+
150
+ COALESCE(ta.fees_sol, 0) AS fees_sol,
151
+ COALESCE(ta.volume_usd, 0) AS volume_usd,
152
+ COALESCE(m.unique_holders, 0) AS unique_holders,
153
+ (ta.t_ath - ta.t0) AS time_to_ath_sec,
154
+
155
+ COALESCE(s.snipers_total_peak, 0) AS snipers_val,
156
+ COALESCE(b.bundled_total_peak, 0) AS bundled_val,
157
+ COALESCE(d.dev_peak, 0) AS dev_val,
158
+
159
+ t.total_supply AS total_supply,
160
+ t.decimals AS decimals
161
+
162
+ FROM token_meta t
163
+ LEFT JOIN trade_agg ta ON t.token_address = ta.base_address
164
+ LEFT JOIN metrics m ON t.token_address = m.token_address
165
+ LEFT JOIN snipers_agg s ON t.token_address = s.token_address
166
+ LEFT JOIN bundled_agg b ON t.token_address = b.token_address
167
+ LEFT JOIN dev_hold_agg d ON t.token_address = d.token_address
168
  """
 
 
 
169
 
170
+ rows = client.execute(query)
171
+ # Convert to list of dicts
172
+ cols = [
173
+ "token_address", "ret", "fees_sol", "volume_usd", "unique_holders", "time_to_ath_sec",
174
+ "snipers_val", "bundled_val", "dev_val", "total_supply", "decimals"
175
+ ]
176
+ results = []
177
+
178
+ print(f" -> Fetched {len(rows)} tokens.")
179
+
180
+ for r in rows:
181
+ d = dict(zip(cols, r))
182
+
183
+ supply = d["total_supply"]
184
+ decimals = d["decimals"]
185
+
186
+ try:
187
+ adj_supply = supply / (10 ** decimals) if (supply and decimals is not None) else supply
188
+ except:
189
+ adj_supply = supply
190
 
191
+ if adj_supply and adj_supply > 0:
192
+ d["snipers_pct"] = (d["snipers_val"] / adj_supply) * 100
193
+ d["dev_hold_pct"] = (d["dev_val"] / adj_supply) * 100
194
+ else:
195
+ d["snipers_pct"] = 0.0
196
+ d["dev_hold_pct"] = 0.0
197
+
198
+ if supply and supply > 0:
199
+ d["bundled_pct"] = (d["bundled_val"] / supply) * 100
200
+ else:
201
+ d["bundled_pct"] = 0.0
202
 
203
+ results.append(d)
204
+
205
+ return results
206
+
207
+ def _classify_tokens(data):
 
 
208
  """
209
+ Internal logic: returns (buckets_dict, thresholds_dict, count_manipulated)
210
+ buckets_dict: {class_id: [list of tokens]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  """
 
 
212
  # 1. Initial Classification
213
+ temp_buckets = {i: [] for i in range(len(RETURN_THRESHOLDS))}
 
 
 
214
 
215
+ for d in data:
216
+ ret = d["ret"]
217
+ if ret > 10000: continue
 
 
 
 
218
 
219
+ cid = 0
220
+ found = False
221
  for i in range(len(RETURN_THRESHOLDS) - 1):
222
  lower = RETURN_THRESHOLDS[i]
223
  upper = RETURN_THRESHOLDS[i+1]
224
+ if ret >= lower and ret < upper:
225
+ cid = i
226
+ found = True
227
  break
228
 
229
+ if found:
230
+ d["class_id_initial"] = cid
231
+ temp_buckets[cid].append(d)
232
+ else:
233
+ if ret >= 10000: continue
234
+ d["class_id_initial"] = 0
235
+ temp_buckets[0].append(d)
236
 
237
+ # 2. Calculate Thresholds (50% of Median)
238
+ print("\n -> Calculating Class Medians & Thresholds (Dynamic Outlier Detection)...")
239
  thresholds = {}
240
+
241
+ for i in range(1, len(RETURN_THRESHOLDS)-1):
242
+ items = temp_buckets.get(i, [])
243
+ if len(items) > 5:
244
+ fees = [x["fees_sol"] for x in items]
245
+ vols = [x["volume_usd"] for x in items]
246
+ holders = [x["unique_holders"] for x in items]
247
+
248
+ med_fees = np.median(fees)
249
+ med_vol = np.median(vols)
250
+ med_holders = np.median(holders)
251
 
252
  thresholds[i] = {
253
  'fees': med_fees * 0.5,
254
  'vol': med_vol * 0.5,
255
  'holders': med_holders * 0.5
256
  }
 
257
  else:
258
+ thresholds[i] = {'fees': 0, 'vol': 0, 'holders': 0}
259
 
260
  # 3. Reclassification
261
+ final_buckets = {i: [] for i in range(len(RETURN_THRESHOLDS))}
262
+ final_buckets[MANIPULATED_CLASS_ID] = []
263
+
264
+ count_manipulated = 0
265
 
266
+ for cid, items in temp_buckets.items():
267
+ for d in items:
268
+ final_cid = cid
269
+ if cid > 0 and cid in thresholds:
270
+ t = thresholds[cid]
271
+ if (d["fees_sol"] < t['fees']) or (d["volume_usd"] < t['vol']) or (d["unique_holders"] < t['holders']):
272
+ final_cid = MANIPULATED_CLASS_ID
273
+ count_manipulated += 1
274
 
275
+ d["class_id_final"] = final_cid
276
+ if final_cid not in final_buckets:
277
+ final_buckets[final_cid] = []
278
+ final_buckets[final_cid].append(d)
 
 
 
279
 
280
+ return final_buckets, thresholds, count_manipulated
 
 
 
 
281
 
282
+ def get_return_class_map(client):
283
+ """
284
+ Returns (map {token_addr: class_id}, thresholds)
285
+ Used by cache_dataset.py
286
+ """
287
+ data = fetch_all_metrics(client)
288
+ buckets, thresholds, _ = _classify_tokens(data)
289
 
290
+ # Flatten buckets to map
291
+ ret_map = {}
292
+ for cid, items in buckets.items():
293
+ for d in items:
294
+ ret_map[d["token_address"]] = cid
295
+
296
+ return ret_map, thresholds
297
+
298
+ def print_stats(name, values):
299
+ """
300
+ prints compact stats: mean, p50, p90, p99
301
+ """
302
+ if not values:
303
+ print(f" {name}: No data")
304
+ return
305
 
306
+ vals = np.array(values)
307
+ mean = np.mean(vals)
308
+ p50 = np.percentile(vals, 50)
309
+ p90 = np.percentile(vals, 90)
310
+ p99 = np.percentile(vals, 99)
311
+ nonzero = np.count_nonzero(vals)
312
+ nonzero_rate = nonzero / len(vals)
 
 
 
 
 
313
 
314
+ print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
315
+
316
+ def analyze():
317
+ client = get_client()
318
+ data = fetch_all_metrics(client)
319
+ final_buckets, thresholds, count_manipulated = _classify_tokens(data)
320
 
321
+ print(f" -> Reclassification Complete. Identified {count_manipulated} manipulated tokens.")
322
+ print("\n=== SEGMENTED DISTRIBUTION ANALYSIS ===")
 
 
323
 
324
+ # Print Thresholds debug
325
+ for k, t in thresholds.items():
326
+ if t['fees'] > 0:
327
+ print(f" [Class {k}] Thresh: Fees>{t['fees']:.3f} Vol>${t['vol']:.0f} Holders>{t['holders']:.0f}")
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ sorted_classes = sorted([k for k in final_buckets.keys() if k != MANIPULATED_CLASS_ID]) + [MANIPULATED_CLASS_ID]
330
+
331
+ for cid in sorted_classes:
332
+ items = final_buckets.get(cid, [])
333
+ if not items: continue
 
 
 
 
 
 
 
 
 
 
 
334
 
335
+ if cid == MANIPULATED_CLASS_ID:
336
+ label = f"{cid}. MANIPULATED / FAKE (Outliers from {1}~{4})"
337
+ elif cid < len(RETURN_THRESHOLDS)-1:
338
+ label = f"{cid}. {RETURN_THRESHOLDS[cid]}x - {RETURN_THRESHOLDS[cid+1]}x"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  else:
340
+ label = f"{cid}. Unknown"
341
+
342
+ print(f"\nSEGMENT: {label}")
343
+ print("="*50)
344
+ print(f"Tokens in segment: {len(items)}")
 
 
345
 
346
+ bundled = [x["bundled_pct"] for x in items]
347
+ dev_hold = [x["dev_hold_pct"] for x in items]
348
+ fees = [x["fees_sol"] for x in items]
349
+ snipers = [x["snipers_pct"] for x in items]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
+ print_stats("bundled_pct", bundled)
352
+ print_stats("dev_hold_pct", dev_hold)
353
+ print_stats("fees_sol", fees)
354
+ print_stats("snipers_pct", snipers)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
  if __name__ == "__main__":
357
  analyze()
scripts/cache_dataset.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  import datetime
7
  import torch
8
  import json
 
9
  from pathlib import Path
10
  from tqdm import tqdm
11
  from dotenv import load_dotenv
@@ -23,6 +24,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
  from data.data_loader import OracleDataset
24
  from data.data_fetcher import DataFetcher
25
  from scripts.analyze_distribution import get_return_class_map
 
 
26
 
27
  from clickhouse_driver import Client as ClickHouseClient
28
  from neo4j import GraphDatabase
@@ -94,6 +97,86 @@ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
94
  print(f"ERROR: Failed to compute OHLC stats: {e}")
95
  # Don't crash, let it try to proceed (though dataset might complain if file missing)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def main():
98
  load_dotenv()
99
 
@@ -140,10 +223,15 @@ def main():
140
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
141
 
142
  # Pre-fetch the Return Class Map
143
- # tokens not in this map (e.g. >10k x) are INVALID and will be skipped
144
  print("INFO: Fetching Return Classification Map...")
145
  return_class_map, thresholds = get_return_class_map(clickhouse_client)
146
  print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
 
 
 
 
 
 
147
 
148
  dataset = OracleDataset(
149
  data_fetcher=data_fetcher,
@@ -158,67 +246,103 @@ def main():
158
  if len(dataset) == 0:
159
  print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
160
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
  # --- 3. Iterate and cache each item ---
163
  print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
164
 
165
- metadata_path = output_dir / "metadata.jsonl"
166
- print(f"INFO: Writing metadata to {metadata_path}")
167
-
168
  skipped_count = 0
169
- filtered_count = 0
170
  cached_count = 0
171
 
172
- # Open metadata file in append mode
173
- with open(metadata_path, 'a') as meta_f:
174
- for i in tqdm(range(len(dataset)), desc="Caching samples"):
175
- mint_addr = dataset.sampled_mints[i]['mint_address']
176
-
177
- # 1. Filter Check
178
- if mint_addr not in return_class_map:
179
- # Token is effectively "filtered out" (e.g. > 10,000x return or missing metrics)
180
- filtered_count += 1
181
- continue
182
-
183
- class_id = return_class_map[mint_addr]
184
-
185
- try:
186
- item = dataset.__cacheitem__(i)
187
- if item is None:
188
- skipped_count += 1
189
- continue
190
-
191
- filename = f"sample_{i}.pt"
192
- output_path = output_dir / filename
193
- torch.save(item, output_path)
194
-
195
- # Write metadata entry
196
- # Minimizing IO overhead by keeping line short
197
- meta_entry = {"file": filename, "class_id": class_id}
198
- meta_f.write(json.dumps(meta_entry) + "\n")
199
-
200
- cached_count += 1
201
-
202
- except Exception as e:
203
- error_msg = str(e)
204
- # If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
205
- if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
206
- print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
207
- sys.exit(1)
208
-
209
- print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
210
- # print trackback
211
- import traceback
212
- traceback.print_exc()
213
  skipped_count += 1
214
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  print(f"\n--- Caching Complete ---")
217
  print(f"Successfully cached: {cached_count} items.")
218
  print(f"Filtered (Invalid/High Return): {filtered_count} items.")
219
  print(f"Skipped (Errors/Empty): {skipped_count} items.")
220
  print(f"Cache location: {output_dir.resolve()}")
221
- print(f"Metadata location: {metadata_path.resolve()}")
222
 
223
  finally:
224
  # --- 4. Close connections ---
 
6
  import datetime
7
  import torch
8
  import json
9
+ import math
10
  from pathlib import Path
11
  from tqdm import tqdm
12
  from dotenv import load_dotenv
 
24
  from data.data_loader import OracleDataset
25
  from data.data_fetcher import DataFetcher
26
  from scripts.analyze_distribution import get_return_class_map
27
+ # Import quality score calculator
28
+ from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
29
 
30
  from clickhouse_driver import Client as ClickHouseClient
31
  from neo4j import GraphDatabase
 
97
  print(f"ERROR: Failed to compute OHLC stats: {e}")
98
  # Don't crash, let it try to proceed (though dataset might complain if file missing)
99
 
100
+ def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float = 1e9):
101
+ """
102
+ Build a map: token_address -> reason string for why a quality score is missing.
103
+ This mirrors compute_quality_scores filtering and feature availability.
104
+ """
105
+ data = fetch_token_metrics(client)
106
+ metrics_by_token = {d.get("token_address"): d for d in data if d.get("token_address")}
107
+
108
+ # Build buckets with the same return filtering as compute_quality_scores
109
+ buckets = {}
110
+ for d in data:
111
+ ret_val = d.get("ret")
112
+ if ret_val is None or ret_val <= 0 or ret_val > max_ret:
113
+ continue
114
+ b = _bucket_id(ret_val)
115
+ if b == -1:
116
+ continue
117
+ d["bucket_id"] = b
118
+ buckets.setdefault(b, []).append(d)
119
+
120
+ # Same feature definitions as compute_quality_scores
121
+ feature_defs = [
122
+ ("fees_log", lambda d: math.log1p(d["fees_sol"]) if d.get("fees_sol") is not None else None, True),
123
+ ("volume_log", lambda d: math.log1p(d["volume_usd"]) if d.get("volume_usd") is not None else None, True),
124
+ ("holders_log", lambda d: math.log1p(d["unique_holders"]) if d.get("unique_holders") is not None else None, True),
125
+ ("time_to_ath_log", lambda d: math.log1p(d["time_to_ath_sec"]) if d.get("time_to_ath_sec") is not None else None, True),
126
+ ("fees_per_volume", lambda d: (d["fees_sol"] / (d["volume_usd"] + EPS)) if d.get("fees_sol") is not None and d.get("volume_usd") is not None else None, True),
127
+ ("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d.get("fees_sol") is not None and d.get("n_trades") is not None else None, True),
128
+ ("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d.get("unique_holders") is not None and d.get("n_trades") is not None else None, True),
129
+ ("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d.get("unique_holders") is not None and d.get("volume_usd") is not None else None, True),
130
+ ("snipers_pct", lambda d: d.get("snipers_pct"), True),
131
+ ("bundled_pct", lambda d: d.get("bundled_pct"), True),
132
+ ("dev_hold_pct", lambda d: d.get("dev_hold_pct"), True),
133
+ ]
134
+
135
+ # Precompute percentiles per bucket + feature
136
+ bucket_feature_percentiles = {}
137
+ for b, items in buckets.items():
138
+ feature_percentiles = {}
139
+ for fname, fget, _pos in feature_defs:
140
+ vals = []
141
+ for d in items:
142
+ v = fget(d)
143
+ if v is None or (isinstance(v, float) and (math.isnan(v) or math.isinf(v))):
144
+ continue
145
+ vals.append((d["token_address"], v))
146
+ feature_percentiles[fname] = _midrank_percentiles(vals)
147
+ bucket_feature_percentiles[b] = feature_percentiles
148
+
149
+ def _reason_for(token_address: str) -> str:
150
+ d = metrics_by_token.get(token_address)
151
+ if not d:
152
+ return "no metrics found (missing from token_metrics/trades/mints joins)"
153
+ ret_val = d.get("ret")
154
+ if ret_val is None:
155
+ return "ret is None (missing ATH/launch metrics)"
156
+ if ret_val <= 0:
157
+ return f"ret <= 0 ({ret_val})"
158
+ if ret_val > max_ret:
159
+ return f"ret > max_ret ({ret_val} > {max_ret})"
160
+ b = _bucket_id(ret_val)
161
+ if b == -1:
162
+ return f"ret {ret_val} not in RETURN_THRESHOLDS"
163
+ items = buckets.get(b, [])
164
+ if not items:
165
+ return f"bucket {b} empty after filtering"
166
+ feature_percentiles = bucket_feature_percentiles.get(b, {})
167
+ has_any = False
168
+ missing_features = []
169
+ for fname, _fget, _pos in feature_defs:
170
+ if feature_percentiles.get(fname, {}).get(token_address) is None:
171
+ missing_features.append(fname)
172
+ else:
173
+ has_any = True
174
+ if not has_any:
175
+ return "no valid feature percentiles for token (all features missing/invalid)"
176
+ return f"unexpected: has feature percentiles but no score; missing features={','.join(missing_features)}"
177
+
178
+ return _reason_for
179
+
180
  def main():
181
  load_dotenv()
182
 
 
223
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
224
 
225
  # Pre-fetch the Return Class Map
 
226
  print("INFO: Fetching Return Classification Map...")
227
  return_class_map, thresholds = get_return_class_map(clickhouse_client)
228
  print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
229
+
230
+ # Pre-fetch Quality Scores
231
+ print("INFO: Fetching Token Quality Scores...")
232
+ quality_scores_map = get_token_quality_scores(clickhouse_client)
233
+ quality_missing_reason = build_quality_missing_reason_map(clickhouse_client, max_ret=1e9)
234
+ print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
235
 
236
  dataset = OracleDataset(
237
  data_fetcher=data_fetcher,
 
246
  if len(dataset) == 0:
247
  print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
248
  return
249
+
250
+ # --- FILTER DATASET BY CLASS MAP ---
251
+ # Only keep mints that are classified (valid return, sufficient data)
252
+ original_size = len(dataset)
253
+ print(f"INFO: Filtering dataset... Original size: {original_size}")
254
+ dataset.sampled_mints = [
255
+ m for m in dataset.sampled_mints
256
+ if m['mint_address'] in return_class_map
257
+ ]
258
+ filtered_size = len(dataset)
259
+ filtered_count = original_size - filtered_size
260
+ print(f"INFO: Filtered size: {filtered_size}")
261
+
262
+ if len(dataset) == 0:
263
+ print("WARNING: No tokens remain after filtering by return_class_map.")
264
+ return
265
 
266
  # --- 3. Iterate and cache each item ---
267
  print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
268
 
 
 
 
269
  skipped_count = 0
 
270
  cached_count = 0
271
 
272
+ for i in tqdm(range(len(dataset)), desc="Caching samples"):
273
+ mint_addr = dataset.sampled_mints[i]['mint_address']
274
+
275
+ # (No need to check if in return_class_map anymore, we filtered)
276
+ class_id = return_class_map[mint_addr]
277
+
278
+ try:
279
+ item = dataset.__cacheitem__(i)
280
+ if item is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  skipped_count += 1
282
  continue
283
+
284
+ # Require quality score only for samples that will be cached
285
+ if mint_addr not in quality_scores_map:
286
+ reason = quality_missing_reason(mint_addr)
287
+ raise RuntimeError(
288
+ f"Missing quality score for mint {mint_addr}. Reason: {reason}. "
289
+ "Refusing to cache without quality_score."
290
+ )
291
+ q_score = quality_scores_map[mint_addr]
292
+
293
+ # INJECT QUALITY SCORE INTO TENSOR DICT
294
+ item["quality_score"] = q_score
295
+ item["class_id"] = class_id
296
+
297
+ filename = f"sample_{i}.pt"
298
+ output_path = output_dir / filename
299
+ torch.save(item, output_path)
300
+
301
+ cached_count += 1
302
+
303
+ # Log progress details (reflect all cached event lists)
304
+ n_trades = len(item.get("trades", []))
305
+ n_transfers = len(item.get("transfers", []))
306
+ n_pool_creations = len(item.get("pool_creations", []))
307
+ n_liquidity_changes = len(item.get("liquidity_changes", []))
308
+ n_fee_collections = len(item.get("fee_collections", []))
309
+ n_burns = len(item.get("burns", []))
310
+ n_supply_locks = len(item.get("supply_locks", []))
311
+ n_migrations = len(item.get("migrations", []))
312
+ n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
313
+ n_snapshots_5m = len(item.get("snapshots_5m", []))
314
+ n_holders = len(item.get("holder_snapshots_list", []))
315
+
316
+ tqdm.write(f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f}")
317
+ tqdm.write(
318
+ " Events | "
319
+ f"Trades: {n_trades} | Transfers: {n_transfers} | Pool Creations: {n_pool_creations} | "
320
+ f"Liquidity Changes: {n_liquidity_changes} | Fee Collections: {n_fee_collections} | "
321
+ f"Burns: {n_burns} | Supply Locks: {n_supply_locks} | Migrations: {n_migrations}"
322
+ )
323
+ tqdm.write(
324
+ f" Derived | Mint: 1 | Ohlc 1s: {n_ohlc} | Snapshots 5m: {n_snapshots_5m} | Holder Snapshots: {n_holders}"
325
+ )
326
+
327
+ except Exception as e:
328
+ error_msg = str(e)
329
+ # If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
330
+ if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
331
+ print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
332
+ sys.exit(1)
333
+
334
+ print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
335
+ # print trackback
336
+ import traceback
337
+ traceback.print_exc()
338
+ skipped_count += 1
339
+ continue
340
 
341
  print(f"\n--- Caching Complete ---")
342
  print(f"Successfully cached: {cached_count} items.")
343
  print(f"Filtered (Invalid/High Return): {filtered_count} items.")
344
  print(f"Skipped (Errors/Empty): {skipped_count} items.")
345
  print(f"Cache location: {output_dir.resolve()}")
 
346
 
347
  finally:
348
  # --- 4. Close connections ---
scripts/compute_quality_score.py CHANGED
@@ -87,15 +87,15 @@ def fetch_token_metrics(client) -> List[dict]:
87
  FROM trades
88
  GROUP BY base_address
89
  ),
90
- -- 2. Token metadata (supply, decimals, creator)
91
  token_meta_raw AS (
92
  SELECT
93
- token_address,
94
- argMax(creator_address, updated_at) AS creator_address,
95
- argMax(total_supply, updated_at) AS total_supply,
96
- argMax(decimals, updated_at) AS decimals
97
- FROM tokens
98
- GROUP BY token_address
99
  ),
100
  token_meta AS (
101
  SELECT
@@ -161,28 +161,21 @@ def fetch_token_metrics(client) -> List[dict]:
161
  GROUP BY s.base_address
162
  ),
163
 
164
- -- 6. BUNDLED: Identify bundled addresses, sum their PEAK balances
165
- -- Bundled definition: Bought in the same slot as the very first buy slot for that token.
166
- bundled_list AS (
167
- SELECT
168
- t.base_address,
169
- t.maker
170
- FROM trades t
171
- JOIN (
172
- SELECT base_address, min(slot) AS min_slot
173
- FROM trades
174
- GROUP BY base_address
175
- ) m ON t.base_address = m.base_address AND t.slot = m.min_slot
176
- WHERE t.trade_type = 0 -- buy
177
- GROUP BY t.base_address, t.maker
178
- ),
179
  bundled_agg AS (
180
  SELECT
181
- b.base_address AS token_address,
182
- sum(wp.peak_balance) AS bundled_total_peak
183
- FROM bundled_list b
184
- JOIN wallet_peaks wp ON b.base_address = wp.mint_address AND b.maker = wp.wallet_address
185
- GROUP BY b.base_address
 
 
 
 
 
 
 
186
  ),
187
 
188
  -- 7. DEV HOLD: Creator's Peak Balance
@@ -196,7 +189,7 @@ def fetch_token_metrics(client) -> List[dict]:
196
  )
197
 
198
  SELECT
199
- r.token_address,
200
  r.ret,
201
  r.unique_holders,
202
  f.fees_sol,
@@ -205,14 +198,14 @@ def fetch_token_metrics(client) -> List[dict]:
205
  (f.t_ath - f.t0) AS time_to_ath_sec,
206
  -- Calculate Percentages using Peak Sums / Total Supply
207
  (COALESCE(s.snipers_total_peak, 0) / t.adj_supply * 100) AS snipers_pct,
208
- (COALESCE(b.bundled_total_peak, 0) / t.adj_supply * 100) AS bundled_pct,
209
  (COALESCE(d.dev_peak, 0) / t.adj_supply * 100) AS dev_hold_pct
210
- FROM ret_agg r
211
- JOIN token_meta t ON r.token_address = t.token_address
212
- LEFT JOIN trade_agg f ON r.token_address = f.base_address
213
- LEFT JOIN snipers_agg s ON r.token_address = s.token_address
214
- LEFT JOIN bundled_agg b ON r.token_address = b.token_address
215
- LEFT JOIN dev_hold_agg d ON r.token_address = d.token_address
216
  """
217
  rows = client.execute(query)
218
  cols = [
@@ -233,7 +226,7 @@ def fetch_token_metrics(client) -> List[dict]:
233
  return out
234
 
235
 
236
- def _compute_quality_scores(
237
  client,
238
  max_ret: float = 10000.0,
239
  rerank: bool = True,
@@ -251,12 +244,12 @@ def _compute_quality_scores(
251
  ("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d["fees_sol"] is not None and d["n_trades"] is not None else None, True),
252
  ("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d["unique_holders"] is not None and d["n_trades"] is not None else None, True),
253
  ("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d["unique_holders"] is not None and d["volume_usd"] is not None else None, True),
254
- ("snipers_pct", lambda d: d["snipers_pct"], False),
255
- ("bundled_pct", lambda d: d["bundled_pct"], False),
256
- ("dev_hold_pct", lambda d: d["dev_hold_pct"], False),
257
  ]
258
 
259
- raw_metrics = ["snipers_pct", "bundled_pct", "dev_hold_pct"]
260
 
261
  debug = None
262
  if with_debug:
@@ -357,6 +350,10 @@ def _compute_quality_scores(
357
  "ret": d["ret"],
358
  "q_raw": q_raw_map[t],
359
  "q": q_final,
 
 
 
 
360
  }
361
  )
362
  else:
@@ -371,6 +368,10 @@ def _compute_quality_scores(
371
  "ret": d["ret"],
372
  "q_raw": q_raw_map[t],
373
  "q": q_raw_map[t],
 
 
 
 
374
  }
375
  )
376
 
@@ -379,12 +380,7 @@ def _compute_quality_scores(
379
  return token_scores
380
 
381
 
382
- def compute_quality_scores(
383
- client,
384
- max_ret: float = 10000.0,
385
- rerank: bool = True,
386
- ) -> List[dict]:
387
- return _compute_quality_scores(client, max_ret=max_ret, rerank=rerank, with_debug=False)
388
 
389
 
390
  def write_jsonl(path: str, rows: List[dict]) -> None:
@@ -491,6 +487,23 @@ def print_summary(scores: List[dict]) -> None:
491
  print(f" Mean: {stats_q_raw['mean']:.4f} | Min: {stats_q_raw['min']:.4f} | Max: {stats_q_raw['max']:.4f}")
492
  print(f" Q: p10={stats_q_raw['p10']:.2f} p50={stats_q_raw['p50']:.2f} p90={stats_q_raw['p90']:.2f} p99={stats_q_raw['p99']:.2f}")
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
 
495
  def print_diagnostics(debug: dict) -> None:
496
  if not debug:
@@ -563,6 +576,77 @@ def print_diagnostics(debug: dict) -> None:
563
  corr = _pearson_corr(xs, ys)
564
  print(f" log(ret) vs {metric}: {corr:.4f} (n={len(xs)})")
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
 
567
  def main():
568
  parser = argparse.ArgumentParser(description="Compute token quality/health score.")
@@ -577,7 +661,7 @@ def main():
577
  scores = compute_quality_scores(client, max_ret=args.max_ret, rerank=not args.no_rerank)
578
  debug = None
579
  else:
580
- scores, debug = _compute_quality_scores(
581
  client,
582
  max_ret=args.max_ret,
583
  rerank=not args.no_rerank,
@@ -587,6 +671,7 @@ def main():
587
  print_summary(scores)
588
  if not args.no_diagnostics:
589
  print_diagnostics(debug)
 
590
 
591
 
592
  if __name__ == "__main__":
 
87
  FROM trades
88
  GROUP BY base_address
89
  ),
90
+ -- 2. "Token list derived MINTS.
91
  token_meta_raw AS (
92
  SELECT
93
+ mint_address AS token_address,
94
+ argMax(creator_address, timestamp) AS creator_address,
95
+ argMax(total_supply, timestamp) AS total_supply,
96
+ argMax(token_decimals, timestamp) AS decimals
97
+ FROM mints
98
+ GROUP BY mint_address
99
  ),
100
  token_meta AS (
101
  SELECT
 
161
  GROUP BY s.base_address
162
  ),
163
 
164
+ -- 6. BUNDLED: Sum the base_amount of ALL trades that happened in a slot with multiple buys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  bundled_agg AS (
166
  SELECT
167
+ t.base_address AS token_address,
168
+ sum(t.base_amount) AS bundled_total_peak
169
+ FROM trades t
170
+ WHERE (t.base_address, t.slot) IN (
171
+ SELECT base_address, slot
172
+ FROM trades
173
+ WHERE trade_type = 0 -- buy
174
+ GROUP BY base_address, slot
175
+ HAVING count() > 1
176
+ )
177
+ AND t.trade_type = 0 -- buy
178
+ GROUP BY t.base_address
179
  ),
180
 
181
  -- 7. DEV HOLD: Creator's Peak Balance
 
189
  )
190
 
191
  SELECT
192
+ t.token_address,
193
  r.ret,
194
  r.unique_holders,
195
  f.fees_sol,
 
198
  (f.t_ath - f.t0) AS time_to_ath_sec,
199
  -- Calculate Percentages using Peak Sums / Total Supply
200
  (COALESCE(s.snipers_total_peak, 0) / t.adj_supply * 100) AS snipers_pct,
201
+ (COALESCE(b.bundled_total_peak, 0) / t.total_supply * 100) AS bundled_pct,
202
  (COALESCE(d.dev_peak, 0) / t.adj_supply * 100) AS dev_hold_pct
203
+ FROM token_meta t
204
+ LEFT JOIN ret_agg r ON t.token_address = r.token_address
205
+ LEFT JOIN trade_agg f ON t.token_address = f.base_address
206
+ LEFT JOIN snipers_agg s ON t.token_address = s.token_address
207
+ LEFT JOIN bundled_agg b ON t.token_address = b.token_address
208
+ LEFT JOIN dev_hold_agg d ON t.token_address = d.token_address
209
  """
210
  rows = client.execute(query)
211
  cols = [
 
226
  return out
227
 
228
 
229
+ def compute_quality_scores(
230
  client,
231
  max_ret: float = 10000.0,
232
  rerank: bool = True,
 
244
  ("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d["fees_sol"] is not None and d["n_trades"] is not None else None, True),
245
  ("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d["unique_holders"] is not None and d["n_trades"] is not None else None, True),
246
  ("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d["unique_holders"] is not None and d["volume_usd"] is not None else None, True),
247
+ ("snipers_pct", lambda d: d["snipers_pct"], True),
248
+ ("bundled_pct", lambda d: d["bundled_pct"], True),
249
+ ("dev_hold_pct", lambda d: d["dev_hold_pct"], True),
250
  ]
251
 
252
+ raw_metrics = ["snipers_pct", "bundled_pct", "dev_hold_pct", "fees_sol"] # Added fees_sol for diagnostic logging
253
 
254
  debug = None
255
  if with_debug:
 
350
  "ret": d["ret"],
351
  "q_raw": q_raw_map[t],
352
  "q": q_final,
353
+ # Pass through raw metrics for analysis
354
+ "bundled_pct": d.get("bundled_pct"),
355
+ "snipers_pct": d.get("snipers_pct"),
356
+ "fees_sol": d.get("fees_sol"),
357
  }
358
  )
359
  else:
 
368
  "ret": d["ret"],
369
  "q_raw": q_raw_map[t],
370
  "q": q_raw_map[t],
371
+ # Pass through raw metrics for analysis
372
+ "bundled_pct": d.get("bundled_pct"),
373
+ "snipers_pct": d.get("snipers_pct"),
374
+ "fees_sol": d.get("fees_sol"),
375
  }
376
  )
377
 
 
380
  return token_scores
381
 
382
 
383
+
 
 
 
 
 
384
 
385
 
386
  def write_jsonl(path: str, rows: List[dict]) -> None:
 
487
  print(f" Mean: {stats_q_raw['mean']:.4f} | Min: {stats_q_raw['min']:.4f} | Max: {stats_q_raw['max']:.4f}")
488
  print(f" Q: p10={stats_q_raw['p10']:.2f} p50={stats_q_raw['p50']:.2f} p90={stats_q_raw['p90']:.2f} p99={stats_q_raw['p99']:.2f}")
489
 
490
+ # --- NEW: Print 3 Examples (Min, Mid, Max) ---
491
+ if items:
492
+ # Sort items by 'q' to find min/mid/max easily
493
+ items_sorted = sorted(items, key=lambda x: x.get("q", 0))
494
+
495
+ ex_min = items_sorted[0]
496
+ ex_max = items_sorted[-1]
497
+
498
+ # Find mid (closest to 0.0, or just median index? Request said "mean quality" which is 0.0)
499
+ # finding item with q closest to 0.0
500
+ ex_mid = min(items_sorted, key=lambda x: abs(x.get("q", 0) - 0.0))
501
+
502
+ print(" Examples:")
503
+ print(f" Low (-1.0): {ex_min['token_address']} (q={ex_min.get('q',0):.4f}, ret={ex_min.get('ret',0):.2f}x)")
504
+ print(f" Mid (~0.0): {ex_mid['token_address']} (q={ex_mid.get('q',0):.4f}, ret={ex_mid.get('ret',0):.2f}x)")
505
+ print(f" High ( 1.0): {ex_max['token_address']} (q={ex_max.get('q',0):.4f}, ret={ex_max.get('ret',0):.2f}x)")
506
+
507
 
508
  def print_diagnostics(debug: dict) -> None:
509
  if not debug:
 
576
  corr = _pearson_corr(xs, ys)
577
  print(f" log(ret) vs {metric}: {corr:.4f} (n={len(xs)})")
578
 
579
+ # Removed placeholder
580
+ pass
581
+
582
+
583
def print_high_ret_analysis(scores: List[dict]) -> None:
    """Split the 10x-20x return cohort at its median bundled% and compare fee/return stats."""
    print("\n=== MID-HIGH RETURN SPLIT ANALYSIS (10x - 20x) ===")

    # Restrict to tokens whose return lies in the 10x-20x band.
    cohort = [
        row for row in scores
        if row.get("ret") is not None and 10.0 <= row["ret"] < 20.0
    ]
    if not cohort:
        print("No tokens 10x-20x found.")
        return

    print(f"Total tokens 10x-20x: {len(cohort)}")

    # Pull bundled percentages for rows that actually carry one.
    bundled_vals = [row.get("bundled_pct", 0) for row in cohort if row.get("bundled_pct") is not None]
    if not bundled_vals:
        print("No bundled_pct data found.")
        return

    median_bundled = _percentile(sorted(bundled_vals), 0.50)
    print(f"Median Bundled% for Cohort: {median_bundled:.2f}%")

    # Partition the cohort around the median bundled percentage
    # (missing bundled_pct is treated as 0, matching the filter above).
    low_group = [row for row in cohort if (row.get("bundled_pct") or 0) <= median_bundled]
    high_group = [row for row in cohort if (row.get("bundled_pct") or 0) > median_bundled]

    def _mean_fees(group):
        # Mean fees_sol over rows that carry a fee value; 0.0 for an empty group.
        fees = [row.get("fees_sol", 0) for row in group if row.get("fees_sol") is not None]
        return (sum(fees) / len(fees)) if fees else 0.0

    def _mean_ret(group):
        # Mean return multiple; 0.0 for an empty group.
        rets = [row["ret"] for row in group]
        return (sum(rets) / len(rets)) if rets else 0.0

    print(f"\nGroup 1: LOW Bundled (<= {median_bundled:.2f}%)")
    print(f" Count: {len(low_group)}")
    print(f" Mean Fees: {_mean_fees(low_group):.4f} SOL")

    print(f"\nGroup 2: HIGH Bundled (> {median_bundled:.2f}%)")
    print(f" Count: {len(high_group)}")
    print(f" Mean Fees: {_mean_fees(high_group):.4f} SOL")

    print(f" Mean Ret: {_mean_ret(high_group):.2f}x (vs Low: {_mean_ret(low_group):.2f}x)")
+
632
+
633
def get_token_quality_scores(client):
    """Return a mapping of token_address -> final quality score ``q``.

    Calls ``compute_quality_scores`` with reranking enabled and an effectively
    unbounded ``max_ret`` (1e9) so no token is clipped out of the ranking.
    With its default ``with_debug=False`` that call returns a plain list of
    score dicts (not a ``(scores, debug)`` tuple), so the result can be
    consumed directly here.

    Args:
        client: Database client passed through to ``compute_quality_scores``.

    Returns:
        dict[str, float]: token address mapped to its quality score ``q``
        (0.0 when a row is missing the ``q`` key).
    """
    results = compute_quality_scores(client, max_ret=1e9, rerank=True)
    return {r["token_address"]: r.get("q", 0.0) for r in results}
+
650
 
651
  def main():
652
  parser = argparse.ArgumentParser(description="Compute token quality/health score.")
 
661
  scores = compute_quality_scores(client, max_ret=args.max_ret, rerank=not args.no_rerank)
662
  debug = None
663
  else:
664
+ scores, debug = compute_quality_scores(
665
  client,
666
  max_ret=args.max_ret,
667
  rerank=not args.no_rerank,
 
671
  print_summary(scores)
672
  if not args.no_diagnostics:
673
  print_diagnostics(debug)
674
+ print_high_ret_analysis(scores) # Call the new analysis
675
 
676
 
677
  if __name__ == "__main__":
token_stats.rs ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use crate::database::insert_rows;
2
+ use crate::services::price_service::PriceService;
3
+ use crate::types::{
4
+ EventPayload, EventType, MigrationRow, MintRow, TokenMetricsRow, TokenStaticRow, TradeRow,
5
+ };
6
+ use anyhow::{Context, Result, anyhow};
7
+ use borsh::BorshDeserialize;
8
+ use clickhouse::Client;
9
+ use futures_util::future;
10
+ use mpl_token_metadata::accounts::Metadata;
11
+ use once_cell::sync::Lazy;
12
+ use redis::aio::MultiplexedConnection;
13
+ use redis::streams::{StreamReadOptions, StreamReadReply};
14
+ use redis::{AsyncCommands, Client as RedisClient, FromRedisValue};
15
+ use solana_client::nonblocking::rpc_client::RpcClient;
16
+ use solana_program::program_pack::Pack;
17
+ use solana_sdk::pubkey::Pubkey;
18
+ use spl_token::state::Mint;
19
+ use std::collections::{HashMap, HashSet};
20
+ use std::env;
21
+ use std::str::FromStr;
22
+ use std::sync::Arc;
23
+ use std::time::Duration;
24
+ use tokio::sync::RwLock;
25
+
26
/// In-memory working set for one batch: token address -> static + metrics rows.
type TokenCache = HashMap<String, TokenEntry>;
27
+
28
/// Read environment variable `key` and parse it as `T`, falling back to
/// `default` when the variable is unset or does not parse.
fn env_parse<T: FromStr>(key: &str, default: T) -> T {
    match env::var(key) {
        Ok(raw) => raw.parse::<T>().unwrap_or(default),
        Err(_) => default,
    }
}
34
+
35
/// Chunk size for batched ClickHouse `IN ?` queries; overridable via the
/// `TOKEN_STATS_CHUNK_SIZE` environment variable (defaults to 1000).
static TOKEN_STATS_CHUNK_SIZE: Lazy<usize> =
    Lazy::new(|| env_parse("TOKEN_STATS_CHUNK_SIZE", 1000usize));
37
+
38
+ #[derive(Debug, Clone)]
39
+ struct TokenEntry {
40
+ token: TokenStaticRow,
41
+ metrics: TokenMetricsRow,
42
+ }
43
+
44
+ impl TokenEntry {
45
+ fn new(token: TokenStaticRow, metrics: Option<TokenMetricsRow>) -> Self {
46
+ let metrics = metrics
47
+ .unwrap_or_else(|| TokenMetricsRow::new(token.token_address.clone(), token.updated_at));
48
+ Self { token, metrics }
49
+ }
50
+ }
51
+
52
/// Earliest known context for a token seen in the current batch:
/// first-seen timestamp plus optional protocol / pool / decimals hints.
#[derive(Clone, Debug)]
struct TokenContext {
    timestamp: u32,               // earliest event timestamp observed for the token
    protocol: Option<u8>,         // first non-None protocol seen
    pool_address: Option<String>, // first non-empty pool address seen
    decimals: Option<u8>,         // latest decimals hint seen
}

impl TokenContext {
    /// Build a context from whatever fields the triggering event provided.
    fn new(
        timestamp: u32,
        protocol: Option<u8>,
        pool_address: Option<String>,
        decimals: Option<u8>,
    ) -> Self {
        Self {
            timestamp,
            protocol,
            pool_address,
            decimals,
        }
    }
}
75
+
76
/// Merge one event's token hints into `contexts`, keyed by `token_address`.
///
/// Merge rules (applied on every call, including the inserting one):
/// - empty `token_address` is ignored entirely;
/// - `timestamp` keeps the minimum (earliest) value seen;
/// - `protocol` is set only if not already known;
/// - `pool_address` is set only while the stored one is missing/empty, and
///   only from a non-empty incoming value;
/// - `decimals` always takes the latest non-None value.
fn record_token_context(
    contexts: &mut HashMap<String, TokenContext>,
    token_address: &str,
    timestamp: u32,
    protocol: Option<u8>,
    pool_address: Option<String>,
    decimals: Option<u8>,
) {
    if token_address.is_empty() {
        return;
    }

    // A separate clone is moved into the insertion closure via `take()`,
    // leaving `pool_address` available for the update path below.
    let mut pool_for_insert = pool_address.clone();
    let entry = contexts
        .entry(token_address.to_string())
        .or_insert_with(|| {
            TokenContext::new(timestamp, protocol, pool_for_insert.take(), decimals)
        });

    // Keep the earliest timestamp observed for this token.
    if timestamp < entry.timestamp {
        entry.timestamp = timestamp;
    }

    if entry.protocol.is_none() {
        entry.protocol = protocol;
    }

    // Update the pool only while the stored value is absent or empty.
    let should_update_pool = entry
        .pool_address
        .as_ref()
        .map(|p| p.is_empty())
        .unwrap_or(true);
    if should_update_pool {
        if let Some(pool) = pool_address {
            if !pool.is_empty() {
                entry.pool_address = Some(pool);
            }
        }
    }

    // Decimals: last non-None hint wins.
    if let Some(dec) = decimals {
        entry.decimals = Some(dec);
    }
}
120
+
121
+ fn pool_addresses_from_context(context: &TokenContext) -> Vec<String> {
122
+ context
123
+ .pool_address
124
+ .as_ref()
125
+ .filter(|addr| !addr.is_empty())
126
+ .map(|addr| vec![addr.clone()])
127
+ .unwrap_or_default()
128
+ }
129
+
130
/// Whether the event's underlying row reported success; failed events are
/// skipped by the aggregation pipeline.
fn event_success(event: &EventType) -> bool {
    match event {
        EventType::Trade(row) => row.success,
        EventType::Mint(row) => row.success,
        EventType::Migration(row) => row.success,
        EventType::FeeCollection(row) => row.success,
        EventType::Liquidity(row) => row.success,
        EventType::PoolCreation(row) => row.success,
        EventType::Transfer(row) => row.success,
        EventType::SupplyLock(row) => row.success,
        EventType::SupplyLockAction(row) => row.success,
        EventType::Burn(row) => row.success,
    }
}
144
+
145
/// Consumes events from the Redis `event_queue` stream, maintains per-token
/// static/metrics rows in ClickHouse, and forwards processed payloads to the
/// wallet aggregation queue.
pub struct TokenAggregator {
    db_client: Client,                 // ClickHouse connection
    redis_conn: MultiplexedConnection, // shared async Redis connection
    rpc_client: Arc<RpcClient>,        // Solana RPC, used for token metadata lookups
    // NOTE(review): price_service is stored but not referenced in this file — confirm usage.
    price_service: PriceService,
    backfill_mode: bool,               // when true, skip RPC metadata fetches
}
152
+
153
impl TokenAggregator {
    /// Connect to Redis and read `BACKFILL_MODE` from the environment
    /// (backfill mode is enabled only by the exact string "true").
    pub async fn new(
        db_client: Client,
        redis_client: RedisClient,
        rpc_client: Arc<RpcClient>,
        price_service: PriceService,
    ) -> Result<Self> {
        let redis_conn = redis_client.get_multiplexed_async_connection().await?;
        println!("[TokenAggregator] ✔️ Connected to ClickHouse, Redis, and Solana RPC.");

        let backfill_mode =
            env::var("BACKFILL_MODE").unwrap_or_else(|_| "false".to_string()) == "true";
        Ok(Self {
            db_client,
            redis_conn,
            rpc_client,
            price_service,
            backfill_mode,
        })
    }

    /// Main consumer loop: read batches from the `event_queue` Redis stream
    /// (consumer group `token_aggregators`), process each batch, then forward
    /// payloads to `wallet_agg_queue` and ACK/delete the source messages.
    /// Runs forever; only returns with an error on unrecoverable failures.
    pub async fn run(&mut self) -> Result<()> {
        let stream_key = "event_queue";
        let group_name = "token_aggregators";
        let consumer_name = format!("consumer-tokens-{}", uuid::Uuid::new_v4());

        let mut publisher_conn = self.redis_conn.clone();
        let next_queue = "wallet_agg_queue";

        // Create the consumer group; a BUSYGROUP error just means it exists.
        let result: redis::RedisResult<()> = self
            .redis_conn
            .xgroup_create_mkstream(stream_key, group_name, "0")
            .await;
        if let Err(e) = result {
            if !e.to_string().contains("BUSYGROUP") {
                return Err(anyhow!(
                    "[TokenAggregator] Failed to create consumer group: {}",
                    e
                ));
            }
            println!(
                "[TokenAggregator] Consumer group '{}' already exists. Resuming.",
                group_name
            );
        } else {
            println!(
                "[TokenAggregator] Created new consumer group '{}'.",
                group_name
            );
        }

        loop {
            let messages = match self
                .collect_events(stream_key, group_name, &consumer_name)
                .await
            {
                Ok(msgs) => msgs,
                Err(e) => {
                    eprintln!(
                        "[TokenAggregator] 🔴 Error reading from Redis: {}. Retrying...",
                        e
                    );
                    tokio::time::sleep(Duration::from_secs(5)).await;
                    continue;
                }
            };
            if messages.is_empty() {
                continue;
            }

            println!(
                "[TokenAggregator] ⚙️ Starting processing for a new batch of {} events...",
                messages.len()
            );
            let message_ids: Vec<String> = messages.iter().map(|(id, _)| id.clone()).collect();
            let payloads: Vec<EventPayload> =
                messages.into_iter().map(|(_, payload)| payload).collect();

            match self.process_batch(payloads.clone()).await {
                // Clone payloads to use them after processing
                Ok(_) => {
                    if !message_ids.is_empty() {
                        // Forward each payload to the next queue in the pipeline
                        for payload in payloads {
                            let payload_data = bincode::serialize(&payload)?;
                            let _: () = publisher_conn
                                .xadd(next_queue, "*", &[("payload", payload_data)])
                                .await?;
                        }
                        println!(
                            "[TokenAggregator] ✅ Finished batch, forwarded {} events to {}.",
                            message_ids.len(),
                            next_queue
                        );

                        // Acknowledge the message from the source queue ('event_queue')
                        let _: () = self
                            .redis_conn
                            .xack(stream_key, group_name, &message_ids)
                            .await?;
                        let _: i64 = self
                            .redis_conn
                            .xdel::<_, _, i64>(stream_key, &message_ids)
                            .await?;
                    }
                }
                Err(e) => {
                    // Leaving the messages un-ACKed lets them be re-delivered later.
                    eprintln!(
                        "[TokenAggregator] ❌ Failed to process batch, will not forward or ACK. Error: {}",
                        e
                    );
                }
            }
        }
    }

    /// Process one batch of payloads:
    /// 1) collect per-token context hints from all successful events;
    /// 2) load known tokens from ClickHouse, fetching metadata via RPC for
    ///    missing ones (or creating placeholders in backfill mode);
    /// 3) look up which (token, trader) pairs already exist in wallet_holdings;
    /// 4) apply mint/trade/migration events to the in-memory cache;
    /// 5) persist everything.
    async fn process_batch(&self, payloads: Vec<EventPayload>) -> Result<()> {
        let mut token_contexts: HashMap<String, TokenContext> = HashMap::new();
        for payload in &payloads {
            if !event_success(&payload.event) {
                continue;
            }
            let decimals_map = &payload.token_decimals;
            match &payload.event {
                EventType::Trade(t) => {
                    // Both sides of the trade get a context entry.
                    let pool = (!t.pool_address.is_empty()).then(|| t.pool_address.clone());
                    record_token_context(
                        &mut token_contexts,
                        &t.base_address,
                        t.timestamp,
                        Some(t.protocol),
                        pool.clone(),
                        decimals_map.get(&t.base_address).cloned(),
                    );
                    record_token_context(
                        &mut token_contexts,
                        &t.quote_address,
                        t.timestamp,
                        Some(t.protocol),
                        pool,
                        decimals_map.get(&t.quote_address).cloned(),
                    );
                }
                EventType::Mint(m) => {
                    record_token_context(
                        &mut token_contexts,
                        &m.mint_address,
                        m.timestamp,
                        Some(m.protocol),
                        (!m.pool_address.is_empty()).then(|| m.pool_address.clone()),
                        Some(m.token_decimals),
                    );
                }
                EventType::Migration(m) => {
                    record_token_context(
                        &mut token_contexts,
                        &m.mint_address,
                        m.timestamp,
                        Some(m.protocol),
                        (!m.pool_address.is_empty()).then(|| m.pool_address.clone()),
                        decimals_map.get(&m.mint_address).cloned(),
                    );
                }
                EventType::FeeCollection(f) => {
                    // NOTE(review): the vault address is recorded in the pool slot here — confirm intended.
                    let vault = (!f.vault_address.is_empty()).then(|| f.vault_address.clone());
                    record_token_context(
                        &mut token_contexts,
                        &f.token_0_mint_address,
                        f.timestamp,
                        Some(f.protocol),
                        vault.clone(),
                        decimals_map.get(&f.token_0_mint_address).cloned(),
                    );
                    if let Some(token_1) = &f.token_1_mint_address {
                        record_token_context(
                            &mut token_contexts,
                            token_1,
                            f.timestamp,
                            Some(f.protocol),
                            vault.clone(),
                            decimals_map.get(token_1).cloned(),
                        );
                    }
                }
                EventType::PoolCreation(p) => {
                    record_token_context(
                        &mut token_contexts,
                        &p.base_address,
                        p.timestamp,
                        Some(p.protocol),
                        (!p.pool_address.is_empty()).then(|| p.pool_address.clone()),
                        p.base_decimals
                            .or_else(|| decimals_map.get(&p.base_address).cloned()),
                    );
                    record_token_context(
                        &mut token_contexts,
                        &p.quote_address,
                        p.timestamp,
                        Some(p.protocol),
                        (!p.pool_address.is_empty()).then(|| p.pool_address.clone()),
                        p.quote_decimals
                            .or_else(|| decimals_map.get(&p.quote_address).cloned()),
                    );
                }
                EventType::Transfer(t) => {
                    record_token_context(
                        &mut token_contexts,
                        &t.mint_address,
                        t.timestamp,
                        None,
                        None,
                        decimals_map.get(&t.mint_address).cloned(),
                    );
                }
                EventType::SupplyLock(lock) => {
                    record_token_context(
                        &mut token_contexts,
                        &lock.mint_address,
                        lock.timestamp,
                        Some(lock.protocol),
                        None,
                        decimals_map.get(&lock.mint_address).cloned(),
                    );
                }
                EventType::SupplyLockAction(action) => {
                    record_token_context(
                        &mut token_contexts,
                        &action.mint_address,
                        action.timestamp,
                        Some(action.protocol),
                        None,
                        decimals_map.get(&action.mint_address).cloned(),
                    );
                }
                EventType::Burn(burn) => {
                    record_token_context(
                        &mut token_contexts,
                        &burn.mint_address,
                        burn.timestamp,
                        None,
                        None,
                        decimals_map.get(&burn.mint_address).cloned(),
                    );
                }
                EventType::Liquidity(_) => {}
                _ => {}
            }
        }

        if token_contexts.is_empty() {
            println!("[TokenAggregator] -> Batch contains no relevant token events. Skipping.");
            return Ok(());
        }
        println!(
            "[TokenAggregator] -> Batch contains {} unique tokens.",
            token_contexts.len()
        );

        let mut tokens = self
            .fetch_tokens_from_db(&token_contexts.keys().cloned().collect::<Vec<_>>())
            .await?;

        // Tokens seen in this batch but not yet present in ClickHouse.
        let missing_tokens: Vec<String> = token_contexts
            .keys()
            .filter(|address| !tokens.contains_key(*address))
            .cloned()
            .collect();

        if !missing_tokens.is_empty() {
            println!(
                "[TokenAggregator] -> Found {} new tokens to fetch metadata for.",
                missing_tokens.len()
            );

            if !self.backfill_mode {
                // Fetch on-chain metadata for all missing tokens concurrently.
                let fetch_futures = missing_tokens
                    .iter()
                    .map(|key| async move { (key.clone(), self.fetch_token_metadata(key).await) });
                let fetched_results = future::join_all(fetch_futures).await;

                for (key, rpc_result) in fetched_results {
                    let context = match token_contexts.get(&key) {
                        Some(ctx) => ctx.clone(),
                        None => continue,
                    };
                    let protocol = context.protocol.unwrap_or(0);
                    let token_row = match rpc_result {
                        Ok((metadata, mint_data)) => {
                            println!(
                                "[TokenAggregator] -> ✅ Successfully fetched metadata for new token {}.",
                                key
                            );

                            // First listed creator, per Metaplex metadata convention.
                            let creator = metadata
                                .creators
                                .as_ref()
                                .and_then(|creators| creators.first())
                                .map(|c| c.address.to_string())
                                .unwrap_or_default();

                            TokenStaticRow::new(
                                key.clone(),
                                context.timestamp,
                                metadata.name.trim_end_matches('\0').to_string(),
                                metadata.symbol.trim_end_matches('\0').to_string(),
                                metadata.uri.trim_end_matches('\0').to_string(),
                                mint_data.decimals,
                                creator,
                                pool_addresses_from_context(&context),
                                protocol,
                                mint_data.supply,
                                metadata.is_mutable,
                                Some(metadata.update_authority.to_string()),
                                Option::from(mint_data.mint_authority)
                                    .map(|pk: Pubkey| pk.to_string()),
                                Option::from(mint_data.freeze_authority)
                                    .map(|pk: Pubkey| pk.to_string()),
                            )
                        }
                        Err(e) => {
                            // RPC failure: store a placeholder row built from context hints.
                            eprintln!(
                                "[TokenAggregator] -> ❌ RPC failed for {}: {}. Creating placeholder.",
                                key, e
                            );
                            TokenStaticRow::new(
                                key.clone(),
                                context.timestamp,
                                String::new(),
                                String::new(),
                                String::new(),
                                context.decimals.unwrap_or(0),
                                String::new(),
                                pool_addresses_from_context(&context),
                                protocol,
                                0,
                                true,
                                None,
                                None,
                                None,
                            )
                        }
                    };
                    tokens.insert(key.clone(), TokenEntry::new(token_row, None));
                }
            } else {
                // Backfill mode: never hit RPC; placeholders only.
                println!(
                    "[TokenAggregator] -> Creating {} placeholder tokens in backfill mode.",
                    missing_tokens.len()
                );
                for key in missing_tokens {
                    if let Some(context) = token_contexts.get(&key) {
                        let placeholder_row = TokenStaticRow::new(
                            key.clone(),
                            context.timestamp,
                            String::new(),
                            String::new(),
                            String::new(),
                            context.decimals.unwrap_or(0),
                            String::new(),
                            pool_addresses_from_context(context),
                            context.protocol.unwrap_or(0),
                            0,
                            false,
                            None,
                            None,
                            None,
                        );
                        tokens.insert(key.clone(), TokenEntry::new(placeholder_row, None));
                    }
                }
            }
        }

        // (token, maker) pairs traded in this batch, used for unique-holder counting.
        let trader_pairs_in_batch: Vec<(String, String)> = payloads
            .iter()
            .filter_map(|p| {
                if let EventType::Trade(t) = &p.event {
                    Some((t.base_address.clone(), t.maker.clone()))
                } else {
                    None
                }
            })
            .collect();

        let mut existing_traders = HashSet::new();
        if !trader_pairs_in_batch.is_empty() {
            for chunk in trader_pairs_in_batch.chunks(*TOKEN_STATS_CHUNK_SIZE) {
                let mut cursor = self.db_client
                    .query("SELECT DISTINCT (mint_address, wallet_address) FROM wallet_holdings WHERE (mint_address, wallet_address) IN ?")
                    .bind(chunk)
                    .fetch::<(String, String)>()?;

                while let Some(pair) = cursor.next().await? {
                    existing_traders.insert(pair);
                }
            }
        }

        let mut counted_in_this_batch: HashSet<(String, String)> = HashSet::new();

        for payload in payloads.iter() {
            if !event_success(&payload.event) {
                continue;
            }
            match &payload.event {
                EventType::Mint(mint) => self.process_mint(mint, &mut tokens),
                EventType::Trade(trade) => {
                    self.process_trade(
                        trade,
                        &mut tokens,
                        &existing_traders,
                        &mut counted_in_this_batch,
                    );
                }
                EventType::Migration(migration) => self.process_migration(migration, &mut tokens),
                _ => {}
            }
        }

        self.finalize_and_persist(tokens).await
    }

    /// Apply one trade to the base token's metrics: volume, ATH price,
    /// buy/sell counters, and de-duplicated unique-holder counting.
    fn process_trade(
        &self,
        trade: &TradeRow,
        tokens: &mut TokenCache,
        existing_traders: &HashSet<(String, String)>,
        counted_in_this_batch: &mut HashSet<(String, String)>,
    ) {
        if let Some(entry) = tokens.get_mut(&trade.base_address) {
            entry.token.updated_at = trade.timestamp;
            entry.metrics.updated_at = trade.timestamp;

            // --- START: CORRECT UNIQUE HOLDER LOGIC ---

            let current_pair = (trade.base_address.clone(), trade.maker.clone());

            // We only increment the counter if:
            // 1. The trader is NOT in the set of traders we know about from the database.
            // 2. We have NOT already counted this trader for this token in this batch.
            if !existing_traders.contains(&current_pair) {
                // The .insert() returns true only the first time we see this pair in this batch.
                if counted_in_this_batch.insert(current_pair) {
                    entry.metrics.unique_holders += 1;
                }
            }

            let trade_total_in_usd = trade.total_usd;

            entry.metrics.total_volume_usd += trade_total_in_usd;
            entry.metrics.ath_price_usd = entry.metrics.ath_price_usd.max(trade.price_usd);

            if trade.trade_type == 0 {
                // Buy
                entry.metrics.total_buys += 1;
            } else {
                // Sell
                entry.metrics.total_sells += 1;
            }
        }
    }

    /// Load static token rows for `keys` from `tokens_latest` (chunked),
    /// then pair each with its latest metrics row.
    async fn fetch_tokens_from_db(&self, keys: &[String]) -> Result<TokenCache> {
        if keys.is_empty() {
            return Ok(HashMap::new());
        }
        let query_str = "
            SELECT
                *
            FROM tokens_latest
            WHERE token_address IN ?
        ";

        let mut statics = HashMap::new();
        for chunk in keys.chunks(*TOKEN_STATS_CHUNK_SIZE) {
            let mut cursor = self
                .db_client
                .query(query_str)
                .bind(chunk)
                .fetch::<TokenStaticRow>()?;

            // NOTE(review): `while let Ok(...)` stops on the first cursor error
            // without surfacing it — confirm this silent truncation is intended.
            while let Ok(Some(token)) = cursor.next().await {
                statics.insert(token.token_address.clone(), token);
            }
        }

        let metrics_map = self.fetch_token_metrics(keys).await?;
        let mut tokens = HashMap::new();

        for (address, token) in statics {
            let metrics = metrics_map.get(&address).cloned();
            tokens.insert(address.clone(), TokenEntry::new(token, metrics));
        }

        Ok(tokens)
    }

    /// Fetch the most recent metrics row per token from `token_metrics_latest`
    /// (chunked; `LIMIT 1 BY token_address` keeps only the newest per token).
    async fn fetch_token_metrics(
        &self,
        keys: &[String],
    ) -> Result<HashMap<String, TokenMetricsRow>> {
        if keys.is_empty() {
            return Ok(HashMap::new());
        }

        let query_str = "
            SELECT
                *
            FROM token_metrics_latest
            WHERE token_address IN ?
            ORDER BY token_address, updated_at DESC
            LIMIT 1 BY token_address
        ";

        let mut metrics = HashMap::new();

        for chunk in keys.chunks(*TOKEN_STATS_CHUNK_SIZE) {
            let mut cursor = self
                .db_client
                .query(query_str)
                .bind(chunk)
                .fetch::<TokenMetricsRow>()?;

            // NOTE(review): cursor errors are swallowed here, same as in
            // fetch_tokens_from_db — confirm intended.
            while let Ok(Some(row)) = cursor.next().await {
                metrics.insert(row.token_address.clone(), row);
            }
        }

        Ok(metrics)
    }

    /// Fetch a token's SPL mint account and its Metaplex metadata PDA in
    /// parallel, returning the deserialized pair.
    async fn fetch_token_metadata(&self, mint_address_str: &str) -> Result<(Metadata, Mint)> {
        let mint_pubkey = Pubkey::from_str(mint_address_str)?;
        let metadata_pubkey = Metadata::find_pda(&mint_pubkey).0;

        let (mint_account_res, metadata_account_res) = future::join(
            self.rpc_client.get_account(&mint_pubkey),
            self.rpc_client.get_account(&metadata_pubkey),
        )
        .await;

        let mint_account = mint_account_res?;
        let metadata_account = metadata_account_res?;

        let mint_data = Mint::unpack(&mint_account.data)?;
        let metadata = Metadata::deserialize(&mut &metadata_account.data[..])?;

        Ok((metadata, mint_data))
    }

    /// Create a token entry from a MINT event, or enrich an existing one
    /// (only filling name/symbol/uri/creator when currently empty).
    fn process_mint(&self, mint: &MintRow, tokens: &mut TokenCache) {
        let is_new = !tokens.contains_key(&mint.mint_address);
        let entry = tokens
            .entry(mint.mint_address.clone())
            .or_insert_with(|| TokenEntry::new(TokenStaticRow::new_from_mint(mint), None));
        let token = &mut entry.token;

        if is_new {
            println!(
                "[TokenAggregator] -> Created new token record for {} from MINT event.",
                mint.mint_address
            );
        } else {
            println!(
                "[TokenAggregator] -> Enriched existing token record for {} with MINT event data.",
                mint.mint_address
            );
            token.updated_at = mint.timestamp;
            token.created_at = token.created_at.min(mint.timestamp);
            token.decimals = mint.token_decimals;
            token.launchpad = mint.protocol;
            token.protocol = mint.protocol;
            token.total_supply = mint.total_supply;
            token.is_mutable = mint.is_mutable;
            token.update_authority = mint.update_authority.clone();
            token.mint_authority = mint.mint_authority.clone();
            token.freeze_authority = mint.freeze_authority.clone();
            if token.name.is_empty() {
                token.name = mint.token_name.clone().unwrap_or_default();
            }
            if token.symbol.is_empty() {
                token.symbol = mint.token_symbol.clone().unwrap_or_default();
            }
            if token.token_uri.is_empty() {
                token.token_uri = mint.token_uri.clone().unwrap_or_default();
            }
            if token.creator_address.is_empty() {
                token.creator_address = mint.creator_address.clone();
            }
            if !mint.pool_address.is_empty() && !token.pool_addresses.contains(&mint.pool_address) {
                token.pool_addresses.push(mint.pool_address.clone());
            }
        }
    }

    /// Update a token's protocol and pool list when it migrates; tokens not
    /// already in the cache are ignored.
    fn process_migration(&self, migration: &MigrationRow, tokens: &mut TokenCache) {
        if let Some(entry) = tokens.get_mut(&migration.mint_address) {
            let token = &mut entry.token;
            println!(
                "[TokenAggregator] -> Updating protocol for token {} due to migration.",
                migration.mint_address
            );
            token.updated_at = migration.timestamp;
            token.protocol = migration.protocol;
            if !token.pool_addresses.contains(&migration.pool_address) {
                token.pool_addresses.push(migration.pool_address.clone());
            }
        }
    }

    /// Persist all cached token rows (history + latest snapshot tables);
    /// metrics rows are written only when they show activity.
    async fn finalize_and_persist(&self, tokens: TokenCache) -> Result<()> {
        if tokens.is_empty() {
            return Ok(());
        }

        let mut updated_tokens = Vec::new();
        let mut metric_rows = Vec::new();

        for entry in tokens.into_values() {
            if Self::metrics_has_activity(&entry.metrics) {
                metric_rows.push(entry.metrics);
            }
            updated_tokens.push(entry.token);
        }

        insert_rows(
            &self.db_client,
            "tokens",
            updated_tokens.clone(),
            "Token Aggregator",
            "tokens",
        )
        .await
        .with_context(|| "Failed to persist token data to ClickHouse")?;

        insert_rows(
            &self.db_client,
            "tokens_latest",
            updated_tokens,
            "Token Aggregator",
            "tokens_latest",
        )
        .await
        .with_context(|| "Failed to persist token snapshot data to ClickHouse")?;

        insert_rows(
            &self.db_client,
            "token_metrics",
            metric_rows.clone(),
            "Token Aggregator",
            "token_metrics",
        )
        .await
        .with_context(|| "Failed to persist token metric history to ClickHouse")?;

        insert_rows(
            &self.db_client,
            "token_metrics_latest",
            metric_rows,
            "Token Aggregator",
            "token_metrics_latest",
        )
        .await
        .with_context(|| "Failed to persist token metric snapshots to ClickHouse")?;

        Ok(())
    }

    /// True when any metric field is non-zero — i.e. the row is worth writing.
    fn metrics_has_activity(metrics: &TokenMetricsRow) -> bool {
        metrics.total_volume_usd > 0.0
            || metrics.total_buys > 0
            || metrics.total_sells > 0
            || metrics.unique_holders > 0
            || metrics.ath_price_usd > 0.0
    }

    /// Blocking-read up to 1000 new messages for this consumer (2s block),
    /// decoding each `payload` field with bincode. Messages that fail to
    /// decode are silently dropped (and thus never ACKed).
    async fn collect_events(
        &mut self,
        stream_key: &str,
        group_name: &str,
        consumer_name: &str,
    ) -> Result<Vec<(String, EventPayload)>> {
        let opts = StreamReadOptions::default()
            .group(group_name, consumer_name)
            .count(1000)
            .block(2000);
        let reply: StreamReadReply = self
            .redis_conn
            .xread_options(&[stream_key], &[">"], &opts)
            .await?;
        let mut events = Vec::new();
        for stream_entry in reply.keys {
            for message in stream_entry.ids {
                if let Some(payload_value) = message.map.get("payload") {
                    if let Ok(payload_bytes) = Vec::<u8>::from_redis_value(payload_value) {
                        if let Ok(payload) = bincode::deserialize::<EventPayload>(&payload_bytes) {
                            events.push((message.id.clone(), payload));
                        }
                    }
                }
            }
        }
        Ok(events)
    }
}
train.py CHANGED
@@ -427,6 +427,7 @@ def main() -> None:
427
 
428
  # --- 7. Training Loop ---
429
  total_steps = 0
 
430
 
431
  logger.info("***** Running training *****")
432
  logger.info(f" Num examples = {len(dataset)}")
@@ -470,8 +471,12 @@ def main() -> None:
470
  outputs = model(batch)
471
 
472
  preds = outputs["quantile_logits"]
 
473
  labels = batch["labels"]
474
  labels_mask = batch["labels_mask"]
 
 
 
475
  if labels_mask is not None and labels_mask.sum().item() == 0:
476
  token_addresses = batch.get('token_addresses', [])
477
  t_cutoffs = batch.get('t_cutoffs', [])
@@ -482,11 +487,14 @@ def main() -> None:
482
  token_addresses[0] if token_addresses else "unknown",
483
  t_cutoffs[0] if t_cutoffs else "unknown",
484
  )
485
-
486
  if labels_mask.sum() == 0:
487
- loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
488
  else:
489
- loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
 
 
 
490
 
491
  accelerator.backward(loss)
492
 
@@ -519,6 +527,8 @@ def main() -> None:
519
  log_debug_batch_context(batch, logger, total_steps)
520
 
521
  current_loss = loss.item()
 
 
522
  epoch_loss += current_loss
523
  valid_batches += 1
524
 
@@ -526,6 +536,8 @@ def main() -> None:
526
  lr = scheduler.get_last_lr()[0]
527
  log_payload = {
528
  "train/loss": current_loss,
 
 
529
  "train/learning_rate": lr,
530
  "train/epoch": epoch + (step / len(dataloader))
531
  }
 
427
 
428
  # --- 7. Training Loop ---
429
  total_steps = 0
430
+ quality_loss_fn = nn.MSELoss()
431
 
432
  logger.info("***** Running training *****")
433
  logger.info(f" Num examples = {len(dataset)}")
 
471
  outputs = model(batch)
472
 
473
  preds = outputs["quantile_logits"]
474
+ quality_preds = outputs["quality_logits"]
475
  labels = batch["labels"]
476
  labels_mask = batch["labels_mask"]
477
+ if "quality_score" not in batch:
478
+ raise RuntimeError("FATAL: quality_score missing from batch. Cannot train quality head.")
479
+ quality_targets = batch["quality_score"].to(accelerator.device, dtype=quality_preds.dtype)
480
  if labels_mask is not None and labels_mask.sum().item() == 0:
481
  token_addresses = batch.get('token_addresses', [])
482
  t_cutoffs = batch.get('t_cutoffs', [])
 
487
  token_addresses[0] if token_addresses else "unknown",
488
  t_cutoffs[0] if t_cutoffs else "unknown",
489
  )
490
+
491
  if labels_mask.sum() == 0:
492
+ return_loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
493
  else:
494
+ return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
495
+
496
+ quality_loss = quality_loss_fn(quality_preds, quality_targets)
497
+ loss = return_loss + quality_loss
498
 
499
  accelerator.backward(loss)
500
 
 
527
  log_debug_batch_context(batch, logger, total_steps)
528
 
529
  current_loss = loss.item()
530
+ current_return_loss = return_loss.item()
531
+ current_quality_loss = quality_loss.item()
532
  epoch_loss += current_loss
533
  valid_batches += 1
534
 
 
536
  lr = scheduler.get_last_lr()[0]
537
  log_payload = {
538
  "train/loss": current_loss,
539
+ "train/return_loss": current_return_loss,
540
+ "train/quality_loss": current_quality_loss,
541
  "train/learning_rate": lr,
542
  "train/epoch": epoch + (step / len(dataloader))
543
  }
train.sh CHANGED
@@ -11,7 +11,7 @@ accelerate launch train.py \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
- --max_seq_len 8192 \
15
  --horizons_seconds 60 180 300 600 1800 3600 7200 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
 
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
+ --max_seq_len 4096 \
15
  --horizons_seconds 60 180 300 600 1800 3600 7200 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \