import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any

import models.vocabulary as vocab  # For event type IDs


class XPostEncoder(nn.Module):
    """Encodes an original X post from <author wallet, post text, post media>."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Wallet (d_model) + Text (d_model) + Media (d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor,
                media_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the three embeddings on the last dim and project to d_model."""
        combined = torch.cat([author_emb, text_emb, media_emb], dim=-1)
        return self.mlp(combined)


class XRetweetEncoder(nn.Module):
    """Encodes a retweet from <retweeter wallet, original author, original text, original media>."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Retweeter (d_model) + Original Author (d_model) + Original Text (d_model) + Original Media (d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, retweeter_emb: torch.Tensor, orig_author_emb: torch.Tensor,
                orig_text_emb: torch.Tensor, orig_media_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the four embeddings on the last dim and project to d_model."""
        combined = torch.cat(
            [retweeter_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1
        )
        return self.mlp(combined)


class XReplyEncoder(nn.Module):
    """Encodes a reply from <author wallet, reply text, reply media, main tweet text>."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Author (d_model) + Reply Text (d_model) + Reply Media (d_model) + Main Tweet Text (d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor,
                media_emb: torch.Tensor, main_tweet_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the four embeddings on the last dim and project to d_model."""
        combined = torch.cat([author_emb, text_emb, media_emb, main_tweet_emb], dim=-1)
        return self.mlp(combined)


class XQuoteTweetEncoder(nn.Module):
    """Encodes a quote tweet from <quoter wallet, quoter text, original author, original text, original media>."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Quoter Wallet (d_model) + Quoter Text (d_model) + Orig Author (d_model) + Orig Text (d_model) + Orig Media (d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 5, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, quoter_wallet_emb: torch.Tensor, quoter_text_emb: torch.Tensor,
                orig_author_emb: torch.Tensor, orig_text_emb: torch.Tensor,
                orig_media_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the five embeddings on the last dim and project to d_model."""
        combined = torch.cat(
            [quoter_wallet_emb, quoter_text_emb, orig_author_emb, orig_text_emb, orig_media_emb],
            dim=-1,
        )
        return self.mlp(combined)


class PumpReplyEncoder(nn.Module):
    """Encodes a pump.fun reply from <user wallet, reply text>."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: User Wallet (d_model) + Reply Text (d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, user_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the two embeddings on the last dim and project to d_model."""
        combined = torch.cat([user_emb, text_emb], dim=-1)
        return self.mlp(combined)


# --- NEW: Encoders for other text-based events ---

class DexProfileUpdatedEncoder(nn.Module):
    """Encodes a DEX profile update from <website, twitter, telegram, description> text embeddings.

    The profile flags are projected and added separately by the main model, so
    this encoder consumes only the four text embeddings.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: 4x text embeds (d_model each). Flags are handled separately in
        # the main model, hence d_model * 4 rather than * 5.
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ).to(dtype)

    def forward(self, website_emb: torch.Tensor, twitter_emb: torch.Tensor,
                telegram_emb: torch.Tensor, description_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the four text embeddings on the last dim and project to d_model."""
        combined = torch.cat([website_emb, twitter_emb, telegram_emb, description_emb], dim=-1)
        return self.mlp(combined)


class GlobalTrendingEncoder(nn.Module):
    """Encodes a globally trending hashtag from its <hashtag name> embedding."""

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: hashtag_emb (d_model). Single input, so no LayerNorm bottleneck.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        ).to(dtype)

    def forward(self, hashtag_emb: torch.Tensor) -> torch.Tensor:
        """Project the hashtag embedding through a two-layer MLP (d_model -> d_model)."""
        return self.mlp(hashtag_emb)


