Upload folder using huggingface_hub
Browse files- README.md +67 -4
- data/data_collator.py +3 -1
- data/data_loader.py +194 -50
- data/ohlc_stats.npz +1 -1
- log.log +2 -2
- train.py +27 -21
- train.sh +1 -4
README.md
CHANGED
|
@@ -36,7 +36,70 @@ Launch training with updated hyperparameters.
|
|
| 36 |
./train.sh
|
| 37 |
```
|
| 38 |
|
| 39 |
-
##
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
./train.sh
|
| 37 |
```
|
| 38 |
|
| 39 |
+
## TODO: Future Enhancements
|
| 40 |
+
|
| 41 |
+
### Multi-Task Quality Prediction Head
|
| 42 |
+
Add a secondary head (Head B) that predicts **token quality percentiles** alongside price returns:
|
| 43 |
+
- **Fees Percentile** — Predicted future fees relative to class median
|
| 44 |
+
- **Volume Percentile** — Predicted future volume relative to class median
|
| 45 |
+
- **Holders Percentile** — Predicted future holder count relative to class median
|
| 46 |
+
|
| 47 |
+
**Rationale:** The `analyze_distribution.py` script currently uses hard thresholds on future metrics to classify tokens as "Manipulated". This head would let the model **learn to predict** those quality metrics from current features, enabling scam detection at inference time without access to future data.
|
| 48 |
+
|
| 49 |
+
**Approach Options:**
|
| 50 |
+
1. Single composite quality score (simpler)
|
| 51 |
+
2. Three separate percentile predictions (more interpretable)
|
| 52 |
+
3. Three binary classifications (fees_ok, volume_ok, holders_ok)
|
| 53 |
+
|
| 54 |
+
Data Sampling (Context Optimization)
|
| 55 |
+
Replace hardcoded H/B/H limits with a dynamic sampling strategy that maximizes the model's context window usage.
|
| 56 |
+
|
| 57 |
+
The Problem
|
| 58 |
+
Currently, the system triggers H/B/H logic based on a fixed 30k trade count and uses hardcoded limits (10k early, 15k recent). This mismatch with the model's max_seq_len (e.g., 8192) leads to inefficient data usage—either truncating valuable data arbitrarily or feeding too little when more could fit.
|
| 59 |
+
|
| 60 |
+
The Solution: Dynamic Context Filling
|
| 61 |
+
Implementation moves to
|
| 62 |
+
data_loader.py
|
| 63 |
+
(since cache contains full history).
|
| 64 |
+
|
| 65 |
+
Algorithm
|
| 66 |
+
Input: Full sorted list of events (Trades, Chart Segments, etc.) up to T_cutoff.
|
| 67 |
+
Check: if
|
| 68 |
+
len(events) <= max_seq_len
|
| 69 |
+
, use ALL events.
|
| 70 |
+
Split: If
|
| 71 |
+
len(events) > max_seq_len
|
| 72 |
+
:
|
| 73 |
+
Reserve space for special tokens (start/end/pad).
|
| 74 |
+
Calculate Budget: budget = max_seq_len - reserve (e.g., 8100).
|
| 75 |
+
Dynamic Split:
|
| 76 |
+
Head (Early): First budget / 2 events.
|
| 77 |
+
Tail (Recent): Last budget / 2 events.
|
| 78 |
+
Construct: [HEAD] ... [GAP_TOKEN] ... [TAIL].
|
| 79 |
+
Implementation Changes
|
| 80 |
+
[MODIFY]
|
| 81 |
+
data_loader.py
|
| 82 |
+
Remove Constants: Delete HBH_EARLY_EVENT_LIMIT, HBH_RECENT_EVENT_LIMIT.
|
| 83 |
+
Update
|
| 84 |
+
_generate_dataset_item
|
| 85 |
+
:
|
| 86 |
+
Accept max_seq_len.
|
| 87 |
+
Implement the split logic defined above before returning event_sequence.
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
Here it is explained simply:
|
| 93 |
+
|
| 94 |
+
We check whether the total number of final events exceeds the context window we have available.
|
| 95 |
+
Then we filter out all the trade events and check how many non-aggregable events we have — for example a burn, a deployer trade, etc.
|
| 96 |
+
Then we take the remaining context, excluding those IMPORTANT events shown above, and check how many snapshots will fit (chart segments, holder snapshots, chain stats, etc.).
|
| 97 |
+
Then whatever budget remains after the snapshots and the important non-aggregable events is used to build the H segments (high definition), while in the middle (Blurry) section we keep just the snapshots.
|
| 98 |
+
|
| 99 |
+
This works because ~90% of the context is taken up by trades and transfers, so they are the only events that need compressing to free up context.
|
| 100 |
+
|
| 101 |
+
You don't need new tokens because there are already special tokens for it:
|
| 102 |
+
'MIDDLE',
|
| 103 |
+
'RECENT'
|
| 104 |
+
|
| 105 |
+
So when you switch to the blurry section you emit <MIDDLE>, and when you go back to high definition you emit <RECENT>.
|
data/data_collator.py
CHANGED
|
@@ -710,7 +710,9 @@ class MemecoinCollator:
|
|
| 710 |
'textual_event_data': textual_event_data_list, # RENAMED
|
| 711 |
# Labels
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
-
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None
|
|
|
|
|
|
|
| 714 |
}
|
| 715 |
|
| 716 |
# Filter out None values (e.g., if no labels provided)
|
|
|
|
| 710 |
'textual_event_data': textual_event_data_list, # RENAMED
|
| 711 |
# Labels
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
+
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 714 |
+
# Debug info
|
| 715 |
+
'token_addresses': [item.get('token_address', 'unknown') for item in batch]
|
| 716 |
}
|
| 717 |
|
| 718 |
# Filter out None values (e.g., if no labels provided)
|
data/data_loader.py
CHANGED
|
@@ -33,10 +33,25 @@ LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03 # 3% of supply
|
|
| 33 |
SMART_WALLET_PNL_THRESHOLD = 3.0 # 300% PNL
|
| 34 |
SMART_WALLET_USD_THRESHOLD = 20000.0
|
| 35 |
|
| 36 |
-
# ---
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
# --- NEW: OHLC Sequence Length Constant ---
|
| 42 |
OHLC_SEQ_LEN = 300 # 4 minutes of chart
|
|
@@ -107,7 +122,10 @@ class OracleDataset(Dataset):
|
|
| 107 |
t_cutoff_seconds: int = 60,
|
| 108 |
cache_dir: Optional[Union[str, Path]] = None,
|
| 109 |
start_date: Optional[datetime.datetime] = None,
|
| 110 |
-
min_trade_usd: float = 0.0
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 113 |
# Configure robust HTTP session
|
|
@@ -261,6 +279,90 @@ class OracleDataset(Dataset):
|
|
| 261 |
denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
|
| 262 |
return [(float(v) - self.ohlc_price_mean) / denom for v in values]
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _compute_future_return_labels(self,
|
| 265 |
anchor_price: Optional[float],
|
| 266 |
anchor_timestamp: int,
|
|
@@ -830,9 +932,11 @@ class OracleDataset(Dataset):
|
|
| 830 |
raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 831 |
except Exception as e:
|
| 832 |
raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
|
|
|
|
|
|
|
| 833 |
else:
|
| 834 |
-
#
|
| 835 |
-
|
| 836 |
|
| 837 |
if not raw_data:
|
| 838 |
raise RuntimeError(f"No raw data loaded for index {idx}")
|
|
@@ -882,19 +986,31 @@ class OracleDataset(Dataset):
|
|
| 882 |
preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
|
| 883 |
|
| 884 |
mint_ts_value = _timestamp_to_order_value(mint_timestamp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
trade_ts_values = [
|
| 886 |
_timestamp_to_order_value(trade.get('timestamp'))
|
| 887 |
-
for trade in
|
| 888 |
-
if trade.get('timestamp') is not None
|
| 889 |
]
|
| 890 |
-
|
|
|
|
|
|
|
| 891 |
return None
|
| 892 |
|
| 893 |
-
# Cache guarantees min_trades=25, so we proceed assuming valid data.
|
| 894 |
-
# But for safety in dynamic sampling:
|
| 895 |
-
if not trade_ts_values:
|
| 896 |
-
return None
|
| 897 |
-
|
| 898 |
# Sort trades to find the 24th trade timestamp
|
| 899 |
sorted_trades_ts = sorted(trade_ts_values)
|
| 900 |
|
|
@@ -1057,10 +1173,10 @@ class OracleDataset(Dataset):
|
|
| 1057 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1058 |
include_wallet_data=False,
|
| 1059 |
include_graph=False,
|
| 1060 |
-
min_trades=
|
| 1061 |
full_history=True, # Bypass H/B/H limits
|
| 1062 |
-
prune_failed=
|
| 1063 |
-
prune_transfers=
|
| 1064 |
)
|
| 1065 |
if raw_data is None:
|
| 1066 |
return None
|
|
@@ -1447,9 +1563,12 @@ class OracleDataset(Dataset):
|
|
| 1447 |
cached_holders_list=cached_holders_list
|
| 1448 |
)
|
| 1449 |
|
| 1450 |
-
# 7. Finalize Sequence
|
| 1451 |
event_sequence_entries.sort(key=lambda x: x[0])
|
| 1452 |
-
|
|
|
|
|
|
|
|
|
|
| 1453 |
|
| 1454 |
# 8. Compute Labels using future data
|
| 1455 |
# Define horizons (e.g., [60, 120, ...])
|
|
@@ -1459,7 +1578,31 @@ class OracleDataset(Dataset):
|
|
| 1459 |
# Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff)
|
| 1460 |
# We need to find the price at T_cutoff and at T_cutoff + h
|
| 1461 |
|
| 1462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1463 |
# Ensure sorted
|
| 1464 |
all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
|
| 1465 |
|
|
@@ -1483,40 +1626,41 @@ class OracleDataset(Dataset):
|
|
| 1483 |
label_values = []
|
| 1484 |
mask_values = []
|
| 1485 |
|
| 1486 |
-
|
| 1487 |
-
|
| 1488 |
-
|
| 1489 |
-
|
| 1490 |
-
|
| 1491 |
-
|
| 1492 |
-
|
| 1493 |
-
|
| 1494 |
-
|
| 1495 |
-
# Find price at target_ts
|
| 1496 |
-
# It is the last trade strictly before or at target_ts
|
| 1497 |
-
future_price = current_price # Default to current if no trades found in window? Unlikely if checked range.
|
| 1498 |
-
|
| 1499 |
-
# Check trades between current_idx and target
|
| 1500 |
-
# Optimization: start search from current_price_idx
|
| 1501 |
-
found_future = False
|
| 1502 |
-
for j in range(current_price_idx, len(all_trades)):
|
| 1503 |
-
t = all_trades[j]
|
| 1504 |
-
t_ts = _timestamp_to_order_value(t['timestamp'])
|
| 1505 |
-
if t_ts <= target_ts:
|
| 1506 |
-
future_price = float(t['price_usd'])
|
| 1507 |
-
found_future = True
|
| 1508 |
-
else:
|
| 1509 |
-
break # Optimization: surpassed target_ts
|
| 1510 |
|
| 1511 |
-
if
|
| 1512 |
-
|
|
|
|
|
|
|
|
|
|
| 1513 |
else:
|
| 1514 |
-
|
|
|
|
|
|
|
| 1515 |
|
| 1516 |
-
|
| 1517 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1518 |
|
| 1519 |
return {
|
|
|
|
| 1520 |
'event_sequence': event_sequence,
|
| 1521 |
'wallets': wallet_data,
|
| 1522 |
'tokens': all_token_data,
|
|
|
|
| 33 |
SMART_WALLET_PNL_THRESHOLD = 3.0 # 300% PNL
|
| 34 |
SMART_WALLET_USD_THRESHOLD = 20000.0
|
| 35 |
|
| 36 |
+
# --- Event Categorization for Dynamic Sampling ---
|
| 37 |
+
# Events that are rare and should ALWAYS be kept
|
| 38 |
+
CRITICAL_EVENTS = {
|
| 39 |
+
'Mint', 'Deployer_Trade', 'SmartWallet_Trade', 'LargeTrade', 'LargeTransfer',
|
| 40 |
+
'TokenBurn', 'SupplyLock', 'PoolCreated', 'LiquidityChange', 'Migrated',
|
| 41 |
+
'FeeCollected', 'TrendingToken', 'BoostedToken', 'XPost', 'XRetweet',
|
| 42 |
+
'XReply', 'XQuoteTweet', 'PumpReply', 'DexBoost_Paid', 'DexProfile_Updated',
|
| 43 |
+
'AlphaGroup_Call', 'Channel_Call', 'CexListing', 'TikTok_Trending_Hashtag',
|
| 44 |
+
'XTrending_Hashtag'
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
# Periodic snapshots - kept for context continuity
|
| 48 |
+
SNAPSHOT_EVENTS = {
|
| 49 |
+
'Chart_Segment', 'OnChain_Snapshot', 'HolderSnapshot',
|
| 50 |
+
'ChainSnapshot', 'Lighthouse_Snapshot'
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# High-volume events that can be compressed (Head/Tail)
|
| 54 |
+
COMPRESSIBLE_EVENTS = {'Trade', 'Transfer'}
|
| 55 |
|
| 56 |
# --- NEW: OHLC Sequence Length Constant ---
|
| 57 |
OHLC_SEQ_LEN = 300 # 4 minutes of chart
|
|
|
|
| 122 |
t_cutoff_seconds: int = 60,
|
| 123 |
cache_dir: Optional[Union[str, Path]] = None,
|
| 124 |
start_date: Optional[datetime.datetime] = None,
|
| 125 |
+
min_trade_usd: float = 0.0,
|
| 126 |
+
max_seq_len: int = 8192):
|
| 127 |
+
|
| 128 |
+
self.max_seq_len = max_seq_len
|
| 129 |
|
| 130 |
# --- NEW: Create a persistent requests session for efficiency ---
|
| 131 |
# Configure robust HTTP session
|
|
|
|
| 279 |
denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
|
| 280 |
return [(float(v) - self.ohlc_price_mean) / denom for v in values]
|
| 281 |
|
| 282 |
+
def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 283 |
+
"""
|
| 284 |
+
Applies dynamic context sampling to fit events within max_seq_len.
|
| 285 |
+
|
| 286 |
+
Priority:
|
| 287 |
+
1. CRITICAL events (always kept)
|
| 288 |
+
2. SNAPSHOT events (kept for continuity)
|
| 289 |
+
3. COMPRESSIBLE events (Trade/Transfer) - split into Head/Tail with MIDDLE token
|
| 290 |
+
|
| 291 |
+
Uses existing 'MIDDLE' and 'RECENT' tokens to mark transitions.
|
| 292 |
+
"""
|
| 293 |
+
if len(events) <= self.max_seq_len:
|
| 294 |
+
return events
|
| 295 |
+
|
| 296 |
+
# Categorize events by type
|
| 297 |
+
critical_events = [] # (original_idx, event)
|
| 298 |
+
snapshot_events = []
|
| 299 |
+
compressible_events = []
|
| 300 |
+
|
| 301 |
+
for idx, event in enumerate(events):
|
| 302 |
+
event_type = event.get('event_type', '')
|
| 303 |
+
if event_type in CRITICAL_EVENTS:
|
| 304 |
+
critical_events.append((idx, event))
|
| 305 |
+
elif event_type in SNAPSHOT_EVENTS:
|
| 306 |
+
snapshot_events.append((idx, event))
|
| 307 |
+
elif event_type in COMPRESSIBLE_EVENTS:
|
| 308 |
+
compressible_events.append((idx, event))
|
| 309 |
+
else:
|
| 310 |
+
# Unknown event types go to critical (safe default)
|
| 311 |
+
critical_events.append((idx, event))
|
| 312 |
+
|
| 313 |
+
# Calculate budget for compressible events
|
| 314 |
+
# Reserve 2 tokens for MIDDLE and RECENT markers
|
| 315 |
+
reserved_tokens = 2
|
| 316 |
+
fixed_count = len(critical_events) + len(snapshot_events) + reserved_tokens
|
| 317 |
+
budget_for_compressible = max(0, self.max_seq_len - fixed_count)
|
| 318 |
+
|
| 319 |
+
# If no budget for compressible, just return critical + snapshots
|
| 320 |
+
if budget_for_compressible == 0 or len(compressible_events) <= budget_for_compressible:
|
| 321 |
+
# All compressible fit, just return sorted
|
| 322 |
+
all_events = critical_events + snapshot_events + compressible_events
|
| 323 |
+
all_events.sort(key=lambda x: x[0])
|
| 324 |
+
return [e[1] for e in all_events]
|
| 325 |
+
|
| 326 |
+
# Apply Head/Tail split for compressible events
|
| 327 |
+
head_size = budget_for_compressible // 2
|
| 328 |
+
tail_size = budget_for_compressible - head_size
|
| 329 |
+
|
| 330 |
+
head_events = compressible_events[:head_size]
|
| 331 |
+
tail_events = compressible_events[-tail_size:] if tail_size > 0 else []
|
| 332 |
+
|
| 333 |
+
# Find the timestamp boundary for MIDDLE/RECENT markers
|
| 334 |
+
# MIDDLE goes after head, RECENT goes before tail
|
| 335 |
+
middle_marker_idx = head_events[-1][0] if head_events else 0
|
| 336 |
+
recent_marker_idx = tail_events[0][0] if tail_events else len(events)
|
| 337 |
+
|
| 338 |
+
# Create marker events
|
| 339 |
+
middle_marker = {
|
| 340 |
+
'event_type': 'MIDDLE',
|
| 341 |
+
'relative_ts': events[middle_marker_idx].get('relative_ts', 0) if middle_marker_idx < len(events) else 0,
|
| 342 |
+
'is_marker': True
|
| 343 |
+
}
|
| 344 |
+
recent_marker = {
|
| 345 |
+
'event_type': 'RECENT',
|
| 346 |
+
'relative_ts': events[recent_marker_idx - 1].get('relative_ts', 0) if recent_marker_idx > 0 and recent_marker_idx <= len(events) else 0,
|
| 347 |
+
'is_marker': True
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
# Combine all events with markers
|
| 351 |
+
# We need to maintain chronological order
|
| 352 |
+
all_indexed_events = critical_events + snapshot_events + head_events + tail_events
|
| 353 |
+
|
| 354 |
+
# Add markers with synthetic indices
|
| 355 |
+
middle_idx = middle_marker_idx + 0.5 # After last head event
|
| 356 |
+
recent_idx = recent_marker_idx - 0.5 # Before first tail event
|
| 357 |
+
|
| 358 |
+
all_indexed_events.append((middle_idx, middle_marker))
|
| 359 |
+
all_indexed_events.append((recent_idx, recent_marker))
|
| 360 |
+
|
| 361 |
+
# Sort by original index to maintain chronological order
|
| 362 |
+
all_indexed_events.sort(key=lambda x: x[0])
|
| 363 |
+
|
| 364 |
+
return [e[1] for e in all_indexed_events]
|
| 365 |
+
|
| 366 |
def _compute_future_return_labels(self,
|
| 367 |
anchor_price: Optional[float],
|
| 368 |
anchor_timestamp: int,
|
|
|
|
| 932 |
raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 933 |
except Exception as e:
|
| 934 |
raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
|
| 935 |
+
except Exception as e:
|
| 936 |
+
raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
|
| 937 |
else:
|
| 938 |
+
# Strict Offline Mode: No dynamic generation fallback
|
| 939 |
+
raise RuntimeError(f"Offline mode required. No cache directory provided or configured.")
|
| 940 |
|
| 941 |
if not raw_data:
|
| 942 |
raise RuntimeError(f"No raw data loaded for index {idx}")
|
|
|
|
| 986 |
preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
|
| 987 |
|
| 988 |
mint_ts_value = _timestamp_to_order_value(mint_timestamp)
|
| 989 |
+
|
| 990 |
+
# ============================================================================
|
| 991 |
+
# CRITICAL: Use ONLY successful trades for T_cutoff sampling!
|
| 992 |
+
# ============================================================================
|
| 993 |
+
# Failed trades have invalid price_usd values and should not be used for:
|
| 994 |
+
# 1. Determining the valid T_cutoff range (trades[24] to trades[-1])
|
| 995 |
+
# 2. Computing price returns for labels
|
| 996 |
+
# The T_cutoff range must guarantee at least one successful trade after cutoff.
|
| 997 |
+
# ============================================================================
|
| 998 |
+
successful_trades = [
|
| 999 |
+
trade for trade in raw_data.get('trades', [])
|
| 1000 |
+
if trade.get('success', False)
|
| 1001 |
+
and trade.get('timestamp') is not None
|
| 1002 |
+
and float(trade.get('price_usd', 0) or 0) > 0
|
| 1003 |
+
]
|
| 1004 |
+
|
| 1005 |
trade_ts_values = [
|
| 1006 |
_timestamp_to_order_value(trade.get('timestamp'))
|
| 1007 |
+
for trade in successful_trades
|
|
|
|
| 1008 |
]
|
| 1009 |
+
|
| 1010 |
+
if not trade_ts_values or len(trade_ts_values) < 25:
|
| 1011 |
+
# Not enough successful trades for valid sampling
|
| 1012 |
return None
|
| 1013 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1014 |
# Sort trades to find the 24th trade timestamp
|
| 1015 |
sorted_trades_ts = sorted(trade_ts_values)
|
| 1016 |
|
|
|
|
| 1173 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1174 |
include_wallet_data=False,
|
| 1175 |
include_graph=False,
|
| 1176 |
+
min_trades=24, # Enforce min trades for context
|
| 1177 |
full_history=True, # Bypass H/B/H limits
|
| 1178 |
+
prune_failed=False, # Keep failed trades for realistic simulation
|
| 1179 |
+
prune_transfers=False # Keep transfers for snapshot reconstruction
|
| 1180 |
)
|
| 1181 |
if raw_data is None:
|
| 1182 |
return None
|
|
|
|
| 1563 |
cached_holders_list=cached_holders_list
|
| 1564 |
)
|
| 1565 |
|
| 1566 |
+
# 7. Finalize Sequence with Dynamic Sampling
|
| 1567 |
event_sequence_entries.sort(key=lambda x: x[0])
|
| 1568 |
+
raw_event_sequence = [entry[1] for entry in event_sequence_entries]
|
| 1569 |
+
|
| 1570 |
+
# Apply dynamic context sampling if needed
|
| 1571 |
+
event_sequence = self._apply_dynamic_sampling(raw_event_sequence)
|
| 1572 |
|
| 1573 |
# 8. Compute Labels using future data
|
| 1574 |
# Define horizons (e.g., [60, 120, ...])
|
|
|
|
| 1578 |
# Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff)
|
| 1579 |
# We need to find the price at T_cutoff and at T_cutoff + h
|
| 1580 |
|
| 1581 |
+
# ============================================================================
|
| 1582 |
+
# CRITICAL: Filter for successful trades with valid prices ONLY!
|
| 1583 |
+
# ============================================================================
|
| 1584 |
+
# Failed trades (success=False) often have price_usd=0 or invalid values.
|
| 1585 |
+
# Using these for label computation causes mathematically impossible returns
|
| 1586 |
+
# like -1.0 (price went to 0) or 0.0 (no price change despite trading).
|
| 1587 |
+
# ALWAYS filter by: success=True AND price_usd > 0
|
| 1588 |
+
# ============================================================================
|
| 1589 |
+
all_trades = [
|
| 1590 |
+
t for t in future_trades_for_labels
|
| 1591 |
+
if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
|
| 1592 |
+
]
|
| 1593 |
+
|
| 1594 |
+
if not all_trades:
|
| 1595 |
+
# No valid trades for label computation
|
| 1596 |
+
return {
|
| 1597 |
+
'event_sequence': event_sequence,
|
| 1598 |
+
'wallets': wallet_data,
|
| 1599 |
+
'tokens': all_token_data,
|
| 1600 |
+
'graph_links': graph_links,
|
| 1601 |
+
'embedding_pooler': pooler,
|
| 1602 |
+
'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
|
| 1603 |
+
'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32)
|
| 1604 |
+
}
|
| 1605 |
+
|
| 1606 |
# Ensure sorted
|
| 1607 |
all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
|
| 1608 |
|
|
|
|
| 1626 |
label_values = []
|
| 1627 |
mask_values = []
|
| 1628 |
|
| 1629 |
+
# Edge case: no trades before cutoff means we have no anchor price
|
| 1630 |
+
if current_price_idx < 0 or current_price <= 0:
|
| 1631 |
+
# No valid anchor price - mask all labels
|
| 1632 |
+
for h in horizons:
|
| 1633 |
+
label_values.append(0.0)
|
| 1634 |
+
mask_values.append(0.0)
|
| 1635 |
+
else:
|
| 1636 |
+
for h in horizons:
|
| 1637 |
+
target_ts = cutoff_ts_val + h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1638 |
|
| 1639 |
+
if target_ts > last_trade_ts_val:
|
| 1640 |
+
# Horizon extends beyond known history
|
| 1641 |
+
# We MASK this label. We do NOT guess 0.
|
| 1642 |
+
label_values.append(0.0) # Dummy value
|
| 1643 |
+
mask_values.append(0.0) # Mask = 0 (Ignore)
|
| 1644 |
else:
|
| 1645 |
+
# Find price at target_ts
|
| 1646 |
+
# Start searching AFTER current_price_idx to find the NEXT trade
|
| 1647 |
+
future_price = current_price
|
| 1648 |
|
| 1649 |
+
# Search from current_price_idx + 1 to find trades in the horizon window
|
| 1650 |
+
for j in range(current_price_idx + 1, len(all_trades)):
|
| 1651 |
+
t = all_trades[j]
|
| 1652 |
+
t_ts = _timestamp_to_order_value(t['timestamp'])
|
| 1653 |
+
if t_ts <= target_ts:
|
| 1654 |
+
future_price = float(t['price_usd'])
|
| 1655 |
+
else:
|
| 1656 |
+
break # Surpassed target_ts
|
| 1657 |
+
|
| 1658 |
+
ret = (future_price - current_price) / current_price
|
| 1659 |
+
label_values.append(ret)
|
| 1660 |
+
mask_values.append(1.0) # Mask = 1 (Valid)
|
| 1661 |
|
| 1662 |
return {
|
| 1663 |
+
'token_address': token_address, # For debugging
|
| 1664 |
'event_sequence': event_sequence,
|
| 1665 |
'wallets': wallet_data,
|
| 1666 |
'tokens': all_token_data,
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e84cff0cfabf73d50f94c3f9a5cf9224e89c634db76982a4e3e5428c9df4ea91
|
| 3 |
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1c2198c3ed6e249ddf7b7b017b99b2389e4611b8b0649c63d30c40c59e03ac1
|
| 3 |
+
size 76001
|
train.py
CHANGED
|
@@ -150,10 +150,14 @@ def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step:
|
|
| 150 |
events.append(name)
|
| 151 |
|
| 152 |
logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
# Show Labels
|
| 159 |
# Assuming flattened labels [H*Q]
|
|
@@ -190,11 +194,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 190 |
parser.add_argument("--num_workers", type=int, default=0)
|
| 191 |
parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
|
| 192 |
parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
|
| 193 |
-
parser.add_argument("--
|
| 194 |
-
parser.add_argument("--clickhouse_port", type=int, default=9000)
|
| 195 |
-
parser.add_argument("--neo4j_uri", type=str, default="bolt://localhost:7687")
|
| 196 |
-
parser.add_argument("--neo4j_user", type=str, default=None)
|
| 197 |
-
parser.add_argument("--neo4j_password", type=str, default=None)
|
| 198 |
return parser.parse_args()
|
| 199 |
|
| 200 |
|
|
@@ -394,21 +394,27 @@ def main() -> None:
|
|
| 394 |
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
|
| 395 |
if dirs:
|
| 396 |
# Sort by modification time or name to find latest
|
|
|
|
| 397 |
dirs.sort(key=lambda x: x.stat().st_mtime)
|
| 398 |
latest_checkpoint = dirs[-1]
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
else:
|
| 411 |
-
logger.info("No
|
| 412 |
|
| 413 |
# --- 7. Training Loop ---
|
| 414 |
total_steps = 0
|
|
|
|
| 150 |
events.append(name)
|
| 151 |
|
| 152 |
logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")
|
| 153 |
+
|
| 154 |
+
# Log token address for manual verification
|
| 155 |
+
token_addresses = batch.get('token_addresses', [])
|
| 156 |
+
if token_addresses:
|
| 157 |
+
logger.info(f"Token Address: {token_addresses[0]}")
|
| 158 |
+
|
| 159 |
+
context_str = ", ".join(events)
|
| 160 |
+
logger.info(f"Event Stream ({len(events)}): [{context_str}]")
|
| 161 |
|
| 162 |
# Show Labels
|
| 163 |
# Assuming flattened labels [H*Q]
|
|
|
|
| 194 |
parser.add_argument("--num_workers", type=int, default=0)
|
| 195 |
parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
|
| 196 |
parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
|
| 197 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
return parser.parse_args()
|
| 199 |
|
| 200 |
|
|
|
|
| 394 |
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
|
| 395 |
if dirs:
|
| 396 |
# Sort by modification time or name to find latest
|
| 397 |
+
# Sort by modification time or name to find latest
|
| 398 |
dirs.sort(key=lambda x: x.stat().st_mtime)
|
| 399 |
latest_checkpoint = dirs[-1]
|
| 400 |
+
|
| 401 |
+
if args.resume_from_checkpoint:
|
| 402 |
+
if args.resume_from_checkpoint == "latest":
|
| 403 |
+
if latest_checkpoint:
|
| 404 |
+
logger.info(f"Resuming from latest checkpoint: {latest_checkpoint}")
|
| 405 |
+
accelerator.load_state(str(latest_checkpoint))
|
| 406 |
+
else:
|
| 407 |
+
logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
|
| 408 |
+
else:
|
| 409 |
+
# Specific path
|
| 410 |
+
custom_ckpt = Path(args.resume_from_checkpoint)
|
| 411 |
+
if custom_ckpt.exists():
|
| 412 |
+
logger.info(f"Resuming from specific checkpoint: {custom_ckpt}")
|
| 413 |
+
accelerator.load_state(str(custom_ckpt))
|
| 414 |
+
else:
|
| 415 |
+
raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
|
| 416 |
else:
|
| 417 |
+
logger.info("No resume flag provided. Starting fresh.")
|
| 418 |
|
| 419 |
# --- 7. Training Loop ---
|
| 420 |
total_steps = 0
|
train.sh
CHANGED
|
@@ -15,7 +15,4 @@ accelerate launch train.py \
|
|
| 15 |
--horizons_seconds 60 180 300 600 1800 3600 7200 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 18 |
-
--num_workers 4
|
| 19 |
-
--clickhouse_host localhost \
|
| 20 |
-
--clickhouse_port 9000 \
|
| 21 |
-
--neo4j_uri bolt://localhost:7687
|
|
|
|
| 15 |
--horizons_seconds 60 180 300 600 1800 3600 7200 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 18 |
+
--num_workers 4
|
|
|
|
|
|
|
|
|