zirobtc commited on
Commit
bb2313b
·
1 Parent(s): 23349c5

Upload folder using huggingface_hub

Browse files
data/data_loader.py CHANGED
@@ -128,9 +128,22 @@ class OracleDataset(Dataset):
128
  cache_dir: Optional[Union[str, Path]] = None,
129
  start_date: Optional[datetime.datetime] = None,
130
  min_trade_usd: float = 0.0,
131
- max_seq_len: int = 8192):
 
132
 
133
  self.max_seq_len = max_seq_len
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # --- NEW: Create a persistent requests session for efficiency ---
136
  # Configure robust HTTP session
@@ -633,10 +646,18 @@ class OracleDataset(Dataset):
633
  sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
634
  volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
635
  total_txns = len(trades_win) + len(xfers_win)
636
- global_fees_paid = sum(float(e.get('priority_fee', 0.0) or 0.0) for e in trades_win) + \
637
- sum(float(e.get('priority_fee', 0.0) or 0.0) for e in xfers_win)
 
 
638
 
639
- smart_trader_addrs = set(e['wallet_address'] for e in trades_win if e.get('event_type') == 'SmartWallet_Trade')
 
 
 
 
 
 
640
  smart_traders = len(smart_trader_addrs)
641
 
642
  kol_addrs = set()
@@ -825,7 +846,7 @@ class OracleDataset(Dataset):
825
 
826
  # --- Define all expected numerical keys for a profile ---
827
  expected_profile_keys = [
828
- 'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
829
  'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
830
  'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
831
  'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
@@ -852,14 +873,7 @@ class OracleDataset(Dataset):
852
  social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
853
  social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
854
 
855
- # --- Calculate 'age' based on user's logic ---
856
- funded_ts = profile_data.get('funded_timestamp', 0)
857
- if funded_ts and funded_ts > 0:
858
- age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
859
- else:
860
- age_seconds = 12_960_000
861
-
862
- profile_data['age'] = float(age_seconds)
863
 
864
  username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
865
 
@@ -909,12 +923,12 @@ class OracleDataset(Dataset):
909
  'mint_address': mint_addr,
910
  'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0),
911
  'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)),
912
- 'history_bought_cost_sol': min(1000.0, float(holding_item.get('history_bought_cost_sol', 0.0) or 0.0)),
913
  'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)),
914
  'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0),
915
  'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0),
916
  'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0),
917
- 'realized_profit_sol': max(-1000.0, min(1000.0, float(holding_item.get('realized_profit_sol', 0.0) or 0.0))),
918
  'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0),
919
  'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0),
920
  'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0),
@@ -926,7 +940,7 @@ class OracleDataset(Dataset):
926
  compact_profile = {'wallet_address': addr}
927
  for key in expected_profile_keys:
928
  compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0)
929
- compact_profile['age'] = float(profile_data.get('age', 0.0) or 0.0)
930
 
931
  compact_social = {
932
  'has_pf_profile': bool(social_data.get('has_pf_profile', False)),
@@ -2073,7 +2087,7 @@ class OracleDataset(Dataset):
2073
  'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0,
2074
  'token_amount_pct_of_holding': token_pct_hold,
2075
  'quote_amount_pct_of_holding': quote_pct_hold,
2076
- 'slippage': min(10.0, float(trade.get('slippage', 0.0) or 0.0)),
2077
  'token_amount_pct_to_total_supply': token_pct_supply,
2078
  'success': is_success,
2079
  'is_bundle': trade.get('is_bundle', False),
 
128
  cache_dir: Optional[Union[str, Path]] = None,
129
  start_date: Optional[datetime.datetime] = None,
130
  min_trade_usd: float = 0.0,
131
+ max_seq_len: int = 8192,
132
+ p99_clamps: Optional[Dict[str, float]] = None):
133
 
134
  self.max_seq_len = max_seq_len
135
+
136
+ # --- P99 data-driven clamp values (replace hardcoded min/max) ---
137
+ self.p99_clamps = {
138
+ 'slippage': 1.0,
139
+ 'priority_fee': 0.1,
140
+ 'total_usd': 100000.0,
141
+ 'history_bought_cost_sol': 30.0,
142
+ 'realized_profit_sol': 150.0,
143
+ }
144
+ if p99_clamps:
145
+ self.p99_clamps.update(p99_clamps)
146
+ print(f"INFO: Using P99 clamps: {self.p99_clamps}")
147
 
