"""
Ad-aware state encoder.

State = concat(
    GRU_hidden(ad_interaction_sequence),  # 32-dim
    ReLU(Linear(user_features)),          # 16-dim
    context_features,                     # 4-dim (sinusoidal hour + day of week)
) → 52-dim state vector
"""
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
|
|
class AdStateEncoder(nn.Module):
    """Encode an ad-interaction history plus user/context features into a state.

    The state is ``concat(gru_hidden (32), relu(linear(user_feat)) (16), ctx_feat)``.
    ``OUTPUT_DIM`` (52) holds only when ``ctx_feat`` has 4 features.
    """

    OUTPUT_DIM = 52

    # Outcome codes fed to ``outcome_emb``; 0 is reserved for left-padding.
    PAD_OUTCOME = 0
    CLICK_OUTCOME = 1
    CONVERT_OUTCOME = 2
    NO_CLICK_OUTCOME = 3

    def __init__(self, n_ads: int, user_feat_dim: int = 21):
        """Build the encoder.

        Args:
            n_ads: number of distinct ads. Ad ids are shifted by +1 before
                lookup so that embedding row 0 can serve as padding.
            user_feat_dim: length of the raw user-feature vector.
        """
        super().__init__()
        self.n_ads = n_ads

        # Row 0 is the padding row (encode() stores ad ids shifted by +1).
        self.ad_emb = nn.Embedding(n_ads + 1, 16, padding_idx=0)
        # NOTE(review): no padding_idx here, so the pad outcome (0) gets a
        # learned vector, unlike pad ads which embed to zeros — confirm intent.
        self.outcome_emb = nn.Embedding(4, 4)
        # GRU input = 16-dim ad embedding + 4-dim outcome embedding.
        self.gru = nn.GRU(20, 32, batch_first=True)

        self.user_proj = nn.Sequential(
            nn.Linear(user_feat_dim, 16),
            nn.ReLU(),
        )

    def forward(
        self,
        ad_seq: torch.Tensor,
        outcome_seq: torch.Tensor,
        user_feat: torch.Tensor,
        ctx_feat: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the state vector for a batch.

        Args:
            ad_seq: long tensor of shifted ad ids (0 = padding), e.g. (batch, seq).
            outcome_seq: long tensor of outcome codes, same shape as ``ad_seq``.
            user_feat: float tensor, e.g. (batch, user_feat_dim).
            ctx_feat: float tensor, e.g. (batch, ctx_dim); 4 features for OUTPUT_DIM.

        Returns:
            The concatenation ``[gru_hidden(32), user_proj(16), ctx_feat]``
            along the last dimension.
        """
        x = torch.cat([self.ad_emb(ad_seq), self.outcome_emb(outcome_seq)], dim=-1)
        _, h = self.gru(x)
        # h is (num_layers, ..., 32); keep the last layer's final hidden state.
        gru_out = h[-1]

        user_out = self.user_proj(user_feat)
        return torch.cat([gru_out, user_out, ctx_feat], dim=-1)

    @classmethod
    def _outcome_code(cls, clicked, converted) -> int:
        """Map interaction flags to an outcome code (conversion wins over click)."""
        if converted:
            return cls.CONVERT_OUTCOME
        if clicked:
            return cls.CLICK_OUTCOME
        return cls.NO_CLICK_OUTCOME

    @torch.no_grad()
    def encode(
        self,
        history: list,
        user_feat: np.ndarray,
        ctx_feat: np.ndarray,
        max_len: int = 10,
    ) -> torch.Tensor:
        """Single-sample encoding. Returns a (1, 52) state tensor.

        Args:
            history: sequence of ``(ad_id, _, clicked, converted)`` 4-tuples,
                oldest first. Only the last ``max_len`` entries are used.
                Shorter histories are left-padded.
            user_feat: raw user features, flattened to one row.
            ctx_feat: context features, flattened to one row.
            max_len: fixed sequence length fed to the GRU (must be >= 1).

        Raises:
            ValueError: if ``max_len < 1``.
        """
        if max_len < 1:
            # history[-0:] would be the whole list, and the GRU cannot take an
            # empty sequence, so reject non-positive lengths explicitly.
            raise ValueError(f"max_len must be >= 1, got {max_len}")

        recent = list(history)[-max_len:]
        pad = max_len - len(recent)

        ads = [0] * pad
        outcomes = [self.PAD_OUTCOME] * pad
        for ad_id, _, clicked, converted in recent:
            ads.append(int(ad_id) + 1)  # shift so 0 stays the padding index
            outcomes.append(self._outcome_code(clicked, converted))

        # Build inputs on the module's own device/dtype so encode() keeps
        # working after .to(device) / .double() etc.
        weight = self.ad_emb.weight
        device, dtype = weight.device, weight.dtype
        ad_t = torch.tensor([ads], dtype=torch.long, device=device)
        out_t = torch.tensor([outcomes], dtype=torch.long, device=device)
        u_t = torch.tensor(
            np.asarray(user_feat, dtype=np.float32).reshape(1, -1),
            dtype=dtype,
            device=device,
        )
        ctx_t = torch.tensor(
            np.asarray(ctx_feat, dtype=np.float32).reshape(1, -1),
            dtype=dtype,
            device=device,
        )

        return self.forward(ad_t, out_t, u_t, ctx_t)
|
|