"""모두의 빛길 (Modu-ui Bitgil) — training script.

Training runs in three stages:

    Stage 1) Slot-extractor fine-tuning (KLUE/RoBERTa token classification).
    Stage 2) Self-supervised GNN pre-training via link prediction (BPR loss).
    Stage 3) Recommendation fine-tuning (BPR loss over venue scores).

Device placement is the caller's job: move the models to the target device
(e.g. in a separate notebook cell) before calling these functions. Each
function reads the device from the model's first parameter and moves its
inputs there; nothing in this module chooses a device by itself.
"""
from __future__ import annotations

from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# Stage 1) Slot-extractor training
# ============================================================
def train_slot_extractor(
    model,  # RobertaSlotExtractor
    train_examples: List[Dict],  # [{"text": ..., "labels": [...]}]
    epochs: int = 5,
    batch_size: int = 8,
    lr: float = 3e-5,
):
    """Fine-tune the KLUE/RoBERTa slot extractor on token-level BIO labels.

    ``model`` must expose a ``tokenizer`` attribute and accept
    ``input_ids`` / ``attention_mask`` / ``labels`` keyword arguments,
    returning an object with a ``.loss`` attribute. Examples are consumed
    in the given order (no shuffling), ``batch_size`` at a time, with one
    AdamW step per batch. After each epoch the mean batch loss is printed.

    Expected ``train_examples`` format::

        [
            {
                "text": "아이와 저녁 공연을 보고 안전하게 귀가할 수 있는 동선을 추천해줘",
                "labels": ["O", "B-USER_TYPE", "O", "B-TIME_WINDOW", "B-CULTURE_PREF", ...]
            },
            ...
        ]
    """
    from .slot_extractor import LABEL2ID

    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    def encode(batch):
        """Tokenize ``batch`` and build a matching (batch, seq_len) label-id tensor."""
        texts = [b["text"] for b in batch]
        labels = [b["labels"] for b in batch]
        enc = model.tokenizer(
            texts,
            is_split_into_words=False,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
            return_offsets_mapping=True,
        )
        # Offsets are not used yet; pop them so they never reach the model.
        # TODO: proper char-offset -> token-offset alignment of the BIO labels.
        # For now labels are assumed to be pre-expanded to token level at
        # dataset-build time and are written starting at sequence position 0.
        # NOTE(review): position 0 is the tokenizer's first emitted token,
        # possibly a special token — confirm the data builder accounts for it.
        enc.pop("offset_mapping")
        # -100 fills positions without a label (padding / past the label list);
        # presumably the model's loss uses it as ignore_index — confirm.
        label_ids = torch.full_like(enc["input_ids"], -100, dtype=torch.long)
        for i, lab in enumerate(labels):
            for j in range(min(len(lab), label_ids.size(1))):
                # Unknown label strings fall back to id 0.
                label_ids[i, j] = LABEL2ID.get(lab[j], 0)
        return enc, label_ids

    model.train()
    device = next(model.parameters()).device
    for ep in range(epochs):
        total = 0.0
        n_batches = 0
        for start in range(0, len(train_examples), batch_size):
            batch = train_examples[start:start + batch_size]
            enc, lab = encode(batch)
            enc = {k: v.to(device) for k, v in enc.items()}
            lab = lab.to(device)
            out = model(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                labels=lab,
            )
            loss = out.loss
            optim.zero_grad()
            loss.backward()
            optim.step()
            total += loss.item()
            n_batches += 1
        # Average over the batches actually run; ``len // batch_size`` would
        # drop a trailing partial batch from the denominator.
        print(f"[Stage1] epoch {ep+1}/{epochs} loss={total / max(1, n_batches):.4f}")


