# modu-lightway / src / graph_builder.py
# Munsusu's picture
# Commit message: Initial deploy — 모두의 빛길 v1.0
# 131589b verified
"""모두의 빛길 — 이종 그래프 빌더.
GraphData(노드 컬렉션) → PyG HeteroData(GNN 입력) 변환을 담당한다.
좌표 기반으로 거리 임계값 내 엣지를 자동 생성하고, 각 엣지에
거리/위험도/접근성 등 가중치 feature를 부여한다.
PyG metadata는 GNN 모델 초기화에 그대로 전달된다.
"""
from __future__ import annotations
import math
from typing import List, Tuple, Dict
import torch
from torch_geometric.data import HeteroData
from .data_schema import (
GraphData, VenueNode, EventNode, TransitNode, AmenityNode, HazardNode,
)
# ============================================================
# 거리 계산
# ============================================================
def haversine_m(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
    """Return the haversine (great-circle) distance in metres.

    The Earth is modelled as a sphere of radius 6,371,000 m, which is adequate
    for the short, city-scale distances this module works with.

    Parameters
    ----------
    lat1, lng1 : latitude / longitude of the first point, in degrees.
    lat2, lng2 : latitude / longitude of the second point, in degrees.

    Returns
    -------
    float
        Distance between the two points in metres.
    """
    earth_radius_m = 6_371_000.0
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lng2 - lng1)
    a = (
        math.sin(d_phi / 2) ** 2
        + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    )
    # Floating-point rounding can push ``a`` a few ulps outside [0, 1] (e.g.
    # for near-antipodal points), which would make sqrt/asin raise ValueError.
    a = min(1.0, max(0.0, a))
    return 2 * earth_radius_m * math.asin(math.sqrt(a))
