Safetensors
miniOneRec-kuairec / intent_classifier.py
hiiamkik's picture
Upload intent_classifier.py with huggingface_hub
30796e4 verified
"""
็”จๆˆทๆ„ๅ›พๅˆ†็ฑปๅ™จ
่พ“ๅ…ฅ: ็”จๆˆทๆœ€่ฟ‘ N ไธช็‚นๅ‡ป item ็š„ embedding ๅบๅˆ—
่พ“ๅ‡บ: ๆ„ๅ›พ็ฑปๅˆซ๏ผˆไปŽ KuaiRec item tag ไธญๆๅ–็š„ Top-K ็ฑปๅˆซ๏ผ‰
ไฝœ็”จ: ๅฌๅ›ž้˜ถๆฎตไฝœไธบ็ฑปๅˆซ bias๏ผŒ่กฅๅ…… mindset ๅ‘้‡
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Tuple
from config import cfg
# ---------------------------------------------
# Model
# ---------------------------------------------
class IntentClassifier(nn.Module):
    """GRU-based classifier that maps a click-embedding sequence to intent logits.

    The last hidden state of a single-layer GRU is passed through a small
    two-layer MLP head producing one logit per intent category.
    """

    def __init__(self, embed_dim: int = None, hidden_dim: int = 128, n_classes: int = 20):
        super().__init__()
        if not embed_dim:
            # No explicit width given: use the project-wide embedding size.
            embed_dim = cfg.embed_dim
        self.n_classes = n_classes

        # GRU encoder over the click history.
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )

        # MLP head: hidden_dim -> hidden_dim // 2 -> n_classes.
        bottleneck = hidden_dim // 2
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, n_classes),
        )

    def forward(self, seq: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Compute intent logits for a padded batch of sequences.

        Args:
            seq: Padded embeddings of shape (B, T, embed_dim).
            lengths: Shape (B,), the true (unpadded) length of each sequence.

        Returns:
            Logits of shape (B, n_classes).
        """
        # Packing makes the GRU stop at each sequence's real last step, so the
        # returned hidden state is not polluted by padding rows.
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            seq,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, last_hidden = self.gru(packed_seq)
        # last_hidden is (1, B, hidden_dim) for the single layer; drop axis 0.
        return self.classifier(last_hidden.squeeze(0))

    def predict(self, history_embs: np.ndarray) -> Tuple[int, np.ndarray]:
        """Predict the intent category for one click history.

        Args:
            history_embs: Array of shape (T, embed_dim) holding the recent
                click embeddings; only the last 20 rows are used.

        Returns:
            (top_class_idx, probs). For an empty history the result is class 0
            together with a uniform distribution over all classes.
        """
        if len(history_embs) == 0:
            return 0, np.full(self.n_classes, 1.0 / self.n_classes)

        device = next(self.parameters()).device
        with torch.no_grad():
            recent = torch.tensor(history_embs[-20:], dtype=torch.float32)
            batch = recent.unsqueeze(0).to(device)
            n_steps = torch.tensor([batch.shape[1]])
            logits = self.forward(batch, n_steps)
            probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
        return int(probs.argmax()), probs
# ---------------------------------------------
# Label extraction: derive the Top-K categories from item tags
# ---------------------------------------------
def extract_categories(data, top_k: int = 20) -> Tuple[Dict[int, int], List[str]]:
    """Derive category labels from the most frequent tags in the item texts.

    Each text in ``data.id2text`` is split on whitespace; up to five tokens
    longer than one character are kept as that item's tags. The ``top_k`` most
    frequent tags become the categories, and every item is assigned the
    category of its first tag that made the cut.

    Returns:
        (iid -> category_id, category_names). Items with no top tag are
        assigned ``top_k - 1``.
    """
    from collections import Counter

    # split() already yields whitespace-free tokens, so no extra stripping is
    # needed before the length filter.
    iid_tags: Dict[int, List[str]] = {
        iid: [token for token in text.split() if len(token) > 1][:5]
        for iid, text in data.id2text.items()
    }
    tag_counter = Counter(tag for tags in iid_tags.values() for tag in tags)

    top_tags = [name for name, _ in tag_counter.most_common(top_k)]
    tag2id = {name: position for position, name in enumerate(top_tags)}

    fallback_id = top_k - 1  # "other" bucket
    iid2cat: Dict[int, int] = {
        iid: next((tag2id[tag] for tag in tags if tag in tag2id), fallback_id)
        for iid, tags in iid_tags.items()
    }

    print(f"[IntentClassifier] Categories: {top_tags[:10]}...")
    return iid2cat, top_tags
# ---------------------------------------------
# Training dataset
# ---------------------------------------------
class IntentDataset(Dataset):
    """Dataset of (padded sequence, true length, label) training triples.

    Every sequence is cut to its last ``max_len`` rows and, when shorter,
    zero-padded at the end up to ``max_len`` rows.
    """

    def __init__(self, seqs: List[np.ndarray], labels: List[int], max_len: int = 20):
        self.seqs = seqs
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Keep only the most recent max_len steps.
        window = self.seqs[idx][-self.max_len:]
        n_valid = len(window)

        missing = self.max_len - n_valid
        if missing > 0:
            # Append zero rows so every sample has exactly max_len steps.
            filler = np.zeros((missing, window.shape[1]), dtype=np.float32)
            window = np.concatenate((window, filler), axis=0)

        seq_tensor = torch.tensor(window, dtype=torch.float32)
        length_tensor = torch.tensor(n_valid, dtype=torch.long)
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)
        return seq_tensor, length_tensor, label_tensor
