zirobtc commited on
Commit
7d63a09
·
1 Parent(s): 5800f64

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. README.md +67 -4
  2. data/data_collator.py +3 -1
  3. data/data_loader.py +194 -50
  4. data/ohlc_stats.npz +1 -1
  5. log.log +2 -2
  6. train.py +27 -21
  7. train.sh +1 -4
README.md CHANGED
@@ -36,7 +36,70 @@ Launch training with updated hyperparameters.
36
  ./train.sh
37
  ```
38
 
39
- ## TODOs
40
- * [ ] **Re-run Caching**: Since horizons changed, the existing cache (if any) is stale. Expected to run `pre_cache.sh`.
41
- * [ ] **Verify Inference**: Ensure `inference.py` handles the 20s latency constraints gracefully (e.g. timestamp checks).
42
- * [ ] **Model Architecture**: Confirm `8192` context length fits in VRAM with current model config (Attention implementation).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  ./train.sh
37
  ```
38
 
39
+ ## TODO: Future Enhancements
40
+
41
+ ### Multi-Task Quality Prediction Head
42
+ Add a secondary head (Head B) that predicts **token quality percentiles** alongside price returns:
43
+ - **Fees Percentile** — Predicted future fees relative to class median
44
+ - **Volume Percentile** — Predicted future volume relative to class median
45
+ - **Holders Percentile** — Predicted future holder count relative to class median
46
+
47
+ **Rationale:** The `analyze_distribution.py` script currently uses hard thresholds on future metrics to classify tokens as "Manipulated". This head would let the model **learn to predict** those quality metrics from current features, enabling scam detection at inference time without access to future data.
48
+
49
+ **Approach Options:**
50
+ 1. Single composite quality score (simpler)
51
+ 2. Three separate percentile predictions (more interpretable)
52
+ 3. Three binary classifications (fees_ok, volume_ok, holders_ok)
53
+
54
+ Data Sampling (Context Optimization)
55
+ Replace hardcoded H/B/H limits with a dynamic sampling strategy that maximizes the model's context window usage.
56
+
57
+ The Problem
58
+ Currently, the system triggers H/B/H logic based on a fixed 30k trade count and uses hardcoded limits (10k early, 15k recent). This mismatch with the model's max_seq_len (e.g., 8192) leads to inefficient data usage—either truncating valuable data arbitrarily or feeding too little when more could fit.
59
+
60
+ The Solution: Dynamic Context Filling
61
+ Implementation moves to
62
+ data_loader.py
63
+ (since cache contains full history).
64
+
65
+ Algorithm
66
+ Input: Full sorted list of events (Trades, Chart Segments, etc.) up to T_cutoff.
67
+ Check: if
68
+ len(events) <= max_seq_len
69
+ , use ALL events.
70
+ Split: If
71
+ len(events) > max_seq_len
72
+ :
73
+ Reserve space for special tokens (start/end/pad).
74
+ Calculate Budget: budget = max_seq_len - reserve (e.g., 8100).
75
+ Dynamic Split:
76
+ Head (Early): First budget / 2 events.
77
+ Tail (Recent): Last budget / 2 events.
78
+ Construct: [HEAD] ... [GAP_TOKEN] ... [TAIL].
79
+ Implementation Changes
80
+ [MODIFY]
81
+ data_loader.py
82
+ Remove Constants: Delete HBH_EARLY_EVENT_LIMIT, HBH_RECENT_EVENT_LIMIT.
83
+ Update
84
+ _generate_dataset_item
85
+ :
86
+ Accept max_seq_len.
87
+ Implement the split logic defined above before returning event_sequence.
88
+
89
+
90
+
91
+
92
+ Here it is explained more simply:
93
+
94
+ We check whether the final events exceed the total context we have available.
95
+ Then we filter out all the trade events and check how many non-aggregable events we have — for example a burn, a deployer trade, etc.
96
+ Then we take the context remaining after excluding those IMPORTANT events shown above, and we check how many snapshots will fit (chart segments, holder snapshots, chain stats, etc.).
97
+ Then the budget remaining after snapshots and important non-aggregable events is used to build the H segments (high definition), while in the middle (Blurry) section we keep just the snapshots.
98
+
99
+ This works because roughly 90% of the context is taken up by trades and transfers, so they are the only events that need compressing to free up context.
100
+
101
+ You don't need new tokens because special tokens for this already exist:
102
+ 'MIDDLE',
103
+ 'RECENT'
104
+
105
+ So when you switch to blurry you emit <MIDDLE>, and when you go back to high definition you emit <RECENT>.
data/data_collator.py CHANGED
@@ -710,7 +710,9 @@ class MemecoinCollator:
710
  'textual_event_data': textual_event_data_list, # RENAMED
