|
|
""" |
|
|
Simplified Query Auto-Completion Model |
|
|
Uses a CNN+Transformer encoder for the prefix and a Transformer-only encoder for the candidate (IE module)
|
|
Optionally uses pretrained ByT5 embeddings |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import T5EncoderModel |
|
|
|
|
|
|
|
|
class CNNLocalEncoder(nn.Module):
    """Multi-scale CNN for local pattern extraction.

    Runs one ``Conv1d`` per kernel width over the sequence axis and applies
    ReLU. Each feature map is global-max-pooled over time, the pooled vectors
    are concatenated, and the result is layer-normalised.

    Args:
        embed_dim: Size of the input embedding (last axis of ``x``).
        num_filters: Output channels produced by each convolution.
        filter_sizes: Kernel widths, one convolution per entry. The default is
            a tuple so it is never shared mutable state; any sequence of ints
            is accepted.

    Raises:
        ValueError: If ``filter_sizes`` is empty.
    """

    def __init__(self, embed_dim=128, num_filters=64, filter_sizes=(3, 4, 5)):
        super().__init__()
        filter_sizes = tuple(filter_sizes)
        if not filter_sizes:
            raise ValueError("filter_sizes must contain at least one kernel width")

        # Width of the concatenated pooled features returned by forward().
        self.output_dim = num_filters * len(filter_sizes)

        # padding=fs // 2 keeps the output length at L for odd kernels and at
        # L + 1 for even ones. The global max-pool in forward() makes the
        # difference irrelevant.
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(embed_dim, num_filters, fs, padding=fs // 2)
                for fs in filter_sizes
            ]
        )
        self.layer_norm = nn.LayerNorm(self.output_dim)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU) conv weights with zero biases."""
        for conv in self.convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Encode a batch of embedded sequences into one vector per sequence.

        Args:
            x: Tensor whose last axis is ``embed_dim``. It is transposed to
                the ``(batch, channels, length)`` layout ``Conv1d`` expects.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        x = x.transpose(1, 2)
        # Global max over the time axis. This is equivalent to max_pool1d with
        # a kernel spanning the whole sequence, including the length-1 case,
        # so no special-casing is needed.
        pooled = [F.relu(conv(x)).max(dim=2).values for conv in self.convs]
        return self.layer_norm(torch.cat(pooled, dim=1))
|
|
|
|
|
|
|
|
class PrefixEncoder(nn.Module):
    """CNN + Transformer encoder for the prefix.

    ``CNNLocalEncoder`` collapses the embedded prefix into a single pooled
    feature vector. That vector is passed through a pre-norm Transformer
    encoder as a length-1 sequence, projected back to ``embed_dim``, and
    layer-normalised.

    Args:
        embed_dim: Width of the incoming token embeddings and of the output.
        num_filters: Channels per CNN kernel width.
        num_heads: Attention heads. Must divide the CNN output width, which
            ``nn.MultiheadAttention`` checks at construction.
        num_layers: Number of Transformer encoder layers.
    """

    def __init__(self, embed_dim=128, num_filters=64, num_heads=4, num_layers=2):
        super().__init__()
        self.cnn = CNNLocalEncoder(embed_dim, num_filters)
        # Derive the width from the CNN that was actually built instead of
        # hard-coding "3" (its default number of kernel sizes).
        cnn_out_dim = num_filters * len(self.cnn.convs)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=cnn_out_dim,
                nhead=num_heads,
                dim_feedforward=cnn_out_dim * 4,
                dropout=0.1,
                batch_first=True,
                activation="gelu",
                layer_norm_eps=1e-6,
                norm_first=True,
            ),
            num_layers=num_layers,
        )
        self.proj = nn.Linear(cnn_out_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.5) projection weights with a zero bias."""
        nn.init.xavier_uniform_(self.proj.weight, gain=0.5)
        nn.init.zeros_(self.proj.bias)

    def forward(self, prefix_embed):
        """Encode embedded prefixes into ``(batch, embed_dim)`` vectors.

        Args:
            prefix_embed: Embedded prefix tokens whose last axis is
                ``embed_dim``. Presumably ``(batch, prefix_len, embed_dim)``;
                TODO confirm against the caller.

        Returns:
            Layer-normalised tensor of shape ``(batch, embed_dim)``.
        """
        # The CNN yields one vector per prefix. unsqueeze(1) makes it a
        # length-1 sequence for the batch_first Transformer. With a single
        # key the attention softmax is always 1, so each attention sub-layer
        # reduces to a fixed linear map of its input.
        cnn_out = self.cnn(prefix_embed).unsqueeze(1)
        transformer_out = self.transformer(cnn_out).squeeze(1)
        return self.layer_norm(self.proj(transformer_out))
|
|
|
|
|
|
|
|
class CandidateEncoder(nn.Module):
    """Transformer-only encoder for the candidate (no CNN stage).

    Embedded candidate tokens go through a pre-norm Transformer encoder. The
    sequence is reduced by an element-wise maximum over positions, and the
    pooled vector is layer-normalised.

    Args:
        embed_dim: Width of the incoming embeddings and of the output vector.
        num_heads: Attention heads; must divide ``embed_dim``.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, embed_dim=128, num_heads=4, num_layers=2):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=0.1,
            activation="gelu",
            layer_norm_eps=1e-6,
            batch_first=True,
            norm_first=True,
        )
        # Submodule names are part of the state_dict layout; keep them stable.
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, candidate_embed):
        """Pool an embedded candidate sequence into one vector per example.

        ``candidate_embed`` goes to a ``batch_first`` encoder, so it is
        presumably ``(batch, seq_len, embed_dim)``; TODO confirm. No attention
        or padding mask is applied, so every position (including any padding)
        takes part in both attention and the max-pool.
        """
        hidden = self.transformer(candidate_embed)
        return self.layer_norm(hidden.max(dim=1).values)
