Upload folder using huggingface_hub
Browse files- data/data_collator.py +2 -1
- data/data_loader.py +11 -0
- log.log +2 -2
- train.py +5 -0
data/data_collator.py
CHANGED
|
@@ -712,7 +712,8 @@ class MemecoinCollator:
|
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 714 |
# Debug info
|
| 715 |
-
'token_addresses': [item.get('token_address', 'unknown') for item in batch]
|
|
|
|
| 716 |
}
|
| 717 |
|
| 718 |
# Filter out None values (e.g., if no labels provided)
|
|
|
|
| 712 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 713 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 714 |
# Debug info
|
| 715 |
+
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
|
| 716 |
+
't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch]
|
| 717 |
}
|
| 718 |
|
| 719 |
# Filter out None values (e.g., if no labels provided)
|
data/data_loader.py
CHANGED
|
@@ -1623,6 +1623,16 @@ class OracleDataset(Dataset):
|
|
| 1623 |
else:
|
| 1624 |
break
|
| 1625 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1626 |
label_values = []
|
| 1627 |
mask_values = []
|
| 1628 |
|
|
@@ -1661,6 +1671,7 @@ class OracleDataset(Dataset):
|
|
| 1661 |
|
| 1662 |
return {
|
| 1663 |
'token_address': token_address, # For debugging
|
|
|
|
| 1664 |
'event_sequence': event_sequence,
|
| 1665 |
'wallets': wallet_data,
|
| 1666 |
'tokens': all_token_data,
|
|
|
|
| 1623 |
else:
|
| 1624 |
break
|
| 1625 |
|
| 1626 |
+
# DEBUG: Log label computation details
|
| 1627 |
+
print(f" DEBUG LABELS: token={token_address[:12]}...")
|
| 1628 |
+
print(f" T_cutoff={T_cutoff.isoformat()}, cutoff_ts={cutoff_ts_val}")
|
| 1629 |
+
print(f" Successful trades count: {len(all_trades)}")
|
| 1630 |
+
print(f" current_price_idx={current_price_idx}, current_price={current_price}")
|
| 1631 |
+
print(f" last_trade_ts={last_trade_ts_val}, trades_after_cutoff={len(all_trades) - current_price_idx - 1}")
|
| 1632 |
+
if current_price_idx >= 0 and current_price_idx + 1 < len(all_trades):
|
| 1633 |
+
next_trade = all_trades[current_price_idx + 1]
|
| 1634 |
+
print(f" Next trade: ts={_timestamp_to_order_value(next_trade['timestamp'])}, price={next_trade.get('price_usd')}")
|
| 1635 |
+
|
| 1636 |
label_values = []
|
| 1637 |
mask_values = []
|
| 1638 |
|
|
|
|
| 1671 |
|
| 1672 |
return {
|
| 1673 |
'token_address': token_address, # For debugging
|
| 1674 |
+
't_cutoff': T_cutoff.isoformat() if T_cutoff else None, # For debugging
|
| 1675 |
'event_sequence': event_sequence,
|
| 1676 |
'wallets': wallet_data,
|
| 1677 |
'tokens': all_token_data,
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c006d7598527d9f388edff81fcd301b87ad3698d315090426dec45751757798
|
| 3 |
+
size 109185
|
train.py
CHANGED
|
@@ -156,6 +156,11 @@ def log_debug_batch_context(batch: Dict[str, Any], logger: logging.Logger, step:
|
|
| 156 |
if token_addresses:
|
| 157 |
logger.info(f"Token Address: {token_addresses[0]}")
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
context_str = ", ".join(events)
|
| 160 |
logger.info(f"Event Stream ({len(events)}): [{context_str}]")
|
| 161 |
|
|
|
|
| 156 |
if token_addresses:
|
| 157 |
logger.info(f"Token Address: {token_addresses[0]}")
|
| 158 |
|
| 159 |
+
# Log T_cutoff timestamp
|
| 160 |
+
t_cutoffs = batch.get('t_cutoffs', [])
|
| 161 |
+
if t_cutoffs:
|
| 162 |
+
logger.info(f"T_cutoff: {t_cutoffs[0]}")
|
| 163 |
+
|
| 164 |
context_str = ", ".join(events)
|
| 165 |
logger.info(f"Event Stream ({len(events)}): [{context_str}]")
|
| 166 |
|