zirobtc commited on
Commit
5800f64
·
1 Parent(s): 256e651

Upload folder using huggingface_hub

Browse files
data/data_collator.py CHANGED
@@ -282,23 +282,13 @@ class MemecoinCollator:
282
  wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
283
  token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
284
 
 
 
285
  # Collate Static Raw Features (Tokens, Wallets, Graph)
286
  token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
287
  wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
288
  graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
289
 
290
- # --- Logging ---
291
- pool_contents = batch_wide_pooler.get_all_items()
292
- print(f"\n[DataCollator: Final Embedding Pool] ({len(pool_contents)} items):")
293
- if pool_contents:
294
- for item_data in pool_contents:
295
- sample_item = item_data['item']
296
- sample_type = "Image" if isinstance(sample_item, Image.Image) else "Text"
297
- content_preview = str(sample_item)
298
- if sample_type == "Text" and len(content_preview) > 100:
299
- content_preview = content_preview[:97] + "..."
300
- print(f" - Item (Original Idx {item_data['idx']}): Type='{sample_type}', Content='{content_preview}'")
301
-
302
  # --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
303
  B = batch_size
304
  L = max_len
@@ -417,13 +407,7 @@ class MemecoinCollator:
417
 
418
  # Loop through sequences to populate tensors and collect chart events
419
  for i, seq in enumerate(all_event_sequences):
420
- # --- LOGGING CONTEXT (First item only) ---
421
- if i == 0:
422
- context_names = [e.get('event_type', 'Unknown') for e in seq]
423
- print("\n[DataCollator] Context Preview (Event Sequence Names):")
424
- print(context_names)
425
- print(f"[DataCollator] Sequence Length: {len(context_names)}\n")
426
-
427
  seq_len = len(seq)
428
  if seq_len == 0: continue
429
  attention_mask[i, :seq_len] = 1
 
282
  wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
283
  token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
284
 
285
+ # Collate Static Raw Features (Tokens, Wallets, Graph)
286
+ token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
287
  # Collate Static Raw Features (Tokens, Wallets, Graph)
288
  token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
289
  wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
290
  graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
291
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  # --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
293
  B = batch_size
294
  L = max_len
 
407
 
408
  # Loop through sequences to populate tensors and collect chart events
409
  for i, seq in enumerate(all_event_sequences):
410
+
 
 
 
 
 
 
411
  seq_len = len(seq)
412
  if seq_len == 0: continue
413
  attention_mask[i, :seq_len] = 1
data/data_loader.py CHANGED
@@ -273,7 +273,6 @@ class OracleDataset(Dataset):
273
 
274
  ts_list = [int(entry[0]) for entry in price_series]
275
  price_list = [float(entry[1]) for entry in price_series]
276
- print(f"[DEBUG-TRACE-LABELS] ts_list len: {len(ts_list)}, price_list len: {len(price_list)}")
277
  if not ts_list:
278
  return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
279
 
@@ -531,7 +530,6 @@ class OracleDataset(Dataset):
531
  profiles, socials = profiles_override, socials_override
532
  holdings = holdings_override if holdings_override is not None else {}
533
  else:
534
- print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
535
  if self.fetcher:
536
  profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
537
  holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
@@ -539,40 +537,29 @@ class OracleDataset(Dataset):
539
  profiles, socials, holdings = {}, {}, {}
540
 
541
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
542
- dropped_wallets = set(wallet_addresses) - set(valid_wallets)
543
- if dropped_wallets:
544
- print(f"INFO: Skipping {len(dropped_wallets)} wallets with no profile before cutoff.")
545
  if not valid_wallets:
546
- print("INFO: All wallets were graph-only or appeared after cutoff; skipping wallet processing for this token.")
547
  return {}, token_data
548
  wallet_addresses = valid_wallets
549
 
550
- # --- NEW: Collect all unique mints from holdings to fetch their data ---
551
  all_holding_mints = set()
552
  for wallet_addr in wallet_addresses:
553
  for holding_item in holdings.get(wallet_addr, []):
554
  if 'mint_address' in holding_item:
555
  all_holding_mints.add(holding_item['mint_address'])
556
 
557
- # --- NEW: Process all discovered tokens with point-in-time logic ---
558
- # 1. Fetch raw data for all newly found tokens from holdings.
559
- # 2. Process this raw data to get embedding indices and add to the pooler.
560
- # Note: _process_token_data is designed to take a list and return a dict.
561
- # We pass the addresses and let it handle the fetching and processing internally.
562
  processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
563
- # 3. Merge the fully processed new tokens with the existing main token data.
564
  all_token_data = {**token_data, **(processed_new_tokens or {})}
565
 
566
- # --- NEW: Calculate deployed token stats using point-in-time logic ---
567
  self._calculate_deployed_token_stats(profiles, T_cutoff)
568
 
569
  # --- Assemble the final wallet dictionary ---
570
- # This structure is exactly what the WalletEncoder expects.
571
  final_wallets = {}
572
  for addr in wallet_addresses:
573
 
574
  # --- Define all expected numerical keys for a profile ---
575
- # This prevents KeyErrors if the DB returns a partial profile.
576
  expected_profile_keys = [
577
  'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
578
  'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
@@ -585,54 +572,39 @@ class OracleDataset(Dataset):
585
  'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
586
  'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
587
  ]
588
- # --- FIXED: Use .get() and provide a default empty dict if not found ---
589
- # --- NEW: If a wallet profile doesn't exist in the DB, skip it entirely. ---
590
- # This removes the old logic that created a placeholder profile with zeroed-out features.
591
- # "If it doesn't exist, it doesn't exist."
592
  profile_data = profiles.get(addr, None)
