| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, Any |
| import models.vocabulary as vocab |
|
|
class XPostEncoder(nn.Module):
    """Fuse an X post into a single ``d_model`` vector.

    Inputs: author wallet embedding, post-text embedding and media embedding.
    They are concatenated on the last axis (their widths must sum to
    ``3 * d_model``; presumably ``d_model`` each) and projected back to
    ``d_model`` by Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        fused_width = 3 * d_model
        hidden_width = 2 * d_model
        layers = [
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, d_model),
        ]
        # A single nn.Sequential named ``mlp`` keeps the state_dict keys
        # (mlp.0.*, mlp.2.*, mlp.3.*) compatible with existing checkpoints.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor, media_emb: torch.Tensor) -> torch.Tensor:
        fused = torch.cat((author_emb, text_emb, media_emb), dim=-1)
        return self.mlp(fused)
|
|
class XRetweetEncoder(nn.Module):
    """Fuse a retweet into a single ``d_model`` vector.

    Inputs: retweeter wallet embedding, original-author wallet embedding,
    original post-text embedding and original post-media embedding. They are
    concatenated on the last axis (total width ``4 * d_model``) and projected
    to ``d_model`` by Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        fused_width = 4 * d_model
        hidden_width = 2 * d_model
        layers = [
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, d_model),
        ]
        # Layer positions inside ``mlp`` define the state_dict keys; keep them.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self,
                retweeter_emb: torch.Tensor,
                orig_author_emb: torch.Tensor,
                orig_text_emb: torch.Tensor,
                orig_media_emb: torch.Tensor) -> torch.Tensor:
        parts = (retweeter_emb, orig_author_emb, orig_text_emb, orig_media_emb)
        return self.mlp(torch.cat(parts, dim=-1))
|
|
class XReplyEncoder(nn.Module):
    """Fuse an X reply into a single ``d_model`` vector.

    Inputs: reply author's wallet embedding, reply text embedding, reply media
    embedding and the embedding of the tweet being replied to. They are
    concatenated on the last axis (total width ``4 * d_model``) and projected
    to ``d_model`` by Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        fused_width = 4 * d_model
        hidden_width = 2 * d_model
        layers = [
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, d_model),
        ]
        # Layer positions inside ``mlp`` define the state_dict keys; keep them.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self,
                author_emb: torch.Tensor,
                text_emb: torch.Tensor,
                media_emb: torch.Tensor,
                main_tweet_emb: torch.Tensor) -> torch.Tensor:
        parts = (author_emb, text_emb, media_emb, main_tweet_emb)
        return self.mlp(torch.cat(parts, dim=-1))
|
|
class XQuoteTweetEncoder(nn.Module):
    """Fuse a quote tweet into a single ``d_model`` vector.

    Inputs: quoter wallet embedding, quoter text embedding, original-author
    wallet embedding, original post-text embedding and original post-media
    embedding. They are concatenated on the last axis (total width
    ``5 * d_model``) and projected to ``d_model`` by
    Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        fused_width = 5 * d_model
        hidden_width = 2 * d_model
        layers = [
            nn.Linear(fused_width, hidden_width),
            nn.GELU(),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, d_model),
        ]
        # Layer positions inside ``mlp`` define the state_dict keys; keep them.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self,
                quoter_wallet_emb: torch.Tensor,
                quoter_text_emb: torch.Tensor,
                orig_author_emb: torch.Tensor,
                orig_text_emb: torch.Tensor,
                orig_media_emb: torch.Tensor) -> torch.Tensor:
        parts = (quoter_wallet_emb, quoter_text_emb, orig_author_emb, orig_text_emb, orig_media_emb)
        return self.mlp(torch.cat(parts, dim=-1))