148
  # --- NEW: Create a persistent requests session for efficiency ---
149
  # Configure robust HTTP session
 
646
  sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
647
  volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
648
  total_txns = len(trades_win) + len(xfers_win)
649
+ global_fees_paid = sum(
650
+ float(e.get('priority_fee', 0.0) or 0.0) + float(e.get('bribe_fee', 0.0) or 0.0)
651
+ for e in trades_win
652
+ )
653
 
654
+ smart_trader_addrs = set(
655
+ e['wallet_address'] for e in trade_events
656
+ if e.get('event_type') == 'SmartWallet_Trade'
657
+ and e.get('success', False)
658
+ and e['timestamp'] <= ts_value
659
+ and holder_pct_map_ts.get(e['wallet_address'], 0.0) > 0.0
660
+ )
661
  smart_traders = len(smart_trader_addrs)
662
 
663
  kol_addrs = set()
 
846
 
847
  # --- Define all expected numerical keys for a profile ---
848
  expected_profile_keys = [
849
+ 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
850
  'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
851
  'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
852
  'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
 
873
  social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
874
  social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
875
 
876
+
 
 
 
 
 
 
 
877
 
878
  username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
879
 
 
923
  'mint_address': mint_addr,
924
  'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0),
925
  'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)),
926
+ 'history_bought_cost_sol': min(self.p99_clamps['history_bought_cost_sol'], float(holding_item.get('history_bought_cost_sol', 0.0) or 0.0)),
927
  'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)),
928
  'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0),
929
  'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0),
930
  'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0),
931
+ 'realized_profit_sol': max(-self.p99_clamps['realized_profit_sol'], min(self.p99_clamps['realized_profit_sol'], float(holding_item.get('realized_profit_sol', 0.0) or 0.0))),
932
  'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0),
933
  'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0),
934
  'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0),
 
940
  compact_profile = {'wallet_address': addr}
941
  for key in expected_profile_keys:
942
  compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0)
943
+
944
 
945
  compact_social = {
946
  'has_pf_profile': bool(social_data.get('has_pf_profile', False)),
 
2087
  'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0,
2088
  'token_amount_pct_of_holding': token_pct_hold,
2089
  'quote_amount_pct_of_holding': quote_pct_hold,
2090
+ 'slippage': min(self.p99_clamps['slippage'], float(trade.get('slippage', 0.0) or 0.0)),
2091
  'token_amount_pct_to_total_supply': token_pct_supply,
2092
  'success': is_success,
2093
  'is_bundle': trade.get('is_bundle', False),
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:87f6e823bed45b3e399d6fe2ab46f3297d80a623e5a41cca785a60c5a7db067d
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1d757990a0158118444be61f3d944dfb125237928809b4568ac209ab260f032e
3
  size 1660
inference.py CHANGED
@@ -29,7 +29,7 @@ if __name__ == "__main__":
29
  print("--- Oracle Inference Script (Full Pipeline Test) ---")
30
 
31
  # --- 1. Define Configs ---
32
- OHLC_SEQ_LEN = 60
33
  print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
34
 
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
29
  print("--- Oracle Inference Script (Full Pipeline Test) ---")
30
 
31
  # --- 1. Define Configs ---
32
+ OHLC_SEQ_LEN = 300
33
  print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
34
 
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c605ecab2de1c8c8442dda85ada5345b9d6ba43aae4095130f1d92ce6261c127
3
- size 44400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:656c6818f224b26869b5d0ae10f6b522ff7eb5c7b1b3aeb59b34c3db218338a9
3
+ size 11360
models/wallet_encoder.py CHANGED
@@ -47,14 +47,14 @@ class WalletEncoder(nn.Module):
47
  self.mmp_dim = self.encoder.embedding_dim # 1152
48
 
49
  # === 1. Profile Encoder (FIXED) ===
50
- # 1 age + 5 deployer_stats + 1 balance + 4 lifetime_counts +
51
- # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 38
52
- self.profile_numerical_features = 38
53
  self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
54
 
55
 
56
  # FIXED: Input dim no longer has bool embed or deployed tokens embed
57
- profile_mlp_in_dim = self.profile_numerical_features # 38
58
  self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)
