# oracle/models/SocialEncoders.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any
import models.vocabulary as vocab # For event type IDs
class XPostEncoder(nn.Module):
""" Encodes: <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: Wallet (d_model) + Text (d_model) + Media (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 3, d_model * 2),
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor, media_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([author_emb, text_emb, media_emb], dim=-1)
return self.mlp(combined)
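
# Usage sketch for the concat-then-MLP encoders in this file (illustrative
# values; d_model and dtype come from the main model's config):
#   enc = XPostEncoder(d_model=512, dtype=torch.float32)
#   fused = enc(author_emb, text_emb, media_emb)  # each (..., 512) -> (..., 512)
# XRetweetEncoder, XReplyEncoder, XQuoteTweetEncoder and PumpReplyEncoder below
# follow the same pattern with different input counts.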
class XRetweetEncoder(nn.Module):
""" Encodes: <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: Retweeter (d_model) + Original Author (d_model) + Original Text (d_model) + Original Media (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 4, d_model * 2),
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self,
retweeter_emb: torch.Tensor,
orig_author_emb: torch.Tensor,
orig_text_emb: torch.Tensor,
orig_media_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([retweeter_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
return self.mlp(combined)
class XReplyEncoder(nn.Module):
""" Encodes: <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: Author (d_model) + Reply Text (d_model) + Reply Media (d_model) + Main Tweet Text (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 4, d_model * 2),
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self,
author_emb: torch.Tensor,
text_emb: torch.Tensor,
media_emb: torch.Tensor,
main_tweet_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([author_emb, text_emb, media_emb, main_tweet_emb], dim=-1)
return self.mlp(combined)
class XQuoteTweetEncoder(nn.Module):
""" Encodes: <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: Quoter Wallet (d_model) + Quoter Text (d_model) + Orig Author (d_model) + Orig Text (d_model) + Orig Media (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 5, d_model * 2),
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self,
quoter_wallet_emb: torch.Tensor,
quoter_text_emb: torch.Tensor,
orig_author_emb: torch.Tensor,
orig_text_emb: torch.Tensor,
orig_media_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([quoter_wallet_emb, quoter_text_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
return self.mlp(combined)
class PumpReplyEncoder(nn.Module):
""" Encodes: <UserWalletEmbedding>, <ReplyTextEmbedding> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: User Wallet (d_model) + Reply Text (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 2, d_model * 2),
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self, user_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([user_emb, text_emb], dim=-1)
return self.mlp(combined)
# --- NEW: Encoders for other text-based events ---
class DexProfileUpdatedEncoder(nn.Module):
""" Encodes: <4_flags_projection>, <website_emb>, <twitter_emb>, <telegram_emb>, <description_emb> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: flags_proj (d_model) + 4x text_embeds (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model * 4, d_model * 2), # Corrected from 5 to 4, flags are separate
nn.GELU(),
nn.LayerNorm(d_model * 2),
nn.Linear(d_model * 2, d_model)
).to(dtype)
def forward(self, website_emb: torch.Tensor, twitter_emb: torch.Tensor, telegram_emb: torch.Tensor, description_emb: torch.Tensor) -> torch.Tensor:
combined = torch.cat([website_emb, twitter_emb, telegram_emb, description_emb], dim=-1)
return self.mlp(combined)
class GlobalTrendingEncoder(nn.Module):
""" Encodes: <hashtag_emb> """
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
# Input: hashtag_emb (d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Linear(d_model, d_model)
).to(dtype)
def forward(self, hashtag_emb: torch.Tensor) -> torch.Tensor:
return self.mlp(hashtag_emb)
class SocialEncoder(nn.Module):
"""
A single module to house all social event encoders.
This simplifies instantiation in the main Oracle model.
"""
def __init__(self, d_model: int, dtype: torch.dtype):
super().__init__()
self.x_post_encoder = XPostEncoder(d_model, dtype)
self.x_retweet_encoder = XRetweetEncoder(d_model, dtype)
self.x_reply_encoder = XReplyEncoder(d_model, dtype)
self.x_quote_tweet_encoder = XQuoteTweetEncoder(d_model, dtype)
self.pump_reply_encoder = PumpReplyEncoder(d_model, dtype)
# --- NEW: Add the other text-based encoders ---
self.dex_profile_encoder = DexProfileUpdatedEncoder(d_model, dtype)
self.global_trending_encoder = GlobalTrendingEncoder(d_model, dtype)
# Store for convenience
self.d_model = d_model
self.dtype = dtype
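        # Typical call from the main Oracle model (hedged sketch; the exact
        # batch keys are produced by the collator, see forward() below):
        #   social = SocialEncoder(d_model, dtype)
        #   social_embeds = social(batch, gathered_embeds)  # (B, L, d_model)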
def forward(self, batch: Dict[str, Any], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
REFACTORED: Processes all text-based events for the entire batch in a vectorized way.
This replaces the inefficient loops in the main Oracle model.
"""
device = gathered_embeds['wallet'].device
B, L, D = gathered_embeds['wallet'].shape
final_embeds = torch.zeros(B, L, D, device=device, dtype=self.dtype)
textual_event_indices = batch['textual_event_indices']
textual_event_data = batch.get('textual_event_data', [])
precomputed_lookup = gathered_embeds['precomputed']
# --- Create masks for each event type ---
event_type_ids = batch['event_type_ids']
event_masks = {
'XPost': (event_type_ids == vocab.EVENT_TO_ID.get('XPost', -1)),
'XReply': (event_type_ids == vocab.EVENT_TO_ID.get('XReply', -1)),
'XRetweet': (event_type_ids == vocab.EVENT_TO_ID.get('XRetweet', -1)),
'XQuoteTweet': (event_type_ids == vocab.EVENT_TO_ID.get('XQuoteTweet', -1)),
'PumpReply': (event_type_ids == vocab.EVENT_TO_ID.get('PumpReply', -1)),
'DexProfile_Updated': (event_type_ids == vocab.EVENT_TO_ID.get('DexProfile_Updated', -1)),
'TikTok_Trending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('TikTok_Trending_Hashtag', -1)),
'XTrending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('XTrending_Hashtag', -1)),
}
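        # EVENT_TO_ID.get(name, -1) yields an all-False mask for any event type
        # missing from the vocabulary (assuming valid IDs are non-negative).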
# --- Gather all necessary pre-computed embeddings in one go ---
# Flatten indices for efficient lookup, then reshape
flat_indices = textual_event_indices.flatten()
# Create a default event structure for padding indices (idx=0)
default_event = {'event_type': 'PAD'}
# Use 1-based index from collator, so textual_event_data[idx-1]
raw_events_flat = [textual_event_data[idx-1] if idx > 0 else default_event for idx in flat_indices.tolist()]
# Helper to gather embeddings for a specific key
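        # Missing keys fall back to index 0, which is assumed to be a
        # zero/padding row of precomputed_lookup.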
def gather_precomputed(key: str) -> torch.Tensor:
indices = torch.tensor([e.get(key, 0) for e in raw_events_flat], device=device, dtype=torch.long)
return F.embedding(indices, precomputed_lookup).view(B, L, -1)
# --- Process each event type ---
# XPost
if event_masks['XPost'].any():
text_emb = gather_precomputed('text_emb_idx')
media_emb = gather_precomputed('media_emb_idx')
post_embeds = self.x_post_encoder(gathered_embeds['wallet'], text_emb, media_emb)
final_embeds += post_embeds * event_masks['XPost'].unsqueeze(-1)
# XReply
if event_masks['XReply'].any():
text_emb = gather_precomputed('text_emb_idx')
media_emb = gather_precomputed('media_emb_idx')
main_tweet_emb = gather_precomputed('main_tweet_text_emb_idx')
reply_embeds = self.x_reply_encoder(gathered_embeds['wallet'], text_emb, media_emb, main_tweet_emb)
final_embeds += reply_embeds * event_masks['XReply'].unsqueeze(-1)
# XRetweet
if event_masks['XRetweet'].any():
orig_text_emb = gather_precomputed('original_post_text_emb_idx')
orig_media_emb = gather_precomputed('original_post_media_emb_idx')
retweet_embeds = self.x_retweet_encoder(gathered_embeds['wallet'], gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
final_embeds += retweet_embeds * event_masks['XRetweet'].unsqueeze(-1)
# XQuoteTweet
if event_masks['XQuoteTweet'].any():
quoter_text_emb = gather_precomputed('quoter_text_emb_idx')
orig_text_emb = gather_precomputed('original_post_text_emb_idx')
orig_media_emb = gather_precomputed('original_post_media_emb_idx')
quote_embeds = self.x_quote_tweet_encoder(gathered_embeds['wallet'], quoter_text_emb, gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
final_embeds += quote_embeds * event_masks['XQuoteTweet'].unsqueeze(-1)
# PumpReply
if event_masks['PumpReply'].any():
text_emb = gather_precomputed('reply_text_emb_idx')
pump_reply_embeds = self.pump_reply_encoder(gathered_embeds['wallet'], text_emb)
final_embeds += pump_reply_embeds * event_masks['PumpReply'].unsqueeze(-1)
# DexProfile_Updated
if event_masks['DexProfile_Updated'].any():
website_emb = gather_precomputed('website_emb_idx')
twitter_emb = gather_precomputed('twitter_link_emb_idx')
telegram_emb = gather_precomputed('telegram_link_emb_idx')
description_emb = gather_precomputed('description_emb_idx')
profile_embeds = self.dex_profile_encoder(website_emb, twitter_emb, telegram_emb, description_emb)
# Note: The flags are handled separately in the main model now, so we just add the text embeds
final_embeds += profile_embeds * event_masks['DexProfile_Updated'].unsqueeze(-1)
# Global Trending Hashtags
trending_mask = event_masks['TikTok_Trending_Hashtag'] | event_masks['XTrending_Hashtag']
if trending_mask.any():
hashtag_emb = gather_precomputed('hashtag_name_emb_idx')
trending_embeds = self.global_trending_encoder(hashtag_emb)
final_embeds += trending_embeds * trending_mask.unsqueeze(-1)
return final_embeds
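

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch with arbitrary illustration sizes,
    # not part of the training pipeline.
    B, L, D = 2, 8, 64
    enc = XPostEncoder(d_model=D, dtype=torch.float32)
    author = torch.randn(B, L, D)
    text = torch.randn(B, L, D)
    media = torch.randn(B, L, D)
    fused = enc(author, text, media)
    assert fused.shape == (B, L, D)
    print("XPostEncoder output shape:", tuple(fused.shape))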