"""
User intent classifier.

Input: embedding sequence of the user's N most recently clicked items.
Output: intent category (Top-K categories extracted from KuaiRec item tags).

Purpose: serves as a category bias in the recall stage, complementing the mindset vector.
"""
| import os |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from typing import List, Dict, Tuple |
|
|
| from config import cfg |
|
|
|
|
| |
| |
| |
class IntentClassifier(nn.Module):
    """GRU encoder plus a two-layer MLP head producing intent-category logits.

    Args:
        embed_dim: Size of each item embedding; ``cfg.embed_dim`` is used when
            this is omitted (or falsy).
        hidden_dim: GRU hidden size; the head narrows it to ``hidden_dim // 2``
            before the output layer.
        n_classes: Number of intent categories (size of the logit vector).
    """

    def __init__(self, embed_dim: int = None, hidden_dim: int = 128, n_classes: int = 20):
        super().__init__()
        embed_dim = embed_dim or cfg.embed_dim
        self.n_classes = n_classes

        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, num_layers=1)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, n_classes),
        )

    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, n_classes).

        seq:     (B, T, embed_dim) padded click-history embeddings
        lengths: (B,) actual (unpadded) sequence lengths
        """
        # Packing makes the GRU stop at each sequence's true length, so the
        # padding rows never influence the final hidden state. The lengths
        # tensor is moved to CPU before being handed to the packing utility.
        packed = nn.utils.rnn.pack_padded_sequence(
            seq, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, hidden = self.gru(packed)
        # hidden is (num_layers, B, hidden_dim); indexing the last layer gives
        # (B, hidden_dim) regardless of how many layers the GRU has.
        return self.classifier(hidden[-1])

    def predict(self, history_embs: np.ndarray, max_len: int = 20) -> Tuple[int, np.ndarray]:
        """Classify a single user's click history.

        history_embs: (T, embed_dim) recent click sequence, oldest first
        max_len:      only the last ``max_len`` clicks are fed to the model

        Returns: (top_class_idx, probs). An empty history yields class 0 with
        a uniform distribution over ``n_classes``.

        Raises:
            ValueError: if ``max_len`` is smaller than 1.
        """
        if max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {max_len}")
        if len(history_embs) == 0:
            probs = np.ones(self.n_classes) / self.n_classes
            return 0, probs
        with torch.no_grad():
            recent = np.asarray(history_embs[-max_len:], dtype=np.float32)
            seq = torch.tensor(recent, dtype=torch.float32).unsqueeze(0)
            # Run on whatever device the model's parameters live on.
            seq = seq.to(next(self.parameters()).device)
            length = torch.tensor([seq.shape[1]])
            logits = self.forward(seq, length)
            probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
        return int(probs.argmax()), probs
|
|
|
|
| |
| |
| |
def extract_categories(data, top_k: int = 20) -> Tuple[Dict[int, int], List[str]]:
    """
    Derive category labels from the most frequent tags found in item text.

    Each item's text is split on whitespace; tokens longer than one character
    are kept, at most five per item. The ``top_k`` most common tokens become
    the categories. An item is assigned the category of its first kept token
    that is one of those categories, or ``top_k - 1`` when none is.

    Returns: (iid -> category_id, category_names)
    """
    from collections import Counter

    per_item: Dict[int, List[str]] = {}
    counts = Counter()
    for item_id, raw in data.id2text.items():
        kept = []
        for token in raw.split():
            token = token.strip()
            if len(token) > 1:
                kept.append(token)
        kept = kept[:5]
        per_item[item_id] = kept
        counts.update(kept)

    top_tags = [name for name, _ in counts.most_common(top_k)]
    tag2id = dict(zip(top_tags, range(len(top_tags))))

    # Items with no tag among the top ones land in the last category slot.
    fallback = top_k - 1
    iid2cat: Dict[int, int] = {
        item_id: next((tag2id[tok] for tok in kept if tok in tag2id), fallback)
        for item_id, kept in per_item.items()
    }

    print(f"[IntentClassifier] Categories: {top_tags[:10]}...")
    return iid2cat, top_tags
|
|
|
|
| |
| |
| |
class IntentDataset(Dataset):
    """Serves (padded sequence, true length, label) triples for intent training.

    Every sequence is cropped to its last ``max_len`` rows and, if shorter,
    zero-padded at the end so all samples share the shape (max_len, dim).
    """

    def __init__(self, seqs: List[np.ndarray], labels: List[int], max_len: int = 20):
        self.seqs, self.labels, self.max_len = seqs, labels, max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        window = self.seqs[idx][-self.max_len:]
        n_valid = len(window)
        missing = self.max_len - n_valid
        if missing > 0:
            # Trailing zero rows; the true length is returned separately so
            # the consumer can ignore them.
            filler = np.zeros((missing, window.shape[1]), dtype=np.float32)
            window = np.concatenate([window, filler], axis=0)
        features = torch.tensor(window, dtype=torch.float32)
        true_len = torch.tensor(n_valid, dtype=torch.long)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, true_len, target
|
|
|
|
def build_intent_data(data, item_embeddings: np.ndarray,
                      iid2cat: Dict[int, int],
                      max_samples: int = 100_000,
                      default_cat: int = None):
    """
    Build intent-classification training data:
    history sequence -> category of the next clicked item.

    For every user history (after dropping item ids that fall outside
    ``item_embeddings``) and every position t >= 2, one sample is emitted:
    the embeddings of the first t items, labelled with the category of item t.
    Histories with fewer than three usable items are skipped.

    Args:
        data: object exposing ``user_histories`` (uid -> list of item ids).
        item_embeddings: array indexed by item id.
        iid2cat: item id -> category id.
        max_samples: upper bound on the number of samples produced.
        default_cat: label used when the next item has no entry in ``iid2cat``.
            Defaults to the largest category id present in ``iid2cat``, so the
            label is always a real category id. The previous fallback,
            ``len(iid2cat) - 1``, was derived from the number of items rather
            than the number of categories.

    Returns:
        (seqs, labels): a list of float32 arrays of shape (t, embed_dim) and
        the matching list of integer labels. Sequences from the same user are
        views into one shared array, so treat them as read-only.
    """
    if default_cat is None:
        default_cat = max(iid2cat.values(), default=0)

    n_items = len(item_embeddings)
    seqs, labels = [], []
    for hist in data.user_histories.values():
        if len(labels) >= max_samples:
            break
        hist = [iid for iid in hist if iid < n_items]
        if len(hist) < 3:
            continue
        # Gather this user's embeddings once; each prefix below is a cheap
        # slice instead of a freshly rebuilt array.
        user_embs = np.asarray([item_embeddings[iid] for iid in hist], dtype=np.float32)
        for t in range(2, len(hist)):
            seqs.append(user_embs[:t])
            labels.append(iid2cat.get(hist[t], default_cat))
            if len(labels) >= max_samples:
                break
    print(f"[IntentClassifier] Training samples: {len(labels):,}")
    return seqs, labels
|
|
|
|
| |
| |
| |
def train_intent_classifier(data, item_embeddings: np.ndarray,
                            n_classes: int = 20, epochs: int = 5,
                            batch_size: int = 512, lr: float = 1e-3,
                            max_samples: int = 5_000) -> Tuple["IntentClassifier", List[str]]:
    """
    Train the intent classifier, or load it from an existing checkpoint.

    The checkpoint lives at ``{cfg.output_dir}/intent_classifier.pt``. When it
    already exists the weights are loaded and no training happens. Otherwise
    samples are built, shuffled, split 80/20 into train/validation, and the
    model is trained with Adam and cross-entropy; the weights of the epoch
    with the best validation accuracy are saved and reloaded at the end.

    Args:
        data: dataset object handed to ``extract_categories`` and
            ``build_intent_data``.
        item_embeddings: array indexed by item id.
        n_classes: number of categories / output logits.
        epochs: number of training epochs.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        max_samples: cap on the number of training samples built
            (previously hard-coded to 5,000).

    Returns:
        (model, category_names)

    Raises:
        ValueError: if no training samples can be built.
    """
    ckpt = f"{cfg.output_dir}/intent_classifier.pt"
    iid2cat, category_names = extract_categories(data, top_k=n_classes)
    model = IntentClassifier(n_classes=n_classes).to(cfg.device)

    if os.path.exists(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
        print(f"[IntentClassifier] Loaded checkpoint: {ckpt}")
        return model, category_names

    seqs, labels = build_intent_data(data, item_embeddings, iid2cat, max_samples=max_samples)

    n = len(labels)
    if n == 0:
        raise ValueError("[IntentClassifier] No training samples could be built from user histories")

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(ckpt) or ".", exist_ok=True)

    idx = np.random.permutation(n)
    split = int(n * 0.8)
    train_idx, val_idx = idx[:split], idx[split:]

    train_ds = IntentDataset([seqs[i] for i in train_idx], [labels[i] for i in train_idx])
    val_ds = IntentDataset([seqs[i] for i in val_idx], [labels[i] for i in val_idx])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when validation accuracy is exactly 0.0.
    best_val_acc = -1.0

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for seq, length, y in train_dl:
            # `length` is left where the loader produced it; the model moves
            # it to CPU itself before packing.
            seq, length, y = seq.to(cfg.device), length, y.to(cfg.device)
            logits = model(seq, length)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(y)
            correct += (logits.argmax(1) == y).sum().item()
            total += len(y)

        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for seq, length, y in val_dl:
                seq, length, y = seq.to(cfg.device), length, y.to(cfg.device)
                val_correct += (model(seq, length).argmax(1) == y).sum().item()
                val_total += len(y)

        # With very few samples a split can be empty; avoid dividing by zero.
        seen = max(total, 1)
        val_acc = val_correct / max(val_total, 1)

        print(f" Epoch {epoch}/{epochs} | loss={total_loss/seen:.4f} "
              f"| train_acc={correct/seen:.3f} | val_acc={val_acc:.3f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), ckpt)

    if best_val_acc < 0:
        # No epoch ran, so there is no checkpoint to reload.
        print("[IntentClassifier] No epochs run; returning untrained model")
        return model, category_names

    model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
    print(f"[IntentClassifier] Best val_acc={best_val_acc:.3f}")
    return model, category_names
|
|