zirobtc commited on
Commit
6e3cdd3
·
1 Parent(s): 7d63a09

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. data/data_collator.py +2 -1
  2. data/data_loader.py +11 -0
  3. log.log +2 -2
  4. train.py +5 -0
data/data_collator.py CHANGED
@@ -712,7 +712,8 @@ class MemecoinCollator:
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
714
  # Debug info
715
- 'token_addresses': [item.get('token_address', 'unknown') for item in batch]
 
716
  }
717
 
718
  # Filter out None values (e.g., if no labels provided)
 
712
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
713
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
714
  # Debug info
715
+ 'token_addresses': [item.get('token_address', 'unknown') for item in batch],
716
+ 't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch]
717
  }
718
 
719
  # Filter out None values (e.g., if no labels provided)
data/data_loader.py CHANGED
@@ -1623,6 +1623,16 @@ class OracleDataset(Dataset):
1623
  else:
1624
  break
1625
 
 
 
 
 
 
 
 
 
 
 
1626
  label_values = []
1627
  mask_values = []
1628
 
@@ -1661,6 +1671,7 @@ class OracleDataset(Dataset):
1661
 
1662
  return {
1663
  'token_address': token_address, # For debugging
 
1664
  'event_sequence': event_sequence,
1665
  'wallets': wallet_data,
1666
  'tokens': all_token_data,
 
1623
  else:
1624
  break
1625
 
1626
+ # DEBUG: Log label computation details
1627
+ print(f" DEBUG LABELS: token={token_address[:12]}...")
1628
+ print(f" T_cutoff={T_cutoff.isoformat()}, cutoff_ts={cutoff_ts_val}")
1629
+ print(f" Successful trades count: {len(all_trades)}")
1630
+ print(f" current_price_idx={current_price_idx}, current_price={current_price}")
1631
+ print(f" last_trade_ts={last_trade_ts_val}, trades_after_cutoff={len(all_trades) - current_price_idx - 1}")
1632
+ if current_price_idx >= 0 and current_price_idx + 1 < len(all_trades):
1633
+ next_trade = all_trades[current_price_idx + 1]
1634
+ print(f" Next trade: ts={_timestamp_to_order_value(next_trade['timestamp'])}, price={next_trade.get('price_usd')}")
1635
+
1636
  label_values = []
1637
  mask_values = []
1638
 
 
1671
 
1672
  return {
1673
  'token_address': token_address, # For debugging
1674
+ 't_cutoff': T_cutoff.isoformat() if T_cutoff else None, # For debugging
1675
  'event_sequence': event_sequence,
1676
  'wallets': wallet_data,
1677
  'tokens': all_token_data,
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f1c2198c3ed6e249ddf7b7b017b99b2389e4611b8b0649c63d30c40c59e03ac1
3
- size 76001
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c006d7598527d9f388edff81fcd301b87ad3698d315090426dec45751757798
3
+ size 109185
train.py CHANGED
@@ -156,6 +156,11 @@ def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step:
156
  if token_addresses:
157
  logger.info(f"Token Address: {token_addresses[0]}")
158
 
 
 
 
 
 
159
  context_str = ", ".join(events)
160
  logger.info(f"Event Stream ({len(events)}): [{context_str}]")
161
 
 
156
  if token_addresses:
157
  logger.info(f"Token Address: {token_addresses[0]}")
158
 
159
+ # Log T_cutoff timestamp
160
+ t_cutoffs = batch.get('t_cutoffs', [])
161
+ if t_cutoffs:
162
+ logger.info(f"T_cutoff: {t_cutoffs[0]}")
163
+
164
  context_str = ", ".join(events)
165
  logger.info(f"Event Stream ({len(events)}): [{context_str}]")
166