"""
Simplified Query Auto-Completion Model
Uses CNN+Transformer for prefix/candidate encoding (IE module)
Optionally uses pretrained ByT5 embeddings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5EncoderModel


class CNNLocalEncoder(nn.Module):
    """Multi-scale CNN for local pattern extraction"""

    def __init__(self, embed_dim=128, num_filters=64, filter_sizes=(3, 4, 5)):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(embed_dim, num_filters, fs, padding=fs // 2)
                for fs in filter_sizes
            ]
        )
        self.layer_norm = nn.LayerNorm(num_filters * len(filter_sizes))
        self._init_weights()

    def _init_weights(self):
        for conv in self.convs:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len) for Conv1d
        x = x.transpose(1, 2)
        conv_outs = [F.relu(conv(x)) for conv in self.convs]
        # Global max pool over time: (batch, num_filters, L) -> (batch, num_filters)
        # (with kernel_size == L this also covers the L == 1 case)
        pooled = [F.max_pool1d(out, out.size(2)).squeeze(2) for out in conv_outs]
        out = torch.cat(pooled, dim=1)  # (batch, num_filters * len(filter_sizes))
        return self.layer_norm(out)


class PrefixEncoder(nn.Module):
    """CNN + Transformer encoder for prefix"""

    def __init__(self, embed_dim=128, num_filters=64, num_heads=4, num_layers=2):
        super().__init__()
        self.cnn = CNNLocalEncoder(embed_dim, num_filters)
        cnn_out_dim = num_filters * 3  # 3 == len of CNNLocalEncoder's default filter_sizes

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=cnn_out_dim,
                nhead=num_heads,
                dim_feedforward=cnn_out_dim * 4,
                dropout=0.1,
                batch_first=True,
                activation="gelu",
                layer_norm_eps=1e-6,
                norm_first=True,
            ),
            num_layers=num_layers,
        )
        self.proj = nn.Linear(cnn_out_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.proj.weight, gain=0.5)
        nn.init.zeros_(self.proj.bias)

    def forward(self, prefix_embed):
        # The CNN pools the prefix into a single vector, so the transformer
        # attends over a length-1 sequence: (batch, 1, cnn_out_dim)
        cnn_out = self.cnn(prefix_embed).unsqueeze(1)
        transformer_out = self.transformer(cnn_out).squeeze(1)
        return self.layer_norm(self.proj(transformer_out))


class CandidateEncoder(nn.Module):
    """Transformer encoder for candidate (no CNN)"""

    def __init__(self, embed_dim=128, num_heads=4, num_layers=2):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                dropout=0.1,
                batch_first=True,
                activation="gelu",
                layer_norm_eps=1e-6,
                norm_first=True,
            ),
            num_layers=num_layers,
        )
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, candidate_embed):
        transformer_out = self.transformer(candidate_embed)
        # Global max pooling over the sequence dimension
        pooled = torch.max(transformer_out, dim=1)[0]
        return self.layer_norm(pooled)


class QueryCompletionModel(nn.Module):
    """Query auto-completion: CNN+Transformer for prefix, Transformer for candidate"""

    def __init__(
        self,
        vocab_size=10000,
        embed_dim=128,
        num_filters=64,
        num_heads=4,
        num_transformer_layers=2,
        use_pretrained_embeddings=False,
        pretrained_model_name="google/byt5-small",
    ):
        super().__init__()
        self.use_pretrained_embeddings = use_pretrained_embeddings

        if use_pretrained_embeddings:
            # Load pretrained ByT5 and use its embeddings
            print(f"Loading pretrained embeddings from {pretrained_model_name}...")
            byt5_model = T5EncoderModel.from_pretrained(pretrained_model_name)
            pretrained_embed_dim = byt5_model.config.d_model

            # Share the pretrained embedding for both prefix and candidate
            self.shared_embedding = byt5_model.shared
            self.shared_embedding.requires_grad_(True)  # Fine-tune embeddings

            # Project to target embed_dim if different
            if pretrained_embed_dim != embed_dim:
                self.embed_proj = nn.Linear(pretrained_embed_dim, embed_dim)
                nn.init.xavier_uniform_(self.embed_proj.weight, gain=0.5)
            else:
                self.embed_proj = nn.Identity()

            print(
                f"  ✓ Using pretrained embeddings: {pretrained_embed_dim}D → {embed_dim}D"
            )
        else:
            # Use separate learned embeddings (original behavior)
            self.prefix_embedding = nn.Embedding(vocab_size, embed_dim)
            self.candidate_embedding = nn.Embedding(vocab_size, embed_dim)
            self._init_embeddings()

        self.prefix_encoder = PrefixEncoder(
            embed_dim, num_filters, num_heads, num_transformer_layers
        )
        self.candidate_encoder = CandidateEncoder(
            embed_dim, num_heads, num_transformer_layers
        )

        self.match_predictor = nn.Sequential(
            nn.LayerNorm(embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, 1),
        )
        self._init_predictor()

    def _init_embeddings(self):
        # Only called when separate learned embeddings are in use
        nn.init.normal_(self.prefix_embedding.weight, std=0.02)
        nn.init.normal_(self.candidate_embedding.weight, std=0.02)

    def _init_predictor(self):
        for module in self.match_predictor:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                nn.init.zeros_(module.bias)

    def forward(self, prefix_ids, candidate_ids):
        if self.use_pretrained_embeddings:
            # Use shared pretrained embeddings for both
            prefix_embed = self.embed_proj(self.shared_embedding(prefix_ids))
            candidate_embed = self.embed_proj(self.shared_embedding(candidate_ids))
        else:
            # Use separate learned embeddings
            prefix_embed = self.prefix_embedding(prefix_ids)
            candidate_embed = self.candidate_embedding(candidate_ids)

        prefix_intention = self.prefix_encoder(prefix_embed)
        candidate_intention = self.candidate_encoder(candidate_embed)
        combined = torch.cat([prefix_intention, candidate_intention], dim=-1)
        logits = self.match_predictor(combined)
        # Output is a match probability in [0, 1]; train with nn.BCELoss
        # (not nn.BCEWithLogitsLoss, which applies its own sigmoid)
        return torch.sigmoid(logits)
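

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model). The batch size, sequence
# lengths, and vocab size below are arbitrary assumptions chosen only to
# exercise the forward pass and show the expected output shape.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = QueryCompletionModel(vocab_size=10000, embed_dim=128)
    model.eval()  # disable dropout for the shape check

    batch_size, prefix_len, candidate_len = 4, 8, 16
    prefix_ids = torch.randint(0, 10000, (batch_size, prefix_len))
    candidate_ids = torch.randint(0, 10000, (batch_size, candidate_len))

    with torch.no_grad():
        scores = model(prefix_ids, candidate_ids)

    # One match probability per (prefix, candidate) pair
    print(scores.shape)  # torch.Size([4, 1])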