59
 
60
 
@@ -152,17 +152,15 @@ class WalletEncoder(nn.Module):
152
 
153
  def _encode_profile_batch(self, profile_rows, device):
154
  batch_size = len(profile_rows)
155
- # FIXED: 38 numerical features
156
  num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
157
  # bool_tensor removed
158
  # time_tensor removed
159
 
160
  for i, row in enumerate(profile_rows):
161
- # A: Numerical (FIXED: 38 features, MUST be present)
162
  num_data = [
163
- # 1. Age
164
- row.get('age', 0.0),
165
- # 2. Deployed Token Aggregates (5)
166
  row.get('deployed_tokens_count', 0.0),
167
  row.get('deployed_tokens_migrated_pct', 0.0),
168
  row.get('deployed_tokens_avg_lifetime_sec', 0.0),
@@ -195,7 +193,7 @@ class WalletEncoder(nn.Module):
195
 
196
  # C: Booleans and deployed_tokens lists are GONE
197
 
198
- # Log-normalize all numerical features (age, stats, etc.)
199
  num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
200
 
201
  # The profile fused tensor is now just the numerical embeddings
 
47
  self.mmp_dim = self.encoder.embedding_dim # 1152
48
 
49
  # === 1. Profile Encoder (FIXED) ===
50
+ # 5 deployer_stats + 1 balance + 4 lifetime_counts +
51
+ # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 37
52
+ self.profile_numerical_features = 37
53
  self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
54
 
55
 
56
  # FIXED: Input dim no longer has bool embed or deployed tokens embed
57
+ profile_mlp_in_dim = self.profile_numerical_features # 37
58
  self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)
59
 
60
 
 
152
 
153
  def _encode_profile_batch(self, profile_rows, device):
154
  batch_size = len(profile_rows)
155
+ # FIXED: 37 numerical features
156
  num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
157
  # bool_tensor removed
158
  # time_tensor removed
159
 
160
  for i, row in enumerate(profile_rows):
161
+ # A: Numerical (FIXED: 37 features, MUST be present)
162
  num_data = [
163
+ # 1. Deployed Token Aggregates (5)
 
 
164
  row.get('deployed_tokens_count', 0.0),
165
  row.get('deployed_tokens_migrated_pct', 0.0),
166
  row.get('deployed_tokens_avg_lifetime_sec', 0.0),
 
193
 
194
  # C: Booleans and deployed_tokens lists are GONE
195
 
196
+ # Log-normalize all numerical features (stats, etc.)
197
  num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
198
 
199
  # The profile fused tensor is now just the numerical embeddings
sample_DYtPmhyxPDbMEdVP_0.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/analyze_distribution.py CHANGED
@@ -27,6 +27,55 @@ def get_client():
27
  database=CLICKHOUSE_DATABASE
28
  )
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def fetch_all_metrics(client):
31
  """
32
  Fetches all needed metrics for all tokens in a single query.
 
27
  database=CLICKHOUSE_DATABASE
28
  )
29
 
30
+ def compute_p99_clamps(client):
31
+ """
32
+ Computes P99 percentile clamp values from ClickHouse for fields prone to
33
+ garbage outliers. These values replace hardcoded clamps in data_loader.py.
34
+ Returns a dict of {field_name: p99_value}.
35
+ """
36
+ print(" -> Computing P99 clamp values from trades table...")
37
+ trade_query = """
38
+ SELECT
39
+ quantile(0.99)(abs(slippage)) AS p99_slippage,
40
+ quantile(0.99)(priority_fee) AS p99_priority_fee,
41
+ quantile(0.99)(total_usd) AS p99_total_usd
42
+ FROM trades
43
+ WHERE success = 1
44
+ """
45
+ trade_row = client.execute(trade_query)
46
+
47
+ print(" -> Computing P99 clamp values from wallet_holdings table...")
48
+ holdings_query = """
49
+ SELECT
50
+ quantile(0.99)(history_bought_cost_sol) AS p99_bought_cost_sol,
51
+ quantile(0.99)(abs(realized_profit_sol)) AS p99_realized_profit_sol
52
+ FROM wallet_holdings
53
+ """
54
+ holdings_row = client.execute(holdings_query)
55
+
56
+ clamps = {
57
+ # Defaults as fallback if queries return nothing
58
+ 'slippage': 1.0,
59
+ 'priority_fee': 0.1,
60
+ 'total_usd': 100000.0,
61
+ 'history_bought_cost_sol': 30.0,
62
+ 'realized_profit_sol': 150.0,
63
+ }
64
+
65
+ if trade_row and trade_row[0]:
66
+ r = trade_row[0]
67
+ clamps['slippage'] = max(float(r[0]), 0.01)
68
+ clamps['priority_fee'] = max(float(r[1]), 1e-9)
69
+ clamps['total_usd'] = max(float(r[2]), 1.0)
70
+
71
+ if holdings_row and holdings_row[0]:
72
+ r = holdings_row[0]
73
+ clamps['history_bought_cost_sol'] = max(float(r[0]), 0.01)
74
+ clamps['realized_profit_sol'] = max(float(r[1]), 0.01)
75
+
76
+ print(f" -> P99 Clamps: {clamps}")
77
+ return clamps
78
+
79
  def fetch_all_metrics(client):
80
  """