711
  # Labels
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
- 'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None
 
 
714
  }
715
 
716
  # Filter out None values (e.g., if no labels provided)
 
710
  'textual_event_data': textual_event_data_list, # RENAMED
711
  # Labels
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
+ 'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
714
+ # Debug info
715
+ 'token_addresses': [item.get('token_address', 'unknown') for item in batch]
716
  }
717
 
718
  # Filter out None values (e.g., if no labels provided)
data/data_loader.py CHANGED
@@ -33,10 +33,25 @@ LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03 # 3% of supply
33
  SMART_WALLET_PNL_THRESHOLD = 3.0 # 300% PNL
34
  SMART_WALLET_USD_THRESHOLD = 20000.0
35
 
36
- # --- NEW: Hyperparameters for H/B/H Event Fetching ---
37
- EVENT_COUNT_THRESHOLD_FOR_HBH = 30000 # If total events > this, use H/B/H
38
- HBH_EARLY_EVENT_LIMIT = 10000
39
- HBH_RECENT_EVENT_LIMIT = 15000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # --- NEW: OHLC Sequence Length Constant ---
42
  OHLC_SEQ_LEN = 300 # 4 minutes of chart
@@ -107,7 +122,10 @@ class OracleDataset(Dataset):
107
  t_cutoff_seconds: int = 60,
108
  cache_dir: Optional[Union[str, Path]] = None,
109
  start_date: Optional[datetime.datetime] = None,
110
- min_trade_usd: float = 0.0):
 
 
 
111
 
112
  # --- NEW: Create a persistent requests session for efficiency ---
113
  # Configure robust HTTP session
@@ -261,6 +279,90 @@ class OracleDataset(Dataset):
261
  denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
262
  return [(float(v) - self.ohlc_price_mean) / denom for v in values]
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def _compute_future_return_labels(self,
265
  anchor_price: Optional[float],
266
  anchor_timestamp: int,
@@ -830,9 +932,11 @@ class OracleDataset(Dataset):
830
  raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
831
  except Exception as e:
832
  raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
 
 
833
  else:
834
- # Online mode fallback
835
- raw_data = self.__cacheitem__(idx)
836
 
837
  if not raw_data:
838
  raise RuntimeError(f"No raw data loaded for index {idx}")
@@ -882,19 +986,31 @@ class OracleDataset(Dataset):
882
  preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
883
 
884
  mint_ts_value = _timestamp_to_order_value(mint_timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
885
  trade_ts_values = [
886
  _timestamp_to_order_value(trade.get('timestamp'))
887
- for trade in raw_data.get('trades', [])
888
- if trade.get('timestamp') is not None
889
  ]
890
- if not trade_ts_values:
 
 
891
  return None
892
 
893
- # Cache guarantees min_trades=25, so we proceed assuming valid data.
894
- # But for safety in dynamic sampling:
895
- if not trade_ts_values:
896
- return None
897
-
898
  # Sort trades to find the 24th trade timestamp
899
  sorted_trades_ts = sorted(trade_ts_values)
900
 
@@ -1057,10 +1173,10 @@ class OracleDataset(Dataset):
1057
  max_horizon_seconds=self.max_cache_horizon_seconds,
1058
  include_wallet_data=False,
1059
  include_graph=False,
1060
- min_trades=25,
1061
  full_history=True, # Bypass H/B/H limits
1062
- prune_failed=True, # Drop failed trades
1063
- prune_transfers=True # Drop transfers (captured in snapshots)
1064
  )
1065
  if raw_data is None:
1066
  return None
@@ -1447,9 +1563,12 @@ class OracleDataset(Dataset):
1447
  cached_holders_list=cached_holders_list
1448
  )