|
|
class PumpReplyEncoder(nn.Module):
    """Fuse a pump reply into a single ``d_model`` vector.

    Inputs: the replying user's wallet embedding and the reply text embedding.
    They are concatenated on the last axis (total width ``2 * d_model``) and
    projected to ``d_model`` by Linear -> GELU -> LayerNorm -> Linear.
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input and hidden widths coincide here (both 2 * d_model).
        width = 2 * d_model
        layers = [
            nn.Linear(width, width),
            nn.GELU(),
            nn.LayerNorm(width),
            nn.Linear(width, d_model),
        ]
        # Layer positions inside ``mlp`` define the state_dict keys; keep them.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, user_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        return self.mlp(torch.cat((user_emb, text_emb), dim=-1))
|
|
| |
class DexProfileUpdatedEncoder(nn.Module):
    """ Encodes: <website_emb>, <twitter_emb>, <telegram_emb>, <description_emb>

    The four embeddings are concatenated on the last axis (total width
    ``4 * d_model``) and projected to ``d_model`` by
    Linear -> GELU -> LayerNorm -> Linear.

    NOTE(review): an earlier version of this docstring also listed a
    ``<4_flags_projection>`` input, but neither ``__init__`` (input width
    ``4 * d_model``) nor ``forward`` (four tensors) accepts one. If the profile
    flags are meant to be encoded, they must be added explicitly.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model)
        ).to(dtype)

    def forward(self, website_emb: torch.Tensor, twitter_emb: torch.Tensor, telegram_emb: torch.Tensor, description_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the four link/description embeddings and project to ``d_model``."""
        combined = torch.cat([website_emb, twitter_emb, telegram_emb, description_emb], dim=-1)
        return self.mlp(combined)
|
|
class GlobalTrendingEncoder(nn.Module):
    """Map a trending-hashtag embedding to a ``d_model`` vector.

    Input: a hashtag embedding whose last axis has width ``d_model``.
    Projection: Linear -> GELU -> Linear (no LayerNorm, unlike the other
    encoders in this file).
    """

    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        ]
        # Layer positions inside ``mlp`` define the state_dict keys; keep them.
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, hashtag_emb: torch.Tensor) -> torch.Tensor:
        projected = self.mlp(hashtag_emb)
        return projected
class SocialEncoder(nn.Module):
    """
    A single module to house all social event encoders.
    This simplifies instantiation in the main Oracle model.

    Holds one encoder per supported social/text event type. ``forward`` runs
    each encoder over the batch and keeps its output only at the sequence
    positions whose event type matches.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Registration order is unchanged so parameter-initialisation order and
        # the state_dict layout match earlier versions of this module.
        self.x_post_encoder = XPostEncoder(d_model, dtype)
        self.x_retweet_encoder = XRetweetEncoder(d_model, dtype)
        self.x_reply_encoder = XReplyEncoder(d_model, dtype)
        self.x_quote_tweet_encoder = XQuoteTweetEncoder(d_model, dtype)
        self.pump_reply_encoder = PumpReplyEncoder(d_model, dtype)

        self.dex_profile_encoder = DexProfileUpdatedEncoder(d_model, dtype)
        self.global_trending_encoder = GlobalTrendingEncoder(d_model, dtype)

        self.d_model = d_model
        self.dtype = dtype

    def forward(self, batch: Dict[str, Any], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Encode every supported social/text event in the batch.

        Args:
            batch: Must contain ``'event_type_ids'``, which is compared element-wise
                against ``vocab.EVENT_TO_ID`` (assumed shape ``(B, L)`` — TODO
                confirm). Must also contain ``'textual_event_indices'``, which must
                flatten to ``B * L`` entries. May contain ``'textual_event_data'``,
                a list of per-event dicts. An index ``k > 0`` selects
                ``textual_event_data[k - 1]``; any other value selects the
                placeholder ``{'event_type': 'PAD'}``.
            gathered_embeds: Must contain ``'wallet'`` of shape ``(B, L, D)`` and
                ``'precomputed'``, a 2-D table indexed by the ``*_emb_idx`` fields
                of each event dict. A field missing from an event dict falls back
                to row 0 of that table. ``'original_author'`` (shaped like
                ``'wallet'``) is read only when retweets or quote tweets occur.

        Returns:
            A ``(B, L, D)`` tensor of dtype ``self.dtype``. Positions whose event
            type is not handled here stay zero.

        Note:
            Sub-encoders run over the whole ``(B, L)`` grid, and their outputs are
            kept only where the event-type mask is set. The per-event dict
            lookups are still done in Python.
        """
        wallet_emb = gathered_embeds['wallet']
        device = wallet_emb.device
        B, L, D = wallet_emb.shape
        final_embeds = torch.zeros(B, L, D, device=device, dtype=self.dtype)

        textual_event_indices = batch['textual_event_indices']
        textual_event_data = batch.get('textual_event_data', [])
        precomputed_lookup = gathered_embeds['precomputed']
        event_type_ids = batch['event_type_ids']

        def type_mask(event_name: str) -> torch.Tensor:
            # An event type missing from the vocabulary can never occur, so its
            # mask is all False. Comparing against a -1 sentinel would instead
            # match any position whose id happens to be -1.
            event_id = vocab.EVENT_TO_ID.get(event_name)
            if event_id is None:
                return torch.zeros_like(event_type_ids, dtype=torch.bool)
            return event_type_ids == event_id

        event_masks = {
            name: type_mask(name)
            for name in (
                'XPost',
                'XReply',
                'XRetweet',
                'XQuoteTweet',
                'PumpReply',
                'DexProfile_Updated',
                'TikTok_Trending_Hashtag',
                'XTrending_Hashtag',
            )
        }

        default_event = {'event_type': 'PAD'}
        # Indices into textual_event_data are 1-based; 0 (or less) means "no event".
        raw_events_flat = [
            textual_event_data[idx - 1] if idx > 0 else default_event
            for idx in textual_event_indices.flatten().tolist()
        ]

        emb_width = precomputed_lookup.size(-1)
        gathered_cache: Dict[str, torch.Tensor] = {}

        def gather_precomputed(key: str) -> torch.Tensor:
            # Several event types read the same fields (e.g. 'text_emb_idx',
            # 'original_post_text_emb_idx'). Build each index tensor and do the
            # lookup once per forward pass. An explicit width (not -1) keeps
            # .view valid for empty batches.
            cached = gathered_cache.get(key)
            if cached is None:
                indices = torch.tensor(
                    [event.get(key, 0) for event in raw_events_flat],
                    device=device,
                    dtype=torch.long,
                )
                cached = F.embedding(indices, precomputed_lookup).view(B, L, emb_width)
                gathered_cache[key] = cached
            return cached

        def accumulate(embeds: torch.Tensor, mask: torch.Tensor) -> None:
            # masked_fill (rather than multiplying by the mask) zeroes positions
            # that belong to other event types even if the encoder produced
            # NaN/Inf there. NaN * 0 would otherwise poison final_embeds.
            final_embeds.add_(embeds.masked_fill(~mask.unsqueeze(-1), 0))

        mask = event_masks['XPost']
        if mask.any():
            post_embeds = self.x_post_encoder(
                wallet_emb,
                gather_precomputed('text_emb_idx'),
                gather_precomputed('media_emb_idx'),
            )
            accumulate(post_embeds, mask)

        mask = event_masks['XReply']
        if mask.any():
            reply_embeds = self.x_reply_encoder(
                wallet_emb,
                gather_precomputed('text_emb_idx'),
                gather_precomputed('media_emb_idx'),
                gather_precomputed('main_tweet_text_emb_idx'),
            )
            accumulate(reply_embeds, mask)

        mask = event_masks['XRetweet']
        if mask.any():
            retweet_embeds = self.x_retweet_encoder(
                wallet_emb,
                gathered_embeds['original_author'],
                gather_precomputed('original_post_text_emb_idx'),
                gather_precomputed('original_post_media_emb_idx'),
            )
            accumulate(retweet_embeds, mask)

        mask = event_masks['XQuoteTweet']
        if mask.any():
            quote_embeds = self.x_quote_tweet_encoder(
                wallet_emb,
                gather_precomputed('quoter_text_emb_idx'),
                gathered_embeds['original_author'],
                gather_precomputed('original_post_text_emb_idx'),
                gather_precomputed('original_post_media_emb_idx'),
            )
            accumulate(quote_embeds, mask)

        mask = event_masks['PumpReply']
        if mask.any():
            pump_reply_embeds = self.pump_reply_encoder(
                wallet_emb,
                gather_precomputed('reply_text_emb_idx'),
            )
            accumulate(pump_reply_embeds, mask)

        mask = event_masks['DexProfile_Updated']
        if mask.any():
            profile_embeds = self.dex_profile_encoder(
                gather_precomputed('website_emb_idx'),
                gather_precomputed('twitter_link_emb_idx'),
                gather_precomputed('telegram_link_emb_idx'),
                gather_precomputed('description_emb_idx'),
            )
            accumulate(profile_embeds, mask)

        # TikTok and X trending hashtags share one encoder.
        mask = event_masks['TikTok_Trending_Hashtag'] | event_masks['XTrending_Hashtag']
        if mask.any():
            trending_embeds = self.global_trending_encoder(gather_precomputed('hashtag_name_emb_idx'))
            accumulate(trending_embeds, mask)

        return final_embeds