# ============================================================
# 그래프 빌더
# ============================================================
class GraphBuilder:
    """Convert a ``GraphData`` node collection into a PyG ``HeteroData`` graph.

    Edges between geo-located node types are created for every pair whose
    haversine distance is within a per-relation radius. Each such edge carries
    its distance divided by that radius as an edge feature.

    Parameters
    ----------
    venue_transit_radius_m : float
        Distance limit for venue <-> transit edges (default 500 m).
    venue_amenity_radius_m : float
        Distance limit for venue <-> amenity edges (default 300 m).
    transit_walk_radius_m : float
        Walkable distance limit for transit <-> transit edges (default 600 m).
    transit_hazard_radius_m : float
        Distance limit for transit -> hazard edges (default 200 m).

    NOTE(review): node ``id`` values are written directly into ``edge_index``
    and are used to index ``data.hazards``. This presumably relies on ids being
    0-based positions in the corresponding node lists — confirm against the
    data schema.
    """

    def __init__(
        self,
        venue_transit_radius_m: float = 500.0,
        venue_amenity_radius_m: float = 300.0,
        transit_walk_radius_m: float = 600.0,
        transit_hazard_radius_m: float = 200.0,
    ):
        self.r_vt = venue_transit_radius_m
        self.r_va = venue_amenity_radius_m
        self.r_tt = transit_walk_radius_m
        self.r_th = transit_hazard_radius_m

    # --------------------------------------------------------
    # Main build
    # --------------------------------------------------------
    def build(self, data: GraphData) -> HeteroData:
        """Build the heterogeneous graph for ``data``.

        Populates node feature matrices for the five node types and the
        following relations:

        * ``venue -hosts-> event`` / ``event -hosted_by-> venue``
        * ``venue -near-> transit`` / ``transit -near-> venue``
        * ``venue -has_amenity-> amenity`` / ``amenity -near_venue-> venue``
        * ``transit -walkable-> transit``
        * ``transit -passes_hazard-> hazard``

        Geo relations store an ``(E, 1)`` normalised-distance ``edge_attr``;
        ``passes_hazard`` stores ``(E, 2)``: (normalised distance, severity).
        """
        g = HeteroData()

        # 1) Node features.
        g["venue"].x = self._stack_features([v.feature_vec() for v in data.venues])
        g["event"].x = self._stack_features([e.feature_vec() for e in data.events])
        g["transit"].x = self._stack_features([t.feature_vec() for t in data.transits])
        g["amenity"].x = self._stack_features([a.feature_vec() for a in data.amenities])
        g["hazard"].x = self._stack_features([h.feature_vec() for h in data.hazards])

        # 2) venue -hosts-> event, stored in both directions.
        ei = self._venue_event_edges(data)
        g["venue", "hosts", "event"].edge_index = ei
        g["event", "hosted_by", "venue"].edge_index = ei[[1, 0]]

        venue_pts = [(v.id, v.lat, v.lng) for v in data.venues]
        transit_pts = [(t.id, t.lat, t.lng) for t in data.transits]

        # 3) venue <-> transit ("near"); the reverse relation swaps the rows.
        ei_vt, ew_vt = self._geo_edges(venue_pts, transit_pts, radius_m=self.r_vt)
        g["venue", "near", "transit"].edge_index = ei_vt
        g["venue", "near", "transit"].edge_attr = ew_vt
        g["transit", "near", "venue"].edge_index = ei_vt[[1, 0]]
        g["transit", "near", "venue"].edge_attr = ew_vt

        # 4) venue <-> amenity.
        amenity_pts = [(a.id, a.lat, a.lng) for a in data.amenities]
        ei_va, ew_va = self._geo_edges(venue_pts, amenity_pts, radius_m=self.r_va)
        g["venue", "has_amenity", "amenity"].edge_index = ei_va
        g["venue", "has_amenity", "amenity"].edge_attr = ew_va
        g["amenity", "near_venue", "venue"].edge_index = ei_va[[1, 0]]
        g["amenity", "near_venue", "venue"].edge_attr = ew_va

        # 5) transit <-> transit within walking distance (no self loops).
        ei_tt, ew_tt = self._geo_edges(
            transit_pts,
            transit_pts,
            radius_m=self.r_tt,
            exclude_self=True,
        )
        g["transit", "walkable", "transit"].edge_index = ei_tt
        g["transit", "walkable", "transit"].edge_attr = ew_tt

        # 6) transit -> hazard, with severity appended to the distance feature.
        hazard_pts = [(h.id, h.lat, h.lng) for h in data.hazards]
        ei_th, ew_th = self._geo_edges(transit_pts, hazard_pts, radius_m=self.r_th)
        g["transit", "passes_hazard", "hazard"].edge_index = ei_th
        g["transit", "passes_hazard", "hazard"].edge_attr = self._hazard_edge_attr(
            ei_th, ew_th, data.hazards
        )
        return g

    # --------------------------------------------------------
    # Helpers
    # --------------------------------------------------------
    @staticmethod
    def _hazard_edge_attr(edge_index: torch.Tensor, dist_attr: torch.Tensor, hazards) -> torch.Tensor:
        """Append hazard severity to the distance feature -> ``(E, 2)``.

        ``edge_index[1]`` is used to index ``hazards`` for each edge. When
        there are no edges an empty ``(0, 2)`` tensor is returned, so the
        feature width of the relation does not depend on whether any hazard
        happens to be in range.
        """
        if edge_index.numel() == 0:
            return torch.zeros((0, 2), dtype=torch.float)
        severities = torch.tensor(
            [hazards[j].severity for j in edge_index[1].tolist()],
            dtype=torch.float,
        ).unsqueeze(-1)
        return torch.cat([dist_attr, severities], dim=-1)  # (distance, severity)

    @staticmethod
    def _stack_features(feats: List[List[float]]) -> torch.Tensor:
        """Stack per-node feature lists into a float tensor.

        An empty input yields a ``(0, 1)`` placeholder tensor.
        """
        if not feats:
            return torch.zeros((0, 1), dtype=torch.float)
        return torch.tensor(feats, dtype=torch.float)

    @staticmethod
    def _venue_event_edges(data: GraphData) -> torch.Tensor:
        """Return a ``(2, E)`` long tensor: row 0 = ``venue_id``, row 1 = event ``id``."""
        if not data.events:
            return torch.zeros((2, 0), dtype=torch.long)
        src = [e.venue_id for e in data.events]
        dst = [e.id for e in data.events]
        return torch.tensor([src, dst], dtype=torch.long)

    @staticmethod
    def _geo_edges(
        src: List[Tuple[int, float, float]],
        dst: List[Tuple[int, float, float]],
        radius_m: float,
        exclude_self: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Connect every (src, dst) pair lying within ``radius_m`` metres.

        Parameters
        ----------
        src, dst : lists of ``(id, lat, lng)`` tuples.
        radius_m : inclusive distance threshold in metres.
        exclude_self : skip pairs whose ids are equal (for same-type relations).

        Returns
        -------
        (edge_index, edge_attr)
            ``edge_index`` is a ``(2, E)`` long tensor of ``(src id, dst id)``;
            ``edge_attr`` is an ``(E, 1)`` float tensor holding the distance
            divided by ``radius_m`` (0 to 1).
        """
        edge_src, edge_dst, dists = [], [], []
        for sid, slat, slng in src:
            for did, dlat, dlng in dst:
                if exclude_self and sid == did:
                    continue
                d = haversine_m(slat, slng, dlat, dlng)
                if d <= radius_m:
                    edge_src.append(sid)
                    edge_dst.append(did)
                    # A zero radius only admits coincident points (d == 0);
                    # avoid dividing by zero and report distance 0.
                    dists.append(d / radius_m if radius_m > 0 else 0.0)
        if not edge_src:
            return torch.zeros((2, 0), dtype=torch.long), torch.zeros((0, 1), dtype=torch.float)
        ei = torch.tensor([edge_src, edge_dst], dtype=torch.long)
        ew = torch.tensor(dists, dtype=torch.float).unsqueeze(-1)
        return ei, ew
# Names exported by a star-import of this module.
__all__ = ["GraphBuilder", "haversine_m"]