def build_intent_data(data, item_embeddings: np.ndarray,
                      iid2cat: Dict[int, int],
                      max_samples: int = 100_000):
    """Build next-category training samples from user click histories.

    For each user in ``data.user_histories`` and each position ``t >= 2``, one
    sample is emitted: the embeddings of the first ``t`` clicked items as
    input, and the category of item ``t`` as the label.

    Args:
        data: Object exposing ``user_histories`` (user id -> list of item ids).
        item_embeddings: Embedding table indexed by item id.
        iid2cat: Item id -> category id.
        max_samples: Stop once this many samples have been collected.

    Returns:
        (seqs, labels): a list of float32 arrays of shape (t, embed_dim) and
        the matching list of category ids. The arrays of one user are prefix
        views of a single per-user array, so copy before mutating in place.
    """
    embeddings = np.asarray(item_embeddings)
    n_items = len(embeddings)

    # Label for target items missing from iid2cat. It must be a real category
    # id; the previous ``len(iid2cat) - 1`` was an item count and produced
    # labels outside the classifier's range. The largest assigned category id
    # is always a valid class.
    fallback_cat = max(iid2cat.values(), default=0)

    seqs, labels = [], []
    for hist in data.user_histories.values():
        if len(labels) >= max_samples:
            break

        # Drop ids without an embedding row. Negative ids are rejected too,
        # otherwise they would silently index from the end of the table.
        valid = [iid for iid in hist if 0 <= iid < n_items]
        if len(valid) < 3:
            continue

        # Gather the user's embeddings once; each sample is a prefix slice,
        # which avoids re-copying the history for every t.
        user_embs = np.asarray(
            embeddings[np.asarray(valid, dtype=np.int64)], dtype=np.float32
        )
        for t in range(2, len(valid)):
            if len(labels) >= max_samples:
                break
            seqs.append(user_embs[:t])
            labels.append(iid2cat.get(valid[t], fallback_cat))

    print(f"[IntentClassifier] Training samples: {len(labels):,}")
    return seqs, labels
# ---------------------------------------------
# Training
# ---------------------------------------------
def train_intent_classifier(data, item_embeddings: np.ndarray,
                            n_classes: int = 20, epochs: int = 5,
                            batch_size: int = 512, lr: float = 1e-3,
                            max_samples: int = 5_000) -> Tuple["IntentClassifier", List[str]]:
    """Train the intent classifier, or load it from an existing checkpoint.

    If ``{cfg.output_dir}/intent_classifier.pt`` exists it is loaded and
    returned without training. Otherwise samples are built from ``data``,
    split 80/20 into train/validation, and the weights of the epoch with the
    best validation accuracy are saved to that path and restored at the end.

    Args:
        data: Object passed to ``extract_categories`` and ``build_intent_data``
            (they read ``id2text`` and ``user_histories``).
        item_embeddings: Embedding table indexed by item id.
        n_classes: Number of intent categories / output logits.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both data loaders.
        lr: Adam learning rate.
        max_samples: Cap on the number of training samples that are built.

    Returns:
        (model, category_names).

    Raises:
        ValueError: If fewer than two samples can be built (one of the splits
            would be empty) or a label falls outside ``[0, n_classes)``.
    """
    ckpt = f"{cfg.output_dir}/intent_classifier.pt"
    iid2cat, category_names = extract_categories(data, top_k=n_classes)
    model = IntentClassifier(n_classes=n_classes).to(cfg.device)

    if os.path.exists(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
        print(f"[IntentClassifier] Loaded checkpoint: {ckpt}")
        return model, category_names

    seqs, labels = build_intent_data(data, item_embeddings, iid2cat, max_samples=max_samples)
    n = len(labels)
    if n < 2:
        # With fewer than two samples the 80/20 split leaves a side empty and
        # the accuracy/loss averages below would divide by zero.
        raise ValueError(
            f"[IntentClassifier] Need at least 2 training samples, got {n}"
        )
    if any(not 0 <= label < n_classes for label in labels):
        # Fail early with a readable message instead of an out-of-bounds
        # target error inside the loss.
        raise ValueError(
            f"[IntentClassifier] Labels must lie in [0, {n_classes})"
        )

    order = np.random.permutation(n)
    split = int(n * 0.8)
    train_idx, val_idx = order[:split], order[split:]
    train_ds = IntentDataset([seqs[i] for i in train_idx], [labels[i] for i in train_idx])
    val_ds = IntentDataset([seqs[i] for i in val_idx], [labels[i] for i in val_idx])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Make sure the checkpoint directory exists before the first save.
    ckpt_dir = os.path.dirname(ckpt)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; starting at 0.0 left no file to reload when validation
    # accuracy never rose above zero.
    best_val_acc = -1.0
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for seq, length, y in train_dl:
            # Lengths are left as produced by the loader; forward() moves them
            # to the CPU before packing.
            seq, y = seq.to(cfg.device), y.to(cfg.device)
            logits = model(seq, length)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * len(y)
            correct += (logits.argmax(1) == y).sum().item()
            total += len(y)

        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for seq, length, y in val_dl:
                seq, y = seq.to(cfg.device), y.to(cfg.device)
                val_correct += (model(seq, length).argmax(1) == y).sum().item()
                val_total += len(y)
        val_acc = val_correct / val_total

        print(f" Epoch {epoch}/{epochs} | loss={total_loss/total:.4f} "
              f"| train_acc={correct/total:.3f} | val_acc={val_acc:.3f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), ckpt)

    if best_val_acc >= 0.0:
        # At least one epoch ran, so a checkpoint exists: restore the best one.
        model.load_state_dict(torch.load(ckpt, map_location=cfg.device))
        print(f"[IntentClassifier] Best val_acc={best_val_acc:.3f}")
    else:
        print("[IntentClassifier] No training epochs were run; returning the untrained model")
    return model, category_names