593
  if not profile_data:
594
- print(f"INFO: Wallet {addr} found in graph but has no profile in DB. Skipping this wallet.")
595
  continue
596
 
597
- # --- NEW: Ensure all expected keys exist in the fetched profile ---
598
  for key in expected_profile_keys:
599
- profile_data.setdefault(key, 0.0) # Use 0.0 as a safe default for any missing numerical key
600
 
601
  social_data = socials.get(addr, {})
602
 
603
- # --- NEW: Derive boolean social flags based on schema ---
604
  social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
605
  social_data['has_twitter'] = bool(social_data.get('twitter_username'))
606
  social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
607
- # 'is_exchange_wallet' is not in the schema, so we'll default to False for now.
608
- # This is a feature that would likely come from a 'tags' column or a separate service.
609
  social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
610
 
611
- # --- NEW: Calculate 'age' based on user's logic ---
612
  funded_ts = profile_data.get('funded_timestamp', 0)
613
  if funded_ts and funded_ts > 0:
614
- # Calculate age in seconds from the funding timestamp
615
  age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
616
  else:
617
- # Fallback for wallets older than our DB window, as requested
618
- # 5 months * 30 days/month * 24 hours/day * 3600 seconds/hour
619
  age_seconds = 12_960_000
620
 
621
- # Add the calculated age to the profile data that the WalletEncoder will receive
622
  profile_data['age'] = float(age_seconds)
623
 
624
- # Get the username and add it to the embedding pooler
625
  username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
626
 
627
  if isinstance(username, str) and username.strip():
628
  social_data['username_emb_idx'] = pooler.get_idx(username.strip())
629
  else:
630
- social_data['username_emb_idx'] = 0 # means "no embedding"
631
 
632
- # --- NEW: Filter holdings and calculate derived features ---
633
- # We create a new list `valid_wallet_holdings` to ensure that if a holding's
634
- # token is invalid (filtered out by _process_token_data), the entire holding
635
- # row is removed and not passed to the WalletEncoder.
636
  original_holdings = holdings.get(addr, [])
637
  valid_wallet_holdings = []
638
  now_ts = datetime.datetime.now(datetime.timezone.utc)
@@ -643,7 +615,6 @@ class OracleDataset(Dataset):
643
  token_info = all_token_data.get(mint_addr)
644
 
645
  if not token_info:
646
- print(f"INFO: Skipping holding for token {mint_addr} in wallet {addr} because token data is invalid/missing.")
647
  continue
648
 
649
  end_ts = holding_item.get('end_holding_at')
@@ -662,10 +633,9 @@ class OracleDataset(Dataset):
662
  holding_item['balance_pct_to_supply'] = 0.0
663
 
664
  # 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
665
- # This uses the historically accurate native balance from the profile.
666
  wallet_native_balance = profile_data.get('balance', 0.0)
667
  bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
668
- if wallet_native_balance > 1e-9: # Use a small epsilon to avoid division by zero
669
  holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
670
  else:
671
  holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0
@@ -695,9 +665,7 @@ class OracleDataset(Dataset):
695
  else:
696
  token_data = {}
697
 
698
- # --- NEW: Print the raw fetched token data as requested ---
699
- print("\n--- RAW TOKEN DATA FROM DATABASE ---")
700
- print(token_data)
701
 
702
  # Add pre-computed embedding indices to the token data
703
  # --- CRITICAL FIX: This function now returns None if the main token is invalid ---
@@ -836,14 +804,8 @@ class OracleDataset(Dataset):
836
  full_ohlc = []
837
  start_ts = sorted_intervals[0]
838
  end_ts = int(T_cutoff.timestamp())
