# modu-lightway / src /train.py
# Munsusu's picture
# Commit message: Initial deploy — 모두의 빛길 v1.0
# 131589b verified
"""모두의 빛길 — 학습 스크립트.
세 단계로 학습한다.
Stage 1) 슬롯 추출기 fine-tune (KLUE/RoBERTa)
Stage 2) GNN 자기지도 사전학습 (링크 예측)
Stage 3) 추천 fine-tune (BPR + 멀티태스크)
GPU 설정은 사용자가 별도 셀에서 모델을 .to(device)로 옮긴 뒤 호출한다.
이 모듈 내부에서는 device를 직접 지정하지 않는다.
"""
from __future__ import annotations
from typing import Dict, List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Stage 1) 슬롯 추출기 학습
# ============================================================
def train_slot_extractor(
    model,  # RobertaSlotExtractor
    train_examples: List[Dict],  # [{"text": ..., "labels": [...]}]
    epochs: int = 5,
    batch_size: int = 8,
    lr: float = 3e-5,
):
    """Fine-tune the slot extractor (Stage 1) on BIO-tagged examples.

    The model is trained in place with AdamW on whatever device its
    parameters already live on; this function never moves the model.
    Examples are consumed in their given order, ``batch_size`` at a time,
    and one line with the mean per-batch loss is printed after every epoch.

    Args:
        model: Object exposing ``tokenizer`` (called on a list of strings),
            ``parameters()``, ``train()`` and a forward call that accepts
            ``input_ids``, ``attention_mask`` and ``labels`` and returns an
            object with a ``loss`` attribute.
        train_examples: List of dicts with a ``"text"`` string and a
            ``"labels"`` list of tag strings, e.g.::

                {
                    "text": "아이와 저녁 공연을 보고 안전하게 귀가할 수 있는 동선을 추천해줘",
                    "labels": ["O", "B-USER_TYPE", "O", "B-TIME_WINDOW", ...],
                }

        epochs: Number of passes over ``train_examples``.
        batch_size: Number of examples per optimisation step.
        lr: AdamW learning rate.

    Notes:
        ``labels[j]`` is assigned to token position ``j`` as produced by the
        tokenizer; no character-offset alignment is performed here. The
        labels are therefore expected to be token-level already.
        NOTE(review): position 0 is whatever the tokenizer emits first
        (presumably a special token) -- confirm the data builder accounts
        for this.
        Tags missing from ``LABEL2ID`` fall back to id 0. Positions without
        a label, and padding positions, are set to -100.
    """
    from .slot_extractor import LABEL2ID

    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    def encode(batch):
        """Tokenize a batch and build the matching label-id tensor."""
        texts = [b["text"] for b in batch]
        labels = [b["labels"] for b in batch]
        # Offsets are not requested: they were never used for alignment, and
        # asking for them restricts the call to fast tokenizers.
        enc = model.tokenizer(
            texts, is_split_into_words=False,
            return_tensors="pt", padding=True, truncation=True, max_length=128,
        )
        # Start from the "no label" value everywhere, then fill in tags.
        label_ids = torch.full_like(enc["input_ids"], -100, dtype=torch.long)
        seq_len = label_ids.size(1)
        for row, tags in enumerate(labels):
            kept = tags[:seq_len]  # drop tags beyond the (truncated) sequence
            if kept:
                label_ids[row, :len(kept)] = torch.tensor(
                    [LABEL2ID.get(tag, 0) for tag in kept], dtype=torch.long
                )
        # A label list longer than this example's real token count would
        # otherwise land on padding and leak into the loss.
        label_ids.masked_fill_(enc["attention_mask"] == 0, -100)
        return enc, label_ids

    model.train()
    device = next(model.parameters()).device
    for ep in range(epochs):
        total = 0.0
        num_batches = 0
        for start in range(0, len(train_examples), batch_size):
            batch = train_examples[start:start + batch_size]
            enc, lab = encode(batch)
            enc = {k: v.to(device) for k, v in enc.items()}
            lab = lab.to(device)
            out = model(input_ids=enc["input_ids"],
                        attention_mask=enc["attention_mask"],
                        labels=lab)
            loss = out.loss
            optim.zero_grad()
            loss.backward()
            optim.step()
            total += loss.item()
            num_batches += 1
        # Average over the batches actually run; floor division of the
        # dataset size would undercount when the last batch is partial.
        print(f"[Stage1] epoch {ep+1}/{epochs} loss={total / max(1, num_batches):.4f}")
