Upload folder using huggingface_hub
Browse files
- data/data_loader.py +31 -17
- data/ohlc_stats.npz +1 -1
- inference.py +1 -1
- log.log +2 -2
- models/wallet_encoder.py +8 -10
- sample_DYtPmhyxPDbMEdVP_0.json +0 -0
- scripts/analyze_distribution.py +49 -0
- scripts/cache_dataset.py +8 -4
data/data_loader.py
CHANGED
|
@@ -128,9 +128,22 @@ class OracleDataset(Dataset):
|
|
| 128 |
cache_dir: Optional[Union[str, Path]] = None,
|
| 129 |
start_date: Optional[datetime.datetime] = None,
|
| 130 |
min_trade_usd: float = 0.0,
|
| 131 |
-
max_seq_len: int = 8192
|
|
|
|
| 132 |
|
| 133 |
self.max_seq_len = max_seq_len
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 136 |
# Configure robust HTTP session
|
|
@@ -633,10 +646,18 @@ class OracleDataset(Dataset):
|
|
| 633 |
sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
|
| 634 |
volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
|
| 635 |
total_txns = len(trades_win) + len(xfers_win)
|
| 636 |
-
global_fees_paid = sum(
|
| 637 |
-
|
|
|
|
|
|
|
| 638 |
|
| 639 |
-
smart_trader_addrs = set(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
smart_traders = len(smart_trader_addrs)
|
| 641 |
|
| 642 |
kol_addrs = set()
|
|
@@ -825,7 +846,7 @@ class OracleDataset(Dataset):
|
|
| 825 |
|
| 826 |
# --- Define all expected numerical keys for a profile ---
|
| 827 |
expected_profile_keys = [
|
| 828 |
-
'
|
| 829 |
'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
|
| 830 |
'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
|
| 831 |
'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
|
|
@@ -852,14 +873,7 @@ class OracleDataset(Dataset):
|
|
| 852 |
social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
|
| 853 |
social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
|
| 854 |
|
| 855 |
-
|
| 856 |
-
funded_ts = profile_data.get('funded_timestamp', 0)
|
| 857 |
-
if funded_ts and funded_ts > 0:
|
| 858 |
-
age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
|
| 859 |
-
else:
|
| 860 |
-
age_seconds = 12_960_000
|
| 861 |
-
|
| 862 |
-
profile_data['age'] = float(age_seconds)
|
| 863 |
|
| 864 |
username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
|
| 865 |
|
|
@@ -909,12 +923,12 @@ class OracleDataset(Dataset):
|
|
| 909 |
'mint_address': mint_addr,
|
| 910 |
'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0),
|
| 911 |
'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)),
|
| 912 |
-
'history_bought_cost_sol': min(
|
| 913 |
'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)),
|
| 914 |
'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0),
|
| 915 |
'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0),
|
| 916 |
'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0),
|
| 917 |
-
'realized_profit_sol': max(-
|
| 918 |
'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0),
|
| 919 |
'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0),
|
| 920 |
'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0),
|
|
@@ -926,7 +940,7 @@ class OracleDataset(Dataset):
|
|
| 926 |
compact_profile = {'wallet_address': addr}
|
| 927 |
for key in expected_profile_keys:
|
| 928 |
compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0)
|
| 929 |
-
|
| 930 |
|
| 931 |
compact_social = {
|
| 932 |
'has_pf_profile': bool(social_data.get('has_pf_profile', False)),
|
|
@@ -2073,7 +2087,7 @@ class OracleDataset(Dataset):
|
|
| 2073 |
'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0,
|
| 2074 |
'token_amount_pct_of_holding': token_pct_hold,
|
| 2075 |
'quote_amount_pct_of_holding': quote_pct_hold,
|
| 2076 |
-
'slippage': min(
|
| 2077 |
'token_amount_pct_to_total_supply': token_pct_supply,
|
| 2078 |
'success': is_success,
|
| 2079 |
'is_bundle': trade.get('is_bundle', False),
|
|
|
|
| 128 |
cache_dir: Optional[Union[str, Path]] = None,
|
| 129 |
start_date: Optional[datetime.datetime] = None,
|
| 130 |
min_trade_usd: float = 0.0,
|
| 131 |
+
max_seq_len: int = 8192,
|
| 132 |
+
p99_clamps: Optional[Dict[str, float]] = None):
|
| 133 |
|
| 134 |
self.max_seq_len = max_seq_len
|
| 135 |
+
|
| 136 |
+
# --- P99 data-driven clamp values (replace hardcoded min/max) ---
|
| 137 |
+
self.p99_clamps = {
|
| 138 |
+
'slippage': 1.0,
|
| 139 |
+
'priority_fee': 0.1,
|
| 140 |
+
'total_usd': 100000.0,
|
| 141 |
+
'history_bought_cost_sol': 30.0,
|
| 142 |
+
'realized_profit_sol': 150.0,
|
| 143 |
+
}
|
| 144 |
+
if p99_clamps:
|
| 145 |
+
self.p99_clamps.update(p99_clamps)
|
| 146 |
+
print(f"INFO: Using P99 clamps: {self.p99_clamps}")
|
| 147 |
|
| 148 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 149 |
# Configure robust HTTP session
|
|
|
|
| 646 |
sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
|
| 647 |
volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
|
| 648 |
total_txns = len(trades_win) + len(xfers_win)
|
| 649 |
+
global_fees_paid = sum(
|
| 650 |
+
float(e.get('priority_fee', 0.0) or 0.0) + float(e.get('bribe_fee', 0.0) or 0.0)
|
| 651 |
+
for e in trades_win
|
| 652 |
+
)
|
| 653 |
|
| 654 |
+
smart_trader_addrs = set(
|
| 655 |
+
e['wallet_address'] for e in trade_events
|
| 656 |
+
if e.get('event_type') == 'SmartWallet_Trade'
|
| 657 |
+
and e.get('success', False)
|
| 658 |
+
and e['timestamp'] <= ts_value
|
| 659 |
+
and holder_pct_map_ts.get(e['wallet_address'], 0.0) > 0.0
|
| 660 |
+
)
|
| 661 |
smart_traders = len(smart_trader_addrs)
|
| 662 |
|
| 663 |
kol_addrs = set()
|
|
|
|
| 846 |
|
| 847 |
# --- Define all expected numerical keys for a profile ---
|
| 848 |
expected_profile_keys = [
|
| 849 |
+
'deployed_tokens_count', 'deployed_tokens_migrated_pct',
|
| 850 |
'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
|
| 851 |
'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
|
| 852 |
'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
|
|
|
|
| 873 |
social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
|
| 874 |
social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
|
| 875 |
|
| 876 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
|
| 878 |
username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
|
| 879 |
|
|
|
|
| 923 |
'mint_address': mint_addr,
|
| 924 |
'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0),
|
| 925 |
'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)),
|
| 926 |
+
'history_bought_cost_sol': min(self.p99_clamps['history_bought_cost_sol'], float(holding_item.get('history_bought_cost_sol', 0.0) or 0.0)),
|
| 927 |
'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)),
|
| 928 |
'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0),
|
| 929 |
'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0),
|
| 930 |
'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0),
|
| 931 |
+
'realized_profit_sol': max(-self.p99_clamps['realized_profit_sol'], min(self.p99_clamps['realized_profit_sol'], float(holding_item.get('realized_profit_sol', 0.0) or 0.0))),
|
| 932 |
'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0),
|
| 933 |
'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0),
|
| 934 |
'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0),
|
|
|
|
| 940 |
compact_profile = {'wallet_address': addr}
|
| 941 |
for key in expected_profile_keys:
|
| 942 |
compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0)
|
| 943 |
+
|
| 944 |
|
| 945 |
compact_social = {
|
| 946 |
'has_pf_profile': bool(social_data.get('has_pf_profile', False)),
|
|
|
|
| 2087 |
'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0,
|
| 2088 |
'token_amount_pct_of_holding': token_pct_hold,
|
| 2089 |
'quote_amount_pct_of_holding': quote_pct_hold,
|
| 2090 |
+
'slippage': min(self.p99_clamps['slippage'], float(trade.get('slippage', 0.0) or 0.0)),
|
| 2091 |
'token_amount_pct_to_total_supply': token_pct_supply,
|
| 2092 |
'success': is_success,
|
| 2093 |
'is_bundle': trade.get('is_bundle', False),
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1d757990a0158118444be61f3d944dfb125237928809b4568ac209ab260f032e
|
| 3 |
size 1660
|
inference.py
CHANGED
|
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
|
| 29 |
print("--- Oracle Inference Script (Full Pipeline Test) ---")
|
| 30 |
|
| 31 |
# --- 1. Define Configs ---
|
| 32 |
-
OHLC_SEQ_LEN =
|
| 33 |
print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
|
| 34 |
|
| 35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 29 |
print("--- Oracle Inference Script (Full Pipeline Test) ---")
|
| 30 |
|
| 31 |
# --- 1. Define Configs ---
|
| 32 |
+
OHLC_SEQ_LEN = 300
|
| 33 |
print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
|
| 34 |
|
| 35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:656c6818f224b26869b5d0ae10f6b522ff7eb5c7b1b3aeb59b34c3db218338a9
|
| 3 |
+
size 11360
|
models/wallet_encoder.py
CHANGED
|
@@ -47,14 +47,14 @@ class WalletEncoder(nn.Module):
|
|
| 47 |
self.mmp_dim = self.encoder.embedding_dim # 1152
|
| 48 |
|
| 49 |
# === 1. Profile Encoder (FIXED) ===
|
| 50 |
-
#
|
| 51 |
-
# 3 lifetime_trading + 12 1d_stats + 12 7d_stats =
|
| 52 |
-
self.profile_numerical_features =
|
| 53 |
self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
|
| 54 |
|
| 55 |
|
| 56 |
# FIXED: Input dim no longer has bool embed or deployed tokens embed
|
| 57 |
-
profile_mlp_in_dim = self.profile_numerical_features #
|
| 58 |
self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)
|
| 59 |
|
| 60 |
|
|
@@ -152,17 +152,15 @@ class WalletEncoder(nn.Module):
|
|
| 152 |
|
| 153 |
def _encode_profile_batch(self, profile_rows, device):
|
| 154 |
batch_size = len(profile_rows)
|
| 155 |
-
# FIXED:
|
| 156 |
num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
|
| 157 |
# bool_tensor removed
|
| 158 |
# time_tensor removed
|
| 159 |
|
| 160 |
for i, row in enumerate(profile_rows):
|
| 161 |
-
# A: Numerical (FIXED:
|
| 162 |
num_data = [
|
| 163 |
-
# 1.
|
| 164 |
-
row.get('age', 0.0),
|
| 165 |
-
# 2. Deployed Token Aggregates (5)
|
| 166 |
row.get('deployed_tokens_count', 0.0),
|
| 167 |
row.get('deployed_tokens_migrated_pct', 0.0),
|
| 168 |
row.get('deployed_tokens_avg_lifetime_sec', 0.0),
|
|
@@ -195,7 +193,7 @@ class WalletEncoder(nn.Module):
|
|
| 195 |
|
| 196 |
# C: Booleans and deployed_tokens lists are GONE
|
| 197 |
|
| 198 |
-
# Log-normalize all numerical features (
|
| 199 |
num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
|
| 200 |
|
| 201 |
# The profile fused tensor is now just the numerical embeddings
|
|
|
|
| 47 |
self.mmp_dim = self.encoder.embedding_dim # 1152
|
| 48 |
|
| 49 |
# === 1. Profile Encoder (FIXED) ===
|
| 50 |
+
# 5 deployer_stats + 1 balance + 4 lifetime_counts +
|
| 51 |
+
# 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 37
|
| 52 |
+
self.profile_numerical_features = 37
|
| 53 |
self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
|
| 54 |
|
| 55 |
|
| 56 |
# FIXED: Input dim no longer has bool embed or deployed tokens embed
|
| 57 |
+
profile_mlp_in_dim = self.profile_numerical_features # 37
|
| 58 |
self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)
|
| 59 |
|
| 60 |
|
|
|
|
| 152 |
|
| 153 |
def _encode_profile_batch(self, profile_rows, device):
|
| 154 |
batch_size = len(profile_rows)
|
| 155 |
+
# FIXED: 37 numerical features
|
| 156 |
num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
|
| 157 |
# bool_tensor removed
|
| 158 |
# time_tensor removed
|
| 159 |
|
| 160 |
for i, row in enumerate(profile_rows):
|
| 161 |
+
# A: Numerical (FIXED: 37 features, MUST be present)
|
| 162 |
num_data = [
|
| 163 |
+
# 1. Deployed Token Aggregates (5)
|
|
|
|
|
|
|
| 164 |
row.get('deployed_tokens_count', 0.0),
|
| 165 |
row.get('deployed_tokens_migrated_pct', 0.0),
|
| 166 |
row.get('deployed_tokens_avg_lifetime_sec', 0.0),
|
|
|
|
| 193 |
|
| 194 |
# C: Booleans and deployed_tokens lists are GONE
|
| 195 |
|
| 196 |
+
# Log-normalize all numerical features (stats, etc.)
|
| 197 |
num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
|
| 198 |
|
| 199 |
# The profile fused tensor is now just the numerical embeddings
|
sample_DYtPmhyxPDbMEdVP_0.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/analyze_distribution.py
CHANGED
|
@@ -27,6 +27,55 @@ def get_client():
|
|
| 27 |
database=CLICKHOUSE_DATABASE
|
| 28 |
)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def fetch_all_metrics(client):
|
| 31 |
"""
|
| 32 |
Fetches all needed metrics for all tokens in a single query.
|
|
|
|
| 27 |
database=CLICKHOUSE_DATABASE
|
| 28 |
)
|
| 29 |
|
| 30 |
+
def compute_p99_clamps(client):
    """
    Computes P99 percentile clamp values from ClickHouse for fields prone to
    garbage outliers. These values replace hardcoded clamps in data_loader.py.

    Two queries are issued through ``client.execute``: one over successful rows
    of ``trades`` and one over ``wallet_holdings``. Each result is expected to
    be a sequence of rows whose first row carries the quantiles in SELECT order.

    Every field starts from a built-in default. A queried value replaces the
    default only when it is a finite number; it is then raised to a small
    positive floor so a clamp can never be zero or negative. A value that comes
    back as None, NaN, infinite or non-numeric (presumably what an empty table
    yields -- TODO confirm against the ClickHouse version in use) leaves the
    default in place, because a NaN clamp would turn every ``min(clamp, x)``
    downstream into NaN.

    Args:
        client: object exposing ``execute(query)`` and returning rows
            (assumed to be a clickhouse_driver client -- verify against caller).

    Returns a dict of {field_name: p99_value}.
    """
    import math  # function-scope import keeps this block self-contained

    def _clamp_or_default(value, floor, default):
        """Return max(value, floor), or `default` when value is unusable."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        if not math.isfinite(number):
            return default
        return max(number, floor)

    print(" -> Computing P99 clamp values from trades table...")
    trade_query = """
        SELECT
            quantile(0.99)(abs(slippage)) AS p99_slippage,
            quantile(0.99)(priority_fee) AS p99_priority_fee,
            quantile(0.99)(total_usd) AS p99_total_usd
        FROM trades
        WHERE success = 1
    """
    trade_row = client.execute(trade_query)

    print(" -> Computing P99 clamp values from wallet_holdings table...")
    holdings_query = """
        SELECT
            quantile(0.99)(history_bought_cost_sol) AS p99_bought_cost_sol,
            quantile(0.99)(abs(realized_profit_sol)) AS p99_realized_profit_sol
        FROM wallet_holdings
    """
    holdings_row = client.execute(holdings_query)

    clamps = {
        # Defaults as fallback if queries return nothing (or nothing usable)
        'slippage': 1.0,
        'priority_fee': 0.1,
        'total_usd': 100000.0,
        'history_bought_cost_sol': 30.0,
        'realized_profit_sol': 150.0,
    }

    # (field, column index in the row, minimum allowed clamp)
    trade_fields = (
        ('slippage', 0, 0.01),
        ('priority_fee', 1, 1e-9),
        ('total_usd', 2, 1.0),
    )
    holdings_fields = (
        ('history_bought_cost_sol', 0, 0.01),
        ('realized_profit_sol', 1, 0.01),
    )

    if trade_row and trade_row[0]:
        r = trade_row[0]
        for field, idx, floor in trade_fields:
            clamps[field] = _clamp_or_default(r[idx], floor, clamps[field])

    if holdings_row and holdings_row[0]:
        r = holdings_row[0]
        for field, idx, floor in holdings_fields:
            clamps[field] = _clamp_or_default(r[idx], floor, clamps[field])

    print(f" -> P99 Clamps: {clamps}")
    return clamps
|
| 78 |
+
|
| 79 |
def fetch_all_metrics(client):
|
| 80 |
"""
|
| 81 |
Fetches all needed metrics for all tokens in a single query.
|
scripts/cache_dataset.py
CHANGED
|
@@ -22,7 +22,7 @@ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
|
| 22 |
|
| 23 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 24 |
|
| 25 |
-
from scripts.analyze_distribution import get_return_class_map
|
| 26 |
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
| 27 |
|
| 28 |
from clickhouse_driver import Client as ClickHouseClient
|
|
@@ -65,7 +65,8 @@ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map
|
|
| 65 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 66 |
quantiles=dataset_config['quantiles'],
|
| 67 |
min_trade_usd=dataset_config['min_trade_usd'],
|
| 68 |
-
max_seq_len=dataset_config['max_seq_len']
|
|
|
|
| 69 |
)
|
| 70 |
_worker_dataset.sampled_mints = dataset_config['sampled_mints']
|
| 71 |
_worker_return_class_map = return_class_map
|
|
@@ -179,11 +180,14 @@ def main():
|
|
| 179 |
return_class_map, _ = get_return_class_map(clickhouse_client)
|
| 180 |
print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
print("INFO: Fetching Quality Scores...")
|
| 183 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 184 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 185 |
|
| 186 |
-
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
|
| 187 |
|
| 188 |
if len(dataset) == 0:
|
| 189 |
print("WARNING: No samples. Exiting.")
|
|
@@ -219,7 +223,7 @@ def main():
|
|
| 219 |
print(f"INFO: Workers: {args.num_workers}")
|
| 220 |
|
| 221 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 222 |
-
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
|
| 223 |
|
| 224 |
# Build tasks with class-aware multi-sampling for balanced cache
|
| 225 |
import random
|
|
|
|
| 22 |
|
| 23 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 24 |
|
| 25 |
+
from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
|
| 26 |
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
| 27 |
|
| 28 |
from clickhouse_driver import Client as ClickHouseClient
|
|
|
|
| 65 |
horizons_seconds=dataset_config['horizons_seconds'],
|
| 66 |
quantiles=dataset_config['quantiles'],
|
| 67 |
min_trade_usd=dataset_config['min_trade_usd'],
|
| 68 |
+
max_seq_len=dataset_config['max_seq_len'],
|
| 69 |
+
p99_clamps=dataset_config.get('p99_clamps')
|
| 70 |
)
|
| 71 |
_worker_dataset.sampled_mints = dataset_config['sampled_mints']
|
| 72 |
_worker_return_class_map = return_class_map
|
|
|
|
| 180 |
return_class_map, _ = get_return_class_map(clickhouse_client)
|
| 181 |
print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
|
| 182 |
|
| 183 |
+
print("INFO: Computing P99 clamp values...")
|
| 184 |
+
p99_clamps = compute_p99_clamps(clickhouse_client)
|
| 185 |
+
|
| 186 |
print("INFO: Fetching Quality Scores...")
|
| 187 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 188 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 189 |
|
| 190 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
|
| 191 |
|
| 192 |
if len(dataset) == 0:
|
| 193 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 223 |
print(f"INFO: Workers: {args.num_workers}")
|
| 224 |
|
| 225 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 226 |
+
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
|
| 227 |
|
| 228 |
# Build tasks with class-aware multi-sampling for balanced cache
|
| 229 |
import random
|