1449
 
1450
- # 7. Finalize Sequence
1451
  event_sequence_entries.sort(key=lambda x: x[0])
1452
- event_sequence = [entry[1] for entry in event_sequence_entries]
 
 
 
1453
 
1454
  # 8. Compute Labels using future data
1455
  # Define horizons (e.g., [60, 120, ...])
@@ -1459,7 +1578,31 @@ class OracleDataset(Dataset):
1459
  # Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff)
1460
  # We need to find the price at T_cutoff and at T_cutoff + h
1461
 
1462
- all_trades = future_trades_for_labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1463
  # Ensure sorted
1464
  all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
1465
 
@@ -1483,40 +1626,41 @@ class OracleDataset(Dataset):
1483
  label_values = []
1484
  mask_values = []
1485
 
1486
- for h in horizons:
1487
- target_ts = cutoff_ts_val + h
1488
-
1489
- if target_ts > last_trade_ts_val:
1490
- # Horizon extends beyond known history
1491
- # We MASK this label. We do NOT guess 0.
1492
- label_values.append(0.0) # Dummy value
1493
- mask_values.append(0.0) # Mask = 0 (Ignore)
1494
- else:
1495
- # Find price at target_ts
1496
- # It is the last trade strictly before or at target_ts
1497
- future_price = current_price # Default to current if no trades found in window? Unlikely if checked range.
1498
-
1499
- # Check trades between current_idx and target
1500
- # Optimization: start search from current_price_idx
1501
- found_future = False
1502
- for j in range(current_price_idx, len(all_trades)):
1503
- t = all_trades[j]
1504
- t_ts = _timestamp_to_order_value(t['timestamp'])
1505
- if t_ts <= target_ts:
1506
- future_price = float(t['price_usd'])
1507
- found_future = True
1508
- else:
1509
- break # Optimization: surpassed target_ts
1510
 
1511
- if current_price > 0:
1512
- ret = (future_price - current_price) / current_price
 
 
 
1513
  else:
1514
- ret = 0.0
 
 
1515
 
1516
- label_values.append(ret)
1517
- mask_values.append(1.0) # Mask = 1 (Valid)
 
 
 
 
 
 
 
 
 
 
1518
 
1519
  return {
 
1520
  'event_sequence': event_sequence,
1521
  'wallets': wallet_data,
1522
  'tokens': all_token_data,
 
33
  SMART_WALLET_PNL_THRESHOLD = 3.0 # 300% PNL
34
  SMART_WALLET_USD_THRESHOLD = 20000.0
35
 
36
+ # --- Event Categorization for Dynamic Sampling ---
37
+ # Events that are rare and should ALWAYS be kept
38
+ CRITICAL_EVENTS = {
39
+ 'Mint', 'Deployer_Trade', 'SmartWallet_Trade', 'LargeTrade', 'LargeTransfer',
40
+ 'TokenBurn', 'SupplyLock', 'PoolCreated', 'LiquidityChange', 'Migrated',
41
+ 'FeeCollected', 'TrendingToken', 'BoostedToken', 'XPost', 'XRetweet',
42
+ 'XReply', 'XQuoteTweet', 'PumpReply', 'DexBoost_Paid', 'DexProfile_Updated',
43
+ 'AlphaGroup_Call', 'Channel_Call', 'CexListing', 'TikTok_Trending_Hashtag',
44
+ 'XTrending_Hashtag'
45
+ }
46
+
47
+ # Periodic snapshots - kept for context continuity
48
+ SNAPSHOT_EVENTS = {
49
+ 'Chart_Segment', 'OnChain_Snapshot', 'HolderSnapshot',
50
+ 'ChainSnapshot', 'Lighthouse_Snapshot'
51
+ }
52
+
53
+ # High-volume events that can be compressed (Head/Tail)
54
+ COMPRESSIBLE_EVENTS = {'Trade', 'Transfer'}
55
 
56
  # --- NEW: OHLC Sequence Length Constant ---
57
  OHLC_SEQ_LEN = 300 # 4 minutes of chart
 
122
  t_cutoff_seconds: int = 60,
123
  cache_dir: Optional[Union[str, Path]] = None,
124
  start_date: Optional[datetime.datetime] = None,
125
+ min_trade_usd: float = 0.0,
126
+ max_seq_len: int = 8192):
127
+
128
+ self.max_seq_len = max_seq_len
129
 