839
- # Align end_ts to the interval grid
840
- end_ts = (end_ts // interval_seconds) * interval_seconds
841
  last_price = aggregation_trades[0]['price_usd']
842
 
843
- # --- NEW: Debugging log for trades grouped by interval ---
844
- print(f"\n[DEBUG] OHLC Generation: Trades grouped by interval bucket:")
845
- print(dict(trades_by_interval))
846
-
847
  for ts in range(start_ts, end_ts + 1, interval_seconds):
848
  if ts in trades_by_interval:
849
  prices = trades_by_interval[ts]
@@ -940,12 +902,21 @@ class OracleDataset(Dataset):
940
  # If somehow we have fewer than 25 trades (cache mismatch?), fallback to last.
941
  safe_idx = min(24, len(sorted_trades_ts) - 1)
942
  min_cutoff_ts = sorted_trades_ts[safe_idx]
943
- max_cutoff_ts = sorted_trades_ts[-1]
 
 
 
 
 
 
 
 
 
944
 
945
  if max_cutoff_ts <= min_cutoff_ts:
946
  sample_offset_ts = min_cutoff_ts
947
  else:
948
- # Standard case: sample uniformly between [Trade[24], LastTrade]
949
  sample_offset_ts = random.uniform(min_cutoff_ts, max_cutoff_ts)
950
 
951
  T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
@@ -1221,17 +1192,7 @@ class OracleDataset(Dataset):
1221
  raw_data['snapshots_5m'] = snapshot_stats
1222
  raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
1223
 
1224
- # --- Summary Log ---
1225
- print(f" [Cache Summary]")
1226
- print(f" - 1s Candles: {len(ohlc_1s)}")
1227
- print(f" - 5m Snapshots: {len(snapshot_stats)}")
1228
- print(f" - Trades (Succ): {len(trades)}")
1229
- print(f" - Pool Events: {len(raw_data.get('pool_creations', []))}")
1230
- print(f" - Liquidity Chgs: {len(raw_data.get('liquidity_changes', []))}")
1231
- print(f" - Burns: {len(raw_data.get('burns', []))}")
1232
- print(f" - Supply Locks: {len(raw_data.get('supply_locks', []))}")
1233
- print(f" - Migrations: {len(raw_data.get('migrations', []))}")
1234
-
1235
  raw_data["protocol_id"] = initial_mint_record.get("protocol")
1236
  return raw_data
1237
 
 
273
 
274
  ts_list = [int(entry[0]) for entry in price_series]
275
  price_list = [float(entry[1]) for entry in price_series]
 
276
  if not ts_list:
277
  return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
278
 
 
530
  profiles, socials = profiles_override, socials_override
531
  holdings = holdings_override if holdings_override is not None else {}
532
  else:
 
533
  if self.fetcher:
534
  profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
535
  holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
 
537
  profiles, socials, holdings = {}, {}, {}
538
 
539
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
 
 
 
540
  if not valid_wallets:
 
541
  return {}, token_data
542
  wallet_addresses = valid_wallets
543
 
544
+ # --- Collect all unique mints from holdings to fetch their data ---
545
  all_holding_mints = set()
546
  for wallet_addr in wallet_addresses:
547
  for holding_item in holdings.get(wallet_addr, []):
548
  if 'mint_address' in holding_item:
549
  all_holding_mints.add(holding_item['mint_address'])
550
 
551
+ # --- Process all discovered tokens with point-in-time logic ---
 
 
 
 
552
  processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
 
553
  all_token_data = {**token_data, **(processed_new_tokens or {})}
554
 
555
+ # --- Calculate deployed token stats using point-in-time logic ---
556
  self._calculate_deployed_token_stats(profiles, T_cutoff)
557
 
558
  # --- Assemble the final wallet dictionary ---
 
559
  final_wallets = {}
560
  for addr in wallet_addresses:
561
 
562
  # --- Define all expected numerical keys for a profile ---
 
563
  expected_profile_keys = [
564
  'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
565
  'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
 
572
  'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
573
  'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
574
  ]
575
+
 
 
 
576
  profile_data = profiles.get(addr, None)
577
  if not profile_data:
 
578
  continue
579
 
 
580
  for key in expected_profile_keys:
581
+ profile_data.setdefault(key, 0.0)
582
 
583
  social_data = socials.get(addr, {})
584
 
585
+ # --- Derive boolean social flags based on schema ---
586
  social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
587
  social_data['has_twitter'] = bool(social_data.get('twitter_username'))
588
  social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
 
 
589
  social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
590
 
591
+ # --- Calculate 'age' based on user's logic ---
592
  funded_ts = profile_data.get('funded_timestamp', 0)
593
  if funded_ts and funded_ts > 0:
 
594
  age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
595
  else:
 
 
596
  age_seconds = 12_960_000
597
 
 
598
  profile_data['age'] = float(age_seconds)
599
 
 
600
  username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
601
 
602
  if isinstance(username, str) and username.strip():
603
  social_data['username_emb_idx'] = pooler.get_idx(username.strip())
604
  else:
605
+ social_data['username_emb_idx'] = 0
606
 
607
+ # --- Filter holdings and calculate derived features ---
 
 
 
608
  original_holdings = holdings.get(addr, [])
609
  valid_wallet_holdings = []
610
  now_ts = datetime.datetime.now(datetime.timezone.utc)
 
615
  token_info = all_token_data.get(mint_addr)
616
 
617
  if not token_info:
 
618
  continue
619
 
620
  end_ts = holding_item.get('end_holding_at')
 
633
  holding_item['balance_pct_to_supply'] = 0.0
634
 
635
  # 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
 
636
  wallet_native_balance = profile_data.get('balance', 0.0)
637
  bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
638
+ if wallet_native_balance > 1e-9:
639
  holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
640
  else:
641
  holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0
 
665
  else:
666
  token_data = {}
667
 
668
+
 
 
669
 
670
  # Add pre-computed embedding indices to the token data
671
  # --- CRITICAL FIX: This function now returns None if the main token is invalid ---
 
804
  full_ohlc = []
805
  start_ts = sorted_intervals[0]
806
  end_ts = int(T_cutoff.timestamp())
 
 
807
  last_price = aggregation_trades[0]['price_usd']
808
 
 
 
 
 
809
  for ts in range(start_ts, end_ts + 1, interval_seconds):
810
  if ts in trades_by_interval:
811
  prices = trades_by_interval[ts]
 
902
  # If somehow we have fewer than 25 trades (cache mismatch?), fallback to last.
903
  safe_idx = min(24, len(sorted_trades_ts) - 1)
904
  min_cutoff_ts = sorted_trades_ts[safe_idx]
905
+
906
+ # --- FIX: Ensure max_cutoff leaves room for the largest horizon ---
907
+ # Otherwise, if T_cutoff is near the end, all horizons are masked as 0.
908
+ max_horizon = max(horizons) if horizons else 600
909
+ max_cutoff_ts = sorted_trades_ts[-1] - max_horizon
910
+
911
+ # Safety: Ensure max_cutoff_ts >= min_cutoff_ts
912
+ if max_cutoff_ts < min_cutoff_ts:
913
+ # Token duration is too short for the horizons, use earliest valid cutoff
914
+ max_cutoff_ts = min_cutoff_ts
915
 
916
  if max_cutoff_ts <= min_cutoff_ts:
917
  sample_offset_ts = min_cutoff_ts
918
  else:
919
+ # Standard case: sample uniformly between [Trade[24], LastTrade - max_horizon]
920
  sample_offset_ts = random.uniform(min_cutoff_ts, max_cutoff_ts)
921
 
922
  T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
 
1192
  raw_data['snapshots_5m'] = snapshot_stats
1193
  raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
1194
 
1195
+ raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
 
 
 
 
 
 
 
 
 
 
1196
  raw_data["protocol_id"] = initial_mint_record.get("protocol")
1197
  return raw_data
1198
 
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2faeb4a20390db85ca6a4f09d609f56da11266084aa0550fe7861de2dee2da4f
3
- size 556
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:847193fc90f4b0313f515ea38a24fd073be09188cfc4764c5dce3f658d4dc117
3
+ size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:df6cd6a1404a931ba4869d7eaf6e6a564e98b0a87f04d8edf8f6189aebfdeab4
3
- size 20694
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10917f8ad8d8962a8c05a46f2b24dcb1180b23665d0767ea5c65c63d9ec09c92
3
+ size 314966
models/graph_updater.py CHANGED
@@ -352,6 +352,11 @@ class GraphUpdater(nn.Module):
352
  self.norm = nn.LayerNorm(node_dim)
353
  self.to(dtype) # Move norm layer and ModuleList container
354
 
 
 
 
 
 
355
  def _build_edge_groups(self) -> Dict[tuple, List[str]]:
356
  """Group relations by (src_type, dst_type) so conv weights can be shared."""
357
  groups: Dict[tuple, List[str]] = defaultdict(list)
 
352
  self.norm = nn.LayerNorm(node_dim)
353
  self.to(dtype) # Move norm layer and ModuleList container
354
 
355
+ # Log params
356
+ total_params = sum(p.numel() for p in self.parameters())
357
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
358
+ print(f"[GraphUpdater] Params: {total_params:,} (Trainable: {trainable_params:,})")
359
+
360
  def _build_edge_groups(self) -> Dict[tuple, List[str]]:
361
  """Group relations by (src_type, dst_type) so conv weights can be shared."""
362
  groups: Dict[tuple, List[str]] = defaultdict(list)
models/helper_encoders.py CHANGED
@@ -33,6 +33,11 @@ class ContextualTimeEncoder(nn.Module):
33
  # Cast the entire module to the specified dtype
34
  self.to(dtype)
35
 
 
 
 
 
 
36
  def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
37
  device = values.device
38
  half_dim = d_model // 2
 
33
  # Cast the entire module to the specified dtype
34
  self.to(dtype)
35
 
36
+ # Log params
37
+ total_params = sum(p.numel() for p in self.parameters())
38
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
39
+ print(f"[ContextualTimeEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")
40
+
41
  def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
42
  device = values.device
43
  half_dim = d_model // 2
models/model.py CHANGED
@@ -375,14 +375,6 @@ class Oracle(nn.Module):
375
  # 1a. Encode Tokens
376
  # --- FIXED: Check for a key that still exists ---
377
  if token_encoder_inputs['name_embed_indices'].numel() > 0:
378
- # --- AGGRESSIVE LOGGING ---
379
- print("\n--- [Oracle DynamicEncoder LOG] ---")
380
- print(f"[Oracle LOG] embedding_pool shape: {embedding_pool.shape}")
381
- print(f"[Oracle LOG] name_embed_indices (shape {token_encoder_inputs['name_embed_indices'].shape}):\n{token_encoder_inputs['name_embed_indices']}")
382
- print(f"[Oracle LOG] symbol_embed_indices (shape {token_encoder_inputs['symbol_embed_indices'].shape}):\n{token_encoder_inputs['symbol_embed_indices']}")
383
- print(f"[Oracle LOG] image_embed_indices (shape {token_encoder_inputs['image_embed_indices'].shape}):\n{token_encoder_inputs['image_embed_indices']}")
384
- print("--- [Oracle LOG] Calling F.embedding and TokenEncoder... ---")
385
- # --- END LOGGING ---
386
  # --- NEW: Gather pre-computed embeddings and pass to encoder ---
387
  # --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
388
  encoder_args = token_encoder_inputs.copy()
 
375
  # 1a. Encode Tokens
376
  # --- FIXED: Check for a key that still exists ---
377
  if token_encoder_inputs['name_embed_indices'].numel() > 0:
 
 
 
 
 
 
 
 
378
  # --- NEW: Gather pre-computed embeddings and pass to encoder ---
379
  # --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
380
  encoder_args = token_encoder_inputs.copy()
models/multi_modal_processor.py CHANGED
@@ -11,6 +11,8 @@ import os
11
  import traceback
12
  import numpy as np
13
 
 
 
14
  # Suppress warnings
15
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
16
 
@@ -22,6 +24,10 @@ class MultiModalEncoder:
22
  """
23
 
24
  def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
 
 
 
 
25
  self.model_id = model_id
26
  if device:
27
  self.device = device
@@ -72,46 +78,24 @@ class MultiModalEncoder:
72
 
73
  autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None
74
 
75
- print(f"\n[MME LOG] ENTERING __call__ for {'TEXT' if is_text else 'IMAGE'} batch of size {len(x)}")
76
- print(f"[MME LOG] Input data preview: {str(x[0])[:100] if is_text else x[0]}")
77
-
78
- with torch.amp.autocast(device_type=self.device, enabled=(self.device == 'cuda' and autocast_dtype is not None), dtype=autocast_dtype):
79
  try:
80
  if is_text:
81
- inputs = self.processor(
82
- text=x,
83
- return_tensors="pt",
84
- padding="max_length",
85
- truncation=True
86
- ).to(self.device)
87
- print(f"[MME LOG] Text processor output shape: {inputs['input_ids'].shape}")
88
  embeddings = self.model.get_text_features(**inputs)
89
  else:
90
- rgb_images = [img.convert("RGB") if img.mode != 'RGB' else img for img in x]
91
- inputs = self.processor(
92
- images=rgb_images,
93
- return_tensors="pt"
94
- ).to(self.device)
95
-
96
- if 'pixel_values' in inputs and inputs['pixel_values'].dtype != self.dtype:
97
- inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
98
-
99
  embeddings = self.model.get_image_features(**inputs)
100
 
101
- print(f"[MME LOG] Raw model output embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
102
-
103
- # <<< THIS IS THE FIX. I accidentally removed this.
104
  # Normalize in float32 for numerical stability
105
  embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
106
- print(f"[MME LOG] Normalized embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
107
 
108
  final_embeddings = embeddings.to(self.dtype)
109
- print(f"[MME LOG] Final embeddings shape: {final_embeddings.shape}, dtype: {final_embeddings.dtype}. EXITING __call__.")
110
  return final_embeddings
111
 
112
  except Exception as e:
113
- print(f"❌ [MME LOG] FATAL ERROR during encoding {'text' if is_text else 'images'}: {e}")
114
- traceback.print_exc()
115
  return torch.empty(0, self.embedding_dim).to(self.device)
116
 
117
  # --- Test block (SigLIP) ---
 
11
  import traceback
12
  import numpy as np
13
 
14
+ from transformers.utils import logging as hf_logging
15
+
16
  # Suppress warnings
17
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
18
 
 
24
  """
25
 
26
  def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
27
+ # Force silence progress bars locally for this class
28
+ hf_logging.set_verbosity_error()
29
+ hf_logging.disable_progress_bar()
30
+
31
  self.model_id = model_id
32
  if device:
33
  self.device = device
 
78
 
79
  autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None
80
 
81
+ with torch.autocast(device_type=self.device, dtype=autocast_dtype, enabled=(autocast_dtype is not None)):
 
 
 
82
  try:
83
  if is_text:
84
+ inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
 
 
 
 
 
 
85
  embeddings = self.model.get_text_features(**inputs)
86
  else:
87
+ inputs = self.processor(images=x, return_tensors="pt").to(self.device)
 
 
 
 
 
 
 
 
88
  embeddings = self.model.get_image_features(**inputs)
89
 
 
 
 
90
  # Normalize in float32 for numerical stability
91
  embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
 
92
 
93
  final_embeddings = embeddings.to(self.dtype)
 
94
  return final_embeddings
95
 
96
  except Exception as e:
97
+ # Silently fail or log debug only if needed
98
+ # traceback.print_exc()
99
  return torch.empty(0, self.embedding_dim).to(self.device)
100
 
101
  # --- Test block (SigLIP) ---
models/ohlc_embedder.py CHANGED
@@ -71,6 +71,11 @@ class OHLCEmbedder(nn.Module):
71
 
72
  self.to(dtype)
73
 
 
 
 
 
 
74
  def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
75
  """
76
  Args:
 
71
 
72
  self.to(dtype)
73
 
74
+ # Log params
75
+ total_params = sum(p.numel() for p in self.parameters())
76
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
77
+ print(f"[OHLCEmbedder] Params: {total_params:,} (Trainable: {trainable_params:,})")
78
+
79
  def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
80
  """
81
  Args:
models/token_encoder.py CHANGED
@@ -98,6 +98,11 @@ class TokenEncoder(nn.Module):
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  self.to(device=device, dtype=dtype)
100
 
 
 
 
 
 
101
  def forward(
102
  self,
103
  name_embeds: torch.Tensor,
@@ -123,21 +128,14 @@ class TokenEncoder(nn.Module):
123
  device = name_embeds.device
124
  batch_size = name_embeds.shape[0]
125
 
126
- # 2. Get Protocol embedding (small)
127
- print(f"\n--- [TokenEncoder LOG] ENTERING FORWARD PASS (Batch Size: {batch_size}) ---")
128
- print(f"[TokenEncoder LOG] Input protocol_ids (shape {protocol_ids.shape}):\n{protocol_ids}")
129
- print(f"[TokenEncoder LOG] Protocol Embedding Vocab Size: {self.protocol_embedding.num_embeddings}")
130
-
131
  protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
132
  protocol_emb_raw = self.protocol_embedding(protocol_ids_long) # [B, 64]
133
- print(f"[TokenEncoder LOG] Raw protocol embeddings shape: {protocol_emb_raw.shape}")
134
 
135
  # NEW: Get vanity embedding
136
  vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
137
  vanity_emb_raw = self.vanity_embedding(vanity_ids_long) # [B, 32]
138
 
139
  # 3. Project all features to internal_dim (e.g., 1024)
140
- print(f"[TokenEncoder LOG] Projecting features to internal_dim: {self.internal_dim}")
141
  name_emb = self.name_proj(name_embeds)
142
  symbol_emb = self.symbol_proj(symbol_embeds)
143
  image_emb = self.image_proj(image_embeds)
@@ -153,16 +151,8 @@ class TokenEncoder(nn.Module):
153
  vanity_emb, # NEW: Add the vanity embedding to the sequence
154
  ], dim=1)
