import torch
import torch.nn as nn
class VotePredictor(nn.Module):
    """Score a (text embedding, country id) pair with a small MLP.

    The country id is looked up in a learned embedding table, concatenated
    with the text embedding along the last (feature) dimension, and fed
    through ``Linear -> ReLU -> Dropout -> Linear`` to produce one value per
    pair. No output activation is applied.

    Args:
        text_dim: Size of the last dimension of ``text_vecs``.
        country_count: Number of rows in the country embedding table; valid
            ids are ``0 .. country_count - 1``.
        country_emb_dim: Width of each learned country embedding.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden activation.
    """

    def __init__(
        self,
        text_dim: int = 384,
        country_count: int = 193,
        country_emb_dim: int = 32,
        hidden_dim: int = 256,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        self.country_embedding = nn.Embedding(country_count, country_emb_dim)
        # Layer order/indices are kept stable so existing state_dict keys
        # (model.0.*, model.3.*) keep loading.
        self.model = nn.Sequential(
            nn.Linear(text_dim + country_emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        text_vecs: torch.Tensor,
        country_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return one unnormalized score per (text, country) pair.

        Args:
            text_vecs: Floating-point tensor whose last dimension is
                ``text_dim``; any leading dimensions are treated as batch
                dimensions (including none, for a single unbatched pair).
            country_ids: Integer tensor of country indices whose shape equals
                the leading (batch) shape of ``text_vecs``.

        Returns:
            Tensor of shape ``(*batch, 1)``. No sigmoid is applied here —
            presumably the loss/caller does that; confirm against training
            code.
        """
        country_vecs = self.country_embedding(country_ids)
        # Join on the last axis so this works for 1-D, 2-D, or N-D batches;
        # identical to dim=1 for the usual (batch, features) case.
        features = torch.cat([text_vecs, country_vecs], dim=-1)
        return self.model(features)