File size: 2,767 Bytes
f19eb84
 
 
 
 
 
 
 
 
fb62752
 
f19eb84
 
fb62752
f19eb84
 
fb62752
f19eb84
fb62752
f19eb84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb62752
f19eb84
 
 
 
 
fb62752
 
f19eb84
 
 
 
 
 
 
 
fb62752
f19eb84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Ad-aware state encoder.

State = concat(
    GRU_hidden(ad_interaction_sequence),   # 32-dim
    Linear(user_features),                 # 16-dim
    context_features,                      #  4-dim (sinusoidal hour + dow)
)  →  52-dim state vector
"""
import torch
import torch.nn as nn
import numpy as np


class AdStateEncoder(nn.Module):
    OUTPUT_DIM = 52   # 32 + 16 + 4

    def __init__(self, n_ads: int, user_feat_dim: int = 21):
        super().__init__()
        self.n_ads = n_ads

        # Ad interaction sequence encoder
        # Each step: ad_emb(16) + outcome_emb(4) = 20-dim input
        self.ad_emb      = nn.Embedding(n_ads + 1, 16, padding_idx=0)  # +1 for pad
        self.outcome_emb = nn.Embedding(4, 4)   # 0=no_imp, 1=click, 2=convert, 3=skip
        self.gru         = nn.GRU(20, 32, batch_first=True)

        # User features projection
        self.user_proj = nn.Sequential(
            nn.Linear(user_feat_dim, 16),
            nn.ReLU(),
        )

    def forward(
        self,
        ad_seq:      torch.Tensor,   # (batch, seq_len)  LongTensor
        outcome_seq: torch.Tensor,   # (batch, seq_len)  LongTensor
        user_feat:   torch.Tensor,   # (batch, user_feat_dim)
        ctx_feat:    torch.Tensor,   # (batch, 4)
    ) -> torch.Tensor:
        # Sequence branch
        x = torch.cat([self.ad_emb(ad_seq), self.outcome_emb(outcome_seq)], dim=-1)
        _, h = self.gru(x)
        gru_out = h.squeeze(0)                   # (batch, 32)

        user_out = self.user_proj(user_feat)      # (batch, 16)
        state    = torch.cat([gru_out, user_out, ctx_feat], dim=-1)  # (batch, 52)
        return state

    @torch.no_grad()
    def encode(
        self,
        history:    list,           # [(ad_id, reward, clicked, converted), ...]
        user_feat:  np.ndarray,     # (21,)
        ctx_feat:   np.ndarray,     # (4,)
        max_len:    int = 10,
    ) -> torch.Tensor:
        """Single-sample encoding. Returns (1, 52) state tensor."""
        history = history[-max_len:]
        pad     = max_len - len(history)

        ads      = [0] * pad + [int(ad_id) + 1 for ad_id, *_ in history]
        outcomes = [0] * pad

        for _, _, clicked, converted in history:
            if converted:
                outcomes.append(2)
            elif clicked:
                outcomes.append(1)
            else:
                outcomes.append(3)

        ad_t  = torch.tensor([ads],     dtype=torch.long)
        out_t = torch.tensor([outcomes], dtype=torch.long)
        u_t   = torch.from_numpy(np.asarray(user_feat, dtype=np.float32).reshape(1, -1))
        ctx_t = torch.from_numpy(np.asarray(ctx_feat,  dtype=np.float32).reshape(1, -1))

        return self.forward(ad_t, out_t, u_t, ctx_t)  # (1, 52)