155
 
156
- print(f"[TokenEncoder LOG] Stacked feature_sequence shape: {feature_sequence.shape}")
157
- print(f" - name_emb shape: {name_emb.shape}")
158
- print(f" - symbol_emb shape: {symbol_emb.shape}")
159
- print(f" - image_emb shape: {image_emb.shape}")
160
- print(f" - protocol_emb shape: {protocol_emb.shape}")
161
- print(f" - vanity_emb shape: {vanity_emb.shape}") # ADDED: Log the new vanity embedding shape
162
-
163
  # 5. Create the padding mask (all False, since we have a fixed number of features for all)
164
  padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
165
- print(f"[TokenEncoder LOG] Created padding_mask of shape: {padding_mask.shape}")
166
 
167
  # 6. Fuse the sequence with the Transformer Encoder
168
  # This returns the [CLS] token output.
@@ -171,12 +161,9 @@ class TokenEncoder(nn.Module):
171
  item_embeds=feature_sequence,
172
  src_key_padding_mask=padding_mask
173
  )
174
- print(f"[TokenEncoder LOG] Fused embedding shape after transformer: {fused_embedding.shape}")
175
 
176
  # 7. Project to the final output dimension
177
  # Shape: [B, output_dim]
178
  token_vibe_embedding = self.final_projection(fused_embedding)
