Upload folder using huggingface_hub
Browse files- data/data_collator.py +3 -19
- data/data_loader.py +23 -62
- data/ohlc_stats.npz +2 -2
- log.log +2 -2
- models/graph_updater.py +5 -0
- models/helper_encoders.py +5 -0
- models/model.py +0 -8
- models/multi_modal_processor.py +11 -27
- models/ohlc_embedder.py +5 -0
- models/token_encoder.py +5 -18
- models/wallet_encoder.py +22 -0
- scripts/cache_dataset.py +86 -0
- train.py +72 -2
data/data_collator.py
CHANGED
|
@@ -282,23 +282,13 @@ class MemecoinCollator:
|
|
| 282 |
wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
|
| 283 |
token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
|
| 284 |
|
|
|
|
|
|
|
| 285 |
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 286 |
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
| 287 |
wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
|
| 288 |
graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
|
| 289 |
|
| 290 |
-
# --- Logging ---
|
| 291 |
-
pool_contents = batch_wide_pooler.get_all_items()
|
| 292 |
-
print(f"\n[DataCollator: Final Embedding Pool] ({len(pool_contents)} items):")
|
| 293 |
-
if pool_contents:
|
| 294 |
-
for item_data in pool_contents:
|
| 295 |
-
sample_item = item_data['item']
|
| 296 |
-
sample_type = "Image" if isinstance(sample_item, Image.Image) else "Text"
|
| 297 |
-
content_preview = str(sample_item)
|
| 298 |
-
if sample_type == "Text" and len(content_preview) > 100:
|
| 299 |
-
content_preview = content_preview[:97] + "..."
|
| 300 |
-
print(f" - Item (Original Idx {item_data['idx']}): Type='{sample_type}', Content='{content_preview}'")
|
| 301 |
-
|
| 302 |
# --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
|
| 303 |
B = batch_size
|
| 304 |
L = max_len
|
|
@@ -417,13 +407,7 @@ class MemecoinCollator:
|
|
| 417 |
|
| 418 |
# Loop through sequences to populate tensors and collect chart events
|
| 419 |
for i, seq in enumerate(all_event_sequences):
|
| 420 |
-
|
| 421 |
-
if i == 0:
|
| 422 |
-
context_names = [e.get('event_type', 'Unknown') for e in seq]
|
| 423 |
-
print("\n[DataCollator] Context Preview (Event Sequence Names):")
|
| 424 |
-
print(context_names)
|
| 425 |
-
print(f"[DataCollator] Sequence Length: {len(context_names)}\n")
|
| 426 |
-
|
| 427 |
seq_len = len(seq)
|
| 428 |
if seq_len == 0: continue
|
| 429 |
attention_mask[i, :seq_len] = 1
|
|
|
|
| 282 |
wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
|
| 283 |
token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
|
| 284 |
|
| 285 |
+
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 286 |
+
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
| 287 |
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 288 |
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
| 289 |
wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
|
| 290 |
graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
|
| 291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
# --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
|
| 293 |
B = batch_size
|
| 294 |
L = max_len
|
|
|
|
| 407 |
|
| 408 |
# Loop through sequences to populate tensors and collect chart events
|
| 409 |
for i, seq in enumerate(all_event_sequences):
|
| 410 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
seq_len = len(seq)
|
| 412 |
if seq_len == 0: continue
|
| 413 |
attention_mask[i, :seq_len] = 1
|
data/data_loader.py
CHANGED
|
@@ -273,7 +273,6 @@ class OracleDataset(Dataset):
|
|
| 273 |
|
| 274 |
ts_list = [int(entry[0]) for entry in price_series]
|
| 275 |
price_list = [float(entry[1]) for entry in price_series]
|
| 276 |
-
print(f"[DEBUG-TRACE-LABELS] ts_list len: {len(ts_list)}, price_list len: {len(price_list)}")
|
| 277 |
if not ts_list:
|
| 278 |
return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
|
| 279 |
|
|
@@ -531,7 +530,6 @@ class OracleDataset(Dataset):
|
|
| 531 |
profiles, socials = profiles_override, socials_override
|
| 532 |
holdings = holdings_override if holdings_override is not None else {}
|
| 533 |
else:
|
| 534 |
-
print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
|
| 535 |
if self.fetcher:
|
| 536 |
profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
|
| 537 |
holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
|
|
@@ -539,40 +537,29 @@ class OracleDataset(Dataset):
|
|
| 539 |
profiles, socials, holdings = {}, {}, {}
|
| 540 |
|
| 541 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
| 542 |
-
dropped_wallets = set(wallet_addresses) - set(valid_wallets)
|
| 543 |
-
if dropped_wallets:
|
| 544 |
-
print(f"INFO: Skipping {len(dropped_wallets)} wallets with no profile before cutoff.")
|
| 545 |
if not valid_wallets:
|
| 546 |
-
print("INFO: All wallets were graph-only or appeared after cutoff; skipping wallet processing for this token.")
|
| 547 |
return {}, token_data
|
| 548 |
wallet_addresses = valid_wallets
|
| 549 |
|
| 550 |
-
# ---
|
| 551 |
all_holding_mints = set()
|
| 552 |
for wallet_addr in wallet_addresses:
|
| 553 |
for holding_item in holdings.get(wallet_addr, []):
|
| 554 |
if 'mint_address' in holding_item:
|
| 555 |
all_holding_mints.add(holding_item['mint_address'])
|
| 556 |
|
| 557 |
-
# ---
|
| 558 |
-
# 1. Fetch raw data for all newly found tokens from holdings.
|
| 559 |
-
# 2. Process this raw data to get embedding indices and add to the pooler.
|
| 560 |
-
# Note: _process_token_data is designed to take a list and return a dict.
|
| 561 |
-
# We pass the addresses and let it handle the fetching and processing internally.
|
| 562 |
processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
|
| 563 |
-
# 3. Merge the fully processed new tokens with the existing main token data.
|
| 564 |
all_token_data = {**token_data, **(processed_new_tokens or {})}
|
| 565 |
|
| 566 |
-
# ---
|
| 567 |
self._calculate_deployed_token_stats(profiles, T_cutoff)
|
| 568 |
|
| 569 |
# --- Assemble the final wallet dictionary ---
|
| 570 |
-
# This structure is exactly what the WalletEncoder expects.
|
| 571 |
final_wallets = {}
|
| 572 |
for addr in wallet_addresses:
|
| 573 |
|
| 574 |
# --- Define all expected numerical keys for a profile ---
|
| 575 |
-
# This prevents KeyErrors if the DB returns a partial profile.
|
| 576 |
expected_profile_keys = [
|
| 577 |
'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
|
| 578 |
'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
|
|
@@ -585,54 +572,39 @@ class OracleDataset(Dataset):
|
|
| 585 |
'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
|
| 586 |
'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
|
| 587 |
]
|
| 588 |
-
|
| 589 |
-
# --- NEW: If a wallet profile doesn't exist in the DB, skip it entirely. ---
|
| 590 |
-
# This removes the old logic that created a placeholder profile with zeroed-out features.
|
| 591 |
-
# "If it doesn't exist, it doesn't exist."
|
| 592 |
profile_data = profiles.get(addr, None)
|
| 593 |
if not profile_data:
|
| 594 |
-
print(f"INFO: Wallet {addr} found in graph but has no profile in DB. Skipping this wallet.")
|
| 595 |
continue
|
| 596 |
|
| 597 |
-
# --- NEW: Ensure all expected keys exist in the fetched profile ---
|
| 598 |
for key in expected_profile_keys:
|
| 599 |
-
profile_data.setdefault(key, 0.0)
|
| 600 |
|
| 601 |
social_data = socials.get(addr, {})
|
| 602 |
|
| 603 |
-
# ---
|
| 604 |
social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
|
| 605 |
social_data['has_twitter'] = bool(social_data.get('twitter_username'))
|
| 606 |
social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
|
| 607 |
-
# 'is_exchange_wallet' is not in the schema, so we'll default to False for now.
|
| 608 |
-
# This is a feature that would likely come from a 'tags' column or a separate service.
|
| 609 |
social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
|
| 610 |
|
| 611 |
-
# ---
|
| 612 |
funded_ts = profile_data.get('funded_timestamp', 0)
|
| 613 |
if funded_ts and funded_ts > 0:
|
| 614 |
-
# Calculate age in seconds from the funding timestamp
|
| 615 |
age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
|
| 616 |
else:
|
| 617 |
-
# Fallback for wallets older than our DB window, as requested
|
| 618 |
-
# 5 months * 30 days/month * 24 hours/day * 3600 seconds/hour
|
| 619 |
age_seconds = 12_960_000
|
| 620 |
|
| 621 |
-
# Add the calculated age to the profile data that the WalletEncoder will receive
|
| 622 |
profile_data['age'] = float(age_seconds)
|
| 623 |
|
| 624 |
-
# Get the username and add it to the embedding pooler
|
| 625 |
username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
|
| 626 |
|
| 627 |
if isinstance(username, str) and username.strip():
|
| 628 |
social_data['username_emb_idx'] = pooler.get_idx(username.strip())
|
| 629 |
else:
|
| 630 |
-
social_data['username_emb_idx'] = 0
|
| 631 |
|
| 632 |
-
# ---
|
| 633 |
-
# We create a new list `valid_wallet_holdings` to ensure that if a holding's
|
| 634 |
-
# token is invalid (filtered out by _process_token_data), the entire holding
|
| 635 |
-
# row is removed and not passed to the WalletEncoder.
|
| 636 |
original_holdings = holdings.get(addr, [])
|
| 637 |
valid_wallet_holdings = []
|
| 638 |
now_ts = datetime.datetime.now(datetime.timezone.utc)
|
|
@@ -643,7 +615,6 @@ class OracleDataset(Dataset):
|
|
| 643 |
token_info = all_token_data.get(mint_addr)
|
| 644 |
|
| 645 |
if not token_info:
|
| 646 |
-
print(f"INFO: Skipping holding for token {mint_addr} in wallet {addr} because token data is invalid/missing.")
|
| 647 |
continue
|
| 648 |
|
| 649 |
end_ts = holding_item.get('end_holding_at')
|
|
@@ -662,10 +633,9 @@ class OracleDataset(Dataset):
|
|
| 662 |
holding_item['balance_pct_to_supply'] = 0.0
|
| 663 |
|
| 664 |
# 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
|
| 665 |
-
# This uses the historically accurate native balance from the profile.
|
| 666 |
wallet_native_balance = profile_data.get('balance', 0.0)
|
| 667 |
bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
|
| 668 |
-
if wallet_native_balance > 1e-9:
|
| 669 |
holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
|
| 670 |
else:
|
| 671 |
holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0
|
|
@@ -695,9 +665,7 @@ class OracleDataset(Dataset):
|
|
| 695 |
else:
|
| 696 |
token_data = {}
|
| 697 |
|
| 698 |
-
|
| 699 |
-
print("\n--- RAW TOKEN DATA FROM DATABASE ---")
|
| 700 |
-
print(token_data)
|
| 701 |
|
| 702 |
# Add pre-computed embedding indices to the token data
|
| 703 |
# --- CRITICAL FIX: This function now returns None if the main token is invalid ---
|
|
@@ -836,14 +804,8 @@ class OracleDataset(Dataset):
|
|
| 836 |
full_ohlc = []
|
| 837 |
start_ts = sorted_intervals[0]
|
| 838 |
end_ts = int(T_cutoff.timestamp())
|
| 839 |
-
# Align end_ts to the interval grid
|
| 840 |
-
end_ts = (end_ts // interval_seconds) * interval_seconds
|
| 841 |
last_price = aggregation_trades[0]['price_usd']
|
| 842 |
|
| 843 |
-
# --- NEW: Debugging log for trades grouped by interval ---
|
| 844 |
-
print(f"\n[DEBUG] OHLC Generation: Trades grouped by interval bucket:")
|
| 845 |
-
print(dict(trades_by_interval))
|
| 846 |
-
|
| 847 |
for ts in range(start_ts, end_ts + 1, interval_seconds):
|
| 848 |
if ts in trades_by_interval:
|
| 849 |
prices = trades_by_interval[ts]
|
|
@@ -940,12 +902,21 @@ class OracleDataset(Dataset):
|
|
| 940 |
# If somehow we have fewer than 25 trades (cache mismatch?), fallback to last.
|
| 941 |
safe_idx = min(24, len(sorted_trades_ts) - 1)
|
| 942 |
min_cutoff_ts = sorted_trades_ts[safe_idx]
|
| 943 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 944 |
|
| 945 |
if max_cutoff_ts <= min_cutoff_ts:
|
| 946 |
sample_offset_ts = min_cutoff_ts
|
| 947 |
else:
|
| 948 |
-
# Standard case: sample uniformly between [Trade[24], LastTrade]
|
| 949 |
sample_offset_ts = random.uniform(min_cutoff_ts, max_cutoff_ts)
|
| 950 |
|
| 951 |
T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
|
|
@@ -1221,17 +1192,7 @@ class OracleDataset(Dataset):
|
|
| 1221 |
raw_data['snapshots_5m'] = snapshot_stats
|
| 1222 |
raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
|
| 1223 |
|
| 1224 |
-
|
| 1225 |
-
print(f" [Cache Summary]")
|
| 1226 |
-
print(f" - 1s Candles: {len(ohlc_1s)}")
|
| 1227 |
-
print(f" - 5m Snapshots: {len(snapshot_stats)}")
|
| 1228 |
-
print(f" - Trades (Succ): {len(trades)}")
|
| 1229 |
-
print(f" - Pool Events: {len(raw_data.get('pool_creations', []))}")
|
| 1230 |
-
print(f" - Liquidity Chgs: {len(raw_data.get('liquidity_changes', []))}")
|
| 1231 |
-
print(f" - Burns: {len(raw_data.get('burns', []))}")
|
| 1232 |
-
print(f" - Supply Locks: {len(raw_data.get('supply_locks', []))}")
|
| 1233 |
-
print(f" - Migrations: {len(raw_data.get('migrations', []))}")
|
| 1234 |
-
|
| 1235 |
raw_data["protocol_id"] = initial_mint_record.get("protocol")
|
| 1236 |
return raw_data
|
| 1237 |
|
|
|
|
| 273 |
|
| 274 |
ts_list = [int(entry[0]) for entry in price_series]
|
| 275 |
price_list = [float(entry[1]) for entry in price_series]
|
|
|
|
| 276 |
if not ts_list:
|
| 277 |
return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
|
| 278 |
|
|
|
|
| 530 |
profiles, socials = profiles_override, socials_override
|
| 531 |
holdings = holdings_override if holdings_override is not None else {}
|
| 532 |
else:
|
|
|
|
| 533 |
if self.fetcher:
|
| 534 |
profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
|
| 535 |
holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
|
|
|
|
| 537 |
profiles, socials, holdings = {}, {}, {}
|
| 538 |
|
| 539 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
|
|
|
|
|
|
|
|
|
| 540 |
if not valid_wallets:
|
|
|
|
| 541 |
return {}, token_data
|
| 542 |
wallet_addresses = valid_wallets
|
| 543 |
|
| 544 |
+
# --- Collect all unique mints from holdings to fetch their data ---
|
| 545 |
all_holding_mints = set()
|
| 546 |
for wallet_addr in wallet_addresses:
|
| 547 |
for holding_item in holdings.get(wallet_addr, []):
|
| 548 |
if 'mint_address' in holding_item:
|
| 549 |
all_holding_mints.add(holding_item['mint_address'])
|
| 550 |
|
| 551 |
+
# --- Process all discovered tokens with point-in-time logic ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
|
|
|
|
| 553 |
all_token_data = {**token_data, **(processed_new_tokens or {})}
|
| 554 |
|
| 555 |
+
# --- Calculate deployed token stats using point-in-time logic ---
|
| 556 |
self._calculate_deployed_token_stats(profiles, T_cutoff)
|
| 557 |
|
| 558 |
# --- Assemble the final wallet dictionary ---
|
|
|
|
| 559 |
final_wallets = {}
|
| 560 |
for addr in wallet_addresses:
|
| 561 |
|
| 562 |
# --- Define all expected numerical keys for a profile ---
|
|
|
|
| 563 |
expected_profile_keys = [
|
| 564 |
'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
|
| 565 |
'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
|
|
|
|
| 572 |
'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
|
| 573 |
'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
|
| 574 |
]
|
| 575 |
+
|
|
|
|
|
|
|
|
|
|
| 576 |
profile_data = profiles.get(addr, None)
|
| 577 |
if not profile_data:
|
|
|
|
| 578 |
continue
|
| 579 |
|
|
|
|
| 580 |
for key in expected_profile_keys:
|
| 581 |
+
profile_data.setdefault(key, 0.0)
|
| 582 |
|
| 583 |
social_data = socials.get(addr, {})
|
| 584 |
|
| 585 |
+
# --- Derive boolean social flags based on schema ---
|
| 586 |
social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
|
| 587 |
social_data['has_twitter'] = bool(social_data.get('twitter_username'))
|
| 588 |
social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
|
|
|
|
|
|
|
| 589 |
social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
|
| 590 |
|
| 591 |
+
# --- Calculate 'age' based on user's logic ---
|
| 592 |
funded_ts = profile_data.get('funded_timestamp', 0)
|
| 593 |
if funded_ts and funded_ts > 0:
|
|
|
|
| 594 |
age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
|
| 595 |
else:
|
|
|
|
|
|
|
| 596 |
age_seconds = 12_960_000
|
| 597 |
|
|
|
|
| 598 |
profile_data['age'] = float(age_seconds)
|
| 599 |
|
|
|
|
| 600 |
username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
|
| 601 |
|
| 602 |
if isinstance(username, str) and username.strip():
|
| 603 |
social_data['username_emb_idx'] = pooler.get_idx(username.strip())
|
| 604 |
else:
|
| 605 |
+
social_data['username_emb_idx'] = 0
|
| 606 |
|
| 607 |
+
# --- Filter holdings and calculate derived features ---
|
|
|
|
|
|
|
|
|
|
| 608 |
original_holdings = holdings.get(addr, [])
|
| 609 |
valid_wallet_holdings = []
|
| 610 |
now_ts = datetime.datetime.now(datetime.timezone.utc)
|
|
|
|
| 615 |
token_info = all_token_data.get(mint_addr)
|
| 616 |
|
| 617 |
if not token_info:
|
|
|
|
| 618 |
continue
|
| 619 |
|
| 620 |
end_ts = holding_item.get('end_holding_at')
|
|
|
|
| 633 |
holding_item['balance_pct_to_supply'] = 0.0
|
| 634 |
|
| 635 |
# 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
|
|
|
|
| 636 |
wallet_native_balance = profile_data.get('balance', 0.0)
|
| 637 |
bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
|
| 638 |
+
if wallet_native_balance > 1e-9:
|
| 639 |
holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
|
| 640 |
else:
|
| 641 |
holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0
|
|
|
|
| 665 |
else:
|
| 666 |
token_data = {}
|
| 667 |
|
| 668 |
+
|
|
|
|
|
|
|
| 669 |
|
| 670 |
# Add pre-computed embedding indices to the token data
|
| 671 |
# --- CRITICAL FIX: This function now returns None if the main token is invalid ---
|
|
|
|
| 804 |
full_ohlc = []
|
| 805 |
start_ts = sorted_intervals[0]
|
| 806 |
end_ts = int(T_cutoff.timestamp())
|
|
|
|
|
|
|
| 807 |
last_price = aggregation_trades[0]['price_usd']
|
| 808 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 809 |
for ts in range(start_ts, end_ts + 1, interval_seconds):
|
| 810 |
if ts in trades_by_interval:
|
| 811 |
prices = trades_by_interval[ts]
|
|
|
|
| 902 |
# If somehow we have fewer than 25 trades (cache mismatch?), fallback to last.
|
| 903 |
safe_idx = min(24, len(sorted_trades_ts) - 1)
|
| 904 |
min_cutoff_ts = sorted_trades_ts[safe_idx]
|
| 905 |
+
|
| 906 |
+
# --- FIX: Ensure max_cutoff leaves room for the largest horizon ---
|
| 907 |
+
# Otherwise, if T_cutoff is near the end, all horizons are masked as 0.
|
| 908 |
+
max_horizon = max(horizons) if horizons else 600
|
| 909 |
+
max_cutoff_ts = sorted_trades_ts[-1] - max_horizon
|
| 910 |
+
|
| 911 |
+
# Safety: Ensure max_cutoff_ts >= min_cutoff_ts
|
| 912 |
+
if max_cutoff_ts < min_cutoff_ts:
|
| 913 |
+
# Token duration is too short for the horizons, use earliest valid cutoff
|
| 914 |
+
max_cutoff_ts = min_cutoff_ts
|
| 915 |
|
| 916 |
if max_cutoff_ts <= min_cutoff_ts:
|
| 917 |
sample_offset_ts = min_cutoff_ts
|
| 918 |
else:
|
| 919 |
+
# Standard case: sample uniformly between [Trade[24], LastTrade - max_horizon]
|
| 920 |
sample_offset_ts = random.uniform(min_cutoff_ts, max_cutoff_ts)
|
| 921 |
|
| 922 |
T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
|
|
|
|
| 1192 |
raw_data['snapshots_5m'] = snapshot_stats
|
| 1193 |
raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
|
| 1194 |
|
| 1195 |
+
raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1196 |
raw_data["protocol_id"] = initial_mint_record.get("protocol")
|
| 1197 |
return raw_data
|
| 1198 |
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:847193fc90f4b0313f515ea38a24fd073be09188cfc4764c5dce3f658d4dc117
|
| 3 |
+
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10917f8ad8d8962a8c05a46f2b24dcb1180b23665d0767ea5c65c63d9ec09c92
|
| 3 |
+
size 314966
|
models/graph_updater.py
CHANGED
|
@@ -352,6 +352,11 @@ class GraphUpdater(nn.Module):
|
|
| 352 |
self.norm = nn.LayerNorm(node_dim)
|
| 353 |
self.to(dtype) # Move norm layer and ModuleList container
|
| 354 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
def _build_edge_groups(self) -> Dict[tuple, List[str]]:
|
| 356 |
"""Group relations by (src_type, dst_type) so conv weights can be shared."""
|
| 357 |
groups: Dict[tuple, List[str]] = defaultdict(list)
|
|
|
|
| 352 |
self.norm = nn.LayerNorm(node_dim)
|
| 353 |
self.to(dtype) # Move norm layer and ModuleList container
|
| 354 |
|
| 355 |
+
# Log params
|
| 356 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 357 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 358 |
+
print(f"[GraphUpdater] Params: {total_params:,} (Trainable: {trainable_params:,})")
|
| 359 |
+
|
| 360 |
def _build_edge_groups(self) -> Dict[tuple, List[str]]:
|
| 361 |
"""Group relations by (src_type, dst_type) so conv weights can be shared."""
|
| 362 |
groups: Dict[tuple, List[str]] = defaultdict(list)
|
models/helper_encoders.py
CHANGED
|
@@ -33,6 +33,11 @@ class ContextualTimeEncoder(nn.Module):
|
|
| 33 |
# Cast the entire module to the specified dtype
|
| 34 |
self.to(dtype)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
|
| 37 |
device = values.device
|
| 38 |
half_dim = d_model // 2
|
|
|
|
| 33 |
# Cast the entire module to the specified dtype
|
| 34 |
self.to(dtype)
|
| 35 |
|
| 36 |
+
# Log params
|
| 37 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 38 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 39 |
+
print(f"[ContextualTimeEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")
|
| 40 |
+
|
| 41 |
def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
|
| 42 |
device = values.device
|
| 43 |
half_dim = d_model // 2
|
models/model.py
CHANGED
|
@@ -375,14 +375,6 @@ class Oracle(nn.Module):
|
|
| 375 |
# 1a. Encode Tokens
|
| 376 |
# --- FIXED: Check for a key that still exists ---
|
| 377 |
if token_encoder_inputs['name_embed_indices'].numel() > 0:
|
| 378 |
-
# --- AGGRESSIVE LOGGING ---
|
| 379 |
-
print("\n--- [Oracle DynamicEncoder LOG] ---")
|
| 380 |
-
print(f"[Oracle LOG] embedding_pool shape: {embedding_pool.shape}")
|
| 381 |
-
print(f"[Oracle LOG] name_embed_indices (shape {token_encoder_inputs['name_embed_indices'].shape}):\n{token_encoder_inputs['name_embed_indices']}")
|
| 382 |
-
print(f"[Oracle LOG] symbol_embed_indices (shape {token_encoder_inputs['symbol_embed_indices'].shape}):\n{token_encoder_inputs['symbol_embed_indices']}")
|
| 383 |
-
print(f"[Oracle LOG] image_embed_indices (shape {token_encoder_inputs['image_embed_indices'].shape}):\n{token_encoder_inputs['image_embed_indices']}")
|
| 384 |
-
print("--- [Oracle LOG] Calling F.embedding and TokenEncoder... ---")
|
| 385 |
-
# --- END LOGGING ---
|
| 386 |
# --- NEW: Gather pre-computed embeddings and pass to encoder ---
|
| 387 |
# --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
|
| 388 |
encoder_args = token_encoder_inputs.copy()
|
|
|
|
| 375 |
# 1a. Encode Tokens
|
| 376 |
# --- FIXED: Check for a key that still exists ---
|
| 377 |
if token_encoder_inputs['name_embed_indices'].numel() > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
# --- NEW: Gather pre-computed embeddings and pass to encoder ---
|
| 379 |
# --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
|
| 380 |
encoder_args = token_encoder_inputs.copy()
|
models/multi_modal_processor.py
CHANGED
|
@@ -11,6 +11,8 @@ import os
|
|
| 11 |
import traceback
|
| 12 |
import numpy as np
|
| 13 |
|
|
|
|
|
|
|
| 14 |
# Suppress warnings
|
| 15 |
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| 16 |
|
|
@@ -22,6 +24,10 @@ class MultiModalEncoder:
|
|
| 22 |
"""
|
| 23 |
|
| 24 |
def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
self.model_id = model_id
|
| 26 |
if device:
|
| 27 |
self.device = device
|
|
@@ -72,46 +78,24 @@ class MultiModalEncoder:
|
|
| 72 |
|
| 73 |
autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None
|
| 74 |
|
| 75 |
-
|
| 76 |
-
print(f"[MME LOG] Input data preview: {str(x[0])[:100] if is_text else x[0]}")
|
| 77 |
-
|
| 78 |
-
with torch.amp.autocast(device_type=self.device, enabled=(self.device == 'cuda' and autocast_dtype is not None), dtype=autocast_dtype):
|
| 79 |
try:
|
| 80 |
if is_text:
|
| 81 |
-
inputs = self.processor(
|
| 82 |
-
text=x,
|
| 83 |
-
return_tensors="pt",
|
| 84 |
-
padding="max_length",
|
| 85 |
-
truncation=True
|
| 86 |
-
).to(self.device)
|
| 87 |
-
print(f"[MME LOG] Text processor output shape: {inputs['input_ids'].shape}")
|
| 88 |
embeddings = self.model.get_text_features(**inputs)
|
| 89 |
else:
|
| 90 |
-
|
| 91 |
-
inputs = self.processor(
|
| 92 |
-
images=rgb_images,
|
| 93 |
-
return_tensors="pt"
|
| 94 |
-
).to(self.device)
|
| 95 |
-
|
| 96 |
-
if 'pixel_values' in inputs and inputs['pixel_values'].dtype != self.dtype:
|
| 97 |
-
inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
|
| 98 |
-
|
| 99 |
embeddings = self.model.get_image_features(**inputs)
|
| 100 |
|
| 101 |
-
print(f"[MME LOG] Raw model output embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
|
| 102 |
-
|
| 103 |
-
# <<< THIS IS THE FIX. I accidentally removed this.
|
| 104 |
# Normalize in float32 for numerical stability
|
| 105 |
embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
|
| 106 |
-
print(f"[MME LOG] Normalized embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
|
| 107 |
|
| 108 |
final_embeddings = embeddings.to(self.dtype)
|
| 109 |
-
print(f"[MME LOG] Final embeddings shape: {final_embeddings.shape}, dtype: {final_embeddings.dtype}. EXITING __call__.")
|
| 110 |
return final_embeddings
|
| 111 |
|
| 112 |
except Exception as e:
|
| 113 |
-
|
| 114 |
-
traceback.print_exc()
|
| 115 |
return torch.empty(0, self.embedding_dim).to(self.device)
|
| 116 |
|
| 117 |
# --- Test block (SigLIP) ---
|
|
|
|
| 11 |
import traceback
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
+
from transformers.utils import logging as hf_logging
|
| 15 |
+
|
| 16 |
# Suppress warnings
|
| 17 |
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| 18 |
|
|
|
|
| 24 |
"""
|
| 25 |
|
| 26 |
def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
|
| 27 |
+
# Force silence progress bars locally for this class
|
| 28 |
+
hf_logging.set_verbosity_error()
|
| 29 |
+
hf_logging.disable_progress_bar()
|
| 30 |
+
|
| 31 |
self.model_id = model_id
|
| 32 |
if device:
|
| 33 |
self.device = device
|
|
|
|
| 78 |
|
| 79 |
autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None
|
| 80 |
|
| 81 |
+
with torch.autocast(device_type=self.device, dtype=autocast_dtype, enabled=(autocast_dtype is not None)):
|
|
|
|
|
|
|
|
|
|
| 82 |
try:
|
| 83 |
if is_text:
|
| 84 |
+
inputs = self.processor(text=x, return_tensors="pt", padding=True, truncation=True).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
embeddings = self.model.get_text_features(**inputs)
|
| 86 |
else:
|
| 87 |
+
inputs = self.processor(images=x, return_tensors="pt").to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
embeddings = self.model.get_image_features(**inputs)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
| 90 |
# Normalize in float32 for numerical stability
|
| 91 |
embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
|
|
|
|
| 92 |
|
| 93 |
final_embeddings = embeddings.to(self.dtype)
|
|
|
|
| 94 |
return final_embeddings
|
| 95 |
|
| 96 |
except Exception as e:
|
| 97 |
+
# Silently fail or log debug only if needed
|
| 98 |
+
# traceback.print_exc()
|
| 99 |
return torch.empty(0, self.embedding_dim).to(self.device)
|
| 100 |
|
| 101 |
# --- Test block (SigLIP) ---
|
models/ohlc_embedder.py
CHANGED
|
@@ -71,6 +71,11 @@ class OHLCEmbedder(nn.Module):
|
|
| 71 |
|
| 72 |
self.to(dtype)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
|
| 75 |
"""
|
| 76 |
Args:
|
|
|
|
| 71 |
|
| 72 |
self.to(dtype)
|
| 73 |
|
| 74 |
+
# Log params
|
| 75 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 76 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 77 |
+
print(f"[OHLCEmbedder] Params: {total_params:,} (Trainable: {trainable_params:,})")
|
| 78 |
+
|
| 79 |
def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
|
| 80 |
"""
|
| 81 |
Args:
|
models/token_encoder.py
CHANGED
|
@@ -98,6 +98,11 @@ class TokenEncoder(nn.Module):
|
|
| 98 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 99 |
self.to(device=device, dtype=dtype)
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def forward(
|
| 102 |
self,
|
| 103 |
name_embeds: torch.Tensor,
|
|
@@ -123,21 +128,14 @@ class TokenEncoder(nn.Module):
|
|
| 123 |
device = name_embeds.device
|
| 124 |
batch_size = name_embeds.shape[0]
|
| 125 |
|
| 126 |
-
# 2. Get Protocol embedding (small)
|
| 127 |
-
print(f"\n--- [TokenEncoder LOG] ENTERING FORWARD PASS (Batch Size: {batch_size}) ---")
|
| 128 |
-
print(f"[TokenEncoder LOG] Input protocol_ids (shape {protocol_ids.shape}):\n{protocol_ids}")
|
| 129 |
-
print(f"[TokenEncoder LOG] Protocol Embedding Vocab Size: {self.protocol_embedding.num_embeddings}")
|
| 130 |
-
|
| 131 |
protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
|
| 132 |
protocol_emb_raw = self.protocol_embedding(protocol_ids_long) # [B, 64]
|
| 133 |
-
print(f"[TokenEncoder LOG] Raw protocol embeddings shape: {protocol_emb_raw.shape}")
|
| 134 |
|
| 135 |
# NEW: Get vanity embedding
|
| 136 |
vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
|
| 137 |
vanity_emb_raw = self.vanity_embedding(vanity_ids_long) # [B, 32]
|
| 138 |
|
| 139 |
# 3. Project all features to internal_dim (e.g., 1024)
|
| 140 |
-
print(f"[TokenEncoder LOG] Projecting features to internal_dim: {self.internal_dim}")
|
| 141 |
name_emb = self.name_proj(name_embeds)
|
| 142 |
symbol_emb = self.symbol_proj(symbol_embeds)
|
| 143 |
image_emb = self.image_proj(image_embeds)
|
|
@@ -153,16 +151,8 @@ class TokenEncoder(nn.Module):
|
|
| 153 |
vanity_emb, # NEW: Add the vanity embedding to the sequence
|
| 154 |
], dim=1)
|
| 155 |
|
| 156 |
-
print(f"[TokenEncoder LOG] Stacked feature_sequence shape: {feature_sequence.shape}")
|
| 157 |
-
print(f" - name_emb shape: {name_emb.shape}")
|
| 158 |
-
print(f" - symbol_emb shape: {symbol_emb.shape}")
|
| 159 |
-
print(f" - image_emb shape: {image_emb.shape}")
|
| 160 |
-
print(f" - protocol_emb shape: {protocol_emb.shape}")
|
| 161 |
-
print(f" - vanity_emb shape: {vanity_emb.shape}") # ADDED: Log the new vanity embedding shape
|
| 162 |
-
|
| 163 |
# 5. Create the padding mask (all False, since we have a fixed number of features for all)
|
| 164 |
padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
|
| 165 |
-
print(f"[TokenEncoder LOG] Created padding_mask of shape: {padding_mask.shape}")
|
| 166 |
|
| 167 |
# 6. Fuse the sequence with the Transformer Encoder
|
| 168 |
# This returns the [CLS] token output.
|
|
@@ -171,12 +161,9 @@ class TokenEncoder(nn.Module):
|
|
| 171 |
item_embeds=feature_sequence,
|
| 172 |
src_key_padding_mask=padding_mask
|
| 173 |
)
|
| 174 |
-
print(f"[TokenEncoder LOG] Fused embedding shape after transformer: {fused_embedding.shape}")
|
| 175 |
|
| 176 |
# 7. Project to the final output dimension
|
| 177 |
# Shape: [B, output_dim]
|
| 178 |
token_vibe_embedding = self.final_projection(fused_embedding)
|
| 179 |
-
print(f"[TokenEncoder LOG] Final token_vibe_embedding shape: {token_vibe_embedding.shape}")
|
| 180 |
-
print(f"--- [TokenEncoder LOG] EXITING FORWARD PASS ---\n")
|
| 181 |
|
| 182 |
return token_vibe_embedding
|
|
|
|
| 98 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 99 |
self.to(device=device, dtype=dtype)
|
| 100 |
|
| 101 |
+
# Log params
|
| 102 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 103 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 104 |
+
print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")
|
| 105 |
+
|
| 106 |
def forward(
|
| 107 |
self,
|
| 108 |
name_embeds: torch.Tensor,
|
|
|
|
| 128 |
device = name_embeds.device
|
| 129 |
batch_size = name_embeds.shape[0]
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
|
| 132 |
protocol_emb_raw = self.protocol_embedding(protocol_ids_long) # [B, 64]
|
|
|
|
| 133 |
|
| 134 |
# NEW: Get vanity embedding
|
| 135 |
vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
|
| 136 |
vanity_emb_raw = self.vanity_embedding(vanity_ids_long) # [B, 32]
|
| 137 |
|
| 138 |
# 3. Project all features to internal_dim (e.g., 1024)
|
|
|
|
| 139 |
name_emb = self.name_proj(name_embeds)
|
| 140 |
symbol_emb = self.symbol_proj(symbol_embeds)
|
| 141 |
image_emb = self.image_proj(image_embeds)
|
|
|
|
| 151 |
vanity_emb, # NEW: Add the vanity embedding to the sequence
|
| 152 |
], dim=1)
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
# 5. Create the padding mask (all False, since we have a fixed number of features for all)
|
| 155 |
padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
|
|
|
|
| 156 |
|
| 157 |
# 6. Fuse the sequence with the Transformer Encoder
|
| 158 |
# This returns the [CLS] token output.
|
|
|
|
| 161 |
item_embeds=feature_sequence,
|
| 162 |
src_key_padding_mask=padding_mask
|
| 163 |
)
|
|
|
|
| 164 |
|
| 165 |
# 7. Project to the final output dimension
|
| 166 |
# Shape: [B, output_dim]
|
| 167 |
token_vibe_embedding = self.final_projection(fused_embedding)
|
|
|
|
|
|
|
| 168 |
|
| 169 |
return token_vibe_embedding
|
models/wallet_encoder.py
CHANGED
|
@@ -95,6 +95,28 @@ class WalletEncoder(nn.Module):
|
|
| 95 |
)
|
| 96 |
self.to(dtype)
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def _build_mlp(self, in_dim, out_dim):
|
| 99 |
return nn.Sequential(
|
| 100 |
nn.Linear(in_dim, out_dim * 2),
|
|
|
|
| 95 |
)
|
| 96 |
self.to(dtype)
|
| 97 |
|
| 98 |
+
# Log params (excluding the shared encoder which might be huge and already logged)
|
| 99 |
+
# Note: self.encoder is external, but if we include it here, it will double count.
|
| 100 |
+
# Ideally we only log *this* module's params.
|
| 101 |
+
my_params = sum(p.numel() for p in self.parameters())
|
| 102 |
+
# Note: `self.encoder = encoder` does NOT register the shared encoder's
# parameters on this module — MultiModalEncoder is a plain wrapper class,
# not an nn.Module, so the assignment is an ordinary attribute. Hence
# self.parameters() covers only the MLPs/SetEncoders defined in
# WalletEncoder, and the shared encoder is neither double-counted nor
# double-logged here.
|
| 116 |
+
|
| 117 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 118 |
+
print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")
|
| 119 |
+
|
| 120 |
def _build_mlp(self, in_dim, out_dim):
|
| 121 |
return nn.Sequential(
|
| 122 |
nn.Linear(in_dim, out_dim * 2),
|
scripts/cache_dataset.py
CHANGED
|
@@ -2,12 +2,20 @@
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
|
|
|
| 5 |
import datetime
|
| 6 |
import torch
|
| 7 |
import json
|
| 8 |
from pathlib import Path
|
| 9 |
from tqdm import tqdm
|
| 10 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Add parent directory to path to import modules
|
| 13 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
@@ -19,9 +27,84 @@ from scripts.analyze_distribution import get_return_class_map
|
|
| 19 |
from clickhouse_driver import Client as ClickHouseClient
|
| 20 |
from neo4j import GraphDatabase
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def main():
|
| 23 |
load_dotenv()
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
parser = argparse.ArgumentParser(description="Cache dataset samples for training.")
|
| 26 |
parser.add_argument("--output_dir", type=str, default="data/cache", help="Directory to save cached samples")
|
| 27 |
parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to generate")
|
|
@@ -50,6 +133,9 @@ def main():
|
|
| 50 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 51 |
|
| 52 |
try:
|
|
|
|
|
|
|
|
|
|
| 53 |
# --- 2. Initialize DataFetcher and OracleDataset ---
|
| 54 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 55 |
|
|
|
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import argparse
|
| 5 |
+
import numpy as np
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
| 9 |
from pathlib import Path
|
| 10 |
from tqdm import tqdm
|
| 11 |
from dotenv import load_dotenv
|
| 12 |
+
import huggingface_hub
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
# Suppress noisy libraries
|
| 16 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 17 |
+
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 18 |
+
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
| 19 |
|
| 20 |
# Add parent directory to path to import modules
|
| 21 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
| 27 |
from clickhouse_driver import Client as ClickHouseClient
|
| 28 |
from neo4j import GraphDatabase
|
| 29 |
|
| 30 |
+
def compute_save_ohlc_stats(client: "ClickHouseClient", output_path: str,
                            min_price: float = 0.0, min_vol: float = 0.0):
    """
    Compute global mean/std statistics for trade prices and volumes from
    ClickHouse and persist them to an ``.npz`` file.

    The dataset loader reads this file to z-score-normalize OHLC inputs.

    Args:
        client: Connected ClickHouse client; must expose
            ``execute(query, params=...)`` returning rows of tuples.
        output_path: Destination path for the ``.npz`` stats file
            (parent directories are created as needed).
        min_price: Exclusive lower bound on ``price_usd``; trades at or
            below it are ignored so dust does not skew the stats.
            Default 0.0 (previous hard-coded behavior).
        min_vol: Exclusive lower bound on ``total_usd``; same rationale.
            Default 0.0.

    Best-effort by design: failures are printed, not raised, so the
    caching run can proceed (the dataset loader will complain later if
    the stats file is missing).
    """
    print("INFO: Computing OHLC stats (mean/std) from ClickHouse...")

    # Query matches preprocess_distribution.py logic; the filters avoid
    # skewing the aggregates with dust trades.
    query = """
        SELECT
            AVG(t.price_usd) AS mean_price_usd,
            stddevPop(t.price_usd) AS std_price_usd,
            AVG(t.price) AS mean_price_native,
            stddevPop(t.price) AS std_price_native,
            AVG(t.total_usd) AS mean_trade_value_usd,
            stddevPop(t.total_usd) AS std_trade_value_usd
        FROM trades AS t
        WHERE t.price_usd > %(min_price)s AND t.total_usd > %(min_vol)s
    """
    params = {"min_price": min_price, "min_vol": min_vol}

    try:
        result = client.execute(query, params=params)
        if not result or not result[0]:
            print("WARNING: Stats query returned no rows. Using default identity stats.")
            stats = {
                "mean_price_usd": 0.0, "std_price_usd": 1.0,
                "mean_price_native": 0.0, "std_price_native": 1.0,
                "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0,
            }
        else:
            row = result[0]

            # The DB may return NULLs (e.g. empty table); fall back to
            # identity normalization (mean 0, std 1).
            def safe_float(x, default=0.0):
                return float(x) if x is not None else default

            def safe_std(x):
                # Guard against zero/near-zero std -> division blow-ups
                # downstream during normalization.
                val = safe_float(x, 1.0)
                return val if val > 1e-9 else 1.0

            stats = {
                "mean_price_usd": safe_float(row[0]),
                "std_price_usd": safe_std(row[1]),
                "mean_price_native": safe_float(row[2]),
                "std_price_native": safe_std(row[3]),
                "mean_trade_value_usd": safe_float(row[4]),
                "std_trade_value_usd": safe_std(row[5]),
            }

        # Persist as NPZ next to the cache; create parents if needed.
        out_p = Path(output_path)
        out_p.parent.mkdir(parents=True, exist_ok=True)
        np.savez(out_p, **stats)

        print(f"INFO: Saved OHLC stats to {out_p}")
        for k, v in stats.items():
            print(f"  {k}: {v:.4f}")

    except Exception as e:
        # Deliberate best-effort: don't crash the caching run over stats.
        print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 96 |
+
|
| 97 |
def main():
|
| 98 |
load_dotenv()
|
| 99 |
|
| 100 |
+
# Explicit Login
|
| 101 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 102 |
+
if hf_token:
|
| 103 |
+
print(f"INFO: Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
|
| 104 |
+
huggingface_hub.login(token=hf_token)
|
| 105 |
+
else:
|
| 106 |
+
print("WARNING: HF_TOKEN not found in environment.")
|
| 107 |
+
|
| 108 |
parser = argparse.ArgumentParser(description="Cache dataset samples for training.")
|
| 109 |
parser.add_argument("--output_dir", type=str, default="data/cache", help="Directory to save cached samples")
|
| 110 |
parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to generate")
|
|
|
|
| 133 |
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 134 |
|
| 135 |
try:
|
| 136 |
+
# --- 1. Compute OHLC Stats (Global) ---
|
| 137 |
+
compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
|
| 138 |
+
|
| 139 |
# --- 2. Initialize DataFetcher and OracleDataset ---
|
| 140 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 141 |
|
train.py
CHANGED
|
@@ -15,6 +15,14 @@ resolved_tmp = str(_DEFAULT_TMP.resolve())
|
|
| 15 |
for key in ("TMPDIR", "TMP", "TEMP"):
|
| 16 |
os.environ.setdefault(key, resolved_tmp)
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
try:
|
| 19 |
mp.set_start_method('spawn', force=True)
|
| 20 |
except RuntimeError:
|
|
@@ -100,8 +108,8 @@ def quantile_pinball_loss(preds: torch.Tensor,
|
|
| 100 |
# Preds shape: [B, Horizons * Quantiles]
|
| 101 |
# Logic assumes interleaved outputs or consistent flattening.
|
| 102 |
pred_slice = preds[:, idx::num_quantiles]
|
| 103 |
-
target_slice = targets
|
| 104 |
-
mask_slice = mask
|
| 105 |
|
| 106 |
diff = target_slice - pred_slice
|
| 107 |
pinball = torch.maximum((q - 1.0) * diff, q * diff)
|
|
@@ -118,6 +126,44 @@ def filtered_collate(collator: MemecoinCollator,
|
|
| 118 |
return None
|
| 119 |
return collator(batch)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
def parse_args() -> argparse.Namespace:
|
| 123 |
parser = argparse.ArgumentParser(description="Train the Oracle quantile model.")
|
|
@@ -209,6 +255,25 @@ def main() -> None:
|
|
| 209 |
max_seq_len = args.max_seq_len
|
| 210 |
max_seq_len = args.max_seq_len
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 213 |
|
| 214 |
# Encoders
|
|
@@ -423,6 +488,11 @@ def main() -> None:
|
|
| 423 |
# Logging
|
| 424 |
if accelerator.sync_gradients:
|
| 425 |
total_steps += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
current_loss = loss.item()
|
| 427 |
epoch_loss += current_loss
|
| 428 |
valid_batches += 1
|
|
|
|
| 15 |
for key in ("TMPDIR", "TMP", "TEMP"):
|
| 16 |
os.environ.setdefault(key, resolved_tmp)
|
| 17 |
|
| 18 |
+
# --- Environment & Logging Setup ---
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
import huggingface_hub
|
| 21 |
+
from transformers.utils import logging as hf_logging
|
| 22 |
+
|
| 23 |
+
# Load .env explicitly (benign at global scope but moving heavy lifting to main)
|
| 24 |
+
load_dotenv()
|
| 25 |
+
|
| 26 |
try:
|
| 27 |
mp.set_start_method('spawn', force=True)
|
| 28 |
except RuntimeError:
|
|
|
|
| 108 |
# Preds shape: [B, Horizons * Quantiles]
|
| 109 |
# Logic assumes interleaved outputs or consistent flattening.
|
| 110 |
pred_slice = preds[:, idx::num_quantiles]
|
| 111 |
+
target_slice = targets
|
| 112 |
+
mask_slice = mask
|
| 113 |
|
| 114 |
diff = target_slice - pred_slice
|
| 115 |
pinball = torch.maximum((q - 1.0) * diff, q * diff)
|
|
|
|
| 126 |
return None
|
| 127 |
return collator(batch)
|
| 128 |
|
| 129 |
+
def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step: int):
    """
    Log a human-readable preview of the first sample in `batch`: the
    decoded event stream plus the heads of its labels and label mask.
    Use this to verify what the model is actually seeing.
    """
    if not logger.isEnabledFor(logging.INFO):
        return

    try:
        first = 0  # inspect only the leading sample in the batch
        raw_ids = batch['event_type_ids'][first].cpu()   # [L]
        raw_labels = batch['labels'][first].cpu()        # [Horizons * Quantiles]
        raw_mask = batch['labels_mask'][first].cpu()

        # Map event ids -> names via the vocab, dropping PAD (id 0).
        decoded = [
            vocab.ID_TO_EVENT.get(int(tok), f"UNK_{int(tok)}")
            for tok in raw_ids
            if int(tok) != 0
        ]

        logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")

        # Show only the tail of the event stream (e.g. last 50 events).
        tail_len = 50
        tail_preview = ", ".join(decoded[-tail_len:])
        logger.info(f"Event Stream (Last {tail_len} of {len(decoded)}): [{tail_preview}]")

        # Labels are assumed flattened to [H*Q]; show the first few.
        logger.info(f"Labels (First 10): {raw_labels[:10].tolist()}")
        logger.info(f"Masks (First 10): {raw_mask[:10].tolist()}")
        logger.info("----------------------------------------------------\n")

    except Exception as e:
        # Diagnostics only — never let logging break the training loop.
        logger.warning(f"Failed to log batch context: {e}")
|
| 166 |
+
|
| 167 |
|
| 168 |
def parse_args() -> argparse.Namespace:
|
| 169 |
parser = argparse.ArgumentParser(description="Train the Oracle quantile model.")
|
|
|
|
| 255 |
max_seq_len = args.max_seq_len
|
| 256 |
max_seq_len = args.max_seq_len
|
| 257 |
|
| 258 |
+
|
| 259 |
+
# --- Environment & Logging Setup ---
|
| 260 |
+
# Load .env explicitly
|
| 261 |
+
load_dotenv()
|
| 262 |
+
|
| 263 |
+
# Suppress noisy libraries
|
| 264 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 265 |
+
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 266 |
+
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
| 267 |
+
hf_logging.disable_progress_bar()
|
| 268 |
+
|
| 269 |
+
# Explicit Login
|
| 270 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 271 |
+
if hf_token:
|
| 272 |
+
print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
|
| 273 |
+
huggingface_hub.login(token=hf_token)
|
| 274 |
+
else:
|
| 275 |
+
print("WARNING: HF_TOKEN not found in environment.")
|
| 276 |
+
|
| 277 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 278 |
|
| 279 |
# Encoders
|
|
|
|
| 488 |
# Logging
|
| 489 |
if accelerator.sync_gradients:
|
| 490 |
total_steps += 1
|
| 491 |
+
|
| 492 |
+
# --- NEW: Debug Log Batch Context ---
|
| 493 |
+
if total_steps % log_every == 0 and accelerator.is_main_process:
|
| 494 |
+
log_debug_batch_context(batch, logger, total_steps)
|
| 495 |
+
|
| 496 |
current_loss = loss.item()
|
| 497 |
epoch_loss += current_loss
|
| 498 |
valid_batches += 1
|