Safetensors
File size: 8,599 Bytes
30796e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
็”จๆˆทๆ„ๅ›พๅˆ†็ฑปๅ™จ
่พ“ๅ…ฅ: ็”จๆˆทๆœ€่ฟ‘ N ไธช็‚นๅ‡ป item ็š„ embedding ๅบๅˆ—
่พ“ๅ‡บ: ๆ„ๅ›พ็ฑปๅˆซ๏ผˆไปŽ KuaiRec item tag ไธญๆๅ–็š„ Top-K ็ฑปๅˆซ๏ผ‰

ไฝœ็”จ: ๅฌๅ›ž้˜ถๆฎตไฝœไธบ็ฑปๅˆซ bias๏ผŒ่กฅๅ…… mindset ๅ‘้‡
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple

from config import cfg


# ---------------------------------------------
# Model
# ---------------------------------------------
class IntentClassifier(nn.Module):
    """GRU encoder followed by a two-layer MLP head that scores intent classes.

    The GRU reads a (padded) sequence of item embeddings; its final hidden
    state is passed through the MLP to produce one logit per class.
    """

    def __init__(self, embed_dim: int = None, hidden_dim: int = 128, n_classes: int = 20):
        super().__init__()
        if not embed_dim:
            # No explicit size given: use the project-wide embedding size.
            embed_dim = cfg.embed_dim
        self.n_classes = n_classes
        # Single-layer GRU that encodes the click history.
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, num_layers=1)
        bottleneck = hidden_dim // 2
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, n_classes),
        )

    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of padded sequences.

        Args:
            seq: (B, T, embed_dim) padded embedding sequences.
            lengths: (B,) true (unpadded) length of each sequence.

        Returns:
            (B, n_classes) logits.
        """
        # Packing makes the GRU stop at each sequence's true length, so the
        # padded steps do not influence the final hidden state.  The lengths
        # are moved to CPU before being handed to pack_padded_sequence.
        packed_input = nn.utils.rnn.pack_padded_sequence(
            seq, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, final_state = self.gru(packed_input)
        # final_state is (num_layers=1, B, hidden_dim); drop the layer axis.
        return self.classifier(final_state.squeeze(0))

    def predict(self, history_embs: np.ndarray) -> Tuple[int, np.ndarray]:
        """Classify a single click history.

        Args:
            history_embs: (T, embed_dim) embeddings of the clicked items; only
                the last 20 entries are used.

        Returns:
            (top_class_idx, probs) where probs has one entry per class.  An
            empty history yields class 0 together with a uniform distribution.
        """
        if len(history_embs) == 0:
            uniform = np.full(self.n_classes, 1.0) / self.n_classes
            return 0, uniform

        device = next(self.parameters()).device
        with torch.no_grad():
            recent = torch.tensor(history_embs[-20:], dtype=torch.float32)
            batch = recent.unsqueeze(0).to(device)  # (1, T', embed_dim)
            n_steps = torch.tensor([batch.shape[1]])
            logits = self.forward(batch, n_steps)
            probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
        best = int(probs.argmax())
        return best, probs


# ---------------------------------------------
# Label extraction: derive the Top-K categories from item tags
# ---------------------------------------------
def extract_categories(data, top_k: int = 20) -> Tuple[Dict[int, int], List[str]]:
    """Derive category labels from the most frequent tags in the item texts.

    Each text in ``data.id2text`` is split on whitespace; tokens longer than
    one character are kept and the first five of them are the item's tags.
    The ``top_k`` most frequent tags become the categories, and every item is
    assigned the category of its first tag that is one of them.

    NOTE: items with no matching tag get id ``top_k - 1`` ("other"), which is
    the same id as the last of the ``top_k`` tags when that many tags exist.

    Returns:
        (iid -> category_id, category_names)
    """
    from collections import Counter

    # str.split() with no separator already yields whitespace-free tokens.
    tokens_by_item: Dict[int, List[str]] = {
        item_id: [tok for tok in raw_text.split() if len(tok) > 1][:5]
        for item_id, raw_text in data.id2text.items()
    }

    frequencies = Counter()
    for item_tokens in tokens_by_item.values():
        frequencies.update(item_tokens)

    top_tags = [name for name, _count in frequencies.most_common(top_k)]
    index_of = {name: position for position, name in enumerate(top_tags)}

    fallback_id = top_k - 1  # "other" class
    iid2cat: Dict[int, int] = {
        item_id: next(
            (index_of[tok] for tok in item_tokens if tok in index_of),
            fallback_id,
        )
        for item_id, item_tokens in tokens_by_item.items()
    }

    print(f"[IntentClassifier] Categories: {top_tags[:10]}...")
    return iid2cat, top_tags


# ---------------------------------------------
# Training dataset
# ---------------------------------------------
class IntentDataset(Dataset):
    """Dataset of (padded embedding sequence, true length, label) triples.

    Every sequence is cut to its last ``max_len`` steps and, when shorter,
    zero-padded at the end up to ``max_len`` rows.
    """

    def __init__(self, seqs: List[np.ndarray], labels: List[int], max_len: int = 20):
        self.seqs = seqs
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Keep only the most recent max_len steps.
        window = self.seqs[idx][-self.max_len:]
        n_valid = len(window)
        missing = self.max_len - n_valid
        if missing > 0:
            # Right-pad with zero rows so all samples share the same length.
            filler = np.zeros((missing, window.shape[1]), dtype=np.float32)
            window = np.concatenate([window, filler], axis=0)

        features = torch.tensor(window, dtype=torch.float32)
        true_length = torch.tensor(n_valid, dtype=torch.long)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, true_length, target


def build_intent_data(data, item_embeddings: np.ndarray,
                      iid2cat: Dict[int, int],
                      max_samples: int = 100_000,
                      other_cat: int = None):
    """Build training samples for the intent classifier.

    For every user history in ``data.user_histories`` and every position
    ``t >= 2``, one sample is produced: the embeddings of the first ``t``
    items as input and the category of item ``t`` as label.

    Args:
        data: object exposing ``user_histories`` (uid -> list of item ids).
        item_embeddings: embedding table indexed by item id.
        iid2cat: item id -> category id.
        max_samples: upper bound on the number of samples returned.
        other_cat: label used when the next item has no entry in ``iid2cat``.
            Defaults to the largest category id present in ``iid2cat``
            (0 when the mapping is empty), so the label is always a
            category id that already occurs in the mapping.

    Returns:
        (seqs, labels): a list of float32 arrays of shape (t, embed_dim) and
        the matching list of integer labels.
    """
    if other_cat is None:
        # The previous default, len(iid2cat) - 1, is an item count rather than
        # a class index and could exceed the classifier's output range.
        other_cat = max(iid2cat.values(), default=0)

    emb_table = np.asarray(item_embeddings)
    n_items = len(emb_table)

    seqs, labels = [], []
    for hist in data.user_histories.values():
        remaining = max_samples - len(labels)
        if remaining <= 0:
            break
        # Drop ids without an embedding row; negative ids would otherwise
        # silently index from the end of the table.
        valid = [iid for iid in hist if 0 <= iid < n_items]
        if len(valid) < 3:
            continue
        # Gather this user's embeddings once instead of once per prefix.
        user_embs = np.asarray(emb_table[valid], dtype=np.float32)
        last_t = min(len(valid), 2 + remaining)
        for t in range(2, last_t):
            seqs.append(user_embs[:t].copy())
            labels.append(iid2cat.get(valid[t], other_cat))

    print(f"[IntentClassifier] Training samples: {len(labels):,}")
    return seqs, labels


# ---------------------------------------------
# Training
# ---------------------------------------------
def train_intent_classifier(data, item_embeddings: np.ndarray,
                             n_classes: int = 20, epochs: int = 5,
                             batch_size: int = 512, lr: float = 1e-3,
                             max_samples: int = 5_000) -> Tuple["IntentClassifier", List[str]]:
    """Train the intent classifier, or load it from a checkpoint if one exists.

    The checkpoint lives at ``{cfg.output_dir}/intent_classifier.pt``.  When it
    already exists it is loaded and returned without any training.  Otherwise
    samples are built from the user histories, split 80/20 into train and
    validation sets, and the weights with the best validation accuracy are
    saved and returned.

    Args:
        data: dataset object passed to ``extract_categories`` and
            ``build_intent_data``.
        item_embeddings: embedding table indexed by item id.
        n_classes: number of intent categories (also the Top-K for tags).
        epochs: number of training epochs.
        batch_size: mini-batch size for both loaders.
        lr: Adam learning rate.
        max_samples: cap on the number of training samples that are built.

    Returns:
        (model, category_names)

    Raises:
        ValueError: if no training sample could be built.
    """
    ckpt = f"{cfg.output_dir}/intent_classifier.pt"
    iid2cat, category_names = extract_categories(data, top_k=n_classes)
    model = IntentClassifier(n_classes=n_classes).to(cfg.device)

    if os.path.exists(ckpt):
        # NOTE: torch.load unpickles the file; the checkpoint is assumed to be
        # one previously written by this function.
        model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
        print(f"[IntentClassifier] Loaded checkpoint: {ckpt}")
        return model, category_names

    seqs, labels = build_intent_data(data, item_embeddings, iid2cat, max_samples=max_samples)

    n = len(labels)
    if n == 0:
        raise ValueError("[IntentClassifier] No training samples could be built "
                         "from the user histories")

    # CrossEntropyLoss needs targets in [0, n_classes); fold anything outside
    # that range into the last ("other") class instead of failing mid-epoch.
    other_cls = n_classes - 1
    labels = [lab if 0 <= lab < n_classes else other_cls for lab in labels]

    idx = np.random.permutation(n)
    # Keep at least one training sample even for tiny datasets.
    split = max(1, int(n * 0.8))
    train_idx, val_idx = idx[:split], idx[split:]

    train_ds = IntentDataset([seqs[i] for i in train_idx], [labels[i] for i in train_idx])
    val_ds   = IntentDataset([seqs[i] for i in val_idx],   [labels[i] for i in val_idx])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=0)
    val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when its validation accuracy is exactly 0.
    best_val_acc = float("-inf")

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(ckpt), exist_ok=True)

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for seq, length, y in train_dl:
            # `length` stays on CPU; the model only needs it for packing.
            seq, y = seq.to(cfg.device), y.to(cfg.device)
            logits = model(seq, length)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(y)
            correct += (logits.argmax(1) == y).sum().item()
            total += len(y)

        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for seq, length, y in val_dl:
                seq, y = seq.to(cfg.device), y.to(cfg.device)
                val_correct += (model(seq, length).argmax(1) == y).sum().item()
                val_total += len(y)
        # The validation split is empty when fewer than two samples exist.
        val_acc = val_correct / val_total if val_total else 0.0

        print(f"  Epoch {epoch}/{epochs} | loss={total_loss/total:.4f} "
              f"| train_acc={correct/total:.3f} | val_acc={val_acc:.3f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), ckpt)

    if best_val_acc > float("-inf"):
        # Restore the best epoch's weights (skipped when epochs < 1, in which
        # case nothing was saved and the untrained model is returned).
        model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
        print(f"[IntentClassifier] Best val_acc={best_val_acc:.3f}")
    return model, category_names