class SocialEncoder(nn.Module):
    """
    A single module to house all social event encoders.
    This simplifies instantiation in the main Oracle model.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        self.x_post_encoder = XPostEncoder(d_model, dtype)
        self.x_retweet_encoder = XRetweetEncoder(d_model, dtype)
        self.x_reply_encoder = XReplyEncoder(d_model, dtype)
        self.x_quote_tweet_encoder = XQuoteTweetEncoder(d_model, dtype)
        self.pump_reply_encoder = PumpReplyEncoder(d_model, dtype)
        # --- NEW: Add the other text-based encoders ---
        self.dex_profile_encoder = DexProfileUpdatedEncoder(d_model, dtype)
        self.global_trending_encoder = GlobalTrendingEncoder(d_model, dtype)

        # Store for convenience
        self.d_model = d_model
        self.dtype = dtype

    def forward(self, batch: Dict[str, Any],
                gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        REFACTORED: Processes all text-based events for the entire batch in a
        vectorized way, replacing the inefficient loops in the main Oracle model.

        Args:
            batch: Collated batch dict. Uses keys 'textual_event_indices'
                (1-based indices into 'textual_event_data'; 0 == padding),
                'textual_event_data' (flat list of per-event dicts holding
                precomputed-embedding indices), and 'event_type_ids'.
            gathered_embeds: Uses 'wallet' (B, L, D), 'precomputed' (embedding
                lookup table indexed by the per-event *_emb_idx values), and
                'original_author' (B, L, D) for retweets/quotes.
                # NOTE(review): exact lookup-table layout is defined by the
                # collator — confirm index 0 is the designated "missing" row.

        Returns:
            (B, L, D) tensor; positions whose event type is not handled here
            remain zero, each handled position holds its encoder's output.
        """
        device = gathered_embeds['wallet'].device
        B, L, D = gathered_embeds['wallet'].shape
        final_embeds = torch.zeros(B, L, D, device=device, dtype=self.dtype)

        textual_event_indices = batch['textual_event_indices']
        textual_event_data = batch.get('textual_event_data', [])
        precomputed_lookup = gathered_embeds['precomputed']

        # --- Create masks for each event type ---
        # .get(name, -1) makes unknown event names match nothing instead of raising.
        event_type_ids = batch['event_type_ids']
        event_masks = {
            'XPost': (event_type_ids == vocab.EVENT_TO_ID.get('XPost', -1)),
            'XReply': (event_type_ids == vocab.EVENT_TO_ID.get('XReply', -1)),
            'XRetweet': (event_type_ids == vocab.EVENT_TO_ID.get('XRetweet', -1)),
            'XQuoteTweet': (event_type_ids == vocab.EVENT_TO_ID.get('XQuoteTweet', -1)),
            'PumpReply': (event_type_ids == vocab.EVENT_TO_ID.get('PumpReply', -1)),
            'DexProfile_Updated': (event_type_ids == vocab.EVENT_TO_ID.get('DexProfile_Updated', -1)),
            'TikTok_Trending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('TikTok_Trending_Hashtag', -1)),
            'XTrending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('XTrending_Hashtag', -1)),
        }

        # --- Gather all necessary pre-computed embeddings in one go ---
        # Flatten indices for efficient lookup, then reshape
        flat_indices = textual_event_indices.flatten()

        # Create a default event structure for padding indices (idx=0)
        default_event = {'event_type': 'PAD'}
        # Use 1-based index from collator, so textual_event_data[idx-1]
        raw_events_flat = [
            textual_event_data[idx - 1] if idx > 0 else default_event
            for idx in flat_indices.tolist()
        ]

        def gather_precomputed(key: str) -> torch.Tensor:
            """Look up precomputed embeddings for `key` across every position.

            Missing keys fall back to row 0 of the lookup table; positions not
            belonging to the current event type are zeroed out later by the mask.
            """
            indices = torch.tensor(
                [e.get(key, 0) for e in raw_events_flat],
                device=device, dtype=torch.long,
            )
            return F.embedding(indices, precomputed_lookup).view(B, L, -1)

        # --- Process each event type ---
        # Each branch encodes ALL positions, then masks so only its own event
        # slots contribute; final_embeds accumulates disjoint contributions.

        # XPost
        if event_masks['XPost'].any():
            text_emb = gather_precomputed('text_emb_idx')
            media_emb = gather_precomputed('media_emb_idx')
            post_embeds = self.x_post_encoder(gathered_embeds['wallet'], text_emb, media_emb)
            final_embeds += post_embeds * event_masks['XPost'].unsqueeze(-1)

        # XReply
        if event_masks['XReply'].any():
            text_emb = gather_precomputed('text_emb_idx')
            media_emb = gather_precomputed('media_emb_idx')
            main_tweet_emb = gather_precomputed('main_tweet_text_emb_idx')
            reply_embeds = self.x_reply_encoder(
                gathered_embeds['wallet'], text_emb, media_emb, main_tweet_emb
            )
            final_embeds += reply_embeds * event_masks['XReply'].unsqueeze(-1)

        # XRetweet
        if event_masks['XRetweet'].any():
            orig_text_emb = gather_precomputed('original_post_text_emb_idx')
            orig_media_emb = gather_precomputed('original_post_media_emb_idx')
            retweet_embeds = self.x_retweet_encoder(
                gathered_embeds['wallet'], gathered_embeds['original_author'],
                orig_text_emb, orig_media_emb,
            )
            final_embeds += retweet_embeds * event_masks['XRetweet'].unsqueeze(-1)

        # XQuoteTweet
        if event_masks['XQuoteTweet'].any():
            quoter_text_emb = gather_precomputed('quoter_text_emb_idx')
            orig_text_emb = gather_precomputed('original_post_text_emb_idx')
            orig_media_emb = gather_precomputed('original_post_media_emb_idx')
            quote_embeds = self.x_quote_tweet_encoder(
                gathered_embeds['wallet'], quoter_text_emb,
                gathered_embeds['original_author'], orig_text_emb, orig_media_emb,
            )
            final_embeds += quote_embeds * event_masks['XQuoteTweet'].unsqueeze(-1)

        # PumpReply
        if event_masks['PumpReply'].any():
            text_emb = gather_precomputed('reply_text_emb_idx')
            pump_reply_embeds = self.pump_reply_encoder(gathered_embeds['wallet'], text_emb)
            final_embeds += pump_reply_embeds * event_masks['PumpReply'].unsqueeze(-1)

        # DexProfile_Updated
        if event_masks['DexProfile_Updated'].any():
            website_emb = gather_precomputed('website_emb_idx')
            twitter_emb = gather_precomputed('twitter_link_emb_idx')
            telegram_emb = gather_precomputed('telegram_link_emb_idx')
            description_emb = gather_precomputed('description_emb_idx')
            profile_embeds = self.dex_profile_encoder(
                website_emb, twitter_emb, telegram_emb, description_emb
            )
            # Note: The flags are handled separately in the main model now,
            # so we just add the text embeds
            final_embeds += profile_embeds * event_masks['DexProfile_Updated'].unsqueeze(-1)

        # Global Trending Hashtags (shared encoder for both platforms)
        trending_mask = event_masks['TikTok_Trending_Hashtag'] | event_masks['XTrending_Hashtag']
        if trending_mask.any():
            hashtag_emb = gather_precomputed('hashtag_name_emb_idx')
            trending_embeds = self.global_trending_encoder(hashtag_emb)
            final_embeds += trending_embeds * trending_mask.unsqueeze(-1)

        return final_embeds