Spaces:
Sleeping
Sleeping
| """모두의 빛길 — 이종 그래프 빌더. | |
| GraphData(노드 컬렉션) → PyG HeteroData(GNN 입력) 변환을 담당한다. | |
| 좌표 기반으로 거리 임계값 내 엣지를 자동 생성하고, 각 엣지에 | |
| 거리/위험도/접근성 등 가중치 feature를 부여한다. | |
| PyG metadata는 GNN 모델 초기화에 그대로 전달된다. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import List, Tuple, Dict | |
| import torch | |
| from torch_geometric.data import HeteroData | |
| from .data_schema import ( | |
| GraphData, VenueNode, EventNode, TransitNode, AmenityNode, HazardNode, | |
| ) | |
# ============================================================
# Distance computation
# ============================================================
def haversine_m(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
    """Return the great-circle (haversine) distance in metres.

    Uses a spherical Earth of radius 6,371,000 m; inputs are degrees.

    >>> haversine_m(0.0, 0.0, 0.0, 0.0)
    0.0
    """
    R = 6_371_000.0
    p1 = math.radians(lat1)
    p2 = math.radians(lat2)
    dp = math.radians(lat2 - lat1)
    dl = math.radians(lng2 - lng1)
    a = math.sin(dp / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dl / 2) ** 2
    # Rounding can push `a` a hair above 1 for (near-)antipodal points, which
    # would make math.asin(math.sqrt(a)) raise "math domain error".
    a = min(1.0, max(0.0, a))
    return 2 * R * math.asin(math.sqrt(a))
# ============================================================
# Graph builder
# ============================================================
class GraphBuilder:
    """Convert a ``GraphData`` node collection into a PyG ``HeteroData``.

    Node features come from each node's ``feature_vec()``. Geographic edges
    connect every node pair whose haversine distance is within the configured
    radius. Each such edge carries that distance divided by the radius (0..1)
    as its first edge feature.

    Edge indices are *positional*: node ``i`` of a type is row ``i`` of that
    type's ``x`` tensor (the order of the matching ``GraphData`` list). This is
    what PyG requires, independent of how node ``id`` values are assigned.

    Parameters
    ----------
    venue_transit_radius_m : max venue-transit edge distance (default 500 m)
    venue_amenity_radius_m : max venue-amenity edge distance (default 300 m)
    transit_walk_radius_m  : max walkable transit-transit distance (default 600 m)
    transit_hazard_radius_m: max transit-hazard edge distance (default 200 m)
    """

    def __init__(
        self,
        venue_transit_radius_m: float = 500.0,
        venue_amenity_radius_m: float = 300.0,
        transit_walk_radius_m: float = 600.0,
        transit_hazard_radius_m: float = 200.0,
    ):
        self.r_vt = venue_transit_radius_m   # venue <-> transit
        self.r_va = venue_amenity_radius_m   # venue <-> amenity
        self.r_tt = transit_walk_radius_m    # transit <-> transit
        self.r_th = transit_hazard_radius_m  # transit -> hazard

    # --------------------------------------------------------
    # Main build
    # --------------------------------------------------------
    def build(self, data: GraphData) -> HeteroData:
        """Build the heterogeneous graph for *data*.

        Node types: venue, event, transit, amenity, hazard.

        Relations:

        - venue -hosts-> event, with reverse event -hosted_by-> venue.
        - venue <-near-> transit.
        - venue -has_amenity-> amenity, with reverse amenity -near_venue-> venue.
        - transit -walkable-> transit.
        - transit -passes_hazard-> hazard.

        Reverse relations reuse the forward relation's edge features.

        Raises
        ------
        ValueError
            If an event references a ``venue_id`` absent from ``data.venues``.
        """
        g = HeteroData()

        # 1) Node features: row i is the i-th node of the matching list.
        g["venue"].x = self._stack_features([v.feature_vec() for v in data.venues])
        g["event"].x = self._stack_features([e.feature_vec() for e in data.events])
        g["transit"].x = self._stack_features([t.feature_vec() for t in data.transits])
        g["amenity"].x = self._stack_features([a.feature_vec() for a in data.amenities])
        g["hazard"].x = self._stack_features([h.feature_vec() for h in data.hazards])

        venue_pts = self._coords(data.venues)
        transit_pts = self._coords(data.transits)
        amenity_pts = self._coords(data.amenities)
        hazard_pts = self._coords(data.hazards)

        # 2) venue -hosts-> event, plus the reverse relation.
        ei_ve = self._venue_event_edges(data)
        g["venue", "hosts", "event"].edge_index = ei_ve
        g["event", "hosted_by", "venue"].edge_index = ei_ve[[1, 0]]

        # 3) venue <-near-> transit.
        ei_vt, ew_vt = self._geo_edges(venue_pts, transit_pts, radius_m=self.r_vt)
        g["venue", "near", "transit"].edge_index = ei_vt
        g["venue", "near", "transit"].edge_attr = ew_vt
        g["transit", "near", "venue"].edge_index = ei_vt[[1, 0]]
        g["transit", "near", "venue"].edge_attr = ew_vt

        # 4) venue -has_amenity-> amenity, and amenity -near_venue-> venue.
        ei_va, ew_va = self._geo_edges(venue_pts, amenity_pts, radius_m=self.r_va)
        g["venue", "has_amenity", "amenity"].edge_index = ei_va
        g["venue", "has_amenity", "amenity"].edge_attr = ew_va
        g["amenity", "near_venue", "venue"].edge_index = ei_va[[1, 0]]
        g["amenity", "near_venue", "venue"].edge_attr = ew_va

        # 5) transit -walkable-> transit within walking distance, no self-loops.
        ei_tt, ew_tt = self._geo_edges(
            transit_pts, transit_pts, radius_m=self.r_tt, exclude_self=True
        )
        g["transit", "walkable", "transit"].edge_index = ei_tt
        g["transit", "walkable", "transit"].edge_attr = ew_tt

        # 6) transit -passes_hazard-> hazard with (distance, severity) features.
        ei_th, ew_th = self._geo_edges(transit_pts, hazard_pts, radius_m=self.r_th)
        # Concatenate unconditionally so edge_attr is always (E, 2), even when
        # E == 0; a (0, 1) tensor would disagree with the non-empty width.
        severities = torch.tensor(
            [data.hazards[j].severity for j in ei_th[1].tolist()],
            dtype=torch.float,
        ).reshape(-1, 1)
        ew_th = torch.cat([ew_th, severities], dim=-1)  # (E, 2): (distance, severity)
        g["transit", "passes_hazard", "hazard"].edge_index = ei_th
        g["transit", "passes_hazard", "hazard"].edge_attr = ew_th
        return g

    # --------------------------------------------------------
    # Helpers
    # --------------------------------------------------------
    @staticmethod
    def _coords(nodes) -> List[Tuple[int, float, float]]:
        """Extract ``(id, lat, lng)`` tuples from nodes exposing those attributes."""
        return [(n.id, n.lat, n.lng) for n in nodes]

    @staticmethod
    def _stack_features(feats: List[List[float]]) -> torch.Tensor:
        """Stack per-node feature vectors into a float tensor of shape (N, F).

        An empty list yields a (0, 1) placeholder.
        NOTE(review): that width may differ from the real feature width F.
        """
        if not feats:
            return torch.zeros((0, 1), dtype=torch.float)
        return torch.tensor(feats, dtype=torch.float)

    @staticmethod
    def _venue_event_edges(data: GraphData) -> torch.Tensor:
        """Return the (2, E) venue->event ``edge_index``, one edge per event.

        The source is the position in ``data.venues`` of the venue whose ``id``
        equals ``event.venue_id``. The destination is the event's position in
        ``data.events``.

        Raises
        ------
        ValueError
            If an event's ``venue_id`` matches no venue ``id``.
        """
        if not data.events:
            return torch.zeros((2, 0), dtype=torch.long)
        venue_pos: Dict[int, int] = {v.id: i for i, v in enumerate(data.venues)}
        src, dst = [], []
        for e_pos, e in enumerate(data.events):
            try:
                src.append(venue_pos[e.venue_id])
            except KeyError:
                raise ValueError(
                    f"event {e.id!r} references unknown venue_id {e.venue_id!r}"
                ) from None
            dst.append(e_pos)
        return torch.tensor([src, dst], dtype=torch.long)

    @staticmethod
    def _geo_edges(
        src: List[Tuple[int, float, float]],
        dst: List[Tuple[int, float, float]],
        radius_m: float,
        exclude_self: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Connect every src/dst pair lying within ``radius_m`` metres.

        Parameters
        ----------
        src, dst : lists of ``(id, lat, lng)`` tuples. Edge endpoints are the
            tuples' *positions* in these lists. Ids are used only by
            ``exclude_self``.
        radius_m : inclusive distance threshold in metres.
        exclude_self : skip pairs with equal ids, which removes self-loops
            when ``src`` and ``dst`` are the same list.

        Returns
        -------
        edge_index : long tensor of shape (2, E).
        edge_attr : float tensor of shape (E, 1) holding distance / radius_m
            in [0, 1], or 0.0 when radius_m is 0.
        """
        edge_src, edge_dst, dists = [], [], []
        for i, (sid, slat, slng) in enumerate(src):
            for j, (did, dlat, dlng) in enumerate(dst):
                if exclude_self and sid == did:
                    continue
                d = haversine_m(slat, slng, dlat, dlng)
                if d <= radius_m:
                    edge_src.append(i)
                    edge_dst.append(j)
                    # Normalise to [0, 1]; guard the radius_m == 0 case.
                    dists.append(d / radius_m if radius_m > 0 else 0.0)
        if not edge_src:
            return (
                torch.zeros((2, 0), dtype=torch.long),
                torch.zeros((0, 1), dtype=torch.float),
            )
        ei = torch.tensor([edge_src, edge_dst], dtype=torch.long)
        ew = torch.tensor(dists, dtype=torch.float).unsqueeze(-1)
        return ei, ew
# Public API of this module.
__all__ = [
    "GraphBuilder",
    "haversine_m",
]