Spaces:
Running
Running
File size: 6,932 Bytes
"""모두의 빛길 — 이종 그래프 빌더.
GraphData(노드 컬렉션) → PyG HeteroData(GNN 입력) 변환을 담당한다.
좌표 기반으로 거리 임계값 내 엣지를 자동 생성하고, 각 엣지에
거리/위험도/접근성 등 가중치 feature를 부여한다.
PyG metadata는 GNN 모델 초기화에 그대로 전달된다.
"""
from __future__ import annotations
import math
from typing import List, Tuple, Dict
import torch
from torch_geometric.data import HeteroData
from .data_schema import (
GraphData, VenueNode, EventNode, TransitNode, AmenityNode, HazardNode,
)
# ============================================================
# 거리 계산
# ============================================================
def haversine_m(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
    """Return the great-circle distance in metres between two points.

    Coordinates are in decimal degrees. Uses the haversine formula on a
    sphere of radius 6,371,000 m. The formula is valid for any separation,
    including antipodal points.
    """
    R = 6_371_000.0
    p1 = math.radians(lat1)
    p2 = math.radians(lat2)
    dp = math.radians(lat2 - lat1)
    dl = math.radians(lng2 - lng1)
    a = math.sin(dp / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dl / 2) ** 2
    # Rounding can push sqrt(a) a hair above 1.0 for (near-)antipodal points,
    # which would make math.asin raise ValueError; clamp to asin's domain.
    return 2 * R * math.asin(min(1.0, math.sqrt(a)))
# ============================================================
# 그래프 빌더
# ============================================================
class GraphBuilder:
    """Convert ``GraphData`` into a PyG ``HeteroData`` graph.

    Spatial edges are generated automatically. Two nodes are linked when
    their haversine distance is within the radius configured for that
    relation. Each such edge carries the distance divided by that radius
    (a value in [0, 1]) as its first edge feature.

    Every ``edge_index`` entry is a *row position* in the corresponding
    ``GraphData`` list (i.e. a row of ``g[<node_type>].x``), never a node's
    ``id`` attribute, because PyG addresses node features by row.

    Parameters
    ----------
    venue_transit_radius_m : max venue-transit edge distance, metres (default 500)
    venue_amenity_radius_m : max venue-amenity edge distance, metres (default 300)
    transit_walk_radius_m  : max walkable transit-transit distance, metres (default 600)
    transit_hazard_radius_m: max transit-hazard edge distance, metres (default 200)
    """

    def __init__(
        self,
        venue_transit_radius_m: float = 500.0,
        venue_amenity_radius_m: float = 300.0,
        transit_walk_radius_m: float = 600.0,
        transit_hazard_radius_m: float = 200.0,
    ):
        self.r_vt = venue_transit_radius_m
        self.r_va = venue_amenity_radius_m
        self.r_tt = transit_walk_radius_m
        self.r_th = transit_hazard_radius_m

    # --------------------------------------------------------
    # Main build
    # --------------------------------------------------------
    def build(self, data: GraphData) -> HeteroData:
        """Build the heterogeneous graph for *data*.

        Node types: venue, event, transit, amenity, hazard. Each has ``x``
        stacked from the nodes' ``feature_vec()``.

        Edge types (``edge_attr`` shape in parentheses):
          - venue -hosts-> event, event -hosted_by-> venue (no edge_attr)
          - venue -near-> transit and reverse            (E, 1): distance
          - venue -has_amenity-> amenity,
            amenity -near_venue-> venue                  (E, 1): distance
          - transit -walkable-> transit                  (E, 1): distance
          - transit -passes_hazard-> hazard              (E, 2): distance, severity

        Raises
        ------
        ValueError
            If an event references a ``venue_id`` absent from ``data.venues``.
        """
        g = HeteroData()

        # 1) Node features
        g["venue"].x = self._stack_features([v.feature_vec() for v in data.venues])
        g["event"].x = self._stack_features([e.feature_vec() for e in data.events])
        g["transit"].x = self._stack_features([t.feature_vec() for t in data.transits])
        g["amenity"].x = self._stack_features([a.feature_vec() for a in data.amenities])
        g["hazard"].x = self._stack_features([h.feature_vec() for h in data.hazards])

        # Row-positional (index, lat, lng) triples, built once and reused.
        venue_pts = self._positions(data.venues)
        transit_pts = self._positions(data.transits)
        amenity_pts = self._positions(data.amenities)
        hazard_pts = self._positions(data.hazards)

        # 2) venue -hosts-> event, plus the reverse relation
        ei = self._venue_event_edges(data)
        g["venue", "hosts", "event"].edge_index = ei
        g["event", "hosted_by", "venue"].edge_index = ei[[1, 0]]

        # 3) venue <-> transit within walking radius of a venue
        ei_vt, ew_vt = self._geo_edges(venue_pts, transit_pts, radius_m=self.r_vt)
        g["venue", "near", "transit"].edge_index = ei_vt
        g["venue", "near", "transit"].edge_attr = ew_vt
        g["transit", "near", "venue"].edge_index = ei_vt[[1, 0]]
        g["transit", "near", "venue"].edge_attr = ew_vt

        # 4) venue <-> amenity
        ei_va, ew_va = self._geo_edges(venue_pts, amenity_pts, radius_m=self.r_va)
        g["venue", "has_amenity", "amenity"].edge_index = ei_va
        g["venue", "has_amenity", "amenity"].edge_attr = ew_va
        g["amenity", "near_venue", "venue"].edge_index = ei_va[[1, 0]]
        g["amenity", "near_venue", "venue"].edge_attr = ew_va

        # 5) transit <-> transit within walking distance. Both directions come
        #    out of the all-pairs scan; self-loops are excluded.
        ei_tt, ew_tt = self._geo_edges(
            transit_pts, transit_pts, radius_m=self.r_tt, exclude_self=True
        )
        g["transit", "walkable", "transit"].edge_index = ei_tt
        g["transit", "walkable", "transit"].edge_attr = ew_tt

        # 6) transit -passes_hazard-> hazard
        ei_th, ew_th = self._geo_edges(transit_pts, hazard_pts, radius_m=self.r_th)
        # Append hazard severity as a second feature. This is done even when there
        # are no edges so edge_attr is always (E, 2) and never a (0, 1) outlier.
        # ei_th[1] holds row positions, so indexing data.hazards is valid.
        severities = torch.tensor(
            [data.hazards[j].severity for j in ei_th[1].tolist()],
            dtype=torch.float,
        ).unsqueeze(-1)
        ew_th = torch.cat([ew_th, severities], dim=-1)  # (E, 2): (distance, severity)
        g["transit", "passes_hazard", "hazard"].edge_index = ei_th
        g["transit", "passes_hazard", "hazard"].edge_attr = ew_th
        return g

    # --------------------------------------------------------
    # Helpers
    # --------------------------------------------------------
    @staticmethod
    def _stack_features(feats: List[List[float]]) -> torch.Tensor:
        """Stack per-node feature lists into an (N, F) float tensor.

        An empty list yields a (0, 1) placeholder. NOTE(review): that width
        may differ from the real feature width of the node type; confirm the
        model tolerates it.
        """
        if not feats:
            return torch.zeros((0, 1), dtype=torch.float)
        return torch.tensor(feats, dtype=torch.float)

    @staticmethod
    def _positions(nodes: List) -> List[Tuple[int, float, float]]:
        """Return ``(row_index, lat, lng)`` for each node, in list order."""
        return [(i, n.lat, n.lng) for i, n in enumerate(nodes)]

    @staticmethod
    def _venue_event_edges(data: GraphData) -> torch.Tensor:
        """Return the (2, E) venue->event edge index, one edge per event.

        Row 0 holds the hosting venue's position in ``data.venues`` (resolved
        from ``event.venue_id``). Row 1 holds the event's position in
        ``data.events``.

        Raises
        ------
        ValueError
            If an event's ``venue_id`` matches no venue in ``data.venues``.
        """
        if not data.events:
            return torch.zeros((2, 0), dtype=torch.long)
        venue_row = {v.id: i for i, v in enumerate(data.venues)}
        src, dst = [], []
        for row, e in enumerate(data.events):
            if e.venue_id not in venue_row:
                raise ValueError(
                    f"event {e.id!r} references unknown venue_id {e.venue_id!r}"
                )
            src.append(venue_row[e.venue_id])
            dst.append(row)
        return torch.tensor([src, dst], dtype=torch.long)

    @staticmethod
    def _geo_edges(
        src: List[Tuple[int, float, float]],
        dst: List[Tuple[int, float, float]],
        radius_m: float,
        exclude_self: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Link every src/dst pair whose haversine distance is <= ``radius_m``.

        ``src``/``dst`` are ``(index, lat, lng)`` triples. The index is copied
        verbatim into the edge index. With ``exclude_self`` set, pairs with
        equal indices are skipped.

        Returns ``(edge_index (2, E) long, edge_attr (E, 1) float)``. The
        attribute is ``distance / radius_m``, in [0, 1].
        """
        edge_src, edge_dst, dists = [], [], []
        for sid, slat, slng in src:
            for did, dlat, dlng in dst:
                if exclude_self and sid == did:
                    continue
                d = haversine_m(slat, slng, dlat, dlng)
                if d <= radius_m:
                    edge_src.append(sid)
                    edge_dst.append(did)
                    # With a zero radius only coincident points match (d == 0);
                    # avoid 0 / 0 and report them as distance 0.
                    dists.append(d / radius_m if radius_m > 0 else 0.0)
        if not edge_src:
            return torch.zeros((2, 0), dtype=torch.long), torch.zeros((0, 1), dtype=torch.float)
        ei = torch.tensor([edge_src, edge_dst], dtype=torch.long)
        ew = torch.tensor(dists, dtype=torch.float).unsqueeze(-1)
        return ei, ew
# Public API of this module.
__all__ = [
    "GraphBuilder",
    "haversine_m",
]
|