zirobtc committed on
Commit
7901ae2
·
1 Parent(s): 98b813a

Upload data/data_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data/data_loader.py +15 -2
data/data_loader.py CHANGED
@@ -681,14 +681,18 @@ class OracleDataset(Dataset):
681
  profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
682
  profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
683
 
684
- def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
685
  profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
686
  """
687
  Fetches or uses cached profile, social, and holdings data.
688
  """
 
 
 
689
  if not wallet_addresses:
690
  return {}, token_data
691
 
 
692
  if profiles_override is not None and socials_override is not None:
693
  profiles, socials = profiles_override, socials_override
694
  holdings = holdings_override if holdings_override is not None else {}
@@ -698,6 +702,7 @@ class OracleDataset(Dataset):
698
  holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
699
  else:
700
  profiles, socials, holdings = {}, {}, {}
 
701
 
702
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
703
  if not valid_wallets:
@@ -710,11 +715,19 @@ class OracleDataset(Dataset):
710
  for holding_item in holdings.get(wallet_addr, []):
711
  if 'mint_address' in holding_item:
712
  all_holding_mints.add(holding_item['mint_address'])
713
-
 
714
  # --- Process all discovered tokens with point-in-time logic ---
 
715
  processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
 
716
  all_token_data = {**token_data, **(processed_new_tokens or {})}
717
 
 
 
 
 
 
718
  # --- Calculate deployed token stats using point-in-time logic ---
719
  self._calculate_deployed_token_stats(profiles, T_cutoff)
720
 
 
681
  profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
682
  profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
683
 
684
+ def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
685
  profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
686
  """
687
  Fetches or uses cached profile, social, and holdings data.
688
  """
689
+ import time as _time
690
+ _wd_timings = {}
691
+
692
  if not wallet_addresses:
693
  return {}, token_data
694
 
695
+ _t0 = _time.perf_counter()
696
  if profiles_override is not None and socials_override is not None:
697
  profiles, socials = profiles_override, socials_override
698
  holdings = holdings_override if holdings_override is not None else {}
 
702
  holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
703
  else:
704
  profiles, socials, holdings = {}, {}, {}
705
+ _wd_timings['db_fetch'] = _time.perf_counter() - _t0
706
 
707
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
708
  if not valid_wallets:
 
715
  for holding_item in holdings.get(wallet_addr, []):
716
  if 'mint_address' in holding_item:
717
  all_holding_mints.add(holding_item['mint_address'])
718
+ _wd_timings['num_holding_tokens'] = len(all_holding_mints)
719
+
720
  # --- Process all discovered tokens with point-in-time logic ---
721
+ _t0 = _time.perf_counter()
722
  processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
723
+ _wd_timings['holding_token_processing'] = _time.perf_counter() - _t0
724
  all_token_data = {**token_data, **(processed_new_tokens or {})}
725
 
726
+ # Print wallet_data sub-timings
727
+ print(f" [WALLET_DATA] db_fetch: {_wd_timings['db_fetch']*1000:.1f}ms, "
728
+ f"holding_tokens: {_wd_timings['num_holding_tokens']}, "
729
+ f"holding_token_processing: {_wd_timings['holding_token_processing']*1000:.1f}ms")
730
+
731
  # --- Calculate deployed token stats using point-in-time logic ---
732
  self._calculate_deployed_token_stats(profiles, T_cutoff)
733