# sin-qac-model / model.py
# Uploaded to the Hugging Face Hub by lv12 via huggingface_hub (commit 924e4e0, verified).
"""
Simplified Query Auto-Completion Model
Uses CNN+Transformer for prefix/candidate encoding (IE module)
Optionally uses pretrained ByT5 embeddings
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5EncoderModel
class CNNLocalEncoder(nn.Module):
    """Extract local n-gram features with parallel 1-D convolutions.

    Each convolution scans the token embeddings with a different receptive
    field; the activations are globally max-pooled over time, concatenated,
    and layer-normalized into one fixed-size vector per example.
    """

    def __init__(self, embed_dim=128, num_filters=64, filter_sizes=[3, 4, 5]):
        super().__init__()
        self.convs = nn.ModuleList()
        for width in filter_sizes:
            # padding=width//2 keeps the conv output close to the input length.
            self.convs.append(
                nn.Conv1d(embed_dim, num_filters, width, padding=width // 2)
            )
        self.layer_norm = nn.LayerNorm(num_filters * len(filter_sizes))
        self._init_weights()

    def _init_weights(self):
        # He (Kaiming) init suits the ReLU applied after each convolution.
        for conv in self.convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        # (batch, seq, embed) -> (batch, embed, seq) as expected by Conv1d.
        features = x.transpose(1, 2)
        pooled = []
        for conv in self.convs:
            activated = F.relu(conv(features))
            if activated.size(2) > 1:
                # Global max over the time axis -> (batch, num_filters).
                pooled.append(F.max_pool1d(activated, activated.size(2)).squeeze(2))
            else:
                pooled.append(activated.squeeze(2))
        return self.layer_norm(torch.cat(pooled, dim=1))
class PrefixEncoder(nn.Module):
    """Encode a query prefix: multi-scale CNN followed by a Transformer.

    The CNN collapses the prefix into a single multi-scale feature vector,
    a pre-norm Transformer refines it (as a length-1 sequence), and a
    linear projection maps the result back to ``embed_dim``.

    Args:
        embed_dim: Dimension of the input embeddings and the output vector.
        num_filters: Filters per convolution in the CNN stage.
        num_heads: Attention heads in each Transformer layer.
        num_layers: Number of Transformer encoder layers.
        filter_sizes: Convolution widths forwarded to ``CNNLocalEncoder``.
            The Transformer width is ``num_filters * len(filter_sizes)``;
            previously this was hard-coded as ``num_filters * 3``, which
            only matched the CNN's default of three filter sizes.
    """

    def __init__(
        self,
        embed_dim=128,
        num_filters=64,
        num_heads=4,
        num_layers=2,
        filter_sizes=(3, 4, 5),
    ):
        super().__init__()
        filter_sizes = list(filter_sizes)
        self.cnn = CNNLocalEncoder(embed_dim, num_filters, filter_sizes)
        # Derive the width from the actual number of filter sizes so custom
        # filter_sizes stay dimensionally consistent with the CNN output.
        cnn_out_dim = num_filters * len(filter_sizes)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=cnn_out_dim,
                nhead=num_heads,
                dim_feedforward=cnn_out_dim * 4,
                dropout=0.1,
                batch_first=True,
                activation="gelu",
                layer_norm_eps=1e-6,
                norm_first=True,
            ),
            num_layers=num_layers,
        )
        self.proj = nn.Linear(cnn_out_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        # Conservative gain keeps the projection's initial output small.
        nn.init.xavier_uniform_(self.proj.weight, gain=0.5)
        nn.init.zeros_(self.proj.bias)

    def forward(self, prefix_embed):
        # CNN pools the prefix to one vector; treat it as a length-1 sequence
        # for the Transformer, then squeeze the sequence dim back out.
        cnn_out = self.cnn(prefix_embed).unsqueeze(1)
        transformer_out = self.transformer(cnn_out).squeeze(1)
        return self.layer_norm(self.proj(transformer_out))
class CandidateEncoder(nn.Module):
"""Transformer encoder for candidate (no CNN)"""
def __init__(self, embed_dim=128, num_heads=4, num_layers=2):
super().__init__()
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=embed_dim * 4,
dropout=0.1,
batch_first=True,
activation="gelu",
layer_norm_eps=1e-6,
norm_first=True,
),
num_layers=num_layers,
)
self.layer_norm = nn.LayerNorm(embed_dim)
def forward(self, candidate_embed):
transformer_out = self.transformer(candidate_embed)
pooled = torch.max(transformer_out, dim=1)[0]
return self.layer_norm(pooled)
class QueryCompletionModel(nn.Module):
    """Query auto-completion matcher.

    Scores how well a candidate completion matches a typed prefix: the
    prefix goes through a CNN+Transformer encoder, the candidate through a
    Transformer encoder, and an MLP over the concatenated "intention"
    vectors outputs a match probability in (0, 1).

    Args:
        vocab_size: Embedding table size (unused when pretrained
            embeddings are enabled).
        embed_dim: Working embedding dimension of both encoders.
        num_filters: Filters per convolution in the prefix CNN.
        num_heads: Attention heads in the Transformer layers.
        num_transformer_layers: Transformer depth for both encoders.
        use_pretrained_embeddings: If True, reuse (and fine-tune) ByT5's
            shared token embedding for both prefix and candidate inputs.
        pretrained_model_name: Hugging Face model id providing the
            pretrained embeddings.
    """

    def __init__(
        self,
        vocab_size=10000,
        embed_dim=128,
        num_filters=64,
        num_heads=4,
        num_transformer_layers=2,
        use_pretrained_embeddings=False,
        pretrained_model_name="google/byt5-small",
    ):
        super().__init__()
        self.use_pretrained_embeddings = use_pretrained_embeddings
        if use_pretrained_embeddings:
            # Load pretrained ByT5 and share its token embedding between
            # the prefix and candidate inputs.
            print(f"Loading pretrained embeddings from {pretrained_model_name}...")
            byt5_model = T5EncoderModel.from_pretrained(pretrained_model_name)
            pretrained_embed_dim = byt5_model.config.d_model
            self.shared_embedding = byt5_model.shared
            self.shared_embedding.requires_grad_(True)  # Fine-tune embeddings
            if pretrained_embed_dim != embed_dim:
                # Project the pretrained dimension to the working dimension.
                self.embed_proj = nn.Linear(pretrained_embed_dim, embed_dim)
                nn.init.xavier_uniform_(self.embed_proj.weight, gain=0.5)
                # Zero the bias for consistency with every other Linear
                # initialized in this file (it was left at PyTorch's default).
                nn.init.zeros_(self.embed_proj.bias)
            else:
                self.embed_proj = nn.Identity()
            print(
                f" ✓ Using pretrained embeddings: {pretrained_embed_dim}D → {embed_dim}D"
            )
        else:
            # Separate learned embedding tables (original behavior).
            self.prefix_embedding = nn.Embedding(vocab_size, embed_dim)
            self.candidate_embedding = nn.Embedding(vocab_size, embed_dim)
            self._init_embeddings()
        self.prefix_encoder = PrefixEncoder(
            embed_dim, num_filters, num_heads, num_transformer_layers
        )
        self.candidate_encoder = CandidateEncoder(
            embed_dim, num_heads, num_transformer_layers
        )
        # MLP head over the concatenated prefix/candidate intention vectors.
        self.match_predictor = nn.Sequential(
            nn.LayerNorm(embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, 1),
        )
        self._init_predictor()

    def _init_embeddings(self):
        # Small-std normal init, the common choice for Transformer embeddings.
        if not self.use_pretrained_embeddings:
            nn.init.normal_(self.prefix_embedding.weight, std=0.02)
            nn.init.normal_(self.candidate_embedding.weight, std=0.02)

    def _init_predictor(self):
        # Conservative Xavier init with zeroed biases on every Linear in the head.
        for module in self.match_predictor:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                nn.init.zeros_(module.bias)

    def forward(self, prefix_ids, candidate_ids):
        """Return match probabilities of shape (batch, 1), each in (0, 1).

        Args:
            prefix_ids: Token ids of the typed prefix — presumably
                (batch, prefix_len); any shape nn.Embedding accepts works.
            candidate_ids: Token ids of the candidate completion.
        """
        if self.use_pretrained_embeddings:
            # One shared (pretrained) table embeds both inputs.
            prefix_embed = self.embed_proj(self.shared_embedding(prefix_ids))
            candidate_embed = self.embed_proj(self.shared_embedding(candidate_ids))
        else:
            # Independent learned tables for each side.
            prefix_embed = self.prefix_embedding(prefix_ids)
            candidate_embed = self.candidate_embedding(candidate_ids)
        prefix_intention = self.prefix_encoder(prefix_embed)
        candidate_intention = self.candidate_encoder(candidate_embed)
        combined = torch.cat([prefix_intention, candidate_intention], dim=-1)
        logits = self.match_predictor(combined)
        # NOTE: the model outputs probabilities, not raw logits — pair with a
        # plain BCE loss downstream, not BCEWithLogitsLoss.
        return torch.sigmoid(logits)