# ============================================================
# Stage 2) GNN 자기지도 사전학습 (링크 예측)
# ============================================================
def train_gnn_link_pred(
    gnn: nn.Module,
    hetero_data,
    relations: Optional[List[Tuple[str, str, str]]] = None,
    epochs: int = 50,
    lr: float = 1e-3,
    neg_ratio: int = 3,
):
    """Pre-train the GNN (Stage 2) with a link-prediction objective.

    Every epoch runs ``gnn`` once over the full graph (no edges are held out
    or masked), scores each existing edge of the selected relations by the
    dot product of its endpoint embeddings, and contrasts it with
    ``neg_ratio`` edges that keep the source node but use a uniformly random
    destination node. The loss per relation is the BPR loss
    ``-log(sigmoid(pos_score - neg_score))`` averaged over all
    positive/negative pairs; relation losses are summed. The GNN is updated
    in place with Adam on the device its parameters already live on, and the
    loss is printed every 10th epoch.

    Args:
        gnn: Module called as ``gnn(x_dict, edge_index_dict)`` that returns
            a mapping from node type to an embedding tensor.
        hetero_data: Object exposing ``x_dict`` and ``edge_index_dict``
            mappings of tensors. Edge-index tensors are read as row 0 =
            source indices, row 1 = destination indices.
        relations: ``(src_type, relation, dst_type)`` keys to train on.
            Defaults to the venue-transit, venue-amenity and venue-event
            relations. Relations missing from the graph or without edges
            are skipped.
        epochs: Number of optimisation steps (one per epoch).
        lr: Adam learning rate.
        neg_ratio: Negative samples drawn per positive edge; must be >= 1.

    Raises:
        ValueError: If ``neg_ratio`` is smaller than 1.
        RuntimeError: If none of the requested relations has any edge in
            ``hetero_data``, i.e. there is nothing to train on.
    """
    if neg_ratio < 1:
        # With zero negatives the mean over an empty tensor is NaN and would
        # silently poison the parameters.
        raise ValueError(f"neg_ratio must be >= 1, got {neg_ratio}")
    if relations is None:
        relations = [
            ("venue", "near", "transit"),
            ("venue", "has_amenity", "amenity"),
            ("venue", "hosts", "event"),
        ]
    optim = torch.optim.Adam(gnn.parameters(), lr=lr)
    gnn.train()
    device = next(gnn.parameters()).device
    x_dict = {k: v.to(device) for k, v in hetero_data.x_dict.items()}
    eidx = {k: v.to(device) for k, v in hetero_data.edge_index_dict.items()}

    # The graph is fixed during training, so decide once which relations
    # actually contribute instead of re-checking on every epoch.
    active = [rel for rel in relations if rel in eidx and eidx[rel].size(1) > 0]
    if not active:
        # Without this guard the loss stays a constant tensor and
        # loss.backward() fails with an opaque autograd error.
        raise RuntimeError(
            "train_gnn_link_pred: none of the requested relations has edges "
            "in hetero_data; nothing to train on"
        )

    for ep in range(epochs):
        out = gnn(x_dict, eidx)
        loss = torch.tensor(0.0, device=device)
        for rel in active:
            src_t, _, dst_t = rel
            ei = eidx[rel]
            pos_src = out[src_t][ei[0]]
            pos_dst = out[dst_t][ei[1]]
            pos_score = (pos_src * pos_dst).sum(-1)
            # Negative samples: same source, uniformly random destination.
            num_dst = out[dst_t].size(0)
            neg_dst = torch.randint(0, num_dst, (ei.size(1) * neg_ratio,), device=device)
            neg_src = ei[0].repeat_interleave(neg_ratio)
            neg_score = (out[src_t][neg_src] * out[dst_t][neg_dst]).sum(-1)
            # BPR: each positive is paired with its own neg_ratio negatives.
            loss = loss - F.logsigmoid(
                pos_score.repeat_interleave(neg_ratio) - neg_score
            ).mean()
        optim.zero_grad()
        loss.backward()
        optim.step()
        if (ep + 1) % 10 == 0:
            print(f"[Stage2] epoch {ep+1}/{epochs} loss={loss.item():.4f}")
# ============================================================
# Stage 3) 추천 fine-tune (BPR)
# ============================================================
def train_recommender(
    gnn: nn.Module,
    venue_ranker,  # VenueRanker
    hetero_data,
    triples: List[Tuple[torch.Tensor, int, int]],
    # each triple is (slot_vec, pos_venue_id, neg_venue_id)
    epochs: int = 30,
    lr: float = 1e-3,
):
    """Fine-tune the GNN and the venue ranker (Stage 3) with a BPR loss.

    A single Adam optimiser updates the parameters of both modules. For
    every triple the GNN is run over the graph, the ranker scores the
    "venue" embeddings against the slot vector, and the loss
    ``-log(sigmoid(score[pos] - score[neg]))`` is back-propagated in its own
    optimisation step. The mean loss per triple is printed every 5th epoch.

    Args:
        gnn: Module called as ``gnn(x_dict, edge_index_dict)``; its output
            must contain a ``"venue"`` entry.
        venue_ranker: Module called as ``venue_ranker(slot_vec, venue_emb)``
            where ``slot_vec`` has been given a leading dimension of 1; the
            result is squeezed on dim 0 and indexed by venue id.
        hetero_data: Object exposing ``x_dict`` and ``edge_index_dict``
            mappings of tensors.
        triples: ``(slot_vec, pos_venue_id, neg_venue_id)`` training tuples.
        epochs: Number of passes over ``triples``.
        lr: Adam learning rate.
    """
    gnn.train()
    venue_ranker.train()
    optimizer = torch.optim.Adam(
        [*gnn.parameters(), *venue_ranker.parameters()], lr=lr
    )

    # Tensors follow the device the GNN was placed on by the caller.
    target = next(gnn.parameters()).device
    features = {name: feat.to(target) for name, feat in hetero_data.x_dict.items()}
    edges = {
        name: index.to(target)
        for name, index in hetero_data.edge_index_dict.items()
    }

    for epoch_no in range(1, epochs + 1):
        running = 0.0
        for query, pos_idx, neg_idx in triples:
            query = query.to(target).unsqueeze(0)
            # Recomputed for every triple: the previous step changed the
            # GNN parameters, so earlier embeddings are stale.
            venue_emb = gnn(features, edges)["venue"]
            ranking = venue_ranker(query, venue_emb).squeeze(0)
            margin = ranking[pos_idx] - ranking[neg_idx]
            loss = -F.logsigmoid(margin)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()

        if epoch_no % 5 == 0:
            mean_loss = running / max(1, len(triples))
            print(f"[Stage3] epoch {epoch_no}/{epochs} loss={mean_loss:.4f}")
# Public API: one training entry point per stage.
__all__ = ["train_slot_extractor", "train_gnn_link_pred", "train_recommender"]