RL-Recommendation-System / src /models /state_encoder.py
mnoorchenar's picture
Update 2026-03-23 09:33:03
f19eb84
"""
Ad-aware state encoder.
State = concat(
GRU_hidden(ad_interaction_sequence), # 32-dim
Linear(user_features), # 16-dim
context_features, # 4-dim (sinusoidal hour + dow)
) → 52-dim state vector
"""
import torch
import torch.nn as nn
import numpy as np
class AdStateEncoder(nn.Module):
OUTPUT_DIM = 52 # 32 + 16 + 4
def __init__(self, n_ads: int, user_feat_dim: int = 21):
super().__init__()
self.n_ads = n_ads
# Ad interaction sequence encoder
# Each step: ad_emb(16) + outcome_emb(4) = 20-dim input
self.ad_emb = nn.Embedding(n_ads + 1, 16, padding_idx=0) # +1 for pad
self.outcome_emb = nn.Embedding(4, 4) # 0=no_imp, 1=click, 2=convert, 3=skip
self.gru = nn.GRU(20, 32, batch_first=True)
# User features projection
self.user_proj = nn.Sequential(
nn.Linear(user_feat_dim, 16),
nn.ReLU(),
)
def forward(
self,
ad_seq: torch.Tensor, # (batch, seq_len) LongTensor
outcome_seq: torch.Tensor, # (batch, seq_len) LongTensor
user_feat: torch.Tensor, # (batch, user_feat_dim)
ctx_feat: torch.Tensor, # (batch, 4)
) -> torch.Tensor:
# Sequence branch
x = torch.cat([self.ad_emb(ad_seq), self.outcome_emb(outcome_seq)], dim=-1)
_, h = self.gru(x)
gru_out = h.squeeze(0) # (batch, 32)
user_out = self.user_proj(user_feat) # (batch, 16)
state = torch.cat([gru_out, user_out, ctx_feat], dim=-1) # (batch, 52)
return state
@torch.no_grad()
def encode(
self,
history: list, # [(ad_id, reward, clicked, converted), ...]
user_feat: np.ndarray, # (21,)
ctx_feat: np.ndarray, # (4,)
max_len: int = 10,
) -> torch.Tensor:
"""Single-sample encoding. Returns (1, 52) state tensor."""
history = history[-max_len:]
pad = max_len - len(history)
ads = [0] * pad + [int(ad_id) + 1 for ad_id, *_ in history]
outcomes = [0] * pad
for _, _, clicked, converted in history:
if converted:
outcomes.append(2)
elif clicked:
outcomes.append(1)
else:
outcomes.append(3)
ad_t = torch.tensor([ads], dtype=torch.long)
out_t = torch.tensor([outcomes], dtype=torch.long)
u_t = torch.from_numpy(np.asarray(user_feat, dtype=np.float32).reshape(1, -1))
ctx_t = torch.from_numpy(np.asarray(ctx_feat, dtype=np.float32).reshape(1, -1))
return self.forward(ad_t, out_t, u_t, ctx_t) # (1, 52)