Commit: Upload data/data_loader.py with huggingface_hub
Changed file: data/data_loader.py (+15 lines, −2 lines)
data/data_loader.py
CHANGED
|
@@ -681,14 +681,18 @@ class OracleDataset(Dataset):
|
|
| 681 |
profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 682 |
profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 683 |
|
| 684 |
-
def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
|
| 685 |
profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
|
| 686 |
"""
|
| 687 |
Fetches or uses cached profile, social, and holdings data.
|
| 688 |
"""
|
|
|
|
|
|
|
|
|
|
| 689 |
if not wallet_addresses:
|
| 690 |
return {}, token_data
|
| 691 |
|
|
|
|
| 692 |
if profiles_override is not None and socials_override is not None:
|
| 693 |
profiles, socials = profiles_override, socials_override
|
| 694 |
holdings = holdings_override if holdings_override is not None else {}
|
|
@@ -698,6 +702,7 @@ class OracleDataset(Dataset):
|
|
| 698 |
holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
|
| 699 |
else:
|
| 700 |
profiles, socials, holdings = {}, {}, {}
|
|
|
|
| 701 |
|
| 702 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
| 703 |
if not valid_wallets:
|
|
@@ -710,11 +715,19 @@ class OracleDataset(Dataset):
|
|
| 710 |
for holding_item in holdings.get(wallet_addr, []):
|
| 711 |
if 'mint_address' in holding_item:
|
| 712 |
all_holding_mints.add(holding_item['mint_address'])
|
| 713 |
-
|
|
|
|
| 714 |
# --- Process all discovered tokens with point-in-time logic ---
|
|
|
|
| 715 |
processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
|
|
|
|
| 716 |
all_token_data = {**token_data, **(processed_new_tokens or {})}
|
| 717 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 718 |
# --- Calculate deployed token stats using point-in-time logic ---
|
| 719 |
self._calculate_deployed_token_stats(profiles, T_cutoff)
|
| 720 |
|
|
|
|
| 681 |
profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 682 |
profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 683 |
|
| 684 |
+
def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
|
| 685 |
profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
|
| 686 |
"""
|
| 687 |
Fetches or uses cached profile, social, and holdings data.
|
| 688 |
"""
|
| 689 |
+
import time as _time
|
| 690 |
+
_wd_timings = {}
|
| 691 |
+
|
| 692 |
if not wallet_addresses:
|
| 693 |
return {}, token_data
|
| 694 |
|
| 695 |
+
_t0 = _time.perf_counter()
|
| 696 |
if profiles_override is not None and socials_override is not None:
|
| 697 |
profiles, socials = profiles_override, socials_override
|
| 698 |
holdings = holdings_override if holdings_override is not None else {}
|
|
|
|
| 702 |
holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
|
| 703 |
else:
|
| 704 |
profiles, socials, holdings = {}, {}, {}
|
| 705 |
+
_wd_timings['db_fetch'] = _time.perf_counter() - _t0
|
| 706 |
|
| 707 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
| 708 |
if not valid_wallets:
|
|
|
|
| 715 |
for holding_item in holdings.get(wallet_addr, []):
|
| 716 |
if 'mint_address' in holding_item:
|
| 717 |
all_holding_mints.add(holding_item['mint_address'])
|
| 718 |
+
_wd_timings['num_holding_tokens'] = len(all_holding_mints)
|
| 719 |
+
|
| 720 |
# --- Process all discovered tokens with point-in-time logic ---
|
| 721 |
+
_t0 = _time.perf_counter()
|
| 722 |
processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
|
| 723 |
+
_wd_timings['holding_token_processing'] = _time.perf_counter() - _t0
|
| 724 |
all_token_data = {**token_data, **(processed_new_tokens or {})}
|
| 725 |
|
| 726 |
+
# Print wallet_data sub-timings
|
| 727 |
+
print(f" [WALLET_DATA] db_fetch: {_wd_timings['db_fetch']*1000:.1f}ms, "
|
| 728 |
+
f"holding_tokens: {_wd_timings['num_holding_tokens']}, "
|
| 729 |
+
f"holding_token_processing: {_wd_timings['holding_token_processing']*1000:.1f}ms")
|
| 730 |
+
|
| 731 |
# --- Calculate deployed token stats using point-in-time logic ---
|
| 732 |
self._calculate_deployed_token_stats(profiles, T_cutoff)
|
| 733 |
|