Upload folder using huggingface_hub
Browse files- data/data_collator.py +4 -0
- data/data_loader.py +28 -27
- data/ohlc_stats.npz +1 -1
- log.log +2 -2
- models/model.py +12 -1
- models/multi_modal_processor.py +7 -3
- pre_cache.sh +1 -1
- scripts/analyze_distribution.py +285 -437
- scripts/cache_dataset.py +171 -47
- scripts/compute_quality_score.py +132 -47
- token_stats.rs +857 -0
- train.py +15 -3
- train.sh +1 -1
data/data_collator.py
CHANGED
|
@@ -711,11 +711,15 @@ class MemecoinCollator:
|
|
| 711 |
# Labels
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
|
|
|
| 714 |
# Debug info
|
| 715 |
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
|
| 716 |
't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
|
| 717 |
'sample_indices': [item.get('sample_idx', -1) for item in batch]
|
| 718 |
}
|
| 719 |
|
|
|
|
|
|
|
|
|
|
| 720 |
# Filter out None values (e.g., if no labels provided)
|
| 721 |
return {k: v for k, v in collated_batch.items() if v is not None}
|
|
|
|
| 711 |
# Labels
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 714 |
+
'quality_score': torch.stack([item['quality_score'] for item in batch]) if batch and 'quality_score' in batch[0] else None,
|
| 715 |
# Debug info
|
| 716 |
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
|
| 717 |
't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
|
| 718 |
'sample_indices': [item.get('sample_idx', -1) for item in batch]
|
| 719 |
}
|
| 720 |
|
| 721 |
+
if collated_batch['quality_score'] is None:
|
| 722 |
+
raise RuntimeError("FATAL: Missing quality_score in batch items. Rebuild cache with quality_score enabled.")
|
| 723 |
+
|
| 724 |
# Filter out None values (e.g., if no labels provided)
|
| 725 |
return {k: v for k, v in collated_batch.items() if v is not None}
|
data/data_loader.py
CHANGED
|
@@ -156,43 +156,41 @@ class OracleDataset(Dataset):
|
|
| 156 |
if not self.cached_files:
|
| 157 |
raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
|
| 158 |
|
| 159 |
-
# --- NEW: Strict Metadata & Weighting ---
|
| 160 |
-
metadata_path = self.cache_dir / "metadata.jsonl"
|
| 161 |
-
if not metadata_path.exists():
|
| 162 |
-
raise RuntimeError(f"FATAL: metadata.jsonl not found in {self.cache_dir}. Cannot train without class-balanced sampling.")
|
| 163 |
-
|
| 164 |
-
print(f"INFO: Loading metadata from {metadata_path}...")
|
| 165 |
file_class_map = {}
|
| 166 |
class_counts = defaultdict(int)
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
| 170 |
try:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
print(f"INFO: Class Distribution: {dict(class_counts)}")
|
| 180 |
-
|
| 181 |
# Compute Weights
|
| 182 |
self.weights_list = []
|
| 183 |
valid_files = []
|
| 184 |
-
|
| 185 |
# We iterate properly sorted cached files to align with __getitem__ index
|
| 186 |
for p in self.cached_files:
|
| 187 |
fname = p.name
|
| 188 |
if fname not in file_class_map:
|
| 189 |
-
#
|
| 190 |
-
|
| 191 |
-
# Current pipeline writes metadata only for successful caches.
|
| 192 |
-
# So if it's in cached_files but not metadata, it might be a stale file.
|
| 193 |
-
print(f"WARN: File {fname} found in cache but missing metadata. Skipping.")
|
| 194 |
continue
|
| 195 |
-
|
| 196 |
cid = file_class_map[fname]
|
| 197 |
count = class_counts[cid]
|
| 198 |
weight = 1.0 / count if count > 0 else 0.0
|
|
@@ -976,7 +974,8 @@ class OracleDataset(Dataset):
|
|
| 976 |
"fee_collections",
|
| 977 |
"burns",
|
| 978 |
"supply_locks",
|
| 979 |
-
"migrations"
|
|
|
|
| 980 |
]
|
| 981 |
missing_keys = [key for key in required_keys if key not in raw_data]
|
| 982 |
if missing_keys:
|
|
@@ -1683,7 +1682,8 @@ class OracleDataset(Dataset):
|
|
| 1683 |
'graph_links': graph_links,
|
| 1684 |
'embedding_pooler': pooler,
|
| 1685 |
'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 1686 |
-
'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32)
|
|
|
|
| 1687 |
}
|
| 1688 |
|
| 1689 |
# Ensure sorted
|
|
@@ -1758,5 +1758,6 @@ class OracleDataset(Dataset):
|
|
| 1758 |
'graph_links': graph_links,
|
| 1759 |
'embedding_pooler': pooler,
|
| 1760 |
'labels': torch.tensor(label_values, dtype=torch.float32),
|
| 1761 |
-
'labels_mask': torch.tensor(mask_values, dtype=torch.float32)
|
|
|
|
| 1762 |
}
|
|
|
|
| 156 |
if not self.cached_files:
|
| 157 |
raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
|
| 158 |
|
| 159 |
+
# --- NEW: Strict Metadata & Weighting (from cached samples) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
file_class_map = {}
|
| 161 |
class_counts = defaultdict(int)
|
| 162 |
+
|
| 163 |
+
# Read class_id directly from each cached sample
|
| 164 |
+
for p in self.cached_files:
|
| 165 |
+
try:
|
| 166 |
+
# Cached samples are trusted local artifacts; allow full load.
|
| 167 |
try:
|
| 168 |
+
cached_item = torch.load(p, map_location="cpu", weights_only=False)
|
| 169 |
+
except TypeError:
|
| 170 |
+
cached_item = torch.load(p, map_location="cpu")
|
| 171 |
+
cid = cached_item.get("class_id")
|
| 172 |
+
if cid is None:
|
| 173 |
+
print(f"WARN: File {p.name} missing class_id. Skipping.")
|
| 174 |
+
continue
|
| 175 |
+
file_class_map[p.name] = cid
|
| 176 |
+
class_counts[cid] += 1
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f"WARN: Failed to read cached sample {p.name}: {e}")
|
| 179 |
|
| 180 |
print(f"INFO: Class Distribution: {dict(class_counts)}")
|
| 181 |
+
|
| 182 |
# Compute Weights
|
| 183 |
self.weights_list = []
|
| 184 |
valid_files = []
|
| 185 |
+
|
| 186 |
# We iterate properly sorted cached files to align with __getitem__ index
|
| 187 |
for p in self.cached_files:
|
| 188 |
fname = p.name
|
| 189 |
if fname not in file_class_map:
|
| 190 |
+
# If file exists but missing class_id, it might be stale or from an older cache.
|
| 191 |
+
print(f"WARN: File {fname} found in cache but missing class_id. Skipping.")
|
|
|
|
|
|
|
|
|
|
| 192 |
continue
|
| 193 |
+
|
| 194 |
cid = file_class_map[fname]
|
| 195 |
count = class_counts[cid]
|
| 196 |
weight = 1.0 / count if count > 0 else 0.0
|
|
|
|
| 974 |
"fee_collections",
|
| 975 |
"burns",
|
| 976 |
"supply_locks",
|
| 977 |
+
"migrations",
|
| 978 |
+
"quality_score"
|
| 979 |
]
|
| 980 |
missing_keys = [key for key in required_keys if key not in raw_data]
|
| 981 |
if missing_keys:
|
|
|
|
| 1682 |
'graph_links': graph_links,
|
| 1683 |
'embedding_pooler': pooler,
|
| 1684 |
'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 1685 |
+
'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 1686 |
+
'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
|
| 1687 |
}
|
| 1688 |
|
| 1689 |
# Ensure sorted
|
|
|
|
| 1758 |
'graph_links': graph_links,
|
| 1759 |
'embedding_pooler': pooler,
|
| 1760 |
'labels': torch.tensor(label_values, dtype=torch.float32),
|
| 1761 |
+
'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
|
| 1762 |
+
'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
|
| 1763 |
}
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f2c86bf03e5761e7fb319a54274e032f7aa1d01dd5873f2f44a52c9e0be5244
|
| 3 |
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:461e55d31752fd72f09aa30c5bcc3a619654ae86ddf1e759c9c57b0dc5db53f6
|
| 3 |
+
size 21794
|
models/model.py
CHANGED
|
@@ -54,7 +54,9 @@ class Oracle(nn.Module):
|
|
| 54 |
self.dtype = dtype
|
| 55 |
|
| 56 |
# --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
self.d_model = model_config.hidden_size
|
| 59 |
self.model = AutoModel.from_config(model_config, trust_remote_code=True)
|
| 60 |
self.model.to(self.device, dtype=self.dtype)
|
|
@@ -65,6 +67,11 @@ class Oracle(nn.Module):
|
|
| 65 |
nn.GELU(),
|
| 66 |
nn.Linear(self.d_model, self.num_outputs)
|
| 67 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
self.event_type_to_id = event_type_to_id
|
| 70 |
|
|
@@ -947,8 +954,10 @@ class Oracle(nn.Module):
|
|
| 947 |
empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
|
| 948 |
empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
|
| 949 |
empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
|
|
|
|
| 950 |
return {
|
| 951 |
'quantile_logits': empty_quantiles,
|
|
|
|
| 952 |
'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
|
| 953 |
'hidden_states': empty_hidden,
|
| 954 |
'attention_mask': empty_mask
|
|
@@ -1068,9 +1077,11 @@ class Oracle(nn.Module):
|
|
| 1068 |
sequence_hidden = outputs.last_hidden_state
|
| 1069 |
pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
|
| 1070 |
quantile_logits = self.quantile_head(pooled_states)
|
|
|
|
| 1071 |
|
| 1072 |
return {
|
| 1073 |
'quantile_logits': quantile_logits,
|
|
|
|
| 1074 |
'pooled_states': pooled_states,
|
| 1075 |
'hidden_states': sequence_hidden,
|
| 1076 |
'attention_mask': hf_attention_mask
|
|
|
|
| 54 |
self.dtype = dtype
|
| 55 |
|
| 56 |
# --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
|
| 57 |
+
hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
|
| 58 |
+
hf_kwargs = {"token": hf_token} if hf_token else {}
|
| 59 |
+
model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True, **hf_kwargs)
|
| 60 |
self.d_model = model_config.hidden_size
|
| 61 |
self.model = AutoModel.from_config(model_config, trust_remote_code=True)
|
| 62 |
self.model.to(self.device, dtype=self.dtype)
|
|
|
|
| 67 |
nn.GELU(),
|
| 68 |
nn.Linear(self.d_model, self.num_outputs)
|
| 69 |
)
|
| 70 |
+
self.quality_head = nn.Sequential(
|
| 71 |
+
nn.Linear(self.d_model, self.d_model),
|
| 72 |
+
nn.GELU(),
|
| 73 |
+
nn.Linear(self.d_model, 1)
|
| 74 |
+
)
|
| 75 |
|
| 76 |
self.event_type_to_id = event_type_to_id
|
| 77 |
|
|
|
|
| 954 |
empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
|
| 955 |
empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
|
| 956 |
empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
|
| 957 |
+
empty_quality = torch.empty(0, device=device, dtype=self.dtype)
|
| 958 |
return {
|
| 959 |
'quantile_logits': empty_quantiles,
|
| 960 |
+
'quality_logits': empty_quality,
|
| 961 |
'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
|
| 962 |
'hidden_states': empty_hidden,
|
| 963 |
'attention_mask': empty_mask
|
|
|
|
| 1077 |
sequence_hidden = outputs.last_hidden_state
|
| 1078 |
pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
|
| 1079 |
quantile_logits = self.quantile_head(pooled_states)
|
| 1080 |
+
quality_logits = self.quality_head(pooled_states).squeeze(-1)
|
| 1081 |
|
| 1082 |
return {
|
| 1083 |
'quantile_logits': quantile_logits,
|
| 1084 |
+
'quality_logits': quality_logits,
|
| 1085 |
'pooled_states': pooled_states,
|
| 1086 |
'hidden_states': sequence_hidden,
|
| 1087 |
'attention_mask': hf_attention_mask
|
models/multi_modal_processor.py
CHANGED
|
@@ -38,13 +38,16 @@ class MultiModalEncoder:
|
|
| 38 |
|
| 39 |
|
| 40 |
try:
|
|
|
|
|
|
|
| 41 |
# --- SigLIP Loading with Config Fix ---
|
| 42 |
self.processor = AutoProcessor.from_pretrained(
|
| 43 |
self.model_id,
|
| 44 |
-
use_fast=True
|
|
|
|
| 45 |
)
|
| 46 |
|
| 47 |
-
config = AutoConfig.from_pretrained(self.model_id)
|
| 48 |
|
| 49 |
if not hasattr(config, 'projection_dim'):
|
| 50 |
# print("❗ Config missing projection_dim, patching...")
|
|
@@ -54,7 +57,8 @@ class MultiModalEncoder:
|
|
| 54 |
self.model_id,
|
| 55 |
config=config,
|
| 56 |
dtype=self.dtype, # Use torch_dtype for from_pretrained
|
| 57 |
-
trust_remote_code=False
|
|
|
|
| 58 |
).to(self.device).eval()
|
| 59 |
# -----------------------------------------------
|
| 60 |
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
try:
|
| 41 |
+
hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
|
| 42 |
+
hf_kwargs = {"token": hf_token} if hf_token else {}
|
| 43 |
# --- SigLIP Loading with Config Fix ---
|
| 44 |
self.processor = AutoProcessor.from_pretrained(
|
| 45 |
self.model_id,
|
| 46 |
+
use_fast=True,
|
| 47 |
+
**hf_kwargs
|
| 48 |
)
|
| 49 |
|
| 50 |
+
config = AutoConfig.from_pretrained(self.model_id, **hf_kwargs)
|
| 51 |
|
| 52 |
if not hasattr(config, 'projection_dim'):
|
| 53 |
# print("❗ Config missing projection_dim, patching...")
|
|
|
|
| 57 |
self.model_id,
|
| 58 |
config=config,
|
| 59 |
dtype=self.dtype, # Use torch_dtype for from_pretrained
|
| 60 |
+
trust_remote_code=False,
|
| 61 |
+
**hf_kwargs
|
| 62 |
).to(self.device).eval()
|
| 63 |
# -----------------------------------------------
|
| 64 |
|
pre_cache.sh
CHANGED
|
@@ -4,6 +4,6 @@
|
|
| 4 |
echo "Starting dataset caching..."
|
| 5 |
python3 scripts/cache_dataset.py \
|
| 6 |
--ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
|
| 7 |
-
--max_samples
|
| 8 |
|
| 9 |
echo "Done!"
|
|
|
|
| 4 |
echo "Starting dataset caching..."
|
| 5 |
python3 scripts/cache_dataset.py \
|
| 6 |
--ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
|
| 7 |
+
--max_samples 50
|
| 8 |
|
| 9 |
echo "Done!"
|
scripts/analyze_distribution.py
CHANGED
|
@@ -1,21 +1,22 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
import datetime
|
|
|
|
|
|
|
| 5 |
from clickhouse_driver import Client as ClickHouseClient
|
| 6 |
|
| 7 |
# Add parent to path
|
| 8 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 9 |
|
| 10 |
-
|
| 11 |
-
# load_dotenv()
|
| 12 |
|
| 13 |
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
|
| 14 |
CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
|
| 15 |
-
# .env shows empty user/pass, which implies 'default' user and empty password for ClickHouse
|
| 16 |
CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
|
| 17 |
CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
|
| 18 |
CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def get_client():
|
| 21 |
return ClickHouseClient(
|
|
@@ -26,484 +27,331 @@ def get_client():
|
|
| 26 |
database=CLICKHOUSE_DATABASE
|
| 27 |
)
|
| 28 |
|
| 29 |
-
def
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
min(val),
|
| 38 |
-
max(val),
|
| 39 |
-
count()
|
| 40 |
-
FROM (
|
| 41 |
-
{subquery}
|
| 42 |
-
)
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
pct = (r[1] / count_val * 100) if count_val > 0 else 0
|
| 80 |
-
print(f" {r[0]}: {r[1]} ({pct:.1f}%)")
|
| 81 |
-
except Exception as e:
|
| 82 |
-
print(f" Error calculating buckets: {e}")
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
"""
|
| 93 |
-
# We need to know if the inner query produces 'base_address' or 'token_address'
|
| 94 |
-
# Currently our queries produce 'base_address' mostly, except token_metrics ones.
|
| 95 |
-
# Let's standardize inner queries in the main loop to alias the key column to 'join_key'
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
- Calculates Median Fees, Volume, Holders for each Class (1-4).
|
| 112 |
-
- Downgrades tokens with metrics < 10% of their class median to Class 5 (Manipulated).
|
| 113 |
"""
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# aggregating trades for fees/vol to appear more robust than token_metrics snapshots
|
| 117 |
-
print(" -> Fetching metrics for classification...")
|
| 118 |
-
# SQL OPTIMIZATION:
|
| 119 |
-
# 1. Use token_metrics for Volume/Holders (Pre-computed).
|
| 120 |
-
# 2. Pre-aggregate trades for Fees in a subquery to avoid massive JOIN explosion.
|
| 121 |
-
query = """
|
| 122 |
-
SELECT
|
| 123 |
-
tm.token_address,
|
| 124 |
-
(argMax(tm.ath_price_usd, tm.updated_at) / 0.000004) as ret,
|
| 125 |
-
any(tr.fees) as fees,
|
| 126 |
-
argMax(tm.total_volume_usd, tm.updated_at) as vol,
|
| 127 |
-
argMax(tm.unique_holders, tm.updated_at) as holders
|
| 128 |
-
FROM token_metrics tm
|
| 129 |
-
LEFT JOIN (
|
| 130 |
-
SELECT
|
| 131 |
-
base_address,
|
| 132 |
-
sum(priority_fee + coin_creator_fee) as fees
|
| 133 |
-
FROM trades
|
| 134 |
-
GROUP BY base_address
|
| 135 |
-
) tr ON tm.token_address = tr.base_address
|
| 136 |
-
GROUP BY tm.token_address
|
| 137 |
-
HAVING ret <= 10000
|
| 138 |
"""
|
| 139 |
-
rows = client.execute(query)
|
| 140 |
-
|
| 141 |
# 1. Initial Classification
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
# Storage for stats calculation
|
| 145 |
-
class_stats = {i: {'fees': [], 'vol': [], 'holders': []} for i in range(len(RETURN_THRESHOLDS)-1)}
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
ret_val = r[1]
|
| 151 |
-
fees = r[2] or 0.0
|
| 152 |
-
vol = r[3] or 0.0
|
| 153 |
-
holders = r[4] or 0
|
| 154 |
|
| 155 |
-
|
|
|
|
| 156 |
for i in range(len(RETURN_THRESHOLDS) - 1):
|
| 157 |
lower = RETURN_THRESHOLDS[i]
|
| 158 |
upper = RETURN_THRESHOLDS[i+1]
|
| 159 |
-
if
|
| 160 |
-
|
|
|
|
| 161 |
break
|
| 162 |
|
| 163 |
-
if
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
# 2. Calculate
|
|
|
|
| 170 |
thresholds = {}
|
| 171 |
-
|
| 172 |
-
for i in range(1,
|
| 173 |
-
|
| 174 |
-
if len(
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
thresholds[i] = {
|
| 180 |
'fees': med_fees * 0.5,
|
| 181 |
'vol': med_vol * 0.5,
|
| 182 |
'holders': med_holders * 0.5
|
| 183 |
}
|
| 184 |
-
print(f" [Class {i}] Median Fees: {med_fees:.4f} (Thresh: {thresholds[i]['fees']:.4f}) | Median Vol: ${med_vol:.0f} (Thresh: ${thresholds[i]['vol']:.0f}) | Median Holders: {med_holders:.0f} (Thresh: {thresholds[i]['holders']:.0f})")
|
| 185 |
else:
|
| 186 |
-
|
| 187 |
|
| 188 |
# 3. Reclassification
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
| 192 |
|
| 193 |
-
for
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
final_map[token] = cid
|
| 206 |
-
else:
|
| 207 |
-
final_map[token] = cid
|
| 208 |
|
| 209 |
-
|
| 210 |
-
return final_map, thresholds
|
| 211 |
-
|
| 212 |
-
def analyze():
|
| 213 |
-
client = get_client()
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
#
|
| 221 |
-
|
| 222 |
-
for
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
# Common SQL parts
|
| 236 |
-
# We need a robust base for the WHERE clause variables (fees, vol, holders)
|
| 237 |
-
# Since we can't easily alias in the WHERE clause of a subquery filter without re-joining,
|
| 238 |
-
# we will rely on a standardized CTE-like structure or just simpler subqueries in the condition.
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
# So we define a base cohort query that computes these 4 values for EVERY token,
|
| 247 |
-
# and then wrap it with the WHERE clause.
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
any(tr.fees) as fees,
|
| 254 |
-
argMax(tm.total_volume_usd, tm.updated_at) as vol,
|
| 255 |
-
argMax(tm.unique_holders, tm.updated_at) as holders
|
| 256 |
-
FROM token_metrics tm
|
| 257 |
-
LEFT JOIN (
|
| 258 |
-
SELECT base_address, sum(priority_fee + coin_creator_fee) as fees
|
| 259 |
-
FROM trades
|
| 260 |
-
GROUP BY base_address
|
| 261 |
-
) tr ON tm.token_address = tr.base_address
|
| 262 |
-
GROUP BY tm.token_address
|
| 263 |
-
"""
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
print(f"\n\n==================================================")
|
| 272 |
-
print(f"SEGMENT: {label}")
|
| 273 |
-
print(f"==================================================")
|
| 274 |
-
print(f"Tokens in segment: {count}")
|
| 275 |
-
|
| 276 |
-
if count == 0:
|
| 277 |
-
continue
|
| 278 |
-
|
| 279 |
-
# Construct SQL Condition based on ID
|
| 280 |
-
condition = "1=0" # Default fail
|
| 281 |
|
| 282 |
-
if cid ==
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
# The only way to be class 0 in the map is if ret < 3.
|
| 287 |
-
# Downgraded tokens go to Class 5.
|
| 288 |
-
condition = "ret < 3"
|
| 289 |
-
|
| 290 |
-
elif cid == MANIPULATED_CLASS_ID:
|
| 291 |
-
# Manipulated:
|
| 292 |
-
# It's the collection of (Class K logic AND is_outlier)
|
| 293 |
-
sub_conds = []
|
| 294 |
-
for k in range(1, 5):
|
| 295 |
-
if k in thresholds:
|
| 296 |
-
t = thresholds[k]
|
| 297 |
-
# Range for Class K
|
| 298 |
-
lower = RETURN_THRESHOLDS[k]
|
| 299 |
-
upper = RETURN_THRESHOLDS[k+1]
|
| 300 |
-
# Outlier logic
|
| 301 |
-
sub_conds.append(f"(ret >= {lower} AND ret < {upper} AND (fees < {t['fees']} OR vol < {t['vol']} OR holders < {t['holders']}))")
|
| 302 |
-
|
| 303 |
-
if sub_conds:
|
| 304 |
-
condition = " OR ".join(sub_conds)
|
| 305 |
-
|
| 306 |
else:
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
# Valid logic: In Range AND NOT Outlier
|
| 313 |
-
condition = f"(ret >= {lower} AND ret < {upper} AND fees >= {t['fees']} AND vol >= {t['vol']} AND holders >= {t['holders']})"
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
{base_cohort_source}
|
| 320 |
-
) WHERE {condition}
|
| 321 |
-
"""
|
| 322 |
-
|
| 323 |
-
# Helper to construct the full condition "join_key IN (...)"
|
| 324 |
-
# NOW we use the subquery instead of a literal list
|
| 325 |
-
def make_query(inner, cohort_subquery):
|
| 326 |
-
return f"""
|
| 327 |
-
SELECT * FROM (
|
| 328 |
-
{inner}
|
| 329 |
-
) WHERE join_key IN (
|
| 330 |
-
{cohort_subquery}
|
| 331 |
-
)
|
| 332 |
-
"""
|
| 333 |
-
|
| 334 |
-
# --- Metrics Definitions ---
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
GROUP BY base_address
|
| 341 |
-
"""
|
| 342 |
-
fees_buckets = """
|
| 343 |
-
case
|
| 344 |
-
when val < 0.001 then '1. < 0.001 SOL'
|
| 345 |
-
when val >= 0.001 AND val < 0.01 then '2. 0.001 - 0.01'
|
| 346 |
-
when val >= 0.01 AND val < 0.1 then '3. 0.01 - 0.1'
|
| 347 |
-
when val >= 0.1 AND val < 1 then '4. 0.1 - 1'
|
| 348 |
-
when val >= 1 then '5. > 1 SOL'
|
| 349 |
-
else 'Unknown'
|
| 350 |
-
end
|
| 351 |
-
"""
|
| 352 |
-
print_distribution_stats(client, "Total Fees (SOL)", make_query(fees_inner, cohort_sql), fees_buckets)
|
| 353 |
-
|
| 354 |
-
# 2. Volume (USD)
|
| 355 |
-
vol_inner = """
|
| 356 |
-
SELECT base_address as join_key, sum(total_usd) as val
|
| 357 |
-
FROM trades
|
| 358 |
-
GROUP BY base_address
|
| 359 |
-
"""
|
| 360 |
-
vol_buckets = """
|
| 361 |
-
case
|
| 362 |
-
when val < 1000 then '1. < $1k'
|
| 363 |
-
when val >= 1000 AND val < 10000 then '2. $1k - $10k'
|
| 364 |
-
when val >= 10000 AND val < 100000 then '3. $10k - $100k'
|
| 365 |
-
when val >= 100000 AND val < 1000000 then '4. $100k - $1M'
|
| 366 |
-
when val >= 1000000 then '5. > $1M'
|
| 367 |
-
else 'Unknown'
|
| 368 |
-
end
|
| 369 |
-
"""
|
| 370 |
-
print_distribution_stats(client, "Total Volume (USD)", make_query(vol_inner, cohort_sql), vol_buckets)
|
| 371 |
-
|
| 372 |
-
# 3. Unique Holders
|
| 373 |
-
holders_inner = """
|
| 374 |
-
SELECT token_address as join_key, argMax(unique_holders, updated_at) as val
|
| 375 |
-
FROM token_metrics
|
| 376 |
-
GROUP BY token_address
|
| 377 |
-
"""
|
| 378 |
-
holders_buckets = """
|
| 379 |
-
case
|
| 380 |
-
when val < 10 then '1. < 10'
|
| 381 |
-
when val >= 10 AND val < 50 then '2. 10 - 50'
|
| 382 |
-
when val >= 50 AND val < 100 then '3. 50 - 100'
|
| 383 |
-
when val >= 100 AND val < 500 then '4. 100 - 500'
|
| 384 |
-
when val >= 500 then '5. > 500'
|
| 385 |
-
else 'Unknown'
|
| 386 |
-
end
|
| 387 |
-
"""
|
| 388 |
-
print_distribution_stats(client, "Unique Holders", make_query(holders_inner, cohort_sql), holders_buckets)
|
| 389 |
-
|
| 390 |
-
# 4. Snipers % Supply
|
| 391 |
-
snipers_inner = """
|
| 392 |
-
SELECT
|
| 393 |
-
m.base_address as join_key,
|
| 394 |
-
(m.val / t.total_supply * 100) as val
|
| 395 |
-
FROM (
|
| 396 |
-
SELECT
|
| 397 |
-
base_address,
|
| 398 |
-
sumIf(base_amount, buyer_rank <= 70) as val
|
| 399 |
-
FROM (
|
| 400 |
-
SELECT
|
| 401 |
-
base_address,
|
| 402 |
-
base_amount,
|
| 403 |
-
dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) as buyer_rank
|
| 404 |
-
FROM (
|
| 405 |
-
SELECT
|
| 406 |
-
base_address,
|
| 407 |
-
maker,
|
| 408 |
-
min(slot) as min_slot,
|
| 409 |
-
min(transaction_index) as min_idx,
|
| 410 |
-
sum(base_amount) as base_amount
|
| 411 |
-
FROM trades
|
| 412 |
-
WHERE trade_type = 0
|
| 413 |
-
GROUP BY base_address, maker
|
| 414 |
-
)
|
| 415 |
-
)
|
| 416 |
-
GROUP BY base_address
|
| 417 |
-
) m
|
| 418 |
-
JOIN (
|
| 419 |
-
SELECT token_address, argMax(total_supply, updated_at) as total_supply
|
| 420 |
-
FROM tokens
|
| 421 |
-
GROUP BY token_address
|
| 422 |
-
) t ON m.base_address = t.token_address
|
| 423 |
-
WHERE t.total_supply > 0
|
| 424 |
-
"""
|
| 425 |
-
pct_buckets = """
|
| 426 |
-
case
|
| 427 |
-
when val < 1 then '1. < 1%'
|
| 428 |
-
when val >= 1 AND val < 5 then '2. 1% - 5%'
|
| 429 |
-
when val >= 5 AND val < 10 then '3. 5% - 10%'
|
| 430 |
-
when val >= 10 AND val < 20 then '4. 10% - 20%'
|
| 431 |
-
when val >= 20 AND val < 50 then '5. 20% - 50%'
|
| 432 |
-
when val >= 50 then '6. > 50%'
|
| 433 |
-
else 'Unknown'
|
| 434 |
-
end
|
| 435 |
-
"""
|
| 436 |
-
print_distribution_stats(client, "Snipers % Supply (Top 70)", make_query(snipers_inner, cohort_sql), pct_buckets)
|
| 437 |
-
|
| 438 |
-
# 5. Bundled % Supply
|
| 439 |
-
bundled_inner = """
|
| 440 |
-
SELECT
|
| 441 |
-
m.base_address as join_key,
|
| 442 |
-
(m.val / t.total_supply * 100) as val
|
| 443 |
-
FROM (
|
| 444 |
-
SELECT
|
| 445 |
-
t.base_address,
|
| 446 |
-
sum(t.base_amount) as val
|
| 447 |
-
FROM trades t
|
| 448 |
-
JOIN (
|
| 449 |
-
SELECT base_address, min(slot) as min_slot
|
| 450 |
-
FROM trades
|
| 451 |
-
GROUP BY base_address
|
| 452 |
-
) m ON t.base_address = m.base_address AND t.slot = m.min_slot
|
| 453 |
-
WHERE t.trade_type = 0
|
| 454 |
-
GROUP BY t.base_address
|
| 455 |
-
) m
|
| 456 |
-
JOIN (
|
| 457 |
-
SELECT token_address, argMax(total_supply, updated_at) as total_supply
|
| 458 |
-
FROM tokens
|
| 459 |
-
GROUP BY token_address
|
| 460 |
-
) t ON m.base_address = t.token_address
|
| 461 |
-
WHERE t.total_supply > 0
|
| 462 |
-
"""
|
| 463 |
-
print_distribution_stats(client, "Bundled % Supply", make_query(bundled_inner, cohort_sql), pct_buckets)
|
| 464 |
-
|
| 465 |
-
# 6. Dev Holding % Supply
|
| 466 |
-
dev_inner = """
|
| 467 |
-
SELECT
|
| 468 |
-
t.token_address as join_key,
|
| 469 |
-
(wh.current_balance / (t.total_supply / pow(10, t.decimals)) * 100) as val
|
| 470 |
-
FROM (
|
| 471 |
-
SELECT token_address, argMax(creator_address, updated_at) as creator_address, argMax(total_supply, updated_at) as total_supply, argMax(decimals, updated_at) as decimals
|
| 472 |
-
FROM tokens
|
| 473 |
-
GROUP BY token_address
|
| 474 |
-
) t
|
| 475 |
-
JOIN (
|
| 476 |
-
SELECT mint_address, wallet_address, argMax(current_balance, updated_at) as current_balance
|
| 477 |
-
FROM wallet_holdings
|
| 478 |
-
GROUP BY mint_address, wallet_address
|
| 479 |
-
) wh ON t.token_address = wh.mint_address AND t.creator_address = wh.wallet_address
|
| 480 |
-
WHERE t.total_supply > 0
|
| 481 |
-
"""
|
| 482 |
-
print_distribution_stats(client, "Dev Holding % Supply", make_query(dev_inner, cohort_sql), pct_buckets)
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
# 8. Time to ATH (Seconds)
|
| 487 |
-
time_ath_inner = """
|
| 488 |
-
SELECT
|
| 489 |
-
base_address as join_key,
|
| 490 |
-
(argMax(timestamp, price_usd) - min(timestamp)) as val
|
| 491 |
-
FROM trades
|
| 492 |
-
GROUP BY base_address
|
| 493 |
-
"""
|
| 494 |
-
time_ath_buckets = """
|
| 495 |
-
case
|
| 496 |
-
when val < 5 then '1. < 5s'
|
| 497 |
-
when val >= 5 AND val < 30 then '2. 5s - 30s'
|
| 498 |
-
when val >= 30 AND val < 60 then '3. 30s - 1m'
|
| 499 |
-
when val >= 60 AND val < 300 then '4. 1m - 5m'
|
| 500 |
-
when val >= 300 AND val < 3600 then '5. 5m - 1h'
|
| 501 |
-
when val >= 3600 then '6. > 1h'
|
| 502 |
-
else 'Unknown'
|
| 503 |
-
end
|
| 504 |
-
"""
|
| 505 |
-
print_distribution_stats(client, "Time to ATH (Seconds)", make_query(time_ath_inner, cohort_sql), time_ath_buckets)
|
| 506 |
-
|
| 507 |
|
| 508 |
if __name__ == "__main__":
|
| 509 |
analyze()
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import datetime
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
from clickhouse_driver import Client as ClickHouseClient
|
| 7 |
|
| 8 |
# Add parent to path
|
| 9 |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
|
| 11 |
+
from models.vocabulary import RETURN_THRESHOLDS, MANIPULATED_CLASS_ID
|
|
|
|
| 12 |
|
| 13 |
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
|
| 14 |
CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
|
|
|
|
| 15 |
CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
|
| 16 |
CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
|
| 17 |
CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
|
| 18 |
+
LAUNCH_PRICE_USD = 0.000004
|
| 19 |
+
EPS = 1e-9
|
| 20 |
|
| 21 |
def get_client():
|
| 22 |
return ClickHouseClient(
|
|
|
|
| 27 |
database=CLICKHOUSE_DATABASE
|
| 28 |
)
|
| 29 |
|
| 30 |
+
def fetch_all_metrics(client):
|
| 31 |
+
"""
|
| 32 |
+
Fetches all needed metrics for all tokens in a single query.
|
| 33 |
+
Base Table: MINTS (to ensure we cover all ~50k tokens).
|
| 34 |
+
Definitions:
|
| 35 |
+
- Snipers: Peak Balance Sum of top 70 buyers
|
| 36 |
+
- Bundles: Base Amount Sum of trades in multi-buy slots
|
| 37 |
+
- Dev Hold: Max Peak Balance of Creator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"""
|
| 39 |
+
print(" -> Fetching all token metrics (Unified Query)...")
|
| 40 |
+
|
| 41 |
+
query = f"""
|
| 42 |
+
WITH
|
| 43 |
+
-- 1. Aggregated trade stats (Fees, Volume, ATH Time)
|
| 44 |
+
trade_agg AS (
|
| 45 |
+
SELECT
|
| 46 |
+
base_address,
|
| 47 |
+
sum(priority_fee + coin_creator_fee) AS fees_sol,
|
| 48 |
+
sum(total_usd) AS volume_usd,
|
| 49 |
+
count() AS n_trades,
|
| 50 |
+
argMax(timestamp, price_usd) AS t_ath,
|
| 51 |
+
min(timestamp) AS t0
|
| 52 |
+
FROM trades
|
| 53 |
+
GROUP BY base_address
|
| 54 |
+
),
|
| 55 |
+
|
| 56 |
+
-- 2. Token Metadata from MINTS (Base Source of Truth)
|
| 57 |
+
token_meta AS (
|
| 58 |
+
SELECT
|
| 59 |
+
mint_address AS token_address,
|
| 60 |
+
argMax(creator_address, timestamp) AS creator_address,
|
| 61 |
+
argMax(total_supply, timestamp) AS total_supply,
|
| 62 |
+
argMax(token_decimals, timestamp) AS decimals
|
| 63 |
+
FROM mints
|
| 64 |
+
GROUP BY mint_address
|
| 65 |
+
),
|
| 66 |
|
| 67 |
+
-- 3. Returns & Holders (from Token Metrics or manual calc)
|
| 68 |
+
metrics AS (
|
| 69 |
+
SELECT
|
| 70 |
+
token_address,
|
| 71 |
+
argMax(ath_price_usd, updated_at) as ath_price_usd,
|
| 72 |
+
argMax(unique_holders, updated_at) as unique_holders
|
| 73 |
+
FROM token_metrics
|
| 74 |
+
GROUP BY token_address
|
| 75 |
+
),
|
| 76 |
|
| 77 |
+
-- 4. WALLET PEAKS (normalized balance likely)
|
| 78 |
+
wallet_peaks AS (
|
| 79 |
+
SELECT
|
| 80 |
+
mint_address,
|
| 81 |
+
wallet_address,
|
| 82 |
+
max(current_balance) AS peak_balance
|
| 83 |
+
FROM wallet_holdings
|
| 84 |
+
GROUP BY mint_address, wallet_address
|
| 85 |
+
),
|
| 86 |
|
| 87 |
+
-- 5. SNIPERS: Identify sniper addresses (rank <= 70)
|
| 88 |
+
snipers_list AS (
|
| 89 |
+
SELECT
|
| 90 |
+
base_address,
|
| 91 |
+
maker
|
| 92 |
+
FROM (
|
| 93 |
+
SELECT
|
| 94 |
+
base_address,
|
| 95 |
+
maker,
|
| 96 |
+
dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) AS buyer_rank
|
| 97 |
+
FROM (
|
| 98 |
+
SELECT
|
| 99 |
+
base_address,
|
| 100 |
+
maker,
|
| 101 |
+
min(slot) AS min_slot,
|
| 102 |
+
min(transaction_index) AS min_idx
|
| 103 |
+
FROM trades
|
| 104 |
+
WHERE trade_type = 0 -- buy
|
| 105 |
+
GROUP BY base_address, maker
|
| 106 |
+
)
|
| 107 |
+
)
|
| 108 |
+
WHERE buyer_rank <= 70
|
| 109 |
+
),
|
| 110 |
+
snipers_agg AS (
|
| 111 |
+
SELECT
|
| 112 |
+
s.base_address AS token_address,
|
| 113 |
+
sum(wp.peak_balance) AS snipers_total_peak
|
| 114 |
+
FROM snipers_list s
|
| 115 |
+
JOIN wallet_peaks wp ON s.base_address = wp.mint_address AND s.maker = wp.wallet_address
|
| 116 |
+
GROUP BY s.base_address
|
| 117 |
+
),
|
| 118 |
|
| 119 |
+
-- 6. BUNDLED: Sum the base_amount of ALL trades that happened in a slot with multiple buys
|
| 120 |
+
bundled_agg AS (
|
| 121 |
+
SELECT
|
| 122 |
+
t.base_address AS token_address,
|
| 123 |
+
sum(t.base_amount) AS bundled_total_peak
|
| 124 |
+
FROM trades t
|
| 125 |
+
WHERE (t.base_address, t.slot) IN (
|
| 126 |
+
SELECT base_address, slot
|
| 127 |
+
FROM trades
|
| 128 |
+
WHERE trade_type = 0 -- buy
|
| 129 |
+
GROUP BY base_address, slot
|
| 130 |
+
HAVING count() > 1
|
| 131 |
+
)
|
| 132 |
+
AND t.trade_type = 0 -- buy
|
| 133 |
+
GROUP BY t.base_address
|
| 134 |
+
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
-- 7. DEV HOLD: Creator's Peak Balance
|
| 137 |
+
dev_hold_agg AS (
|
| 138 |
+
SELECT
|
| 139 |
+
t.token_address,
|
| 140 |
+
max(wp.peak_balance) AS dev_peak
|
| 141 |
+
FROM token_meta t
|
| 142 |
+
JOIN wallet_peaks wp ON t.token_address = wp.mint_address AND t.creator_address = wp.wallet_address
|
| 143 |
+
GROUP BY t.token_address
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
SELECT
|
| 147 |
+
t.token_address,
|
| 148 |
+
(COALESCE(m.ath_price_usd, ta.t_ath, 0) / {LAUNCH_PRICE_USD}) AS ret,
|
| 149 |
+
|
| 150 |
+
COALESCE(ta.fees_sol, 0) AS fees_sol,
|
| 151 |
+
COALESCE(ta.volume_usd, 0) AS volume_usd,
|
| 152 |
+
COALESCE(m.unique_holders, 0) AS unique_holders,
|
| 153 |
+
(ta.t_ath - ta.t0) AS time_to_ath_sec,
|
| 154 |
+
|
| 155 |
+
COALESCE(s.snipers_total_peak, 0) AS snipers_val,
|
| 156 |
+
COALESCE(b.bundled_total_peak, 0) AS bundled_val,
|
| 157 |
+
COALESCE(d.dev_peak, 0) AS dev_val,
|
| 158 |
+
|
| 159 |
+
t.total_supply AS total_supply,
|
| 160 |
+
t.decimals AS decimals
|
| 161 |
+
|
| 162 |
+
FROM token_meta t
|
| 163 |
+
LEFT JOIN trade_agg ta ON t.token_address = ta.base_address
|
| 164 |
+
LEFT JOIN metrics m ON t.token_address = m.token_address
|
| 165 |
+
LEFT JOIN snipers_agg s ON t.token_address = s.token_address
|
| 166 |
+
LEFT JOIN bundled_agg b ON t.token_address = b.token_address
|
| 167 |
+
LEFT JOIN dev_hold_agg d ON t.token_address = d.token_address
|
| 168 |
"""
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
rows = client.execute(query)
|
| 171 |
+
# Convert to list of dicts
|
| 172 |
+
cols = [
|
| 173 |
+
"token_address", "ret", "fees_sol", "volume_usd", "unique_holders", "time_to_ath_sec",
|
| 174 |
+
"snipers_val", "bundled_val", "dev_val", "total_supply", "decimals"
|
| 175 |
+
]
|
| 176 |
+
results = []
|
| 177 |
+
|
| 178 |
+
print(f" -> Fetched {len(rows)} tokens.")
|
| 179 |
+
|
| 180 |
+
for r in rows:
|
| 181 |
+
d = dict(zip(cols, r))
|
| 182 |
+
|
| 183 |
+
supply = d["total_supply"]
|
| 184 |
+
decimals = d["decimals"]
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
adj_supply = supply / (10 ** decimals) if (supply and decimals is not None) else supply
|
| 188 |
+
except:
|
| 189 |
+
adj_supply = supply
|
| 190 |
|
| 191 |
+
if adj_supply and adj_supply > 0:
|
| 192 |
+
d["snipers_pct"] = (d["snipers_val"] / adj_supply) * 100
|
| 193 |
+
d["dev_hold_pct"] = (d["dev_val"] / adj_supply) * 100
|
| 194 |
+
else:
|
| 195 |
+
d["snipers_pct"] = 0.0
|
| 196 |
+
d["dev_hold_pct"] = 0.0
|
| 197 |
+
|
| 198 |
+
if supply and supply > 0:
|
| 199 |
+
d["bundled_pct"] = (d["bundled_val"] / supply) * 100
|
| 200 |
+
else:
|
| 201 |
+
d["bundled_pct"] = 0.0
|
| 202 |
|
| 203 |
+
results.append(d)
|
| 204 |
+
|
| 205 |
+
return results
|
| 206 |
+
|
| 207 |
+
def _classify_tokens(data):
|
|
|
|
|
|
|
| 208 |
"""
|
| 209 |
+
Internal logic: returns (buckets_dict, thresholds_dict, count_manipulated)
|
| 210 |
+
buckets_dict: {class_id: [list of tokens]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
"""
|
|
|
|
|
|
|
| 212 |
# 1. Initial Classification
|
| 213 |
+
temp_buckets = {i: [] for i in range(len(RETURN_THRESHOLDS))}
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
for d in data:
|
| 216 |
+
ret = d["ret"]
|
| 217 |
+
if ret > 10000: continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
+
cid = 0
|
| 220 |
+
found = False
|
| 221 |
for i in range(len(RETURN_THRESHOLDS) - 1):
|
| 222 |
lower = RETURN_THRESHOLDS[i]
|
| 223 |
upper = RETURN_THRESHOLDS[i+1]
|
| 224 |
+
if ret >= lower and ret < upper:
|
| 225 |
+
cid = i
|
| 226 |
+
found = True
|
| 227 |
break
|
| 228 |
|
| 229 |
+
if found:
|
| 230 |
+
d["class_id_initial"] = cid
|
| 231 |
+
temp_buckets[cid].append(d)
|
| 232 |
+
else:
|
| 233 |
+
if ret >= 10000: continue
|
| 234 |
+
d["class_id_initial"] = 0
|
| 235 |
+
temp_buckets[0].append(d)
|
| 236 |
|
| 237 |
+
# 2. Calculate Thresholds (50% of Median)
|
| 238 |
+
print("\n -> Calculating Class Medians & Thresholds (Dynamic Outlier Detection)...")
|
| 239 |
thresholds = {}
|
| 240 |
+
|
| 241 |
+
for i in range(1, len(RETURN_THRESHOLDS)-1):
|
| 242 |
+
items = temp_buckets.get(i, [])
|
| 243 |
+
if len(items) > 5:
|
| 244 |
+
fees = [x["fees_sol"] for x in items]
|
| 245 |
+
vols = [x["volume_usd"] for x in items]
|
| 246 |
+
holders = [x["unique_holders"] for x in items]
|
| 247 |
+
|
| 248 |
+
med_fees = np.median(fees)
|
| 249 |
+
med_vol = np.median(vols)
|
| 250 |
+
med_holders = np.median(holders)
|
| 251 |
|
| 252 |
thresholds[i] = {
|
| 253 |
'fees': med_fees * 0.5,
|
| 254 |
'vol': med_vol * 0.5,
|
| 255 |
'holders': med_holders * 0.5
|
| 256 |
}
|
|
|
|
| 257 |
else:
|
| 258 |
+
thresholds[i] = {'fees': 0, 'vol': 0, 'holders': 0}
|
| 259 |
|
| 260 |
# 3. Reclassification
|
| 261 |
+
final_buckets = {i: [] for i in range(len(RETURN_THRESHOLDS))}
|
| 262 |
+
final_buckets[MANIPULATED_CLASS_ID] = []
|
| 263 |
+
|
| 264 |
+
count_manipulated = 0
|
| 265 |
|
| 266 |
+
for cid, items in temp_buckets.items():
|
| 267 |
+
for d in items:
|
| 268 |
+
final_cid = cid
|
| 269 |
+
if cid > 0 and cid in thresholds:
|
| 270 |
+
t = thresholds[cid]
|
| 271 |
+
if (d["fees_sol"] < t['fees']) or (d["volume_usd"] < t['vol']) or (d["unique_holders"] < t['holders']):
|
| 272 |
+
final_cid = MANIPULATED_CLASS_ID
|
| 273 |
+
count_manipulated += 1
|
| 274 |
|
| 275 |
+
d["class_id_final"] = final_cid
|
| 276 |
+
if final_cid not in final_buckets:
|
| 277 |
+
final_buckets[final_cid] = []
|
| 278 |
+
final_buckets[final_cid].append(d)
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
+
return final_buckets, thresholds, count_manipulated
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
+
def get_return_class_map(client):
|
| 283 |
+
"""
|
| 284 |
+
Returns (map {token_addr: class_id}, thresholds)
|
| 285 |
+
Used by cache_dataset.py
|
| 286 |
+
"""
|
| 287 |
+
data = fetch_all_metrics(client)
|
| 288 |
+
buckets, thresholds, _ = _classify_tokens(data)
|
| 289 |
|
| 290 |
+
# Flatten buckets to map
|
| 291 |
+
ret_map = {}
|
| 292 |
+
for cid, items in buckets.items():
|
| 293 |
+
for d in items:
|
| 294 |
+
ret_map[d["token_address"]] = cid
|
| 295 |
+
|
| 296 |
+
return ret_map, thresholds
|
| 297 |
+
|
| 298 |
+
def print_stats(name, values):
|
| 299 |
+
"""
|
| 300 |
+
prints compact stats: mean, p50, p90, p99
|
| 301 |
+
"""
|
| 302 |
+
if not values:
|
| 303 |
+
print(f" {name}: No data")
|
| 304 |
+
return
|
| 305 |
|
| 306 |
+
vals = np.array(values)
|
| 307 |
+
mean = np.mean(vals)
|
| 308 |
+
p50 = np.percentile(vals, 50)
|
| 309 |
+
p90 = np.percentile(vals, 90)
|
| 310 |
+
p99 = np.percentile(vals, 99)
|
| 311 |
+
nonzero = np.count_nonzero(vals)
|
| 312 |
+
nonzero_rate = nonzero / len(vals)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
|
| 315 |
+
|
| 316 |
+
def analyze():
|
| 317 |
+
client = get_client()
|
| 318 |
+
data = fetch_all_metrics(client)
|
| 319 |
+
final_buckets, thresholds, count_manipulated = _classify_tokens(data)
|
| 320 |
|
| 321 |
+
print(f" -> Reclassification Complete. Identified {count_manipulated} manipulated tokens.")
|
| 322 |
+
print("\n=== SEGMENTED DISTRIBUTION ANALYSIS ===")
|
|
|
|
|
|
|
| 323 |
|
| 324 |
+
# Print Thresholds debug
|
| 325 |
+
for k, t in thresholds.items():
|
| 326 |
+
if t['fees'] > 0:
|
| 327 |
+
print(f" [Class {k}] Thresh: Fees>{t['fees']:.3f} Vol>${t['vol']:.0f} Holders>{t['holders']:.0f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
+
sorted_classes = sorted([k for k in final_buckets.keys() if k != MANIPULATED_CLASS_ID]) + [MANIPULATED_CLASS_ID]
|
| 330 |
+
|
| 331 |
+
for cid in sorted_classes:
|
| 332 |
+
items = final_buckets.get(cid, [])
|
| 333 |
+
if not items: continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
+
if cid == MANIPULATED_CLASS_ID:
|
| 336 |
+
label = f"{cid}. MANIPULATED / FAKE (Outliers from {1}~{4})"
|
| 337 |
+
elif cid < len(RETURN_THRESHOLDS)-1:
|
| 338 |
+
label = f"{cid}. {RETURN_THRESHOLDS[cid]}x - {RETURN_THRESHOLDS[cid+1]}x"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
else:
|
| 340 |
+
label = f"{cid}. Unknown"
|
| 341 |
+
|
| 342 |
+
print(f"\nSEGMENT: {label}")
|
| 343 |
+
print("="*50)
|
| 344 |
+
print(f"Tokens in segment: {len(items)}")
|
|
|
|
|
|
|
| 345 |
|
| 346 |
+
bundled = [x["bundled_pct"] for x in items]
|
| 347 |
+
dev_hold = [x["dev_hold_pct"] for x in items]
|
| 348 |
+
fees = [x["fees_sol"] for x in items]
|
| 349 |
+
snipers = [x["snipers_pct"] for x in items]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
print_stats("bundled_pct", bundled)
|
| 352 |
+
print_stats("dev_hold_pct", dev_hold)
|
| 353 |
+
print_stats("fees_sol", fees)
|
| 354 |
+
print_stats("snipers_pct", snipers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
if __name__ == "__main__":
|
| 357 |
analyze()
|
scripts/cache_dataset.py
CHANGED
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from tqdm import tqdm
|
| 11 |
from dotenv import load_dotenv
|
|
@@ -23,6 +24,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 23 |
from data.data_loader import OracleDataset
|
| 24 |
from data.data_fetcher import DataFetcher
|
| 25 |
from scripts.analyze_distribution import get_return_class_map
|
|
|
|
|
|
|
| 26 |
|
| 27 |
from clickhouse_driver import Client as ClickHouseClient
|
| 28 |
from neo4j import GraphDatabase
|
|
@@ -94,6 +97,86 @@ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
|
|
| 94 |
print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 95 |
# Don't crash, let it try to proceed (though dataset might complain if file missing)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
def main():
|
| 98 |
load_dotenv()
|
| 99 |
|
|
@@ -140,10 +223,15 @@ def main():
|
|
| 140 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 141 |
|
| 142 |
# Pre-fetch the Return Class Map
|
| 143 |
-
# tokens not in this map (e.g. >10k x) are INVALID and will be skipped
|
| 144 |
print("INFO: Fetching Return Classification Map...")
|
| 145 |
return_class_map, thresholds = get_return_class_map(clickhouse_client)
|
| 146 |
print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
dataset = OracleDataset(
|
| 149 |
data_fetcher=data_fetcher,
|
|
@@ -158,67 +246,103 @@ def main():
|
|
| 158 |
if len(dataset) == 0:
|
| 159 |
print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
|
| 160 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# --- 3. Iterate and cache each item ---
|
| 163 |
print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
|
| 164 |
|
| 165 |
-
metadata_path = output_dir / "metadata.jsonl"
|
| 166 |
-
print(f"INFO: Writing metadata to {metadata_path}")
|
| 167 |
-
|
| 168 |
skipped_count = 0
|
| 169 |
-
filtered_count = 0
|
| 170 |
cached_count = 0
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
continue
|
| 182 |
-
|
| 183 |
-
class_id = return_class_map[mint_addr]
|
| 184 |
-
|
| 185 |
-
try:
|
| 186 |
-
item = dataset.__cacheitem__(i)
|
| 187 |
-
if item is None:
|
| 188 |
-
skipped_count += 1
|
| 189 |
-
continue
|
| 190 |
-
|
| 191 |
-
filename = f"sample_{i}.pt"
|
| 192 |
-
output_path = output_dir / filename
|
| 193 |
-
torch.save(item, output_path)
|
| 194 |
-
|
| 195 |
-
# Write metadata entry
|
| 196 |
-
# Minimizing IO overhead by keeping line short
|
| 197 |
-
meta_entry = {"file": filename, "class_id": class_id}
|
| 198 |
-
meta_f.write(json.dumps(meta_entry) + "\n")
|
| 199 |
-
|
| 200 |
-
cached_count += 1
|
| 201 |
-
|
| 202 |
-
except Exception as e:
|
| 203 |
-
error_msg = str(e)
|
| 204 |
-
# If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
|
| 205 |
-
if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
|
| 206 |
-
print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
|
| 207 |
-
sys.exit(1)
|
| 208 |
-
|
| 209 |
-
print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
|
| 210 |
-
# print trackback
|
| 211 |
-
import traceback
|
| 212 |
-
traceback.print_exc()
|
| 213 |
skipped_count += 1
|
| 214 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
print(f"\n--- Caching Complete ---")
|
| 217 |
print(f"Successfully cached: {cached_count} items.")
|
| 218 |
print(f"Filtered (Invalid/High Return): {filtered_count} items.")
|
| 219 |
print(f"Skipped (Errors/Empty): {skipped_count} items.")
|
| 220 |
print(f"Cache location: {output_dir.resolve()}")
|
| 221 |
-
print(f"Metadata location: {metadata_path.resolve()}")
|
| 222 |
|
| 223 |
finally:
|
| 224 |
# --- 4. Close connections ---
|
|
|
|
| 6 |
import datetime
|
| 7 |
import torch
|
| 8 |
import json
|
| 9 |
+
import math
|
| 10 |
from pathlib import Path
|
| 11 |
from tqdm import tqdm
|
| 12 |
from dotenv import load_dotenv
|
|
|
|
| 24 |
from data.data_loader import OracleDataset
|
| 25 |
from data.data_fetcher import DataFetcher
|
| 26 |
from scripts.analyze_distribution import get_return_class_map
|
| 27 |
+
# Import quality score calculator
|
| 28 |
+
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
| 29 |
|
| 30 |
from clickhouse_driver import Client as ClickHouseClient
|
| 31 |
from neo4j import GraphDatabase
|
|
|
|
| 97 |
print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 98 |
# Don't crash, let it try to proceed (though dataset might complain if file missing)
|
| 99 |
|
| 100 |
+
def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float = 1e9):
|
| 101 |
+
"""
|
| 102 |
+
Build a map: token_address -> reason string for why a quality score is missing.
|
| 103 |
+
This mirrors compute_quality_scores filtering and feature availability.
|
| 104 |
+
"""
|
| 105 |
+
data = fetch_token_metrics(client)
|
| 106 |
+
metrics_by_token = {d.get("token_address"): d for d in data if d.get("token_address")}
|
| 107 |
+
|
| 108 |
+
# Build buckets with the same return filtering as compute_quality_scores
|
| 109 |
+
buckets = {}
|
| 110 |
+
for d in data:
|
| 111 |
+
ret_val = d.get("ret")
|
| 112 |
+
if ret_val is None or ret_val <= 0 or ret_val > max_ret:
|
| 113 |
+
continue
|
| 114 |
+
b = _bucket_id(ret_val)
|
| 115 |
+
if b == -1:
|
| 116 |
+
continue
|
| 117 |
+
d["bucket_id"] = b
|
| 118 |
+
buckets.setdefault(b, []).append(d)
|
| 119 |
+
|
| 120 |
+
# Same feature definitions as compute_quality_scores
|
| 121 |
+
feature_defs = [
|
| 122 |
+
("fees_log", lambda d: math.log1p(d["fees_sol"]) if d.get("fees_sol") is not None else None, True),
|
| 123 |
+
("volume_log", lambda d: math.log1p(d["volume_usd"]) if d.get("volume_usd") is not None else None, True),
|
| 124 |
+
("holders_log", lambda d: math.log1p(d["unique_holders"]) if d.get("unique_holders") is not None else None, True),
|
| 125 |
+
("time_to_ath_log", lambda d: math.log1p(d["time_to_ath_sec"]) if d.get("time_to_ath_sec") is not None else None, True),
|
| 126 |
+
("fees_per_volume", lambda d: (d["fees_sol"] / (d["volume_usd"] + EPS)) if d.get("fees_sol") is not None and d.get("volume_usd") is not None else None, True),
|
| 127 |
+
("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d.get("fees_sol") is not None and d.get("n_trades") is not None else None, True),
|
| 128 |
+
("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d.get("unique_holders") is not None and d.get("n_trades") is not None else None, True),
|
| 129 |
+
("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d.get("unique_holders") is not None and d.get("volume_usd") is not None else None, True),
|
| 130 |
+
("snipers_pct", lambda d: d.get("snipers_pct"), True),
|
| 131 |
+
("bundled_pct", lambda d: d.get("bundled_pct"), True),
|
| 132 |
+
("dev_hold_pct", lambda d: d.get("dev_hold_pct"), True),
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
# Precompute percentiles per bucket + feature
|
| 136 |
+
bucket_feature_percentiles = {}
|
| 137 |
+
for b, items in buckets.items():
|
| 138 |
+
feature_percentiles = {}
|
| 139 |
+
for fname, fget, _pos in feature_defs:
|
| 140 |
+
vals = []
|
| 141 |
+
for d in items:
|
| 142 |
+
v = fget(d)
|
| 143 |
+
if v is None or (isinstance(v, float) and (math.isnan(v) or math.isinf(v))):
|
| 144 |
+
continue
|
| 145 |
+
vals.append((d["token_address"], v))
|
| 146 |
+
feature_percentiles[fname] = _midrank_percentiles(vals)
|
| 147 |
+
bucket_feature_percentiles[b] = feature_percentiles
|
| 148 |
+
|
| 149 |
+
def _reason_for(token_address: str) -> str:
|
| 150 |
+
d = metrics_by_token.get(token_address)
|
| 151 |
+
if not d:
|
| 152 |
+
return "no metrics found (missing from token_metrics/trades/mints joins)"
|
| 153 |
+
ret_val = d.get("ret")
|
| 154 |
+
if ret_val is None:
|
| 155 |
+
return "ret is None (missing ATH/launch metrics)"
|
| 156 |
+
if ret_val <= 0:
|
| 157 |
+
return f"ret <= 0 ({ret_val})"
|
| 158 |
+
if ret_val > max_ret:
|
| 159 |
+
return f"ret > max_ret ({ret_val} > {max_ret})"
|
| 160 |
+
b = _bucket_id(ret_val)
|
| 161 |
+
if b == -1:
|
| 162 |
+
return f"ret {ret_val} not in RETURN_THRESHOLDS"
|
| 163 |
+
items = buckets.get(b, [])
|
| 164 |
+
if not items:
|
| 165 |
+
return f"bucket {b} empty after filtering"
|
| 166 |
+
feature_percentiles = bucket_feature_percentiles.get(b, {})
|
| 167 |
+
has_any = False
|
| 168 |
+
missing_features = []
|
| 169 |
+
for fname, _fget, _pos in feature_defs:
|
| 170 |
+
if feature_percentiles.get(fname, {}).get(token_address) is None:
|
| 171 |
+
missing_features.append(fname)
|
| 172 |
+
else:
|
| 173 |
+
has_any = True
|
| 174 |
+
if not has_any:
|
| 175 |
+
return "no valid feature percentiles for token (all features missing/invalid)"
|
| 176 |
+
return f"unexpected: has feature percentiles but no score; missing features={','.join(missing_features)}"
|
| 177 |
+
|
| 178 |
+
return _reason_for
|
| 179 |
+
|
| 180 |
def main():
|
| 181 |
load_dotenv()
|
| 182 |
|
|
|
|
| 223 |
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 224 |
|
| 225 |
# Pre-fetch the Return Class Map
|
|
|
|
| 226 |
print("INFO: Fetching Return Classification Map...")
|
| 227 |
return_class_map, thresholds = get_return_class_map(clickhouse_client)
|
| 228 |
print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
|
| 229 |
+
|
| 230 |
+
# Pre-fetch Quality Scores
|
| 231 |
+
print("INFO: Fetching Token Quality Scores...")
|
| 232 |
+
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 233 |
+
quality_missing_reason = build_quality_missing_reason_map(clickhouse_client, max_ret=1e9)
|
| 234 |
+
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 235 |
|
| 236 |
dataset = OracleDataset(
|
| 237 |
data_fetcher=data_fetcher,
|
|
|
|
| 246 |
if len(dataset) == 0:
|
| 247 |
print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
|
| 248 |
return
|
| 249 |
+
|
| 250 |
+
# --- FILTER DATASET BY CLASS MAP ---
|
| 251 |
+
# Only keep mints that are classified (valid return, sufficient data)
|
| 252 |
+
original_size = len(dataset)
|
| 253 |
+
print(f"INFO: Filtering dataset... Original size: {original_size}")
|
| 254 |
+
dataset.sampled_mints = [
|
| 255 |
+
m for m in dataset.sampled_mints
|
| 256 |
+
if m['mint_address'] in return_class_map
|
| 257 |
+
]
|
| 258 |
+
filtered_size = len(dataset)
|
| 259 |
+
filtered_count = original_size - filtered_size
|
| 260 |
+
print(f"INFO: Filtered size: {filtered_size}")
|
| 261 |
+
|
| 262 |
+
if len(dataset) == 0:
|
| 263 |
+
print("WARNING: No tokens remain after filtering by return_class_map.")
|
| 264 |
+
return
|
| 265 |
|
| 266 |
# --- 3. Iterate and cache each item ---
|
| 267 |
print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
|
| 268 |
|
|
|
|
|
|
|
|
|
|
| 269 |
skipped_count = 0
|
|
|
|
| 270 |
cached_count = 0
|
| 271 |
|
| 272 |
+
for i in tqdm(range(len(dataset)), desc="Caching samples"):
|
| 273 |
+
mint_addr = dataset.sampled_mints[i]['mint_address']
|
| 274 |
+
|
| 275 |
+
# (No need to check if in return_class_map anymore, we filtered)
|
| 276 |
+
class_id = return_class_map[mint_addr]
|
| 277 |
+
|
| 278 |
+
try:
|
| 279 |
+
item = dataset.__cacheitem__(i)
|
| 280 |
+
if item is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
skipped_count += 1
|
| 282 |
continue
|
| 283 |
+
|
| 284 |
+
# Require quality score only for samples that will be cached
|
| 285 |
+
if mint_addr not in quality_scores_map:
|
| 286 |
+
reason = quality_missing_reason(mint_addr)
|
| 287 |
+
raise RuntimeError(
|
| 288 |
+
f"Missing quality score for mint {mint_addr}. Reason: {reason}. "
|
| 289 |
+
"Refusing to cache without quality_score."
|
| 290 |
+
)
|
| 291 |
+
q_score = quality_scores_map[mint_addr]
|
| 292 |
+
|
| 293 |
+
# INJECT QUALITY SCORE INTO TENSOR DICT
|
| 294 |
+
item["quality_score"] = q_score
|
| 295 |
+
item["class_id"] = class_id
|
| 296 |
+
|
| 297 |
+
filename = f"sample_{i}.pt"
|
| 298 |
+
output_path = output_dir / filename
|
| 299 |
+
torch.save(item, output_path)
|
| 300 |
+
|
| 301 |
+
cached_count += 1
|
| 302 |
+
|
| 303 |
+
# Log progress details (reflect all cached event lists)
|
| 304 |
+
n_trades = len(item.get("trades", []))
|
| 305 |
+
n_transfers = len(item.get("transfers", []))
|
| 306 |
+
n_pool_creations = len(item.get("pool_creations", []))
|
| 307 |
+
n_liquidity_changes = len(item.get("liquidity_changes", []))
|
| 308 |
+
n_fee_collections = len(item.get("fee_collections", []))
|
| 309 |
+
n_burns = len(item.get("burns", []))
|
| 310 |
+
n_supply_locks = len(item.get("supply_locks", []))
|
| 311 |
+
n_migrations = len(item.get("migrations", []))
|
| 312 |
+
n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
|
| 313 |
+
n_snapshots_5m = len(item.get("snapshots_5m", []))
|
| 314 |
+
n_holders = len(item.get("holder_snapshots_list", []))
|
| 315 |
+
|
| 316 |
+
tqdm.write(f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f}")
|
| 317 |
+
tqdm.write(
|
| 318 |
+
" Events | "
|
| 319 |
+
f"Trades: {n_trades} | Transfers: {n_transfers} | Pool Creations: {n_pool_creations} | "
|
| 320 |
+
f"Liquidity Changes: {n_liquidity_changes} | Fee Collections: {n_fee_collections} | "
|
| 321 |
+
f"Burns: {n_burns} | Supply Locks: {n_supply_locks} | Migrations: {n_migrations}"
|
| 322 |
+
)
|
| 323 |
+
tqdm.write(
|
| 324 |
+
f" Derived | Mint: 1 | Ohlc 1s: {n_ohlc} | Snapshots 5m: {n_snapshots_5m} | Holder Snapshots: {n_holders}"
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
except Exception as e:
|
| 328 |
+
error_msg = str(e)
|
| 329 |
+
# If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
|
| 330 |
+
if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
|
| 331 |
+
print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
|
| 332 |
+
sys.exit(1)
|
| 333 |
+
|
| 334 |
+
print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
|
| 335 |
+
# print trackback
|
| 336 |
+
import traceback
|
| 337 |
+
traceback.print_exc()
|
| 338 |
+
skipped_count += 1
|
| 339 |
+
continue
|
| 340 |
|
| 341 |
print(f"\n--- Caching Complete ---")
|
| 342 |
print(f"Successfully cached: {cached_count} items.")
|
| 343 |
print(f"Filtered (Invalid/High Return): {filtered_count} items.")
|
| 344 |
print(f"Skipped (Errors/Empty): {skipped_count} items.")
|
| 345 |
print(f"Cache location: {output_dir.resolve()}")
|
|
|
|
| 346 |
|
| 347 |
finally:
|
| 348 |
# --- 4. Close connections ---
|
scripts/compute_quality_score.py
CHANGED
|
@@ -87,15 +87,15 @@ def fetch_token_metrics(client) -> List[dict]:
|
|
| 87 |
FROM trades
|
| 88 |
GROUP BY base_address
|
| 89 |
),
|
| 90 |
-
-- 2. Token
|
| 91 |
token_meta_raw AS (
|
| 92 |
SELECT
|
| 93 |
-
token_address,
|
| 94 |
-
argMax(creator_address,
|
| 95 |
-
argMax(total_supply,
|
| 96 |
-
argMax(
|
| 97 |
-
FROM
|
| 98 |
-
GROUP BY
|
| 99 |
),
|
| 100 |
token_meta AS (
|
| 101 |
SELECT
|
|
@@ -161,28 +161,21 @@ def fetch_token_metrics(client) -> List[dict]:
|
|
| 161 |
GROUP BY s.base_address
|
| 162 |
),
|
| 163 |
|
| 164 |
-
-- 6. BUNDLED:
|
| 165 |
-
-- Bundled definition: Bought in the same slot as the very first buy slot for that token.
|
| 166 |
-
bundled_list AS (
|
| 167 |
-
SELECT
|
| 168 |
-
t.base_address,
|
| 169 |
-
t.maker
|
| 170 |
-
FROM trades t
|
| 171 |
-
JOIN (
|
| 172 |
-
SELECT base_address, min(slot) AS min_slot
|
| 173 |
-
FROM trades
|
| 174 |
-
GROUP BY base_address
|
| 175 |
-
) m ON t.base_address = m.base_address AND t.slot = m.min_slot
|
| 176 |
-
WHERE t.trade_type = 0 -- buy
|
| 177 |
-
GROUP BY t.base_address, t.maker
|
| 178 |
-
),
|
| 179 |
bundled_agg AS (
|
| 180 |
SELECT
|
| 181 |
-
|
| 182 |
-
sum(
|
| 183 |
-
FROM
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
),
|
| 187 |
|
| 188 |
-- 7. DEV HOLD: Creator's Peak Balance
|
|
@@ -196,7 +189,7 @@ def fetch_token_metrics(client) -> List[dict]:
|
|
| 196 |
)
|
| 197 |
|
| 198 |
SELECT
|
| 199 |
-
|
| 200 |
r.ret,
|
| 201 |
r.unique_holders,
|
| 202 |
f.fees_sol,
|
|
@@ -205,14 +198,14 @@ def fetch_token_metrics(client) -> List[dict]:
|
|
| 205 |
(f.t_ath - f.t0) AS time_to_ath_sec,
|
| 206 |
-- Calculate Percentages using Peak Sums / Total Supply
|
| 207 |
(COALESCE(s.snipers_total_peak, 0) / t.adj_supply * 100) AS snipers_pct,
|
| 208 |
-
(COALESCE(b.bundled_total_peak, 0) / t.
|
| 209 |
(COALESCE(d.dev_peak, 0) / t.adj_supply * 100) AS dev_hold_pct
|
| 210 |
-
FROM
|
| 211 |
-
JOIN
|
| 212 |
-
LEFT JOIN trade_agg f ON
|
| 213 |
-
LEFT JOIN snipers_agg s ON
|
| 214 |
-
LEFT JOIN bundled_agg b ON
|
| 215 |
-
LEFT JOIN dev_hold_agg d ON
|
| 216 |
"""
|
| 217 |
rows = client.execute(query)
|
| 218 |
cols = [
|
|
@@ -233,7 +226,7 @@ def fetch_token_metrics(client) -> List[dict]:
|
|
| 233 |
return out
|
| 234 |
|
| 235 |
|
| 236 |
-
def
|
| 237 |
client,
|
| 238 |
max_ret: float = 10000.0,
|
| 239 |
rerank: bool = True,
|
|
@@ -251,12 +244,12 @@ def _compute_quality_scores(
|
|
| 251 |
("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d["fees_sol"] is not None and d["n_trades"] is not None else None, True),
|
| 252 |
("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d["unique_holders"] is not None and d["n_trades"] is not None else None, True),
|
| 253 |
("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d["unique_holders"] is not None and d["volume_usd"] is not None else None, True),
|
| 254 |
-
("snipers_pct", lambda d: d["snipers_pct"],
|
| 255 |
-
("bundled_pct", lambda d: d["bundled_pct"],
|
| 256 |
-
("dev_hold_pct", lambda d: d["dev_hold_pct"],
|
| 257 |
]
|
| 258 |
|
| 259 |
-
raw_metrics = ["snipers_pct", "bundled_pct", "dev_hold_pct"]
|
| 260 |
|
| 261 |
debug = None
|
| 262 |
if with_debug:
|
|
@@ -357,6 +350,10 @@ def _compute_quality_scores(
|
|
| 357 |
"ret": d["ret"],
|
| 358 |
"q_raw": q_raw_map[t],
|
| 359 |
"q": q_final,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
}
|
| 361 |
)
|
| 362 |
else:
|
|
@@ -371,6 +368,10 @@ def _compute_quality_scores(
|
|
| 371 |
"ret": d["ret"],
|
| 372 |
"q_raw": q_raw_map[t],
|
| 373 |
"q": q_raw_map[t],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
}
|
| 375 |
)
|
| 376 |
|
|
@@ -379,12 +380,7 @@ def _compute_quality_scores(
|
|
| 379 |
return token_scores
|
| 380 |
|
| 381 |
|
| 382 |
-
|
| 383 |
-
client,
|
| 384 |
-
max_ret: float = 10000.0,
|
| 385 |
-
rerank: bool = True,
|
| 386 |
-
) -> List[dict]:
|
| 387 |
-
return _compute_quality_scores(client, max_ret=max_ret, rerank=rerank, with_debug=False)
|
| 388 |
|
| 389 |
|
| 390 |
def write_jsonl(path: str, rows: List[dict]) -> None:
|
|
@@ -491,6 +487,23 @@ def print_summary(scores: List[dict]) -> None:
|
|
| 491 |
print(f" Mean: {stats_q_raw['mean']:.4f} | Min: {stats_q_raw['min']:.4f} | Max: {stats_q_raw['max']:.4f}")
|
| 492 |
print(f" Q: p10={stats_q_raw['p10']:.2f} p50={stats_q_raw['p50']:.2f} p90={stats_q_raw['p90']:.2f} p99={stats_q_raw['p99']:.2f}")
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
|
| 495 |
def print_diagnostics(debug: dict) -> None:
|
| 496 |
if not debug:
|
|
@@ -563,6 +576,77 @@ def print_diagnostics(debug: dict) -> None:
|
|
| 563 |
corr = _pearson_corr(xs, ys)
|
| 564 |
print(f" log(ret) vs {metric}: {corr:.4f} (n={len(xs)})")
|
| 565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
def main():
|
| 568 |
parser = argparse.ArgumentParser(description="Compute token quality/health score.")
|
|
@@ -577,7 +661,7 @@ def main():
|
|
| 577 |
scores = compute_quality_scores(client, max_ret=args.max_ret, rerank=not args.no_rerank)
|
| 578 |
debug = None
|
| 579 |
else:
|
| 580 |
-
scores, debug =
|
| 581 |
client,
|
| 582 |
max_ret=args.max_ret,
|
| 583 |
rerank=not args.no_rerank,
|
|
@@ -587,6 +671,7 @@ def main():
|
|
| 587 |
print_summary(scores)
|
| 588 |
if not args.no_diagnostics:
|
| 589 |
print_diagnostics(debug)
|
|
|
|
| 590 |
|
| 591 |
|
| 592 |
if __name__ == "__main__":
|
|
|
|
| 87 |
FROM trades
|
| 88 |
GROUP BY base_address
|
| 89 |
),
|
| 90 |
+
-- 2. "Token list derived MINTS.
|
| 91 |
token_meta_raw AS (
|
| 92 |
SELECT
|
| 93 |
+
mint_address AS token_address,
|
| 94 |
+
argMax(creator_address, timestamp) AS creator_address,
|
| 95 |
+
argMax(total_supply, timestamp) AS total_supply,
|
| 96 |
+
argMax(token_decimals, timestamp) AS decimals
|
| 97 |
+
FROM mints
|
| 98 |
+
GROUP BY mint_address
|
| 99 |
),
|
| 100 |
token_meta AS (
|
| 101 |
SELECT
|
|
|
|
| 161 |
GROUP BY s.base_address
|
| 162 |
),
|
| 163 |
|
| 164 |
+
-- 6. BUNDLED: Sum the base_amount of ALL trades that happened in a slot with multiple buys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
bundled_agg AS (
|
| 166 |
SELECT
|
| 167 |
+
t.base_address AS token_address,
|
| 168 |
+
sum(t.base_amount) AS bundled_total_peak
|
| 169 |
+
FROM trades t
|
| 170 |
+
WHERE (t.base_address, t.slot) IN (
|
| 171 |
+
SELECT base_address, slot
|
| 172 |
+
FROM trades
|
| 173 |
+
WHERE trade_type = 0 -- buy
|
| 174 |
+
GROUP BY base_address, slot
|
| 175 |
+
HAVING count() > 1
|
| 176 |
+
)
|
| 177 |
+
AND t.trade_type = 0 -- buy
|
| 178 |
+
GROUP BY t.base_address
|
| 179 |
),
|
| 180 |
|
| 181 |
-- 7. DEV HOLD: Creator's Peak Balance
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
SELECT
|
| 192 |
+
t.token_address,
|
| 193 |
r.ret,
|
| 194 |
r.unique_holders,
|
| 195 |
f.fees_sol,
|
|
|
|
| 198 |
(f.t_ath - f.t0) AS time_to_ath_sec,
|
| 199 |
-- Calculate Percentages using Peak Sums / Total Supply
|
| 200 |
(COALESCE(s.snipers_total_peak, 0) / t.adj_supply * 100) AS snipers_pct,
|
| 201 |
+
(COALESCE(b.bundled_total_peak, 0) / t.total_supply * 100) AS bundled_pct,
|
| 202 |
(COALESCE(d.dev_peak, 0) / t.adj_supply * 100) AS dev_hold_pct
|
| 203 |
+
FROM token_meta t
|
| 204 |
+
LEFT JOIN ret_agg r ON t.token_address = r.token_address
|
| 205 |
+
LEFT JOIN trade_agg f ON t.token_address = f.base_address
|
| 206 |
+
LEFT JOIN snipers_agg s ON t.token_address = s.token_address
|
| 207 |
+
LEFT JOIN bundled_agg b ON t.token_address = b.token_address
|
| 208 |
+
LEFT JOIN dev_hold_agg d ON t.token_address = d.token_address
|
| 209 |
"""
|
| 210 |
rows = client.execute(query)
|
| 211 |
cols = [
|
|
|
|
| 226 |
return out
|
| 227 |
|
| 228 |
|
| 229 |
+
def compute_quality_scores(
|
| 230 |
client,
|
| 231 |
max_ret: float = 10000.0,
|
| 232 |
rerank: bool = True,
|
|
|
|
| 244 |
("fees_per_trade", lambda d: (d["fees_sol"] / (d["n_trades"] + EPS)) if d["fees_sol"] is not None and d["n_trades"] is not None else None, True),
|
| 245 |
("holders_per_trade", lambda d: (d["unique_holders"] / (d["n_trades"] + EPS)) if d["unique_holders"] is not None and d["n_trades"] is not None else None, True),
|
| 246 |
("holders_per_volume", lambda d: (d["unique_holders"] / (d["volume_usd"] + EPS)) if d["unique_holders"] is not None and d["volume_usd"] is not None else None, True),
|
| 247 |
+
("snipers_pct", lambda d: d["snipers_pct"], True),
|
| 248 |
+
("bundled_pct", lambda d: d["bundled_pct"], True),
|
| 249 |
+
("dev_hold_pct", lambda d: d["dev_hold_pct"], True),
|
| 250 |
]
|
| 251 |
|
| 252 |
+
raw_metrics = ["snipers_pct", "bundled_pct", "dev_hold_pct", "fees_sol"] # Added fees_sol for diagnostic logging
|
| 253 |
|
| 254 |
debug = None
|
| 255 |
if with_debug:
|
|
|
|
| 350 |
"ret": d["ret"],
|
| 351 |
"q_raw": q_raw_map[t],
|
| 352 |
"q": q_final,
|
| 353 |
+
# Pass through raw metrics for analysis
|
| 354 |
+
"bundled_pct": d.get("bundled_pct"),
|
| 355 |
+
"snipers_pct": d.get("snipers_pct"),
|
| 356 |
+
"fees_sol": d.get("fees_sol"),
|
| 357 |
}
|
| 358 |
)
|
| 359 |
else:
|
|
|
|
| 368 |
"ret": d["ret"],
|
| 369 |
"q_raw": q_raw_map[t],
|
| 370 |
"q": q_raw_map[t],
|
| 371 |
+
# Pass through raw metrics for analysis
|
| 372 |
+
"bundled_pct": d.get("bundled_pct"),
|
| 373 |
+
"snipers_pct": d.get("snipers_pct"),
|
| 374 |
+
"fees_sol": d.get("fees_sol"),
|
| 375 |
}
|
| 376 |
)
|
| 377 |
|
|
|
|
| 380 |
return token_scores
|
| 381 |
|
| 382 |
|
| 383 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
|
| 386 |
def write_jsonl(path: str, rows: List[dict]) -> None:
|
|
|
|
| 487 |
print(f" Mean: {stats_q_raw['mean']:.4f} | Min: {stats_q_raw['min']:.4f} | Max: {stats_q_raw['max']:.4f}")
|
| 488 |
print(f" Q: p10={stats_q_raw['p10']:.2f} p50={stats_q_raw['p50']:.2f} p90={stats_q_raw['p90']:.2f} p99={stats_q_raw['p99']:.2f}")
|
| 489 |
|
| 490 |
+
# --- NEW: Print 3 Examples (Min, Mid, Max) ---
|
| 491 |
+
if items:
|
| 492 |
+
# Sort items by 'q' to find min/mid/max easily
|
| 493 |
+
items_sorted = sorted(items, key=lambda x: x.get("q", 0))
|
| 494 |
+
|
| 495 |
+
ex_min = items_sorted[0]
|
| 496 |
+
ex_max = items_sorted[-1]
|
| 497 |
+
|
| 498 |
+
# Find mid (closest to 0.0, or just median index? Request said "mean quality" which is 0.0)
|
| 499 |
+
# finding item with q closest to 0.0
|
| 500 |
+
ex_mid = min(items_sorted, key=lambda x: abs(x.get("q", 0) - 0.0))
|
| 501 |
+
|
| 502 |
+
print(" Examples:")
|
| 503 |
+
print(f" Low (-1.0): {ex_min['token_address']} (q={ex_min.get('q',0):.4f}, ret={ex_min.get('ret',0):.2f}x)")
|
| 504 |
+
print(f" Mid (~0.0): {ex_mid['token_address']} (q={ex_mid.get('q',0):.4f}, ret={ex_mid.get('ret',0):.2f}x)")
|
| 505 |
+
print(f" High ( 1.0): {ex_max['token_address']} (q={ex_max.get('q',0):.4f}, ret={ex_max.get('ret',0):.2f}x)")
|
| 506 |
+
|
| 507 |
|
| 508 |
def print_diagnostics(debug: dict) -> None:
|
| 509 |
if not debug:
|
|
|
|
| 576 |
corr = _pearson_corr(xs, ys)
|
| 577 |
print(f" log(ret) vs {metric}: {corr:.4f} (n={len(xs)})")
|
| 578 |
|
| 579 |
+
# Removed placeholder
|
| 580 |
+
pass
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def print_high_ret_analysis(scores: List[dict]) -> None:
|
| 584 |
+
print("\n=== MID-HIGH RETURN SPLIT ANALYSIS (10x - 20x) ===")
|
| 585 |
+
|
| 586 |
+
# 1. Filter for Mid-High Return Cohort (10x - 20x)
|
| 587 |
+
cohort = [s for s in scores if s.get("ret") is not None and s["ret"] >= 10.0 and s["ret"] < 20.0]
|
| 588 |
+
if not cohort:
|
| 589 |
+
print("No tokens 10x-20x found.")
|
| 590 |
+
return
|
| 591 |
+
|
| 592 |
+
print(f"Total tokens 10x-20x: {len(cohort)}")
|
| 593 |
+
|
| 594 |
+
# 2. Extract Bundled Pct
|
| 595 |
+
bundled_vals = [s.get("bundled_pct", 0) for s in cohort if s.get("bundled_pct") is not None]
|
| 596 |
+
if not bundled_vals:
|
| 597 |
+
print("No bundled_pct data found.")
|
| 598 |
+
return
|
| 599 |
+
|
| 600 |
+
median_bundled = _percentile(sorted(bundled_vals), 0.50)
|
| 601 |
+
print(f"Median Bundled% for Cohort: {median_bundled:.2f}%")
|
| 602 |
+
|
| 603 |
+
# 3. Split
|
| 604 |
+
low_group = [s for s in cohort if (s.get("bundled_pct") or 0) <= median_bundled]
|
| 605 |
+
high_group = [s for s in cohort if (s.get("bundled_pct") or 0) > median_bundled]
|
| 606 |
+
|
| 607 |
+
# 4. Analyze Fees
|
| 608 |
+
def get_mean_fees(group):
|
| 609 |
+
fees = [s.get("fees_sol", 0) for s in group if s.get("fees_sol") is not None]
|
| 610 |
+
if not fees: return 0.0
|
| 611 |
+
return sum(fees) / len(fees)
|
| 612 |
+
|
| 613 |
+
mean_fees_low = get_mean_fees(low_group)
|
| 614 |
+
mean_fees_high = get_mean_fees(high_group)
|
| 615 |
+
|
| 616 |
+
print(f"\nGroup 1: LOW Bundled (<= {median_bundled:.2f}%)")
|
| 617 |
+
print(f" Count: {len(low_group)}")
|
| 618 |
+
print(f" Mean Fees: {mean_fees_low:.4f} SOL")
|
| 619 |
+
|
| 620 |
+
print(f"\nGroup 2: HIGH Bundled (> {median_bundled:.2f}%)")
|
| 621 |
+
print(f" Count: {len(high_group)}")
|
| 622 |
+
print(f" Mean Fees: {mean_fees_high:.4f} SOL")
|
| 623 |
+
|
| 624 |
+
# Extra: Check returns too
|
| 625 |
+
def get_mean_ret(group):
|
| 626 |
+
rets = [s["ret"] for s in group]
|
| 627 |
+
if not rets: return 0.0
|
| 628 |
+
return sum(rets) / len(rets)
|
| 629 |
+
|
| 630 |
+
print(f" Mean Ret: {get_mean_ret(high_group):.2f}x (vs Low: {get_mean_ret(low_group):.2f}x)")
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def get_token_quality_scores(client):
|
| 634 |
+
"""
|
| 635 |
+
Returns a dictionary mapping token_address -> q (quality score)
|
| 636 |
+
"""
|
| 637 |
+
# Force rerank=True to get final scores
|
| 638 |
+
results = compute_quality_scores(client, max_ret=1e9, rerank=True)
|
| 639 |
+
|
| 640 |
+
# Return mapping
|
| 641 |
+
# If compute_quality_scores returns (scores, debug) tuple (when with_debug=True), handle it.
|
| 642 |
+
# Default call rerank=True returns 'scores' list if with_debug=False?
|
| 643 |
+
# No, looking at main, it returns 'scores' if no_diagnostics.
|
| 644 |
+
# But get_token_quality_scores uses default args.
|
| 645 |
+
# Let's check compute_quality_score signature... it has with_debug=False default.
|
| 646 |
+
# So it returns 'scores'.
|
| 647 |
+
|
| 648 |
+
return {r["token_address"]: r.get("q", 0.0) for r in results}
|
| 649 |
+
|
| 650 |
|
| 651 |
def main():
|
| 652 |
parser = argparse.ArgumentParser(description="Compute token quality/health score.")
|
|
|
|
| 661 |
scores = compute_quality_scores(client, max_ret=args.max_ret, rerank=not args.no_rerank)
|
| 662 |
debug = None
|
| 663 |
else:
|
| 664 |
+
scores, debug = compute_quality_scores(
|
| 665 |
client,
|
| 666 |
max_ret=args.max_ret,
|
| 667 |
rerank=not args.no_rerank,
|
|
|
|
| 671 |
print_summary(scores)
|
| 672 |
if not args.no_diagnostics:
|
| 673 |
print_diagnostics(debug)
|
| 674 |
+
print_high_ret_analysis(scores) # Call the new analysis
|
| 675 |
|
| 676 |
|
| 677 |
if __name__ == "__main__":
|
token_stats.rs
ADDED
|
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use crate::database::insert_rows;
|
| 2 |
+
use crate::services::price_service::PriceService;
|
| 3 |
+
use crate::types::{
|
| 4 |
+
EventPayload, EventType, MigrationRow, MintRow, TokenMetricsRow, TokenStaticRow, TradeRow,
|
| 5 |
+
};
|
| 6 |
+
use anyhow::{Context, Result, anyhow};
|
| 7 |
+
use borsh::BorshDeserialize;
|
| 8 |
+
use clickhouse::Client;
|
| 9 |
+
use futures_util::future;
|
| 10 |
+
use mpl_token_metadata::accounts::Metadata;
|
| 11 |
+
use once_cell::sync::Lazy;
|
| 12 |
+
use redis::aio::MultiplexedConnection;
|
| 13 |
+
use redis::streams::{StreamReadOptions, StreamReadReply};
|
| 14 |
+
use redis::{AsyncCommands, Client as RedisClient, FromRedisValue};
|
| 15 |
+
use solana_client::nonblocking::rpc_client::RpcClient;
|
| 16 |
+
use solana_program::program_pack::Pack;
|
| 17 |
+
use solana_sdk::pubkey::Pubkey;
|
| 18 |
+
use spl_token::state::Mint;
|
| 19 |
+
use std::collections::{HashMap, HashSet};
|
| 20 |
+
use std::env;
|
| 21 |
+
use std::str::FromStr;
|
| 22 |
+
use std::sync::Arc;
|
| 23 |
+
use std::time::Duration;
|
| 24 |
+
use tokio::sync::RwLock;
|
| 25 |
+
|
| 26 |
+
type TokenCache = HashMap<String, TokenEntry>;
|
| 27 |
+
|
| 28 |
+
fn env_parse<T: FromStr>(key: &str, default: T) -> T {
|
| 29 |
+
env::var(key)
|
| 30 |
+
.ok()
|
| 31 |
+
.and_then(|v| v.parse::<T>().ok())
|
| 32 |
+
.unwrap_or(default)
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
static TOKEN_STATS_CHUNK_SIZE: Lazy<usize> =
|
| 36 |
+
Lazy::new(|| env_parse("TOKEN_STATS_CHUNK_SIZE", 1000usize));
|
| 37 |
+
|
| 38 |
+
#[derive(Debug, Clone)]
|
| 39 |
+
struct TokenEntry {
|
| 40 |
+
token: TokenStaticRow,
|
| 41 |
+
metrics: TokenMetricsRow,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
impl TokenEntry {
|
| 45 |
+
fn new(token: TokenStaticRow, metrics: Option<TokenMetricsRow>) -> Self {
|
| 46 |
+
let metrics = metrics
|
| 47 |
+
.unwrap_or_else(|| TokenMetricsRow::new(token.token_address.clone(), token.updated_at));
|
| 48 |
+
Self { token, metrics }
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
#[derive(Clone, Debug)]
|
| 53 |
+
struct TokenContext {
|
| 54 |
+
timestamp: u32,
|
| 55 |
+
protocol: Option<u8>,
|
| 56 |
+
pool_address: Option<String>,
|
| 57 |
+
decimals: Option<u8>,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
impl TokenContext {
|
| 61 |
+
fn new(
|
| 62 |
+
timestamp: u32,
|
| 63 |
+
protocol: Option<u8>,
|
| 64 |
+
pool_address: Option<String>,
|
| 65 |
+
decimals: Option<u8>,
|
| 66 |
+
) -> Self {
|
| 67 |
+
Self {
|
| 68 |
+
timestamp,
|
| 69 |
+
protocol,
|
| 70 |
+
pool_address,
|
| 71 |
+
decimals,
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
fn record_token_context(
|
| 77 |
+
contexts: &mut HashMap<String, TokenContext>,
|
| 78 |
+
token_address: &str,
|
| 79 |
+
timestamp: u32,
|
| 80 |
+
protocol: Option<u8>,
|
| 81 |
+
pool_address: Option<String>,
|
| 82 |
+
decimals: Option<u8>,
|
| 83 |
+
) {
|
| 84 |
+
if token_address.is_empty() {
|
| 85 |
+
return;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
let mut pool_for_insert = pool_address.clone();
|
| 89 |
+
let entry = contexts
|
| 90 |
+
.entry(token_address.to_string())
|
| 91 |
+
.or_insert_with(|| {
|
| 92 |
+
TokenContext::new(timestamp, protocol, pool_for_insert.take(), decimals)
|
| 93 |
+
});
|
| 94 |
+
|
| 95 |
+
if timestamp < entry.timestamp {
|
| 96 |
+
entry.timestamp = timestamp;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
if entry.protocol.is_none() {
|
| 100 |
+
entry.protocol = protocol;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
let should_update_pool = entry
|
| 104 |
+
.pool_address
|
| 105 |
+
.as_ref()
|
| 106 |
+
.map(|p| p.is_empty())
|
| 107 |
+
.unwrap_or(true);
|
| 108 |
+
if should_update_pool {
|
| 109 |
+
if let Some(pool) = pool_address {
|
| 110 |
+
if !pool.is_empty() {
|
| 111 |
+
entry.pool_address = Some(pool);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
if let Some(dec) = decimals {
|
| 117 |
+
entry.decimals = Some(dec);
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
fn pool_addresses_from_context(context: &TokenContext) -> Vec<String> {
|
| 122 |
+
context
|
| 123 |
+
.pool_address
|
| 124 |
+
.as_ref()
|
| 125 |
+
.filter(|addr| !addr.is_empty())
|
| 126 |
+
.map(|addr| vec![addr.clone()])
|
| 127 |
+
.unwrap_or_default()
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
fn event_success(event: &EventType) -> bool {
|
| 131 |
+
match event {
|
| 132 |
+
EventType::Trade(row) => row.success,
|
| 133 |
+
EventType::Mint(row) => row.success,
|
| 134 |
+
EventType::Migration(row) => row.success,
|
| 135 |
+
EventType::FeeCollection(row) => row.success,
|
| 136 |
+
EventType::Liquidity(row) => row.success,
|
| 137 |
+
EventType::PoolCreation(row) => row.success,
|
| 138 |
+
EventType::Transfer(row) => row.success,
|
| 139 |
+
EventType::SupplyLock(row) => row.success,
|
| 140 |
+
EventType::SupplyLockAction(row) => row.success,
|
| 141 |
+
EventType::Burn(row) => row.success,
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
pub struct TokenAggregator {
|
| 146 |
+
db_client: Client,
|
| 147 |
+
redis_conn: MultiplexedConnection,
|
| 148 |
+
rpc_client: Arc<RpcClient>,
|
| 149 |
+
price_service: PriceService,
|
| 150 |
+
backfill_mode: bool,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
impl TokenAggregator {
|
| 154 |
+
pub async fn new(
|
| 155 |
+
db_client: Client,
|
| 156 |
+
redis_client: RedisClient,
|
| 157 |
+
rpc_client: Arc<RpcClient>,
|
| 158 |
+
price_service: PriceService,
|
| 159 |
+
) -> Result<Self> {
|
| 160 |
+
let redis_conn = redis_client.get_multiplexed_async_connection().await?;
|
| 161 |
+
println!("[TokenAggregator] ✔️ Connected to ClickHouse, Redis, and Solana RPC.");
|
| 162 |
+
|
| 163 |
+
let backfill_mode =
|
| 164 |
+
env::var("BACKFILL_MODE").unwrap_or_else(|_| "false".to_string()) == "true";
|
| 165 |
+
Ok(Self {
|
| 166 |
+
db_client,
|
| 167 |
+
redis_conn,
|
| 168 |
+
rpc_client,
|
| 169 |
+
price_service,
|
| 170 |
+
backfill_mode,
|
| 171 |
+
})
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
pub async fn run(&mut self) -> Result<()> {
|
| 175 |
+
let stream_key = "event_queue";
|
| 176 |
+
let group_name = "token_aggregators";
|
| 177 |
+
let consumer_name = format!("consumer-tokens-{}", uuid::Uuid::new_v4());
|
| 178 |
+
|
| 179 |
+
let mut publisher_conn = self.redis_conn.clone();
|
| 180 |
+
let next_queue = "wallet_agg_queue";
|
| 181 |
+
|
| 182 |
+
let result: redis::RedisResult<()> = self
|
| 183 |
+
.redis_conn
|
| 184 |
+
.xgroup_create_mkstream(stream_key, group_name, "0")
|
| 185 |
+
.await;
|
| 186 |
+
if let Err(e) = result {
|
| 187 |
+
if !e.to_string().contains("BUSYGROUP") {
|
| 188 |
+
return Err(anyhow!(
|
| 189 |
+
"[TokenAggregator] Failed to create consumer group: {}",
|
| 190 |
+
e
|
| 191 |
+
));
|
| 192 |
+
}
|
| 193 |
+
println!(
|
| 194 |
+
"[TokenAggregator] Consumer group '{}' already exists. Resuming.",
|
| 195 |
+
group_name
|
| 196 |
+
);
|
| 197 |
+
} else {
|
| 198 |
+
println!(
|
| 199 |
+
"[TokenAggregator] Created new consumer group '{}'.",
|
| 200 |
+
group_name
|
| 201 |
+
);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
loop {
|
| 205 |
+
let messages = match self
|
| 206 |
+
.collect_events(stream_key, group_name, &consumer_name)
|
| 207 |
+
.await
|
| 208 |
+
{
|
| 209 |
+
Ok(msgs) => msgs,
|
| 210 |
+
Err(e) => {
|
| 211 |
+
eprintln!(
|
| 212 |
+
"[TokenAggregator] 🔴 Error reading from Redis: {}. Retrying...",
|
| 213 |
+
e
|
| 214 |
+
);
|
| 215 |
+
tokio::time::sleep(Duration::from_secs(5)).await;
|
| 216 |
+
continue;
|
| 217 |
+
}
|
| 218 |
+
};
|
| 219 |
+
if messages.is_empty() {
|
| 220 |
+
continue;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
println!(
|
| 224 |
+
"[TokenAggregator] ⚙️ Starting processing for a new batch of {} events...",
|
| 225 |
+
messages.len()
|
| 226 |
+
);
|
| 227 |
+
let message_ids: Vec<String> = messages.iter().map(|(id, _)| id.clone()).collect();
|
| 228 |
+
let payloads: Vec<EventPayload> =
|
| 229 |
+
messages.into_iter().map(|(_, payload)| payload).collect();
|
| 230 |
+
|
| 231 |
+
match self.process_batch(payloads.clone()).await {
|
| 232 |
+
// Clone payloads to use them after processing
|
| 233 |
+
Ok(_) => {
|
| 234 |
+
if !message_ids.is_empty() {
|
| 235 |
+
// Forward each payload to the next queue in the pipeline
|
| 236 |
+
for payload in payloads {
|
| 237 |
+
let payload_data = bincode::serialize(&payload)?;
|
| 238 |
+
let _: () = publisher_conn
|
| 239 |
+
.xadd(next_queue, "*", &[("payload", payload_data)])
|
| 240 |
+
.await?;
|
| 241 |
+
}
|
| 242 |
+
println!(
|
| 243 |
+
"[TokenAggregator] ✅ Finished batch, forwarded {} events to {}.",
|
| 244 |
+
message_ids.len(),
|
| 245 |
+
next_queue
|
| 246 |
+
);
|
| 247 |
+
|
| 248 |
+
// Acknowledge the message from the source queue ('event_queue')
|
| 249 |
+
let _: () = self
|
| 250 |
+
.redis_conn
|
| 251 |
+
.xack(stream_key, group_name, &message_ids)
|
| 252 |
+
.await?;
|
| 253 |
+
let _: i64 = self
|
| 254 |
+
.redis_conn
|
| 255 |
+
.xdel::<_, _, i64>(stream_key, &message_ids)
|
| 256 |
+
.await?;
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
Err(e) => {
|
| 260 |
+
eprintln!(
|
| 261 |
+
"[TokenAggregator] ❌ Failed to process batch, will not forward or ACK. Error: {}",
|
| 262 |
+
e
|
| 263 |
+
);
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
async fn process_batch(&self, payloads: Vec<EventPayload>) -> Result<()> {
|
| 270 |
+
let mut token_contexts: HashMap<String, TokenContext> = HashMap::new();
|
| 271 |
+
for payload in &payloads {
|
| 272 |
+
if !event_success(&payload.event) {
|
| 273 |
+
continue;
|
| 274 |
+
}
|
| 275 |
+
let decimals_map = &payload.token_decimals;
|
| 276 |
+
match &payload.event {
|
| 277 |
+
EventType::Trade(t) => {
|
| 278 |
+
let pool = (!t.pool_address.is_empty()).then(|| t.pool_address.clone());
|
| 279 |
+
record_token_context(
|
| 280 |
+
&mut token_contexts,
|
| 281 |
+
&t.base_address,
|
| 282 |
+
t.timestamp,
|
| 283 |
+
Some(t.protocol),
|
| 284 |
+
pool.clone(),
|
| 285 |
+
decimals_map.get(&t.base_address).cloned(),
|
| 286 |
+
);
|
| 287 |
+
record_token_context(
|
| 288 |
+
&mut token_contexts,
|
| 289 |
+
&t.quote_address,
|
| 290 |
+
t.timestamp,
|
| 291 |
+
Some(t.protocol),
|
| 292 |
+
pool,
|
| 293 |
+
decimals_map.get(&t.quote_address).cloned(),
|
| 294 |
+
);
|
| 295 |
+
}
|
| 296 |
+
EventType::Mint(m) => {
|
| 297 |
+
record_token_context(
|
| 298 |
+
&mut token_contexts,
|
| 299 |
+
&m.mint_address,
|
| 300 |
+
m.timestamp,
|
| 301 |
+
Some(m.protocol),
|
| 302 |
+
(!m.pool_address.is_empty()).then(|| m.pool_address.clone()),
|
| 303 |
+
Some(m.token_decimals),
|
| 304 |
+
);
|
| 305 |
+
}
|
| 306 |
+
EventType::Migration(m) => {
|
| 307 |
+
record_token_context(
|
| 308 |
+
&mut token_contexts,
|
| 309 |
+
&m.mint_address,
|
| 310 |
+
m.timestamp,
|
| 311 |
+
Some(m.protocol),
|
| 312 |
+
(!m.pool_address.is_empty()).then(|| m.pool_address.clone()),
|
| 313 |
+
decimals_map.get(&m.mint_address).cloned(),
|
| 314 |
+
);
|
| 315 |
+
}
|
| 316 |
+
EventType::FeeCollection(f) => {
|
| 317 |
+
let vault = (!f.vault_address.is_empty()).then(|| f.vault_address.clone());
|
| 318 |
+
record_token_context(
|
| 319 |
+
&mut token_contexts,
|
| 320 |
+
&f.token_0_mint_address,
|
| 321 |
+
f.timestamp,
|
| 322 |
+
Some(f.protocol),
|
| 323 |
+
vault.clone(),
|
| 324 |
+
decimals_map.get(&f.token_0_mint_address).cloned(),
|
| 325 |
+
);
|
| 326 |
+
if let Some(token_1) = &f.token_1_mint_address {
|
| 327 |
+
record_token_context(
|
| 328 |
+
&mut token_contexts,
|
| 329 |
+
token_1,
|
| 330 |
+
f.timestamp,
|
| 331 |
+
Some(f.protocol),
|
| 332 |
+
vault.clone(),
|
| 333 |
+
decimals_map.get(token_1).cloned(),
|
| 334 |
+
);
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
EventType::PoolCreation(p) => {
|
| 338 |
+
record_token_context(
|
| 339 |
+
&mut token_contexts,
|
| 340 |
+
&p.base_address,
|
| 341 |
+
p.timestamp,
|
| 342 |
+
Some(p.protocol),
|
| 343 |
+
(!p.pool_address.is_empty()).then(|| p.pool_address.clone()),
|
| 344 |
+
p.base_decimals
|
| 345 |
+
.or_else(|| decimals_map.get(&p.base_address).cloned()),
|
| 346 |
+
);
|
| 347 |
+
record_token_context(
|
| 348 |
+
&mut token_contexts,
|
| 349 |
+
&p.quote_address,
|
| 350 |
+
p.timestamp,
|
| 351 |
+
Some(p.protocol),
|
| 352 |
+
(!p.pool_address.is_empty()).then(|| p.pool_address.clone()),
|
| 353 |
+
p.quote_decimals
|
| 354 |
+
.or_else(|| decimals_map.get(&p.quote_address).cloned()),
|
| 355 |
+
);
|
| 356 |
+
}
|
| 357 |
+
EventType::Transfer(t) => {
|
| 358 |
+
record_token_context(
|
| 359 |
+
&mut token_contexts,
|
| 360 |
+
&t.mint_address,
|
| 361 |
+
t.timestamp,
|
| 362 |
+
None,
|
| 363 |
+
None,
|
| 364 |
+
decimals_map.get(&t.mint_address).cloned(),
|
| 365 |
+
);
|
| 366 |
+
}
|
| 367 |
+
EventType::SupplyLock(lock) => {
|
| 368 |
+
record_token_context(
|
| 369 |
+
&mut token_contexts,
|
| 370 |
+
&lock.mint_address,
|
| 371 |
+
lock.timestamp,
|
| 372 |
+
Some(lock.protocol),
|
| 373 |
+
None,
|
| 374 |
+
decimals_map.get(&lock.mint_address).cloned(),
|
| 375 |
+
);
|
| 376 |
+
}
|
| 377 |
+
EventType::SupplyLockAction(action) => {
|
| 378 |
+
record_token_context(
|
| 379 |
+
&mut token_contexts,
|
| 380 |
+
&action.mint_address,
|
| 381 |
+
action.timestamp,
|
| 382 |
+
Some(action.protocol),
|
| 383 |
+
None,
|
| 384 |
+
decimals_map.get(&action.mint_address).cloned(),
|
| 385 |
+
);
|
| 386 |
+
}
|
| 387 |
+
EventType::Burn(burn) => {
|
| 388 |
+
record_token_context(
|
| 389 |
+
&mut token_contexts,
|
| 390 |
+
&burn.mint_address,
|
| 391 |
+
burn.timestamp,
|
| 392 |
+
None,
|
| 393 |
+
None,
|
| 394 |
+
decimals_map.get(&burn.mint_address).cloned(),
|
| 395 |
+
);
|
| 396 |
+
}
|
| 397 |
+
EventType::Liquidity(_) => {}
|
| 398 |
+
_ => {}
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
if token_contexts.is_empty() {
|
| 403 |
+
println!("[TokenAggregator] -> Batch contains no relevant token events. Skipping.");
|
| 404 |
+
return Ok(());
|
| 405 |
+
}
|
| 406 |
+
println!(
|
| 407 |
+
"[TokenAggregator] -> Batch contains {} unique tokens.",
|
| 408 |
+
token_contexts.len()
|
| 409 |
+
);
|
| 410 |
+
|
| 411 |
+
let mut tokens = self
|
| 412 |
+
.fetch_tokens_from_db(&token_contexts.keys().cloned().collect::<Vec<_>>())
|
| 413 |
+
.await?;
|
| 414 |
+
|
| 415 |
+
let missing_tokens: Vec<String> = token_contexts
|
| 416 |
+
.keys()
|
| 417 |
+
.filter(|address| !tokens.contains_key(*address))
|
| 418 |
+
.cloned()
|
| 419 |
+
.collect();
|
| 420 |
+
|
| 421 |
+
if !missing_tokens.is_empty() {
|
| 422 |
+
println!(
|
| 423 |
+
"[TokenAggregator] -> Found {} new tokens to fetch metadata for.",
|
| 424 |
+
missing_tokens.len()
|
| 425 |
+
);
|
| 426 |
+
|
| 427 |
+
if !self.backfill_mode {
|
| 428 |
+
let fetch_futures = missing_tokens
|
| 429 |
+
.iter()
|
| 430 |
+
.map(|key| async move { (key.clone(), self.fetch_token_metadata(key).await) });
|
| 431 |
+
let fetched_results = future::join_all(fetch_futures).await;
|
| 432 |
+
|
| 433 |
+
for (key, rpc_result) in fetched_results {
|
| 434 |
+
let context = match token_contexts.get(&key) {
|
| 435 |
+
Some(ctx) => ctx.clone(),
|
| 436 |
+
None => continue,
|
| 437 |
+
};
|
| 438 |
+
let protocol = context.protocol.unwrap_or(0);
|
| 439 |
+
let token_row = match rpc_result {
|
| 440 |
+
Ok((metadata, mint_data)) => {
|
| 441 |
+
println!(
|
| 442 |
+
"[TokenAggregator] -> ✅ Successfully fetched metadata for new token {}.",
|
| 443 |
+
key
|
| 444 |
+
);
|
| 445 |
+
|
| 446 |
+
let creator = metadata
|
| 447 |
+
.creators
|
| 448 |
+
.as_ref()
|
| 449 |
+
.and_then(|creators| creators.first())
|
| 450 |
+
.map(|c| c.address.to_string())
|
| 451 |
+
.unwrap_or_default();
|
| 452 |
+
|
| 453 |
+
TokenStaticRow::new(
|
| 454 |
+
key.clone(),
|
| 455 |
+
context.timestamp,
|
| 456 |
+
metadata.name.trim_end_matches('\0').to_string(),
|
| 457 |
+
metadata.symbol.trim_end_matches('\0').to_string(),
|
| 458 |
+
metadata.uri.trim_end_matches('\0').to_string(),
|
| 459 |
+
mint_data.decimals,
|
| 460 |
+
creator,
|
| 461 |
+
pool_addresses_from_context(&context),
|
| 462 |
+
protocol,
|
| 463 |
+
mint_data.supply,
|
| 464 |
+
metadata.is_mutable,
|
| 465 |
+
Some(metadata.update_authority.to_string()),
|
| 466 |
+
Option::from(mint_data.mint_authority)
|
| 467 |
+
.map(|pk: Pubkey| pk.to_string()),
|
| 468 |
+
Option::from(mint_data.freeze_authority)
|
| 469 |
+
.map(|pk: Pubkey| pk.to_string()),
|
| 470 |
+
)
|
| 471 |
+
}
|
| 472 |
+
Err(e) => {
|
| 473 |
+
eprintln!(
|
| 474 |
+
"[TokenAggregator] -> ❌ RPC failed for {}: {}. Creating placeholder.",
|
| 475 |
+
key, e
|
| 476 |
+
);
|
| 477 |
+
TokenStaticRow::new(
|
| 478 |
+
key.clone(),
|
| 479 |
+
context.timestamp,
|
| 480 |
+
String::new(),
|
| 481 |
+
String::new(),
|
| 482 |
+
String::new(),
|
| 483 |
+
context.decimals.unwrap_or(0),
|
| 484 |
+
String::new(),
|
| 485 |
+
pool_addresses_from_context(&context),
|
| 486 |
+
protocol,
|
| 487 |
+
0,
|
| 488 |
+
true,
|
| 489 |
+
None,
|
| 490 |
+
None,
|
| 491 |
+
None,
|
| 492 |
+
)
|
| 493 |
+
}
|
| 494 |
+
};
|
| 495 |
+
tokens.insert(key.clone(), TokenEntry::new(token_row, None));
|
| 496 |
+
}
|
| 497 |
+
} else {
|
| 498 |
+
println!(
|
| 499 |
+
"[TokenAggregator] -> Creating {} placeholder tokens in backfill mode.",
|
| 500 |
+
missing_tokens.len()
|
| 501 |
+
);
|
| 502 |
+
for key in missing_tokens {
|
| 503 |
+
if let Some(context) = token_contexts.get(&key) {
|
| 504 |
+
let placeholder_row = TokenStaticRow::new(
|
| 505 |
+
key.clone(),
|
| 506 |
+
context.timestamp,
|
| 507 |
+
String::new(),
|
| 508 |
+
String::new(),
|
| 509 |
+
String::new(),
|
| 510 |
+
context.decimals.unwrap_or(0),
|
| 511 |
+
String::new(),
|
| 512 |
+
pool_addresses_from_context(context),
|
| 513 |
+
context.protocol.unwrap_or(0),
|
| 514 |
+
0,
|
| 515 |
+
false,
|
| 516 |
+
None,
|
| 517 |
+
None,
|
| 518 |
+
None,
|
| 519 |
+
);
|
| 520 |
+
tokens.insert(key.clone(), TokenEntry::new(placeholder_row, None));
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
let trader_pairs_in_batch: Vec<(String, String)> = payloads
|
| 527 |
+
.iter()
|
| 528 |
+
.filter_map(|p| {
|
| 529 |
+
if let EventType::Trade(t) = &p.event {
|
| 530 |
+
Some((t.base_address.clone(), t.maker.clone()))
|
| 531 |
+
} else {
|
| 532 |
+
None
|
| 533 |
+
}
|
| 534 |
+
})
|
| 535 |
+
.collect();
|
| 536 |
+
|
| 537 |
+
let mut existing_traders = HashSet::new();
|
| 538 |
+
if !trader_pairs_in_batch.is_empty() {
|
| 539 |
+
for chunk in trader_pairs_in_batch.chunks(*TOKEN_STATS_CHUNK_SIZE) {
|
| 540 |
+
let mut cursor = self.db_client
|
| 541 |
+
.query("SELECT DISTINCT (mint_address, wallet_address) FROM wallet_holdings WHERE (mint_address, wallet_address) IN ?")
|
| 542 |
+
.bind(chunk)
|
| 543 |
+
.fetch::<(String, String)>()?;
|
| 544 |
+
|
| 545 |
+
while let Some(pair) = cursor.next().await? {
|
| 546 |
+
existing_traders.insert(pair);
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
let mut counted_in_this_batch: HashSet<(String, String)> = HashSet::new();
|
| 552 |
+
|
| 553 |
+
for payload in payloads.iter() {
|
| 554 |
+
if !event_success(&payload.event) {
|
| 555 |
+
continue;
|
| 556 |
+
}
|
| 557 |
+
match &payload.event {
|
| 558 |
+
EventType::Mint(mint) => self.process_mint(mint, &mut tokens),
|
| 559 |
+
EventType::Trade(trade) => {
|
| 560 |
+
self.process_trade(
|
| 561 |
+
trade,
|
| 562 |
+
&mut tokens,
|
| 563 |
+
&existing_traders,
|
| 564 |
+
&mut counted_in_this_batch,
|
| 565 |
+
);
|
| 566 |
+
}
|
| 567 |
+
EventType::Migration(migration) => self.process_migration(migration, &mut tokens),
|
| 568 |
+
_ => {}
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
self.finalize_and_persist(tokens).await
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
fn process_trade(
|
| 576 |
+
&self,
|
| 577 |
+
trade: &TradeRow,
|
| 578 |
+
tokens: &mut TokenCache,
|
| 579 |
+
existing_traders: &HashSet<(String, String)>,
|
| 580 |
+
counted_in_this_batch: &mut HashSet<(String, String)>,
|
| 581 |
+
) {
|
| 582 |
+
if let Some(entry) = tokens.get_mut(&trade.base_address) {
|
| 583 |
+
entry.token.updated_at = trade.timestamp;
|
| 584 |
+
entry.metrics.updated_at = trade.timestamp;
|
| 585 |
+
|
| 586 |
+
// --- START: CORRECT UNIQUE HOLDER LOGIC ---
|
| 587 |
+
|
| 588 |
+
let current_pair = (trade.base_address.clone(), trade.maker.clone());
|
| 589 |
+
|
| 590 |
+
// We only increment the counter if:
|
| 591 |
+
// 1. The trader is NOT in the set of traders we know about from the database.
|
| 592 |
+
// 2. We have NOT already counted this trader for this token in this batch.
|
| 593 |
+
if !existing_traders.contains(¤t_pair) {
|
| 594 |
+
// The .insert() returns true only the first time we see this pair in this batch.
|
| 595 |
+
if counted_in_this_batch.insert(current_pair) {
|
| 596 |
+
entry.metrics.unique_holders += 1;
|
| 597 |
+
}
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
let trade_total_in_usd = trade.total_usd;
|
| 601 |
+
|
| 602 |
+
entry.metrics.total_volume_usd += trade_total_in_usd;
|
| 603 |
+
entry.metrics.ath_price_usd = entry.metrics.ath_price_usd.max(trade.price_usd);
|
| 604 |
+
|
| 605 |
+
if trade.trade_type == 0 {
|
| 606 |
+
// Buy
|
| 607 |
+
entry.metrics.total_buys += 1;
|
| 608 |
+
} else {
|
| 609 |
+
// Sell
|
| 610 |
+
entry.metrics.total_sells += 1;
|
| 611 |
+
}
|
| 612 |
+
}
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
async fn fetch_tokens_from_db(&self, keys: &[String]) -> Result<TokenCache> {
|
| 616 |
+
if keys.is_empty() {
|
| 617 |
+
return Ok(HashMap::new());
|
| 618 |
+
}
|
| 619 |
+
let query_str = "
|
| 620 |
+
SELECT
|
| 621 |
+
*
|
| 622 |
+
FROM tokens_latest
|
| 623 |
+
WHERE token_address IN ?
|
| 624 |
+
";
|
| 625 |
+
|
| 626 |
+
let mut statics = HashMap::new();
|
| 627 |
+
for chunk in keys.chunks(*TOKEN_STATS_CHUNK_SIZE) {
|
| 628 |
+
let mut cursor = self
|
| 629 |
+
.db_client
|
| 630 |
+
.query(query_str)
|
| 631 |
+
.bind(chunk)
|
| 632 |
+
.fetch::<TokenStaticRow>()?;
|
| 633 |
+
|
| 634 |
+
while let Ok(Some(token)) = cursor.next().await {
|
| 635 |
+
statics.insert(token.token_address.clone(), token);
|
| 636 |
+
}
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
let metrics_map = self.fetch_token_metrics(keys).await?;
|
| 640 |
+
let mut tokens = HashMap::new();
|
| 641 |
+
|
| 642 |
+
for (address, token) in statics {
|
| 643 |
+
let metrics = metrics_map.get(&address).cloned();
|
| 644 |
+
tokens.insert(address.clone(), TokenEntry::new(token, metrics));
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
Ok(tokens)
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
async fn fetch_token_metrics(
|
| 651 |
+
&self,
|
| 652 |
+
keys: &[String],
|
| 653 |
+
) -> Result<HashMap<String, TokenMetricsRow>> {
|
| 654 |
+
if keys.is_empty() {
|
| 655 |
+
return Ok(HashMap::new());
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
let query_str = "
|
| 659 |
+
SELECT
|
| 660 |
+
*
|
| 661 |
+
FROM token_metrics_latest
|
| 662 |
+
WHERE token_address IN ?
|
| 663 |
+
ORDER BY token_address, updated_at DESC
|
| 664 |
+
LIMIT 1 BY token_address
|
| 665 |
+
";
|
| 666 |
+
|
| 667 |
+
let mut metrics = HashMap::new();
|
| 668 |
+
|
| 669 |
+
for chunk in keys.chunks(*TOKEN_STATS_CHUNK_SIZE) {
|
| 670 |
+
let mut cursor = self
|
| 671 |
+
.db_client
|
| 672 |
+
.query(query_str)
|
| 673 |
+
.bind(chunk)
|
| 674 |
+
.fetch::<TokenMetricsRow>()?;
|
| 675 |
+
|
| 676 |
+
while let Ok(Some(row)) = cursor.next().await {
|
| 677 |
+
metrics.insert(row.token_address.clone(), row);
|
| 678 |
+
}
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
Ok(metrics)
|
| 682 |
+
}
|
| 683 |
+
|
| 684 |
+
/// Fetches on-chain data for a mint: the SPL `Mint` account and its
/// Metaplex `Metadata` PDA, requested concurrently over RPC.
///
/// # Errors
/// Returns an error when the address fails to parse, either account
/// fetch fails, or the account data cannot be unpacked/deserialized.
async fn fetch_token_metadata(&self, mint_address_str: &str) -> Result<(Metadata, Mint)> {
    let mint_pubkey = Pubkey::from_str(mint_address_str)?;
    // The Metaplex metadata account lives at a program-derived address
    // keyed by the mint.
    let metadata_pubkey = Metadata::find_pda(&mint_pubkey).0;

    // Issue both account fetches in parallel to save an RPC round trip.
    let (mint_account_res, metadata_account_res) = future::join(
        self.rpc_client.get_account(&mint_pubkey),
        self.rpc_client.get_account(&metadata_pubkey),
    )
    .await;

    let mint_account = mint_account_res?;
    let metadata_account = metadata_account_res?;

    let mint_data = Mint::unpack(&mint_account.data)?;
    let metadata = Metadata::deserialize(&mut &metadata_account.data[..])?;

    Ok((metadata, mint_data))
}
|
| 702 |
+
|
| 703 |
+
/// Applies a MINT event to the cache: creates the token record if this is
/// the first sighting of the mint, otherwise enriches the existing record
/// with data from the event.
fn process_mint(&self, mint: &MintRow, tokens: &mut TokenCache) {
    // Whether the entry existed before this event (decides create vs enrich).
    let is_new = !tokens.contains_key(&mint.mint_address);
    let entry = tokens
        .entry(mint.mint_address.clone())
        .or_insert_with(|| TokenEntry::new(TokenStaticRow::new_from_mint(mint), None));
    let token = &mut entry.token;

    if is_new {
        println!(
            "[TokenAggregator] -> Created new token record for {} from MINT event.",
            mint.mint_address
        );
    } else {
        println!(
            "[TokenAggregator] -> Enriched existing token record for {} with MINT event data.",
            mint.mint_address
        );
        token.updated_at = mint.timestamp;
        // Keep the earliest known creation time.
        token.created_at = token.created_at.min(mint.timestamp);
        // Supply/authority fields are overwritten unconditionally from the
        // mint event.
        token.decimals = mint.token_decimals;
        token.launchpad = mint.protocol;
        token.protocol = mint.protocol;
        token.total_supply = mint.total_supply;
        token.is_mutable = mint.is_mutable;
        token.update_authority = mint.update_authority.clone();
        token.mint_authority = mint.mint_authority.clone();
        token.freeze_authority = mint.freeze_authority.clone();
        // Metadata strings only fill gaps — never clobber values already set.
        if token.name.is_empty() {
            token.name = mint.token_name.clone().unwrap_or_default();
        }
        if token.symbol.is_empty() {
            token.symbol = mint.token_symbol.clone().unwrap_or_default();
        }
        if token.token_uri.is_empty() {
            token.token_uri = mint.token_uri.clone().unwrap_or_default();
        }
        if token.creator_address.is_empty() {
            token.creator_address = mint.creator_address.clone();
        }
        // Track each pool address at most once.
        if !mint.pool_address.is_empty() && !token.pool_addresses.contains(&mint.pool_address) {
            token.pool_addresses.push(mint.pool_address.clone());
        }
    }
}
|
| 747 |
+
|
| 748 |
+
/// Applies a migration event: stamps the token with the new protocol and
/// records the destination pool address (deduplicated). Tokens absent
/// from the cache are ignored.
fn process_migration(&self, migration: &MigrationRow, tokens: &mut TokenCache) {
    // Guard clause: nothing to do for tokens we are not tracking.
    let entry = match tokens.get_mut(&migration.mint_address) {
        Some(entry) => entry,
        None => return,
    };
    println!(
        "[TokenAggregator] -> Updating protocol for token {} due to migration.",
        migration.mint_address
    );
    let token = &mut entry.token;
    token.updated_at = migration.timestamp;
    token.protocol = migration.protocol;
    let already_known = token.pool_addresses.contains(&migration.pool_address);
    if !already_known {
        token.pool_addresses.push(migration.pool_address.clone());
    }
}
|
| 762 |
+
|
| 763 |
+
/// Persists the batch's token state to ClickHouse.
///
/// Each cache entry is split into its static row (always written) and its
/// metrics row (written only when it shows activity). Both are written to
/// the append-only history tables and to the `*_latest` snapshot tables,
/// in that order.
///
/// # Errors
/// Returns the first failed insert, annotated with which table failed.
async fn finalize_and_persist(&self, tokens: TokenCache) -> Result<()> {
    if tokens.is_empty() {
        return Ok(());
    }

    let mut updated_tokens = Vec::new();
    let mut metric_rows = Vec::new();

    for entry in tokens.into_values() {
        // Skip all-zero metrics rows so the metrics tables are not
        // polluted with no-op snapshots.
        if Self::metrics_has_activity(&entry.metrics) {
            metric_rows.push(entry.metrics);
        }
        updated_tokens.push(entry.token);
    }

    // History table first, then the snapshot table (clone because the
    // rows are consumed by each insert).
    insert_rows(
        &self.db_client,
        "tokens",
        updated_tokens.clone(),
        "Token Aggregator",
        "tokens",
    )
    .await
    .with_context(|| "Failed to persist token data to ClickHouse")?;

    insert_rows(
        &self.db_client,
        "tokens_latest",
        updated_tokens,
        "Token Aggregator",
        "tokens_latest",
    )
    .await
    .with_context(|| "Failed to persist token snapshot data to ClickHouse")?;

    insert_rows(
        &self.db_client,
        "token_metrics",
        metric_rows.clone(),
        "Token Aggregator",
        "token_metrics",
    )
    .await
    .with_context(|| "Failed to persist token metric history to ClickHouse")?;

    insert_rows(
        &self.db_client,
        "token_metrics_latest",
        metric_rows,
        "Token Aggregator",
        "token_metrics_latest",
    )
    .await
    .with_context(|| "Failed to persist token metric snapshots to ClickHouse")?;

    Ok(())
}
|
| 820 |
+
|
| 821 |
+
fn metrics_has_activity(metrics: &TokenMetricsRow) -> bool {
|
| 822 |
+
metrics.total_volume_usd > 0.0
|
| 823 |
+
|| metrics.total_buys > 0
|
| 824 |
+
|| metrics.total_sells > 0
|
| 825 |
+
|| metrics.unique_holders > 0
|
| 826 |
+
|| metrics.ath_price_usd > 0.0
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
/// Reads up to 1000 pending events for this consumer from a Redis stream.
///
/// Blocks for at most 2000 ms waiting for new entries, then decodes each
/// message's `payload` field from bincode. The returned pairs carry the
/// stream message id so callers can acknowledge after processing.
///
/// # Errors
/// Returns an error only when the XREADGROUP call itself fails.
async fn collect_events(
    &mut self,
    stream_key: &str,
    group_name: &str,
    consumer_name: &str,
) -> Result<Vec<(String, EventPayload)>> {
    let opts = StreamReadOptions::default()
        .group(group_name, consumer_name)
        .count(1000)
        .block(2000);
    // ">" asks for messages never delivered to this consumer group.
    let reply: StreamReadReply = self
        .redis_conn
        .xread_options(&[stream_key], &[">"], &opts)
        .await?;
    let mut events = Vec::new();
    for stream_entry in reply.keys {
        for message in stream_entry.ids {
            if let Some(payload_value) = message.map.get("payload") {
                // NOTE(review): messages missing the field or failing to
                // decode are dropped silently (and never acked here) —
                // confirm this best-effort behavior is intended.
                if let Ok(payload_bytes) = Vec::<u8>::from_redis_value(payload_value) {
                    if let Ok(payload) = bincode::deserialize::<EventPayload>(&payload_bytes) {
                        events.push((message.id.clone(), payload));
                    }
                }
            }
        }
    }
    Ok(events)
}
|
| 857 |
+
}
|
train.py
CHANGED
|
@@ -427,6 +427,7 @@ def main() -> None:
|
|
| 427 |
|
| 428 |
# --- 7. Training Loop ---
|
| 429 |
total_steps = 0
|
|
|
|
| 430 |
|
| 431 |
logger.info("***** Running training *****")
|
| 432 |
logger.info(f" Num examples = {len(dataset)}")
|
|
@@ -470,8 +471,12 @@ def main() -> None:
|
|
| 470 |
outputs = model(batch)
|
| 471 |
|
| 472 |
preds = outputs["quantile_logits"]
|
|
|
|
| 473 |
labels = batch["labels"]
|
| 474 |
labels_mask = batch["labels_mask"]
|
|
|
|
|
|
|
|
|
|
| 475 |
if labels_mask is not None and labels_mask.sum().item() == 0:
|
| 476 |
token_addresses = batch.get('token_addresses', [])
|
| 477 |
t_cutoffs = batch.get('t_cutoffs', [])
|
|
@@ -482,11 +487,14 @@ def main() -> None:
|
|
| 482 |
token_addresses[0] if token_addresses else "unknown",
|
| 483 |
t_cutoffs[0] if t_cutoffs else "unknown",
|
| 484 |
)
|
| 485 |
-
|
| 486 |
if labels_mask.sum() == 0:
|
| 487 |
-
|
| 488 |
else:
|
| 489 |
-
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
accelerator.backward(loss)
|
| 492 |
|
|
@@ -519,6 +527,8 @@ def main() -> None:
|
|
| 519 |
log_debug_batch_context(batch, logger, total_steps)
|
| 520 |
|
| 521 |
current_loss = loss.item()
|
|
|
|
|
|
|
| 522 |
epoch_loss += current_loss
|
| 523 |
valid_batches += 1
|
| 524 |
|
|
@@ -526,6 +536,8 @@ def main() -> None:
|
|
| 526 |
lr = scheduler.get_last_lr()[0]
|
| 527 |
log_payload = {
|
| 528 |
"train/loss": current_loss,
|
|
|
|
|
|
|
| 529 |
"train/learning_rate": lr,
|
| 530 |
"train/epoch": epoch + (step / len(dataloader))
|
| 531 |
}
|
|
|
|
| 427 |
|
| 428 |
# --- 7. Training Loop ---
|
| 429 |
total_steps = 0
|
| 430 |
+
quality_loss_fn = nn.MSELoss()
|
| 431 |
|
| 432 |
logger.info("***** Running training *****")
|
| 433 |
logger.info(f" Num examples = {len(dataset)}")
|
|
|
|
| 471 |
outputs = model(batch)
|
| 472 |
|
| 473 |
preds = outputs["quantile_logits"]
|
| 474 |
+
quality_preds = outputs["quality_logits"]
|
| 475 |
labels = batch["labels"]
|
| 476 |
labels_mask = batch["labels_mask"]
|
| 477 |
+
if "quality_score" not in batch:
|
| 478 |
+
raise RuntimeError("FATAL: quality_score missing from batch. Cannot train quality head.")
|
| 479 |
+
quality_targets = batch["quality_score"].to(accelerator.device, dtype=quality_preds.dtype)
|
| 480 |
if labels_mask is not None and labels_mask.sum().item() == 0:
|
| 481 |
token_addresses = batch.get('token_addresses', [])
|
| 482 |
t_cutoffs = batch.get('t_cutoffs', [])
|
|
|
|
| 487 |
token_addresses[0] if token_addresses else "unknown",
|
| 488 |
t_cutoffs[0] if t_cutoffs else "unknown",
|
| 489 |
)
|
| 490 |
+
|
| 491 |
if labels_mask.sum() == 0:
|
| 492 |
+
return_loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
|
| 493 |
else:
|
| 494 |
+
return_loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
|
| 495 |
+
|
| 496 |
+
quality_loss = quality_loss_fn(quality_preds, quality_targets)
|
| 497 |
+
loss = return_loss + quality_loss
|
| 498 |
|
| 499 |
accelerator.backward(loss)
|
| 500 |
|
|
|
|
| 527 |
log_debug_batch_context(batch, logger, total_steps)
|
| 528 |
|
| 529 |
current_loss = loss.item()
|
| 530 |
+
current_return_loss = return_loss.item()
|
| 531 |
+
current_quality_loss = quality_loss.item()
|
| 532 |
epoch_loss += current_loss
|
| 533 |
valid_batches += 1
|
| 534 |
|
|
|
|
| 536 |
lr = scheduler.get_last_lr()[0]
|
| 537 |
log_payload = {
|
| 538 |
"train/loss": current_loss,
|
| 539 |
+
"train/return_loss": current_return_loss,
|
| 540 |
+
"train/quality_loss": current_quality_loss,
|
| 541 |
"train/learning_rate": lr,
|
| 542 |
"train/epoch": epoch + (step / len(dataloader))
|
| 543 |
}
|
train.sh
CHANGED
|
@@ -11,7 +11,7 @@ accelerate launch train.py \
|
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
| 13 |
--mixed_precision bf16 \
|
| 14 |
-
--max_seq_len
|
| 15 |
--horizons_seconds 60 180 300 600 1800 3600 7200 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
|
|
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
| 13 |
--mixed_precision bf16 \
|
| 14 |
+
--max_seq_len 4096 \
|
| 15 |
--horizons_seconds 60 180 300 600 1800 3600 7200 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|