81
  Fetches all needed metrics for all tokens in a single query.
scripts/cache_dataset.py CHANGED
@@ -22,7 +22,7 @@ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
22
 
23
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
 
25
- from scripts.analyze_distribution import get_return_class_map
26
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
@@ -65,7 +65,8 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
65
  horizons_seconds=dataset_config['horizons_seconds'],
66
  quantiles=dataset_config['quantiles'],
67
  min_trade_usd=dataset_config['min_trade_usd'],
68
- max_seq_len=dataset_config['max_seq_len']
 
69
  )
70
  _worker_dataset.sampled_mints = dataset_config['sampled_mints']
71
  _worker_return_class_map = return_class_map
@@ -179,11 +180,14 @@ def main():
179
  return_class_map, _ = get_return_class_map(clickhouse_client)
180
  print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
181
 
 
 
 
182
  print("INFO: Fetching Quality Scores...")
183
  quality_scores_map = get_token_quality_scores(clickhouse_client)
184
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
185
 
186
- dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
187
 
188
  if len(dataset) == 0:
189
  print("WARNING: No samples. Exiting.")
@@ -219,7 +223,7 @@ def main():
219
  print(f"INFO: Workers: {args.num_workers}")
220
 
221
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
222
- dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
223
 
224
  # Build tasks with class-aware multi-sampling for balanced cache
225
  import random
 
22
 
23
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
 
25
+ from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
26
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
27
 
28
  from clickhouse_driver import Client as ClickHouseClient
 
65
  horizons_seconds=dataset_config['horizons_seconds'],
66
  quantiles=dataset_config['quantiles'],
67
  min_trade_usd=dataset_config['min_trade_usd'],
68
+ max_seq_len=dataset_config['max_seq_len'],
69
+ p99_clamps=dataset_config.get('p99_clamps')
70
  )
71
  _worker_dataset.sampled_mints = dataset_config['sampled_mints']
72
  _worker_return_class_map = return_class_map
 
180
  return_class_map, _ = get_return_class_map(clickhouse_client)
181
  print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
182
 
183
+ print("INFO: Computing P99 clamp values...")
184
+ p99_clamps = compute_p99_clamps(clickhouse_client)
185
+
186
  print("INFO: Fetching Quality Scores...")
187
  quality_scores_map = get_token_quality_scores(clickhouse_client)
188
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
189
 
190
+ dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
191
 
192
  if len(dataset) == 0:
193
  print("WARNING: No samples. Exiting.")
 
223
  print(f"INFO: Workers: {args.num_workers}")
224
 
225
  db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
226
+ dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
227
 
228
  # Build tasks with class-aware multi-sampling for balanced cache
229
  import random