Spaces:
Sleeping
Sleeping
| """모두의 빛길 — 학습 스크립트. | |
| 세 단계로 학습한다. | |
| Stage 1) 슬롯 추출기 fine-tune (KLUE/RoBERTa) | |
| Stage 2) GNN 자기지도 사전학습 (링크 예측) | |
| Stage 3) 추천 fine-tune (BPR + 멀티태스크) | |
| GPU 설정은 사용자가 별도 셀에서 모델을 .to(device)로 옮긴 뒤 호출한다. | |
| 이 모듈 내부에서는 device를 직접 지정하지 않는다. | |
| """ | |
| from __future__ import annotations | |
| from typing import Dict, List, Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ============================================================ | |
| # Stage 1) 슬롯 추출기 학습 | |
| # ============================================================ | |
| def train_slot_extractor( | |
| model, # RobertaSlotExtractor | |
| train_examples: List[Dict], # [{"text": ..., "labels": [...]}] | |
| epochs: int = 5, | |
| batch_size: int = 8, | |
| lr: float = 3e-5, | |
| ): | |
| """KLUE/RoBERTa fine-tuning. | |
| train_examples 형식 | |
| ------------------- | |
| [ | |
| { | |
| "text": "아이와 저녁 공연을 보고 안전하게 귀가할 수 있는 동선을 추천해줘", | |
| "labels": ["O", "B-USER_TYPE", "O", "B-TIME_WINDOW", "B-CULTURE_PREF", ...] | |
| }, | |
| ... | |
| ] | |
| """ | |
| from .slot_extractor import LABEL2ID | |
| optim = torch.optim.AdamW(model.parameters(), lr=lr) | |
| def encode(batch): | |
| texts = [b["text"] for b in batch] | |
| labels = [b["labels"] for b in batch] | |
| enc = model.tokenizer( | |
| texts, is_split_into_words=False, | |
| return_tensors="pt", padding=True, truncation=True, max_length=128, | |
| return_offsets_mapping=True, | |
| ) | |
| offsets = enc.pop("offset_mapping") | |
| # offset → 라벨 정렬 (BIO를 토큰 레벨로 펴는 작업) | |
| # 단순화: 라벨이 토큰 단위로 이미 펴져 있다고 가정 | |
| # 실제로는 char offset → token offset 매핑이 필요. 데이터 빌드 시 처리. | |
| label_ids = torch.zeros_like(enc["input_ids"]).long().fill_(-100) | |
| for i, lab in enumerate(labels): | |
| for j in range(min(len(lab), label_ids.size(1))): | |
| label_ids[i, j] = LABEL2ID.get(lab[j], 0) | |
| return enc, label_ids | |
| model.train() | |
| device = next(model.parameters()).device | |
| for ep in range(epochs): | |
| total = 0.0 | |
| for i in range(0, len(train_examples), batch_size): | |
| batch = train_examples[i:i + batch_size] | |
| enc, lab = encode(batch) | |
| for k in enc: | |
| enc[k] = enc[k].to(device) | |
| lab = lab.to(device) | |
| out = model(input_ids=enc["input_ids"], | |
| attention_mask=enc["attention_mask"], | |
| labels=lab) | |
| loss = out.loss | |
| optim.zero_grad() | |
| loss.backward() | |
| optim.step() | |
| total += loss.item() | |
| print(f"[Stage1] epoch {ep+1}/{epochs} loss={total / max(1, len(train_examples)//batch_size):.4f}") | |
| # ============================================================ | |
| # Stage 2) GNN 자기지도 사전학습 (링크 예측) | |
| # ============================================================ | |
| def train_gnn_link_pred( | |
| gnn: nn.Module, | |
| hetero_data, | |
| relations: List[Tuple[str, str, str]] = None, | |
| epochs: int = 50, | |
| lr: float = 1e-3, | |
| neg_ratio: int = 3, | |
| ): | |
| """엣지를 마스킹하고 복원하는 자기지도 학습. | |
| BPR 손실: sigmoid(z_pos · z_neg)을 최소화. | |
| """ | |
| if relations is None: | |
| relations = [ | |
| ("venue", "near", "transit"), | |
| ("venue", "has_amenity", "amenity"), | |
| ("venue", "hosts", "event"), | |
| ] | |
| optim = torch.optim.Adam(gnn.parameters(), lr=lr) | |
| gnn.train() | |
| device = next(gnn.parameters()).device | |
| x_dict = {k: v.to(device) for k, v in hetero_data.x_dict.items()} | |
| eidx = {k: v.to(device) for k, v in hetero_data.edge_index_dict.items()} | |
| for ep in range(epochs): | |
| out = gnn(x_dict, eidx) | |
| loss = torch.tensor(0.0, device=device) | |
| for rel in relations: | |
| if rel not in eidx: | |
| continue | |
| src_t, _, dst_t = rel | |
| ei = eidx[rel] | |
| if ei.size(1) == 0: | |
| continue | |
| pos_src = out[src_t][ei[0]] | |
| pos_dst = out[dst_t][ei[1]] | |
| pos_score = (pos_src * pos_dst).sum(-1) | |
| # 음성 샘플 | |
| num_dst = out[dst_t].size(0) | |
| neg_dst = torch.randint(0, num_dst, (ei.size(1) * neg_ratio,), device=device) | |
| neg_src = ei[0].repeat_interleave(neg_ratio) | |
| neg_score = (out[src_t][neg_src] * out[dst_t][neg_dst]).sum(-1) | |
| # BPR | |
| loss = loss + -F.logsigmoid(pos_score.repeat_interleave(neg_ratio) - neg_score).mean() | |
| optim.zero_grad() | |
| loss.backward() | |
| optim.step() | |
| if (ep + 1) % 10 == 0: | |
| print(f"[Stage2] epoch {ep+1}/{epochs} loss={loss.item():.4f}") | |
| # ============================================================ | |
| # Stage 3) 추천 fine-tune (BPR) | |
| # ============================================================ | |
| def train_recommender( | |
| gnn: nn.Module, | |
| venue_ranker, # VenueRanker | |
| hetero_data, | |
| triples: List[Tuple[torch.Tensor, int, int]], | |
| # (slot_vec, pos_venue_id, neg_venue_id) 형태 | |
| epochs: int = 30, | |
| lr: float = 1e-3, | |
| ): | |
| """슬롯 → venue ranking BPR 학습.""" | |
| params = list(gnn.parameters()) + list(venue_ranker.parameters()) | |
| optim = torch.optim.Adam(params, lr=lr) | |
| gnn.train() | |
| venue_ranker.train() | |
| device = next(gnn.parameters()).device | |
| x_dict = {k: v.to(device) for k, v in hetero_data.x_dict.items()} | |
| eidx = {k: v.to(device) for k, v in hetero_data.edge_index_dict.items()} | |
| for ep in range(epochs): | |
| total = 0.0 | |
| for slot_vec, pos_id, neg_id in triples: | |
| slot_vec = slot_vec.to(device).unsqueeze(0) | |
| x = gnn(x_dict, eidx) | |
| venue_emb = x["venue"] | |
| scores = venue_ranker(slot_vec, venue_emb).squeeze(0) | |
| pos_score = scores[pos_id] | |
| neg_score = scores[neg_id] | |
| loss = -F.logsigmoid(pos_score - neg_score) | |
| optim.zero_grad() | |
| loss.backward() | |
| optim.step() | |
| total += loss.item() | |
| if (ep + 1) % 5 == 0: | |
| print(f"[Stage3] epoch {ep+1}/{epochs} loss={total / max(1, len(triples)):.4f}") | |
| __all__ = [ | |
| "train_slot_extractor", | |
| "train_gnn_link_pred", | |
| "train_recommender", | |
| ] | |