Update scripts/cache_dataset.py: speed + balance + correctness fixes
Browse files
- scripts/cache_dataset.py +95 -11
scripts/cache_dataset.py
CHANGED
|
@@ -146,6 +146,8 @@ def main():
|
|
| 146 |
parser.add_argument("--context_length", type=int, default=8192)
|
| 147 |
parser.add_argument("--min_trades", type=int, default=10)
|
| 148 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
|
|
|
|
|
|
| 149 |
parser.add_argument("--num_workers", type=int, default=1)
|
| 150 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 151 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
|
@@ -180,7 +182,7 @@ def main():
|
|
| 180 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 181 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 182 |
|
| 183 |
-
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=
|
| 184 |
|
| 185 |
if len(dataset) == 0:
|
| 186 |
print("WARNING: No samples. Exiting.")
|
|
@@ -189,7 +191,25 @@ def main():
|
|
| 189 |
# Filter mints by return_class_map
|
| 190 |
original_size = len(dataset.sampled_mints)
|
| 191 |
filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
|
| 192 |
-
print(f"INFO: Filtered {original_size} -> {len(filtered_mints)} tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
if len(filtered_mints) == 0:
|
| 195 |
print("WARNING: No tokens after filtering.")
|
|
@@ -198,16 +218,69 @@ def main():
|
|
| 198 |
print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
|
| 199 |
|
| 200 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 201 |
-
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
# Build
|
| 204 |
tasks = []
|
| 205 |
-
for
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 213 |
|
|
@@ -255,7 +328,18 @@ def main():
|
|
| 255 |
pass
|
| 256 |
|
| 257 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 258 |
-
json.dump({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
| 261 |
|
|
|
|
| 146 |
parser.add_argument("--context_length", type=int, default=8192)
|
| 147 |
parser.add_argument("--min_trades", type=int, default=10)
|
| 148 |
parser.add_argument("--samples_per_token", type=int, default=1)
|
| 149 |
+
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 150 |
+
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 151 |
parser.add_argument("--num_workers", type=int, default=1)
|
| 152 |
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 153 |
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
|
|
|
| 182 |
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 183 |
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 184 |
|
| 185 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
|
| 186 |
|
| 187 |
if len(dataset) == 0:
|
| 188 |
print("WARNING: No samples. Exiting.")
|
|
|
|
| 191 |
# Filter mints by return_class_map
|
| 192 |
original_size = len(dataset.sampled_mints)
|
| 193 |
filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
|
| 194 |
+
print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")
|
| 195 |
+
|
| 196 |
+
# Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query)
|
| 197 |
+
print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
|
| 198 |
+
trade_counts = clickhouse_client.execute("""
|
| 199 |
+
SELECT base_address, count() as cnt
|
| 200 |
+
FROM trades
|
| 201 |
+
GROUP BY base_address
|
| 202 |
+
HAVING cnt >= %(min_trades)s
|
| 203 |
+
""", {'min_trades': args.min_trades})
|
| 204 |
+
valid_tokens = {row[0] for row in trade_counts}
|
| 205 |
+
pre_filter_size = len(filtered_mints)
|
| 206 |
+
filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
|
| 207 |
+
print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")
|
| 208 |
+
|
| 209 |
+
# Also filter by quality score availability
|
| 210 |
+
pre_quality_size = len(filtered_mints)
|
| 211 |
+
filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
|
| 212 |
+
print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")
|
| 213 |
|
| 214 |
if len(filtered_mints) == 0:
|
| 215 |
print("WARNING: No tokens after filtering.")
|
|
|
|
| 218 |
print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
|
| 219 |
|
| 220 |
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 221 |
+
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
|
| 222 |
+
|
| 223 |
+
# Build tasks with class-aware multi-sampling for balanced cache
|
| 224 |
+
import random
|
| 225 |
+
from collections import Counter, defaultdict
|
| 226 |
+
|
| 227 |
+
# Count eligible tokens per class
|
| 228 |
+
eligible_class_counts = Counter()
|
| 229 |
+
mints_by_class = defaultdict(list)
|
| 230 |
+
for i, m in enumerate(filtered_mints):
|
| 231 |
+
cid = return_class_map.get(m['mint_address'])
|
| 232 |
+
if cid is not None:
|
| 233 |
+
eligible_class_counts[cid] += 1
|
| 234 |
+
mints_by_class[cid].append((i, m))
|
| 235 |
+
|
| 236 |
+
print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
|
| 237 |
+
|
| 238 |
+
# Compute balanced samples_per_token for each class
|
| 239 |
+
num_classes = len(eligible_class_counts)
|
| 240 |
+
if args.max_samples:
|
| 241 |
+
target_total = args.max_samples
|
| 242 |
+
else:
|
| 243 |
+
target_total = 15000 # Default target: 15k balanced files
|
| 244 |
+
target_per_class = target_total // max(num_classes, 1)
|
| 245 |
+
|
| 246 |
+
class_multipliers = {}
|
| 247 |
+
class_token_caps = {}
|
| 248 |
+
for cid, count in eligible_class_counts.items():
|
| 249 |
+
if count >= target_per_class:
|
| 250 |
+
# Enough tokens — 1 sample each, cap token count
|
| 251 |
+
class_multipliers[cid] = 1
|
| 252 |
+
class_token_caps[cid] = target_per_class
|
| 253 |
+
else:
|
| 254 |
+
# Not enough tokens — multi-sample, use all tokens
|
| 255 |
+
class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
|
| 256 |
+
class_token_caps[cid] = count
|
| 257 |
+
|
| 258 |
+
print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
|
| 259 |
+
print(f"INFO: Class multipliers: {dict(sorted(class_multipliers.items()))}")
|
| 260 |
+
print(f"INFO: Class token caps: {dict(sorted(class_token_caps.items()))}")
|
| 261 |
|
| 262 |
+
# Build balanced task list
|
| 263 |
tasks = []
|
| 264 |
+
for cid, mint_list in mints_by_class.items():
|
| 265 |
+
random.shuffle(mint_list)
|
| 266 |
+
cap = class_token_caps.get(cid, len(mint_list))
|
| 267 |
+
spt = class_multipliers.get(cid, 1)
|
| 268 |
+
# Override with CLI --samples_per_token if explicitly set > 1
|
| 269 |
+
if args.samples_per_token > 1:
|
| 270 |
+
spt = args.samples_per_token
|
| 271 |
+
for i, m in mint_list[:cap]:
|
| 272 |
+
mint_addr = m['mint_address']
|
| 273 |
+
if args.cache_mode == "context":
|
| 274 |
+
tasks.append((i, mint_addr, spt, str(output_dir)))
|
| 275 |
+
else:
|
| 276 |
+
tasks.append((i, mint_addr, str(output_dir)))
|
| 277 |
+
|
| 278 |
+
random.shuffle(tasks) # Shuffle tasks for even load distribution across workers
|
| 279 |
+
expected_files = sum(
|
| 280 |
+
class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
|
| 281 |
+
for cid, ml in mints_by_class.items()
|
| 282 |
+
)
|
| 283 |
+
print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
|
| 284 |
|
| 285 |
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 286 |
|
|
|
|
| 328 |
pass
|
| 329 |
|
| 330 |
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 331 |
+
json.dump({
|
| 332 |
+
'file_class_map': file_class_map,
|
| 333 |
+
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 334 |
+
'cache_mode': args.cache_mode,
|
| 335 |
+
'num_workers': args.num_workers,
|
| 336 |
+
'horizons_seconds': args.horizons_seconds,
|
| 337 |
+
'quantiles': args.quantiles,
|
| 338 |
+
'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
|
| 339 |
+
'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
|
| 340 |
+
'target_total': target_total,
|
| 341 |
+
'target_per_class': target_per_class,
|
| 342 |
+
}, f, indent=2)
|
| 343 |
|
| 344 |
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
| 345 |
|