130
  # --- NEW: Create a persistent requests session for efficiency ---
131
  # Configure robust HTTP session
 
279
  denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
280
  return [(float(v) - self.ohlc_price_mean) / denom for v in values]
281
 
282
+ def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
283
+ """
284
+ Applies dynamic context sampling to fit events within max_seq_len.
285
+
286
+ Priority:
287
+ 1. CRITICAL events (always kept)
288
+ 2. SNAPSHOT events (kept for continuity)
289
+ 3. COMPRESSIBLE events (Trade/Transfer) - split into Head/Tail with MIDDLE token
290
+
291
+ Uses existing 'MIDDLE' and 'RECENT' tokens to mark transitions.
292
+ """
293
+ if len(events) <= self.max_seq_len:
294
+ return events
295
+
296
+ # Categorize events by type
297
+ critical_events = [] # (original_idx, event)
298
+ snapshot_events = []
299
+ compressible_events = []
300
+
301
+ for idx, event in enumerate(events):
302
+ event_type = event.get('event_type', '')
303
+ if event_type in CRITICAL_EVENTS:
304
+ critical_events.append((idx, event))
305
+ elif event_type in SNAPSHOT_EVENTS:
306
+ snapshot_events.append((idx, event))
307
+ elif event_type in COMPRESSIBLE_EVENTS:
308
+ compressible_events.append((idx, event))
309
+ else:
310
+ # Unknown event types go to critical (safe default)
311
+ critical_events.append((idx, event))
312
+
313
+ # Calculate budget for compressible events
314
+ # Reserve 2 tokens for MIDDLE and RECENT markers
315
+ reserved_tokens = 2
316
+ fixed_count = len(critical_events) + len(snapshot_events) + reserved_tokens
317
+ budget_for_compressible = max(0, self.max_seq_len - fixed_count)
318
+
319
+ # If no budget for compressible, just return critical + snapshots
320
+ if budget_for_compressible == 0 or len(compressible_events) <= budget_for_compressible:
321
+ # All compressible fit, just return sorted
322
+ all_events = critical_events + snapshot_events + compressible_events
323
+ all_events.sort(key=lambda x: x[0])
324
+ return [e[1] for e in all_events]
325
+
326
+ # Apply Head/Tail split for compressible events
327
+ head_size = budget_for_compressible // 2
328
+ tail_size = budget_for_compressible - head_size
329
+
330
+ head_events = compressible_events[:head_size]
331
+ tail_events = compressible_events[-tail_size:] if tail_size > 0 else []
332
+
333
+ # Find the timestamp boundary for MIDDLE/RECENT markers
334
+ # MIDDLE goes after head, RECENT goes before tail
335
+ middle_marker_idx = head_events[-1][0] if head_events else 0
336
+ recent_marker_idx = tail_events[0][0] if tail_events else len(events)
337
+
338
+ # Create marker events
339
+ middle_marker = {
340
+ 'event_type': 'MIDDLE',
341
+ 'relative_ts': events[middle_marker_idx].get('relative_ts', 0) if middle_marker_idx < len(events) else 0,
342
+ 'is_marker': True
343
+ }
344
+ recent_marker = {
345
+ 'event_type': 'RECENT',
346
+ 'relative_ts': events[recent_marker_idx - 1].get('relative_ts', 0) if recent_marker_idx > 0 and recent_marker_idx <= len(events) else 0,
347
+ 'is_marker': True
348
+ }
349
+
350
+ # Combine all events with markers
351
+ # We need to maintain chronological order
352
+ all_indexed_events = critical_events + snapshot_events + head_events + tail_events
353
+
354
+ # Add markers with synthetic indices
355
+ middle_idx = middle_marker_idx + 0.5 # After last head event
356
+ recent_idx = recent_marker_idx - 0.5 # Before first tail event
357
+
358
+ all_indexed_events.append((middle_idx, middle_marker))
359
+ all_indexed_events.append((recent_idx, recent_marker))
360
+
361
+ # Sort by original index to maintain chronological order
362
+ all_indexed_events.sort(key=lambda x: x[0])
363
+
364
+ return [e[1] for e in all_indexed_events]
365
+
366
  def _compute_future_return_labels(self,
367
  anchor_price: Optional[float],
368
  anchor_timestamp: int,
 
932
  raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
933
  except Exception as e:
934
  raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
935
+ except Exception as e:
936
+ raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
937
  else:
938
+ # Strict Offline Mode: No dynamic generation fallback
939
+ raise RuntimeError(f"Offline mode required. No cache directory provided or configured.")
940
 
941
  if not raw_data:
942
  raise RuntimeError(f"No raw data loaded for index {idx}")
 
986
  preferred_horizon = horizons[1] if len(horizons) > 1 else min_label
987
 
988
  mint_ts_value = _timestamp_to_order_value(mint_timestamp)
989
+
990
+ # ============================================================================
991
+ # CRITICAL: Use ONLY successful trades for T_cutoff sampling!
992
+ # ============================================================================
993
+ # Failed trades have invalid price_usd values and should not be used for:
994
+ # 1. Determining the valid T_cutoff range (trades[24] to trades[-1])
995
+ # 2. Computing price returns for labels
996
+ # The T_cutoff range must guarantee at least one successful trade after cutoff.
997
+ # ============================================================================
998
+ successful_trades = [
999
+ trade for trade in raw_data.get('trades', [])
1000
+ if trade.get('success', False)
1001
+ and trade.get('timestamp') is not None
1002
+ and float(trade.get('price_usd', 0) or 0) > 0
1003
+ ]
1004
+
1005
  trade_ts_values = [
1006
  _timestamp_to_order_value(trade.get('timestamp'))
1007
+ for trade in successful_trades
 
1008
  ]
1009
+
1010
+ if not trade_ts_values or len(trade_ts_values) < 25:
1011
+ # Not enough successful trades for valid sampling
1012
  return None
1013
 
 
 
 
 
 
1014
  # Sort trades to find the 24th trade timestamp
1015
  sorted_trades_ts = sorted(trade_ts_values)
1016
 
 
1173
  max_horizon_seconds=self.max_cache_horizon_seconds,
1174
  include_wallet_data=False,
1175
  include_graph=False,
1176
+ min_trades=24, # Enforce min trades for context
1177
  full_history=True, # Bypass H/B/H limits
1178
+ prune_failed=False, # Keep failed trades for realistic simulation
1179
+ prune_transfers=False # Keep transfers for snapshot reconstruction
1180
  )