179
- print(f"[TokenEncoder LOG] Final token_vibe_embedding shape: {token_vibe_embedding.shape}")
180
- print(f"--- [TokenEncoder LOG] EXITING FORWARD PASS ---\n")
181
 
182
  return token_vibe_embedding
 
98
  device = "cuda" if torch.cuda.is_available() else "cpu"
99
  self.to(device=device, dtype=dtype)
100
 
101
+ # Log params
102
+ total_params = sum(p.numel() for p in self.parameters())
103
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
104
+ print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")
105
+
106
  def forward(
107
  self,
108
  name_embeds: torch.Tensor,
 
128
  device = name_embeds.device
129
  batch_size = name_embeds.shape[0]
130
 
 
 
 
 
 
131
  protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
132
  protocol_emb_raw = self.protocol_embedding(protocol_ids_long) # [B, 64]
 
133
 
134
  # NEW: Get vanity embedding
135
  vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
136
  vanity_emb_raw = self.vanity_embedding(vanity_ids_long) # [B, 32]
137
 
138
  # 3. Project all features to internal_dim (e.g., 1024)
 
139
  name_emb = self.name_proj(name_embeds)
140
  symbol_emb = self.symbol_proj(symbol_embeds)
141
  image_emb = self.image_proj(image_embeds)
 
151
  vanity_emb, # NEW: Add the vanity embedding to the sequence
152
  ], dim=1)
