"""모두의 빛길 — 이종 그래프 빌더. GraphData(노드 컬렉션) → PyG HeteroData(GNN 입력) 변환을 담당한다. 좌표 기반으로 거리 임계값 내 엣지를 자동 생성하고, 각 엣지에 거리/위험도/접근성 등 가중치 feature를 부여한다. PyG metadata는 GNN 모델 초기화에 그대로 전달된다. """ from __future__ import annotations import math from typing import List, Tuple, Dict import torch from torch_geometric.data import HeteroData from .data_schema import ( GraphData, VenueNode, EventNode, TransitNode, AmenityNode, HazardNode, ) # ============================================================ # 거리 계산 # ============================================================ def haversine_m(lat1: float, lng1: float, lat2: float, lng2: float) -> float: """간이 하버사인 거리(m). 광주 인근 단거리에 적합.""" R = 6_371_000.0 p1 = math.radians(lat1) p2 = math.radians(lat2) dp = math.radians(lat2 - lat1) dl = math.radians(lng2 - lng1) a = math.sin(dp / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dl / 2) ** 2 return 2 * R * math.asin(math.sqrt(a)) # ============================================================ # 그래프 빌더 # ============================================================ class GraphBuilder: """GraphData → HeteroData. 파라미터 -------- venue_transit_radius_m : 시설-정류장 엣지를 만들 거리 한계 (기본 500m) venue_amenity_radius_m : 시설-편의시설 엣지 거리 한계 (기본 300m) transit_walk_radius_m : 정류장-정류장 도보 가능 거리 (기본 600m) transit_hazard_radius_m: 정류장-위험구간 엣지 거리 한계 (기본 200m) """ def __init__( self, venue_transit_radius_m: float = 500.0, venue_amenity_radius_m: float = 300.0, transit_walk_radius_m: float = 600.0, transit_hazard_radius_m: float = 200.0, ): self.r_vt = venue_transit_radius_m self.r_va = venue_amenity_radius_m self.r_tt = transit_walk_radius_m self.r_th = transit_hazard_radius_m # -------------------------------------------------------- # 메인 빌드 # -------------------------------------------------------- def build(self, data: GraphData) -> HeteroData: g = HeteroData() # 1) 노드 features g["venue"].x = self._stack_features([v.feature_vec() for v in data.venues]) g["event"].x = self._stack_features([e.feature_vec() for e in data.events]) g["transit"].x = self._stack_features([t.feature_vec() for t in data.transits]) g["amenity"].x = self._stack_features([a.feature_vec() for a in data.amenities]) g["hazard"].x = self._stack_features([h.feature_vec() for h in data.hazards]) # 2) (VENUE)-hosts->(EVENT) — 양방향 ei = self._venue_event_edges(data) g["venue", "hosts", "event"].edge_index = ei g["event", "hosted_by", "venue"].edge_index = ei[[1, 0]] # 3) (VENUE)<->(TRANSIT) near ei_vt, ew_vt = self._geo_edges( [(v.id, v.lat, v.lng) for v in data.venues], [(t.id, t.lat, t.lng) for t in data.transits], radius_m=self.r_vt, ) g["venue", "near", "transit"].edge_index = ei_vt g["venue", "near", "transit"].edge_attr = ew_vt g["transit", "near", "venue"].edge_index = ei_vt[[1, 0]] g["transit", "near", "venue"].edge_attr = ew_vt # 4) (VENUE)<->(AMENITY) has_amenity ei_va, ew_va = self._geo_edges( [(v.id, v.lat, v.lng) for v in data.venues], [(a.id, a.lat, a.lng) for a in data.amenities], radius_m=self.r_va, ) g["venue", "has_amenity", "amenity"].edge_index = ei_va g["venue", "has_amenity", "amenity"].edge_attr = ew_va g["amenity", "near_venue", "venue"].edge_index = ei_va[[1, 0]] g["amenity", "near_venue", "venue"].edge_attr = ew_va # 5) (TRANSIT)<->(TRANSIT) walkable (도보 가능 거리) ei_tt, ew_tt = self._geo_edges( [(t.id, t.lat, t.lng) for t in data.transits], [(t.id, t.lat, t.lng) for t in data.transits], radius_m=self.r_tt, exclude_self=True, ) g["transit", "walkable", "transit"].edge_index = ei_tt g["transit", "walkable", "transit"].edge_attr = ew_tt # 6) (TRANSIT)->(HAZARD) passes_hazard ei_th, ew_th = self._geo_edges( [(t.id, t.lat, t.lng) for t in data.transits], [(h.id, h.lat, h.lng) for h in data.hazards], radius_m=self.r_th, ) if ei_th.numel() > 0: severities = torch.tensor( [data.hazards[j].severity for j in ei_th[1].tolist()], dtype=torch.float, ).unsqueeze(-1) ew_th = torch.cat([ew_th, severities], dim=-1) # (E, 2): (거리, 위험도) g["transit", "passes_hazard", "hazard"].edge_index = ei_th g["transit", "passes_hazard", "hazard"].edge_attr = ew_th return g # -------------------------------------------------------- # 헬퍼 # -------------------------------------------------------- @staticmethod def _stack_features(feats: List[List[float]]) -> torch.Tensor: if not feats: return torch.zeros((0, 1), dtype=torch.float) return torch.tensor(feats, dtype=torch.float) @staticmethod def _venue_event_edges(data: GraphData) -> torch.Tensor: if not data.events: return torch.zeros((2, 0), dtype=torch.long) src, dst = [], [] for e in data.events: src.append(e.venue_id) dst.append(e.id) return torch.tensor([src, dst], dtype=torch.long) @staticmethod def _geo_edges( src: List[Tuple[int, float, float]], dst: List[Tuple[int, float, float]], radius_m: float, exclude_self: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """src와 dst 간 radius 이내인 엣지 + 정규화된 거리(0~1).""" edge_src, edge_dst, dists = [], [], [] for sid, slat, slng in src: for did, dlat, dlng in dst: if exclude_self and sid == did: continue d = haversine_m(slat, slng, dlat, dlng) if d <= radius_m: edge_src.append(sid) edge_dst.append(did) dists.append(d / radius_m) # 정규화 0~1 if not edge_src: return torch.zeros((2, 0), dtype=torch.long), torch.zeros((0, 1), dtype=torch.float) ei = torch.tensor([edge_src, edge_dst], dtype=torch.long) ew = torch.tensor(dists, dtype=torch.float).unsqueeze(-1) return ei, ew __all__ = ["GraphBuilder", "haversine_m"]