|
|
|
|
|
|
|
|
class QueryCompletionModel(nn.Module):
    """Query auto-completion scorer.

    Encodes the prefix with a CNN+Transformer encoder and the candidate with
    a Transformer encoder. The two vectors are concatenated and an MLP maps
    them to a single match score per pair.

    Args:
        vocab_size: Size of the learned embedding tables. Only used when
            ``use_pretrained_embeddings`` is False; in pretrained mode the
            vocabulary is that of the loaded model's shared embedding.
        embed_dim: Width of token embeddings and of both encoder outputs.
        num_filters: Channels per CNN kernel width in the prefix encoder.
        num_heads: Attention heads for both encoders.
        num_transformer_layers: Encoder depth for both encoders.
        use_pretrained_embeddings: If True, reuse (and fine-tune) the
            ``shared`` token embedding of ``pretrained_model_name`` for both
            prefix and candidate. It is projected to ``embed_dim`` when the
            widths differ. Otherwise, separate prefix/candidate embedding
            tables are learned from scratch.
        pretrained_model_name: Model id passed to
            ``T5EncoderModel.from_pretrained``.
    """

    def __init__(
        self,
        vocab_size=10000,
        embed_dim=128,
        num_filters=64,
        num_heads=4,
        num_transformer_layers=2,
        use_pretrained_embeddings=False,
        pretrained_model_name="google/byt5-small",
    ):
        super().__init__()
        self.use_pretrained_embeddings = use_pretrained_embeddings

        if use_pretrained_embeddings:
            print(f"Loading pretrained embeddings from {pretrained_model_name}...")
            byt5_model = T5EncoderModel.from_pretrained(pretrained_model_name)
            pretrained_embed_dim = byt5_model.config.d_model

            # Keep only the token-embedding table and make it trainable.
            self.shared_embedding = byt5_model.shared
            self.shared_embedding.requires_grad_(True)
            # Drop the reference to the rest of the encoder before building
            # the remaining modules; only the embedding table is needed.
            del byt5_model

            if pretrained_embed_dim != embed_dim:
                self.embed_proj = nn.Linear(pretrained_embed_dim, embed_dim)
                nn.init.xavier_uniform_(self.embed_proj.weight, gain=0.5)
                # Zero the bias, like every other Linear initialised here.
                nn.init.zeros_(self.embed_proj.bias)
            else:
                self.embed_proj = nn.Identity()

            print(
                f" ✓ Using pretrained embeddings: {pretrained_embed_dim}D → {embed_dim}D"
            )
        else:
            self.prefix_embedding = nn.Embedding(vocab_size, embed_dim)
            self.candidate_embedding = nn.Embedding(vocab_size, embed_dim)
            self._init_embeddings()

        self.prefix_encoder = PrefixEncoder(
            embed_dim, num_filters, num_heads, num_transformer_layers
        )
        self.candidate_encoder = CandidateEncoder(
            embed_dim, num_heads, num_transformer_layers
        )

        # MLP over [prefix_vector, candidate_vector] -> one score per pair.
        self.match_predictor = nn.Sequential(
            nn.LayerNorm(embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, 1),
        )
        self._init_predictor()

    def _init_embeddings(self):
        """Initialise learned embedding tables with N(0, 0.02)."""
        if not self.use_pretrained_embeddings:
            nn.init.normal_(self.prefix_embedding.weight, std=0.02)
            nn.init.normal_(self.candidate_embedding.weight, std=0.02)

    def _init_predictor(self):
        """Xavier-uniform (gain 0.5) weights and zero biases for the MLP."""
        for module in self.match_predictor:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                nn.init.zeros_(module.bias)

    def forward(self, prefix_ids, candidate_ids, return_logits=False):
        """Score prefix/candidate pairs.

        Args:
            prefix_ids: Token ids of the prefixes. Presumably
                ``(batch, prefix_len)``; TODO confirm against the caller.
            candidate_ids: Token ids of the candidates. Presumably
                ``(batch, candidate_len)``.
            return_logits: If True, return raw pre-sigmoid scores, which
                pair with ``nn.BCEWithLogitsLoss`` for numerically stable
                training. The default False keeps the original behaviour of
                returning probabilities.

        Returns:
            Tensor of shape ``(batch, 1)``: sigmoid probabilities, or
            unbounded logits when ``return_logits`` is True.
        """
        if self.use_pretrained_embeddings:
            # One shared (pretrained) table for both inputs.
            prefix_embed = self.embed_proj(self.shared_embedding(prefix_ids))
            candidate_embed = self.embed_proj(self.shared_embedding(candidate_ids))
        else:
            prefix_embed = self.prefix_embedding(prefix_ids)
            candidate_embed = self.candidate_embedding(candidate_ids)

        prefix_intention = self.prefix_encoder(prefix_embed)
        candidate_intention = self.candidate_encoder(candidate_embed)
        combined = torch.cat([prefix_intention, candidate_intention], dim=-1)
        logits = self.match_predictor(combined)
        if return_logits:
            return logits
        return torch.sigmoid(logits)
|
|
|