# ============================================================
# Stage 2) GNN self-supervised pre-training (link prediction)
# ============================================================
def train_gnn_link_pred(
    gnn: nn.Module,
    hetero_data,
    relations: Optional[List[Tuple[str, str, str]]] = None,
    epochs: int = 50,
    lr: float = 1e-3,
    neg_ratio: int = 3,
):
    """Pre-train a heterogeneous GNN by link prediction with a BPR loss.

    Each epoch runs one full-graph forward pass. Then, for every relation in
    ``relations`` that exists in the graph and has at least one edge, each
    observed edge ``(u, v)`` is scored against ``neg_ratio`` corrupted edges
    ``(u, v')``. Each ``v'`` is drawn uniformly from the destination node type.
    The score ``s`` is the dot product of the two node embeddings. The
    per-relation loss is the BPR objective ``-log sigmoid(s(u, v) - s(u, v'))``
    averaged over all pairs, and relation losses are summed before one Adam
    step.

    NOTE(review): no edges are held out — the supervised edges also take part
    in message passing, so this is not true masked-edge reconstruction.
    Confirm this is intended.

    Args:
        gnn: Called as ``gnn(x_dict, edge_index_dict)``; must return a dict of
            node-embedding tensors keyed by node type.
        hetero_data: Object exposing ``x_dict`` and ``edge_index_dict``
            (e.g. a PyG ``HeteroData``).
        relations: ``(src_type, rel, dst_type)`` edge types to supervise.
            Defaults to venue-near-transit, venue-has_amenity-amenity and
            venue-hosts-event.
        epochs: Number of full-graph optimisation steps.
        lr: Adam learning rate.
        neg_ratio: Negative samples per positive edge; must be >= 1.

    Raises:
        ValueError: If ``neg_ratio < 1``. With no negatives, the BPR mean
            would be taken over an empty tensor (NaN).
    """
    if neg_ratio < 1:
        raise ValueError(f"neg_ratio must be >= 1, got {neg_ratio}")
    if relations is None:
        relations = [
            ("venue", "near", "transit"),
            ("venue", "has_amenity", "amenity"),
            ("venue", "hosts", "event"),
        ]
    optim = torch.optim.Adam(gnn.parameters(), lr=lr)
    gnn.train()
    device = next(gnn.parameters()).device
    x_dict = {k: v.to(device) for k, v in hetero_data.x_dict.items()}
    eidx = {k: v.to(device) for k, v in hetero_data.edge_index_dict.items()}

    # The edge sets never change during training, so decide once which
    # relations can contribute a loss term.
    active = [rel for rel in relations if rel in eidx and eidx[rel].size(1) > 0]
    if not active:
        # Without any term, the loss would be a constant tensor and
        # ``backward()`` would raise. There is nothing to learn, so stop.
        print("[Stage2] no supervised relations found in graph; skipping")
        return

    for ep in range(epochs):
        out = gnn(x_dict, eidx)
        loss = torch.tensor(0.0, device=device)
        for rel in active:
            src_t, _, dst_t = rel
            ei = eidx[rel]
            pos_score = (out[src_t][ei[0]] * out[dst_t][ei[1]]).sum(-1)
            # Negatives: keep each source node, corrupt the destination.
            num_dst = out[dst_t].size(0)
            neg_dst = torch.randint(0, num_dst, (ei.size(1) * neg_ratio,), device=device)
            neg_src = ei[0].repeat_interleave(neg_ratio)
            neg_score = (out[src_t][neg_src] * out[dst_t][neg_dst]).sum(-1)
            # BPR: each positive is repeated to pair with its own negatives.
            loss = loss + -F.logsigmoid(pos_score.repeat_interleave(neg_ratio) - neg_score).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()
        if (ep + 1) % 10 == 0:
            print(f"[Stage2] epoch {ep+1}/{epochs} loss={loss.item():.4f}")


# ============================================================
# Stage 3) Recommendation fine-tune (BPR)
# ============================================================
def train_recommender(
    gnn: nn.Module,
    venue_ranker,  # VenueRanker
    hetero_data,
    triples: List[Tuple[torch.Tensor, int, int]],  # (slot_vec, pos_venue_id, neg_venue_id)
    epochs: int = 30,
    lr: float = 1e-3,
):
    """Jointly fine-tune the GNN and venue ranker on slot -> venue BPR triples.

    For every ``(slot_vec, pos_id, neg_id)`` triple, one optimizer step is
    taken on ``-log sigmoid(score[pos_id] - score[neg_id])``, where ``score``
    is the ranker's output for the slot vector against all venue embeddings.
    The GNN forward pass is re-run for every triple because each step updates
    the GNN weights. Every 5 epochs, the mean per-triple loss is printed.

    NOTE(review): ``venue_ranker(slot_vec[None], venue_emb)`` is presumably
    shaped (1, num_venues) so that ``.squeeze(0)`` gives one score per venue
    id — confirm against VenueRanker.
    """
    params = list(gnn.parameters()) + list(venue_ranker.parameters())
    optim = torch.optim.Adam(params, lr=lr)
    gnn.train()
    venue_ranker.train()
    device = next(gnn.parameters()).device
    x_dict = {k: v.to(device) for k, v in hetero_data.x_dict.items()}
    eidx = {k: v.to(device) for k, v in hetero_data.edge_index_dict.items()}
    for ep in range(epochs):
        total = 0.0
        for slot_vec, pos_id, neg_id in triples:
            slot_vec = slot_vec.to(device).unsqueeze(0)
            x = gnn(x_dict, eidx)
            venue_emb = x["venue"]
            scores = venue_ranker(slot_vec, venue_emb).squeeze(0)
            pos_score = scores[pos_id]
            neg_score = scores[neg_id]
            loss = -F.logsigmoid(pos_score - neg_score)
            optim.zero_grad()
            loss.backward()
            optim.step()
            total += loss.item()
        if (ep + 1) % 5 == 0:
            print(f"[Stage3] epoch {ep+1}/{epochs} loss={total / max(1, len(triples)):.4f}")


__all__ = [
    "train_slot_extractor",
    "train_gnn_link_pred",
    "train_recommender",
]