1181
  if raw_data is None:
1182
  return None
 
1563
  cached_holders_list=cached_holders_list
1564
  )
1565
 
1566
+ # 7. Finalize Sequence with Dynamic Sampling
1567
  event_sequence_entries.sort(key=lambda x: x[0])
1568
+ raw_event_sequence = [entry[1] for entry in event_sequence_entries]
1569
+
1570
+ # Apply dynamic context sampling if needed
1571
+ event_sequence = self._apply_dynamic_sampling(raw_event_sequence)
1572
 
1573
  # 8. Compute Labels using future data
1574
  # Define horizons (e.g., [60, 120, ...])
 
1578
  # Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff)
1579
  # We need to find the price at T_cutoff and at T_cutoff + h
1580
 
1581
+ # ============================================================================
1582
+ # CRITICAL: Filter for successful trades with valid prices ONLY!
1583
+ # ============================================================================
1584
+ # Failed trades (success=False) often have price_usd=0 or invalid values.
1585
+ # Using these for label computation causes mathematically impossible returns
1586
+ # like -1.0 (price went to 0) or 0.0 (no price change despite trading).
1587
+ # ALWAYS filter by: success=True AND price_usd > 0
1588
+ # ============================================================================
1589
+ all_trades = [
1590
+ t for t in future_trades_for_labels
1591
+ if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
1592
+ ]
1593
+
1594
+ if not all_trades:
1595
+ # No valid trades for label computation
1596
+ return {
1597
+ 'event_sequence': event_sequence,
1598
+ 'wallets': wallet_data,
1599
+ 'tokens': all_token_data,
1600
+ 'graph_links': graph_links,
1601
+ 'embedding_pooler': pooler,
1602
+ 'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1603
+ 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32)
1604
+ }
1605
+
1606
  # Ensure sorted