153
 
 
 
 
 
 
 
 
154
  # 5. Create the padding mask (all False, since we have a fixed number of features for all)
155
  padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
 
156
 
157
  # 6. Fuse the sequence with the Transformer Encoder
158
  # This returns the [CLS] token output.
 
161
  item_embeds=feature_sequence,
162
  src_key_padding_mask=padding_mask
163
  )
 
164
 
165
  # 7. Project to the final output dimension
166
  # Shape: [B, output_dim]
167
  token_vibe_embedding = self.final_projection(fused_embedding)
 
 
168
 
169
  return token_vibe_embedding
models/wallet_encoder.py CHANGED
@@ -95,6 +95,28 @@ class WalletEncoder(nn.Module):
95
  )
96
  self.to(dtype)
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def _build_mlp(self, in_dim, out_dim):
99
  return nn.Sequential(
100
  nn.Linear(in_dim, out_dim * 2),
 
95
  )
96
  self.to(dtype)
97
 
98
+ # Log params (excluding the shared encoder which might be huge and already logged)
99
+ # Note: self.encoder is external, but if we include it here, it will double count.
100
+ # Ideally we only log *this* module's params.
101
+ my_params = sum(p.numel() for p in self.parameters())
102
+ # To avoid double counting the external encoder if it's a submodule (it is assigned to self.encoder)
103
+ # But wait, self.encoder IS a submodule.
104
+ # We should subtract it if we just want "WalletEncoder specific" params, or clarify.
105
+ # Let's verify if self.encoder params are included in self.parameters().
106
+ # Yes they are because `self.encoder = encoder` assigns it.
107
+ # Actually `encoder` is passed in. If `MultiModalEncoder` is an `nn.Module` (it is NOT), then it would be registered.
108
+ # `MultiModalEncoder` is a wrapper class, NOT an `nn.Module`.
109
+ # However, it contains `self.model` which is an `nn.Module`.
110
+ # But `WalletEncoder` stores `self.encoder = encoder`.
111
+ # Since `MultiModalEncoder` is not an `nn.Module`, `self.encoder` is just a standard attribute.
112
+ # So `self.parameters()` of `WalletEncoder` will NOT include `MultiModalEncoder` params.
113
+ # EXCEPT... we don't know if `MultiModalEncoder` subclassed `nn.Module`.
114
+ # I checked earlier: `class MultiModalEncoder:` -> No `nn.Module`.
115
+ # So we are safe. `self.parameters()` will only be the MLPs and SetEncoders defined in WalletEncoder.
116
+
117
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
118
+ print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")
119
+
120
  def _build_mlp(self, in_dim, out_dim):