1607
  all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
1608
 
 
1626
  label_values = []
1627
  mask_values = []
1628
 
1629
+ # Edge case: no trades before cutoff means we have no anchor price
1630
+ if current_price_idx < 0 or current_price <= 0:
1631
+ # No valid anchor price - mask all labels
1632
+ for h in horizons:
1633
+ label_values.append(0.0)
1634
+ mask_values.append(0.0)
1635
+ else:
1636
+ for h in horizons:
1637
+ target_ts = cutoff_ts_val + h
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1638
 
1639
+ if target_ts > last_trade_ts_val:
1640
+ # Horizon extends beyond known history
1641
+ # We MASK this label. We do NOT guess 0.
1642
+ label_values.append(0.0) # Dummy value
1643
+ mask_values.append(0.0) # Mask = 0 (Ignore)
1644
  else:
1645
+ # Find price at target_ts
1646
+ # Start searching AFTER current_price_idx to find the NEXT trade
1647
+ future_price = current_price
1648
 
1649
+ # Search from current_price_idx + 1 to find trades in the horizon window
1650
+ for j in range(current_price_idx + 1, len(all_trades)):
1651
+ t = all_trades[j]
1652
+ t_ts = _timestamp_to_order_value(t['timestamp'])
1653
+ if t_ts <= target_ts:
1654
+ future_price = float(t['price_usd'])
1655
+ else:
1656
+ break # Surpassed target_ts
1657
+
1658
+ ret = (future_price - current_price) / current_price
1659
+ label_values.append(ret)
1660
+ mask_values.append(1.0) # Mask = 1 (Valid)
1661
 
1662
  return {
1663
+ 'token_address': token_address, # For debugging
1664
  'event_sequence': event_sequence,
1665
  'wallets': wallet_data,
1666
  'tokens': all_token_data,
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:847193fc90f4b0313f515ea38a24fd073be09188cfc4764c5dce3f658d4dc117
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e84cff0cfabf73d50f94c3f9a5cf9224e89c634db76982a4e3e5428c9df4ea91
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:10917f8ad8d8962a8c05a46f2b24dcb1180b23665d0767ea5c65c63d9ec09c92
3
- size 314966
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1c2198c3ed6e249ddf7b7b017b99b2389e4611b8b0649c63d30c40c59e03ac1
3
+ size 76001
train.py CHANGED
@@ -150,10 +150,14 @@ def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step:
150
  events.append(name)
151
 
152
  logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")
153
- # Show a slice of events (e.g. last 50)
154
- tail_len = 50
155
- context_str = ", ".join(events[-tail_len:])
156
- logger.info(f"Event Stream (Last {tail_len} of {len(events)}): [{context_str}]")
 
 
 
 
157
 
158
  # Show Labels
159
  # Assuming flattened labels [H*Q]
@@ -190,11 +194,7 @@ def parse_args() -> argparse.Namespace:
190
  parser.add_argument("--num_workers", type=int, default=0)
191
  parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
192
  parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
193
- parser.add_argument("--clickhouse_host", type=str, default="localhost")
194
- parser.add_argument("--clickhouse_port", type=int, default=9000)
195
- parser.add_argument("--neo4j_uri", type=str, default="bolt://localhost:7687")
196
- parser.add_argument("--neo4j_user", type=str, default=None)
197
- parser.add_argument("--neo4j_password", type=str, default=None)
198
  return parser.parse_args()
199
 
200
 
@@ -394,21 +394,27 @@ def main() -> None:
394
  dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
395
  if dirs:
396
  # Sort by modification time or name to find latest
 
397
  dirs.sort(key=lambda x: x.stat().st_mtime)
398
  latest_checkpoint = dirs[-1]
399
- logger.info(f"Found checkpoint: {latest_checkpoint}. Resuming training...")
400
- accelerator.load_state(str(latest_checkpoint))
401
-
402
- # Try to infer epoch/step from folder name or saved state if custom tracking
403
- # Accelerate restores DataLoader state, so we mainly need to know where we are for logging
404
- # Assuming standard naming or just relying on DataLoader restore.
405
- # Simple approach: Just trust Accelerate/DataLoader to skip.
406
- # If you need precise epoch/step recovery for logging display:
407
- # You could save a metadata.json inside the checkpoint folder.
408
-
409
- logger.info("Checkpoint loaded. DataLoader state restored.")
 
 
 
 
 
410
  else:
411
- logger.info("No checkpoint found. Starting fresh.")
412
 
413
  # --- 7. Training Loop ---
414
  total_steps = 0
 
150
  events.append(name)
151
 
152
  logger.info(f"\n--- [Step {step}] Batch Input Preview (Sample 0) ---")
153
+
154
+ # Log token address for manual verification
155
+ token_addresses = batch.get('token_addresses', [])
156
+ if token_addresses:
157
+ logger.info(f"Token Address: {token_addresses[0]}")
158
+
159
+ context_str = ", ".join(events)
160
+ logger.info(f"Event Stream ({len(events)}): [{context_str}]")
161
 
162
  # Show Labels
163
  # Assuming flattened labels [H*Q]
 
194
  parser.add_argument("--num_workers", type=int, default=0)
195
  parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
196
  parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
197
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint or 'latest'")
 
 
 
 
198
  return parser.parse_args()
199
 
200
 
 
394
  dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
395
  if dirs:
396
  # Sort by modification time or name to find latest
397
+ # Sort by modification time or name to find latest
398
  dirs.sort(key=lambda x: x.stat().st_mtime)
399
  latest_checkpoint = dirs[-1]
400
+
401
+ if args.resume_from_checkpoint:
402
+ if args.resume_from_checkpoint == "latest":
403
+ if latest_checkpoint:
404
+ logger.info(f"Resuming from latest checkpoint: {latest_checkpoint}")
405
+ accelerator.load_state(str(latest_checkpoint))
406
+ else:
407
+ logger.warning("Resume requested but no checkpoint found in dir. Starting fresh.")
408
+ else:
409
+ # Specific path
410
+ custom_ckpt = Path(args.resume_from_checkpoint)
411
+ if custom_ckpt.exists():
412
+ logger.info(f"Resuming from specific checkpoint: {custom_ckpt}")
413
+ accelerator.load_state(str(custom_ckpt))
414
+ else:
415
+ raise FileNotFoundError(f"Checkpoint not found at {custom_ckpt}")
416
  else:
417
+ logger.info("No resume flag provided. Starting fresh.")
418
 
419
  # --- 7. Training Loop ---
420
  total_steps = 0
train.sh CHANGED
@@ -15,7 +15,4 @@ accelerate launch train.py \
15
  --horizons_seconds 60 180 300 600 1800 3600 7200 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
- --num_workers 4 \
19
- --clickhouse_host localhost \
20
- --clickhouse_port 9000 \
21
- --neo4j_uri bolt://localhost:7687
 
15
  --horizons_seconds 60 180 300 600 1800 3600 7200 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
+ --num_workers 4