121
  return nn.Sequential(
122
  nn.Linear(in_dim, out_dim * 2),
scripts/cache_dataset.py CHANGED
@@ -2,12 +2,20 @@
2
  import os
3
  import sys
4
  import argparse
 
5
  import datetime
6
  import torch
7
  import json
8
  from pathlib import Path
9
  from tqdm import tqdm
10
  from dotenv import load_dotenv
 
 
 
 
 
 
 
11
 
12
  # Add parent directory to path to import modules
13
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -19,9 +27,84 @@ from scripts.analyze_distribution import get_return_class_map
19
  from clickhouse_driver import Client as ClickHouseClient
20
  from neo4j import GraphDatabase
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def main():
23
  load_dotenv()
24
 
 
 
 
 
 
 
 
 
25
  parser = argparse.ArgumentParser(description="Cache dataset samples for training.")
26
  parser.add_argument("--output_dir", type=str, default="data/cache", help="Directory to save cached samples")
27
  parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to generate")
@@ -50,6 +133,9 @@ def main():
50
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
51
 
52
  try:
 
 
 
53
  # --- 2. Initialize DataFetcher and OracleDataset ---
54
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
55
 
 
2
  import os
3
  import sys
4
  import argparse
5
+ import numpy as np
6
  import datetime
7
  import torch
8
  import json
9
  from pathlib import Path
10
  from tqdm import tqdm
11
  from dotenv import load_dotenv
12
+ import huggingface_hub
13
+ import logging
14
+
15
+ # Suppress noisy libraries
16
+ logging.getLogger("httpx").setLevel(logging.WARNING)
17
+ logging.getLogger("transformers").setLevel(logging.ERROR)
18
+ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
19
 
20
  # Add parent directory to path to import modules
21
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
27
  from clickhouse_driver import Client as ClickHouseClient
28
  from neo4j import GraphDatabase
29
 
30
+ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
31
+ """
32
+ Computes global mean/std for price/volume from ClickHouse and saves to .npz
33
+ This allows the dataset loader to normalize inputs correctly.
34
+ """
35
+ print(f"INFO: Computing OHLC stats (mean/std) from ClickHouse...")
36
+
37
+ # Query matching preprocess_distribution.py logic
38
+ # We use hardcoded min_price/vol filters to avoid skewing stats with dust
39
+ min_price = 0.0
40
+ min_vol = 0.0
41
+
42
+ query = """
43
+ SELECT
44
+ AVG(t.price_usd) AS mean_price_usd,
45
+ stddevPop(t.price_usd) AS std_price_usd,
46
+ AVG(t.price) AS mean_price_native,
47
+ stddevPop(t.price) AS std_price_native,
48
+ AVG(t.total_usd) AS mean_trade_value_usd,
49
+ stddevPop(t.total_usd) AS std_trade_value_usd
50
+ FROM trades AS t
51
+ WHERE t.price_usd > %(min_price)s AND t.total_usd > %(min_vol)s
52
+ """
53
+
54
+ params = {"min_price": min_price, "min_vol": min_vol}
55
+
56
+ try:
57
+ result = client.execute(query, params=params)
58
+ if not result or not result[0]:
59
+ print("WARNING: Stats query returned no rows. Using default identity stats.")
60
+ stats = {
61
+ "mean_price_usd": 0.0, "std_price_usd": 1.0,
62
+ "mean_price_native": 0.0, "std_price_native": 1.0,
63
+ "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0,
64
+ }
65
+ else:
66
+ row = result[0]
67
+ # Handle potential None values if DB is empty
68
+ def safe_float(x, default=0.0):
69
+ return float(x) if x is not None else default
70
+
71
+ def safe_std(x):
72
+ val = safe_float(x, 1.0)
73
+ return val if val > 1e-9 else 1.0
74
+
75
+ stats = {
76
+ "mean_price_usd": safe_float(row[0]),
77
+ "std_price_usd": safe_std(row[1]),
78
+ "mean_price_native": safe_float(row[2]),
79
+ "std_price_native": safe_std(row[3]),
80
+ "mean_trade_value_usd": safe_float(row[4]),
81
+ "std_trade_value_usd": safe_std(row[5]),
82
+ }
83
+
84
+ # Save to NPZ
85
+ out_p = Path(output_path)
86
+ out_p.parent.mkdir(parents=True, exist_ok=True)
87
+ np.savez(out_p, **stats)
88
+
89
+ print(f"INFO: Saved OHLC stats to {out_p}")
90
+ for k, v in stats.items():
91
+ print(f" {k}: {v:.4f}")
92
+
93
+ except Exception as e:
94
+ print(f"ERROR: Failed to compute OHLC stats: {e}")
95
+ # Don't crash, let it try to proceed (though dataset might complain if file missing)
96
+
97
  def main():
98
  load_dotenv()
99
 
100
+ # Explicit Login
101
+ hf_token = os.getenv("HF_TOKEN")
102
+ if hf_token:
103
+ print(f"INFO: Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
104
+ huggingface_hub.login(token=hf_token)
105
+ else:
106
+ print("WARNING: HF_TOKEN not found in environment.")
107
+
108
  parser = argparse.ArgumentParser(description="Cache dataset samples for training.")
109
  parser.add_argument("--output_dir", type=str, default="data/cache", help="Directory to save cached samples")
110
  parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to generate")
 
133
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
134
 
135
  try:
136
+ # --- 1. Compute OHLC Stats (Global) ---
137
+ compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
138
+
139
  # --- 2. Initialize DataFetcher and OracleDataset ---
140
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
141
 
train.py CHANGED
@@ -15,6 +15,14 @@ resolved_tmp = str(_DEFAULT_TMP.resolve())
15
  for key in ("TMPDIR", "TMP", "TEMP"):
16
  os.environ.setdefault(key, resolved_tmp)
17
 
 
 
 
 
 
 
 
 
18
  try:
19
  mp.set_start_method('spawn', force=True)
20
  except RuntimeError:
@@ -100,8 +108,8 @@ def quantile_pinball_loss(preds: torch.Tensor,
100
  # Preds shape: [B, Horizons * Quantiles]
101
  # Logic assumes interleaved outputs or consistent flattening.
102
  pred_slice = preds[:, idx::num_quantiles]
103
- target_slice = targets[:, idx::num_quantiles]
104
- mask_slice = mask[:, idx::num_quantiles]
105
 
106
  diff = target_slice - pred_slice
107
  pinball = torch.maximum((q - 1.0) * diff, q * diff)
@@ -118,6 +126,44 @@ def filtered_collate(collator: MemecoinCollator,
118
  return None
119
  return collator(batch)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def parse_args() -> argparse.Namespace:
123
  parser = argparse.ArgumentParser(description="Train the Oracle quantile model.")
@@ -209,6 +255,25 @@ def main() -> None:
209
  max_seq_len = args.max_seq_len
210
  max_seq_len = args.max_seq_len
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
213
 
214
  # Encoders
@@ -423,6 +488,11 @@ def main() -> None:
423
  # Logging
424
  if accelerator.sync_gradients:
425
  total_steps += 1
 
 
 
 
 
426
  current_loss = loss.item()
427
  epoch_loss += current_loss
428
  valid_batches += 1
 
15
  for key in ("TMPDIR", "TMP", "TEMP"):
16
  os.environ.setdefault(key, resolved_tmp)
17
 
18
+ # --- Environment & Logging Setup ---
19
+ from dotenv import load_dotenv
20
+ import huggingface_hub
21
+ from transformers.utils import logging as hf_logging
22
+
23
+ # Load .env explicitly (benign at global scope but moving heavy lifting to main)
24
+ load_dotenv()
25
+
26
  try:
27
  mp.set_start_method('spawn', force=True)
28
  except RuntimeError:
 
108
  # Preds shape: [B, Horizons * Quantiles]
109
  # Logic assumes interleaved outputs or consistent flattening.
110
  pred_slice = preds[:, idx::num_quantiles]
111
+ target_slice = targets
112
+ mask_slice = mask
113
 
114
  diff = target_slice - pred_slice
115
  pinball = torch.maximum((q - 1.0) * diff, q * diff)
 
126
  return None
127
  return collator(batch)
128
 
129
+ def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step: int):
130
+ """
131
+ Logs decoded event sequence and labels for the first sample in the batch.
132
+ Use this to verify what the model is actually seeing.
133
+ """
134
+ if not logger.isEnabledFor(logging.INFO): return
135
+
136
+ try:
137
+ # Only look at the first sample in batch
138
+ idx = 0
139
+ event_ids = batch['event_type_ids'][idx].cpu() # [L]
140
+ labels = batch['labels'][idx].cpu() # [Horizons * Quantiles]
141
+ mask = batch['labels_mask'][idx].cpu()
142
+
143
+ # Decode events
144
+ events = []
145
+ for eid in event_ids:
146
+ eid_val = eid.item()
147
+ if eid_val == 0: continue # Skip PAD
148
+ # Get name from vocab
149
+ name = vocab.ID_TO_EVENT.get(eid_val, f"UNK_{eid_val}")
150
+ events.append(name)
151
+
152
+ logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")
153
+ # Show a slice of events (e.g. last 50)
154
+ tail_len = 50
155
+ context_str = ", ".join(events[-tail_len:])
156
+ logger.info(f"Event Stream (Last {tail_len} of {len(events)}): [{context_str}]")
157
+
158
+ # Show Labels
159
+ # Assuming flattened labels [H*Q]
160
+ logger.info(f"Labels (First 10): {labels[:10].tolist()}")
161
+ logger.info(f"Masks (First 10): {mask[:10].tolist()}")
162
+ logger.info("----------------------------------------------------\n")
163
+
164
+ except Exception as e:
165
+ logger.warning(f"Failed to log batch context: {e}")
166
+
167
 
168
  def parse_args() -> argparse.Namespace:
169
  parser = argparse.ArgumentParser(description="Train the Oracle quantile model.")
 
255
  max_seq_len = args.max_seq_len
256
  max_seq_len = args.max_seq_len
257
 
258
+
259
+ # --- Environment & Logging Setup ---
260
+ # Load .env explicitly
261
+ load_dotenv()
262
+
263
+ # Suppress noisy libraries
264
+ logging.getLogger("httpx").setLevel(logging.WARNING)
265
+ logging.getLogger("transformers").setLevel(logging.ERROR)
266
+ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
267
+ hf_logging.disable_progress_bar()
268
+
269
+ # Explicit Login
270
+ hf_token = os.getenv("HF_TOKEN")
271
+ if hf_token:
272
+ print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
273
+ huggingface_hub.login(token=hf_token)
274
+ else:
275
+ print("WARNING: HF_TOKEN not found in environment.")
276
+
277
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
278
 
279
  # Encoders
 
488
  # Logging
489
  if accelerator.sync_gradients:
490
  total_steps += 1
491
+
492
+ # --- NEW: Debug Log Batch Context ---
493
+ if total_steps % log_every == 0 and accelerator.is_main_process:
494
+ log_debug_batch_context(batch, logger, total_steps)
495
+
496
  current_loss = loss.item()
497
  epoch_loss += current_loss
498
  valid_batches += 1