from __future__ import annotations

# Dependency-free two-layer graph-attention encoder for DCPG patient graphs.
#
# The module turns a graph summary dict (nodes + weighted edges) into a
# fixed-size patient embedding and a scalar risk score, using only the
# Python standard library.

import json
import math
import random
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Tuple

# ---------------------------------------------------------------------------
# Node feature extraction
# ---------------------------------------------------------------------------

MODALITY_INDEX = {
    "text": 0,
    "asr": 1,
    "image_proxy": 2,
    "waveform_proxy": 3,
    "audio_proxy": 4,
    "image_link": 5,
    "audio_link": 6,
}
MODALITY_DIM = len(MODALITY_INDEX) + 1  # +1 for unknown

PHI_TYPE_INDEX = {
    "NAME_DATE_MRN_FACILITY": 0,
    "NAME_DATE_MRN": 1,
    "FACE_IMAGE": 2,
    "WAVEFORM_HEADER": 3,
    "VOICE": 4,
    "FACE_LINK": 5,
    "VOICE_LINK": 6,
}
PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1  # +1 for unknown

NODE_SCALAR_DIM = 3  # risk_entropy, context_confidence, pseudonym_version_norm
NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM  # 8 + 8 + 3 = 19


def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
    """Return a one-hot vector of length *dim* for *key*.

    Keys missing from *idx_map* light up the last slot (``dim - 1``), which
    the ``*_DIM`` constants reserve for "unknown".
    """
    vec = [0.0] * dim
    vec[idx_map.get(key, dim - 1)] = 1.0
    return vec


def _clamp01(value: float) -> float:
    """Clamp *value* into the closed interval [0, 1]."""
    return float(max(0.0, min(1.0, value)))


def node_features(
    modality: str,
    phi_type: str,
    risk_entropy: float,
    context_confidence: float,
    pseudonym_version: int,
    max_pv: int = 10,
) -> List[float]:
    """Build the ``NODE_FEAT_DIM``-long feature vector for one node.

    Layout: one-hot modality, one-hot PHI type, then three scalars
    (risk entropy and context confidence clamped to [0, 1], and the
    pseudonym version normalised by *max_pv* into [0, 1]).

    Raises:
        ValueError: if *max_pv* is not positive (normalisation would divide
            by zero or flip sign).
    """
    if max_pv <= 0:
        raise ValueError("max_pv must be a positive integer")
    # Clamp on both sides so a negative version cannot yield a negative feature.
    bounded_pv = max(0, min(pseudonym_version, max_pv))
    scalars = [
        _clamp01(risk_entropy),
        _clamp01(context_confidence),
        float(bounded_pv) / float(max_pv),
    ]
    return (
        _one_hot(MODALITY_INDEX, modality, MODALITY_DIM)
        + _one_hot(PHI_TYPE_INDEX, phi_type, PHI_TYPE_DIM)
        + scalars
    )


# ---------------------------------------------------------------------------
# Linear layer (no deps)
# ---------------------------------------------------------------------------

def _matmul(A: List[List[float]], B: List[List[float]]) -> List[List[float]]:
    """Multiply matrices *A* (rows x mid) and *B* (mid x cols)."""
    rows, mid = len(A), len(B)
    cols = len(B[0]) if B else 0  # an empty B has no columns to index
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        a_row, out_row = A[i], out[i]
        for k in range(mid):
            a_ik = a_row[k]
            if a_ik == 0.0:
                continue  # skip zero entries (one-hot inputs are mostly zeros)
            b_row = B[k]
            for j in range(cols):
                out_row[j] += a_ik * b_row[j]
    return out


def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
    """Return ``W @ x``; each row of *W* must have at least ``len(x)`` entries."""
    return [sum(row[j] * xj for j, xj in enumerate(x)) for row in W]


def _relu(x: List[float]) -> List[float]:
    """Element-wise ``max(0, v)``."""
    return [max(0.0, v) for v in x]


def _softmax(x: List[float]) -> List[float]:
    """Numerically stable softmax; an empty input yields an empty list."""
    if not x:
        return []
    peak = max(x)
    exps = [math.exp(v - peak) for v in x]
    total = sum(exps) or 1.0
    return [v / total for v in exps]


def _norm(x: List[float]) -> float:
    """Euclidean norm of *x*, with 1.0 substituted for a zero norm."""
    return math.sqrt(sum(v * v for v in x)) or 1.0


def _normalize(x: List[float]) -> List[float]:
    """Scale *x* to unit length (an all-zero vector is returned unchanged)."""
    length = _norm(x)
    return [v / length for v in x]


def _add(a: List[float], b: List[float]) -> List[float]:
    """Element-wise ``a + b`` over the length of *a*."""
    return [v + b[i] for i, v in enumerate(a)]


def _scale(a: List[float], s: float) -> List[float]:
    """Multiply every element of *a* by *s*."""
    return [v * s for v in a]


def _sigmoid(x: float) -> float:
    """Logistic sigmoid, branching on the sign of *x* to avoid exp overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


# Historical monkey-patch of the stdlib ``math`` module. Nothing in this file
# relies on it any more (the encoder calls ``_sigmoid`` directly); it is kept
# only so that external code that may reference ``math.sigmoid_approx`` keeps
# working.
math.sigmoid_approx = _sigmoid  # type: ignore[attr-defined]


def _xavier_init(rows: int, cols: int) -> List[List[float]]:
    """Return a rows x cols matrix drawn uniformly from the Xavier/Glorot range.

    A fresh ``random.Random(42)`` is used on every call, so the result is
    deterministic and matrices of equal shape are identical.
    """
    limit = math.sqrt(6.0 / (rows + cols))
    rng = random.Random(42)
    return [
        [rng.uniform(-limit, limit) for _ in range(cols)]
        for _ in range(rows)
    ]


# ---------------------------------------------------------------------------
# GAT message passing (single attention head, numpy-free)
# ---------------------------------------------------------------------------

@dataclass
class GATLayer:
    """Single-head graph-attention layer operating on plain Python lists.

    ``W`` (out_dim x in_dim) projects node features; ``a_src`` / ``a_dst``
    are the attention vectors applied to the projected source / destination
    node. Any of the three left empty is filled with a default in
    ``__post_init__``.
    """

    in_dim: int
    out_dim: int
    W: List[List[float]] = field(default_factory=list)
    a_src: List[float] = field(default_factory=list)
    a_dst: List[float] = field(default_factory=list)

    def __post_init__(self) -> None:
        if not self.W:
            self.W = _xavier_init(self.out_dim, self.in_dim)
        if not self.a_src:
            self.a_src = [1.0 / self.out_dim] * self.out_dim
        if not self.a_dst:
            self.a_dst = [1.0 / self.out_dim] * self.out_dim

    def forward(
        self,
        node_feats: List[List[float]],
        edge_index: List[Tuple[int, int]],
        edge_weights: List[float],
    ) -> List[List[float]]:
        """Run one round of attention-weighted message passing.

        Args:
            node_feats: one feature vector per node.
            edge_index: directed ``(src, dst)`` node-index pairs. Parallel
                edges between the same pair are all taken into account.
            edge_weights: one multiplicative weight per entry of
                *edge_index*.

        Returns:
            One ``out_dim``-long vector per node: the attention-weighted sum
            of its in-neighbours' projected features plus its own projected
            features (residual).

        Raises:
            ValueError: if *edge_index* and *edge_weights* differ in length.
        """
        if len(edge_index) != len(edge_weights):
            raise ValueError(
                "edge_index and edge_weights must have the same length"
            )

        n = len(node_feats)
        dims = range(self.out_dim)
        h = [_relu(_matvec(self.W, x)) for x in node_feats]

        # The attention score of an edge is a_src.h[src] + a_dst.h[dst], so
        # both dot products can be computed once per node instead of per edge.
        src_score = [sum(self.a_src[k] * hv[k] for k in dims) for hv in h]
        dst_score = [sum(self.a_dst[k] * hv[k] for k in dims) for hv in h]
        scores = [src_score[s] + dst_score[d] for s, d in edge_index]

        # Shift by the per-destination maximum before exponentiating: the
        # shift cancels in the normalisation below but keeps exp() in range.
        peak = [float("-inf")] * n
        for (_, dst), score in zip(edge_index, scores):
            if score > peak[dst]:
                peak[dst] = score

        # Keep one entry per edge (not per node pair) so that parallel edges
        # accumulate instead of overwriting one another.
        raw = [
            math.exp(score - peak[dst]) * float(w)
            for (_, dst), score, w in zip(edge_index, scores, edge_weights)
        ]

        # per-node normalization over incoming edges
        norm_sum = [0.0] * n
        for (_, dst), value in zip(edge_index, raw):
            norm_sum[dst] += value

        # aggregate
        out = [[0.0] * self.out_dim for _ in range(n)]
        for (src, dst), value in zip(edge_index, raw):
            alpha = value / (norm_sum[dst] or 1.0)
            h_src, acc = h[src], out[dst]
            for k in dims:
                acc[k] += alpha * h_src[k]

        # residual add: h is already out_dim-wide, so no projection is needed
        return [_add(out[i], h[i]) for i in range(n)]


# ---------------------------------------------------------------------------
# Pooling
# ---------------------------------------------------------------------------

def mean_pool(node_embeds: List[List[float]]) -> List[float]:
    """Column-wise mean of the node embeddings (``[]`` for no nodes)."""
    if not node_embeds:
        return []
    dim = len(node_embeds[0])
    totals = [0.0] * dim
    for emb in node_embeds:
        for k in range(dim):
            totals[k] += emb[k]
    count = len(node_embeds)
    return [v / count for v in totals]


def max_pool(node_embeds: List[List[float]]) -> List[float]:
    """Column-wise maximum of the node embeddings (``[]`` for no nodes)."""
    if not node_embeds:
        return []
    dim = len(node_embeds[0])
    # A true maximum rather than a fixed sentinel, so arbitrarily negative
    # values are handled correctly.
    return [max(emb[k] for emb in node_embeds) for k in range(dim)]


def attention_pool(
    node_embeds: List[List[float]],
    risk_entropies: List[float],
) -> List[float]:
    """Weighted sum of node embeddings with ``softmax(risk_entropies)`` weights.

    Returns ``[]`` when there are no nodes.

    Raises:
        ValueError: if the two sequences differ in length (the weights would
            otherwise no longer sum to one).
    """
    if not node_embeds:
        return []
    if len(node_embeds) != len(risk_entropies):
        raise ValueError(
            "node_embeds and risk_entropies must have the same length"
        )
    weights = _softmax(risk_entropies)
    dim = len(node_embeds[0])
    pooled = [0.0] * dim
    for emb, w in zip(node_embeds, weights):
        for k in range(dim):
            pooled[k] += w * emb[k]
    return pooled


# ---------------------------------------------------------------------------
# Encoder
# ---------------------------------------------------------------------------

HIDDEN_DIM = 32
EMBED_DIM = 16


@dataclass
class DCPGEncoder:
    """
    Two-layer GAT encoder over a DCPG graph.

    Input:  graph_summary dict from DCPGAdapter.graph_summary()
            or CRDTGraph.summary() enriched with node features
    Output: patient_embedding (EMBED_DIM floats by default) + risk_score (float)
    """

    layer1: GATLayer = field(
        default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM)
    )
    layer2: GATLayer = field(
        default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM)
    )
    risk_head: List[List[float]] = field(
        default_factory=lambda: _xavier_init(1, EMBED_DIM)
    )

    def encode(self, graph: "DCPGGraph") -> "EncoderOutput":
        """Encode *graph* into a unit-length patient embedding and a risk score.

        An empty graph yields an all-zero embedding and a risk score of 0.0.
        Otherwise both GAT layers are applied, node embeddings are pooled
        with risk-entropy attention, and the risk score is the sigmoid of
        the first ``risk_head`` row dotted with the pooled embedding,
        rounded to 4 decimals.
        """
        # Use the configured layer width rather than the module constant so
        # encoders built with custom layer sizes work too.
        embed_dim = self.layer2.out_dim

        if not graph.nodes:
            return EncoderOutput(
                patient_embedding=[0.0] * embed_dim,
                node_embeddings=[],
                risk_score=0.0,
                node_ids=[],
            )

        feats = [node.feature_vec() for node in graph.nodes]
        ei = graph.edge_index()
        ew = graph.edge_weights()

        h1 = self.layer1.forward(feats, ei, ew)
        h2 = self.layer2.forward(h1, ei, ew)

        risk_entropies = [node.risk_entropy for node in graph.nodes]
        patient_emb = _normalize(attention_pool(h2, risk_entropies))

        logit = sum(w * v for w, v in zip(self.risk_head[0], patient_emb))
        risk_score = _sigmoid(logit)

        return EncoderOutput(
            patient_embedding=patient_emb,
            node_embeddings=[_normalize(emb) for emb in h2],
            risk_score=round(risk_score, 4),
            node_ids=[node.node_id for node in graph.nodes],
        )


@dataclass
class EncoderOutput:
    """Result of ``DCPGEncoder.encode``: pooled and per-node embeddings."""

    patient_embedding: List[float]
    node_embeddings: List[List[float]]
    risk_score: float
    node_ids: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict with embeddings rounded to 5 decimals.

        ``node_embeddings`` is keyed by node id, so nodes sharing an id
        collapse to the last one.
        """
        per_node = {
            nid: [round(v, 5) for v in emb]
            for nid, emb in zip(self.node_ids, self.node_embeddings)
        }
        return {
            "patient_embedding": [round(v, 5) for v in self.patient_embedding],
            "node_embeddings": per_node,
            "risk_score": self.risk_score,
            "embed_dim": len(self.patient_embedding),
        }


# ---------------------------------------------------------------------------
# DCPGGraph — thin wrapper to consume DCPGAdapter.graph_summary() output
# ---------------------------------------------------------------------------

@dataclass
class DCPGGraphNode:
    """One graph node together with the raw attributes used as features."""

    node_id: str
    modality: str
    phi_type: str
    risk_entropy: float
    context_confidence: float
    pseudonym_version: int

    def feature_vec(self) -> List[float]:
        """Return this node's feature vector (see ``node_features``)."""
        return node_features(
            self.modality,
            self.phi_type,
            self.risk_entropy,
            self.context_confidence,
            self.pseudonym_version,
        )


@dataclass
class DCPGGraph:
    """Node list plus raw edge dicts (``source`` / ``target`` / ``weight``)."""

    nodes: List[DCPGGraphNode] = field(default_factory=list)
    edges: List[Dict[str, Any]] = field(default_factory=list)

    def _node_index(self) -> Dict[str, int]:
        """Map each node id to its position in ``self.nodes``."""
        return {node.node_id: i for i, node in enumerate(self.nodes)}

    def _resolved_edges(self) -> Iterator[Tuple[int, int, Dict[str, Any]]]:
        """Yield ``(src_idx, dst_idx, edge)`` for edges whose endpoints exist.

        Shared by ``edge_index`` and ``edge_weights`` so the two lists are
        guaranteed to stay aligned.
        """
        idx = self._node_index()
        for edge in self.edges:
            s = idx.get(edge["source"])
            t = idx.get(edge["target"])
            if s is not None and t is not None:
                yield s, t, edge

    def edge_index(self) -> List[Tuple[int, int]]:
        """Return directed index pairs, two per edge (graph is undirected).

        Edges that reference an unknown node id are skipped.
        """
        pairs: List[Tuple[int, int]] = []
        for s, t, _ in self._resolved_edges():
            pairs.append((s, t))
            pairs.append((t, s))  # undirected
        return pairs

    def edge_weights(self) -> List[float]:
        """Return weights aligned with ``edge_index`` (default weight 1.0)."""
        weights: List[float] = []
        for _, _, edge in self._resolved_edges():
            w = float(edge.get("weight", 1.0))
            weights.extend([w, w])
        return weights

    @classmethod
    def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
        """Build a graph from a summary dict with ``nodes`` and ``edges`` lists.

        Each node dict must provide ``node_id``, ``modality`` and
        ``phi_type``; the remaining attributes fall back to defaults.
        """
        nodes = [
            DCPGGraphNode(
                node_id=n["node_id"],
                modality=n["modality"],
                phi_type=n["phi_type"],
                risk_entropy=float(n.get("risk_entropy", 0.0)),
                context_confidence=float(n.get("context_confidence", 1.0)),
                pseudonym_version=int(n.get("pseudonym_version", 0)),
            )
            for n in summary.get("nodes", [])
        ]
        return cls(nodes=nodes, edges=summary.get("edges", []))

    @classmethod
    def from_crdt_summary(
        cls,
        summary: Dict[str, Any],
        provisional_risk: float = 0.0,
    ) -> "DCPGGraph":
        """Build an edge-less graph from a CRDT-style summary.

        Modality and PHI type are parsed out of the ``::``-separated node id.
        *provisional_risk* is used for nodes without their own
        ``risk_entropy``; context confidence is ``total_phi_units / 10``
        capped at 1.0.
        """
        nodes = []
        for n in summary.get("nodes", []):
            parts = str(n["node_id"]).split("::")
            modality = parts[1] if len(parts) > 1 else "text"
            # NOTE(review): assumes ids look like "<patient>::<modality>::<phi_type>"
            # as in the smoke test below - confirm for CRDT summaries. The
            # previous value, modality.upper(), never matches a PHI_TYPE_INDEX
            # key; it is kept as the fallback for ids without a third segment.
            phi_type = parts[2] if len(parts) > 2 else modality.upper()
            nodes.append(
                DCPGGraphNode(
                    node_id=n["node_id"],
                    modality=modality,
                    phi_type=phi_type,
                    risk_entropy=float(n.get("risk_entropy", provisional_risk)),
                    context_confidence=min(
                        1.0, float(n.get("total_phi_units", 1)) / 10.0
                    ),
                    pseudonym_version=int(n.get("pseudonym_version", 0)),
                )
            )
        return cls(nodes=nodes, edges=[])


# ---------------------------------------------------------------------------
# Inference helper
# ---------------------------------------------------------------------------

def encode_patient(
    graph_summary: Dict[str, Any],
    encoder: Optional[DCPGEncoder] = None,
    source: str = "dcpg",
    risk_key: str = "merged_risk_patient_1",
) -> Dict[str, Any]:
    """Encode a graph summary and return ``EncoderOutput.to_dict()``.

    Args:
        graph_summary: summary dict containing at least a ``nodes`` list.
        encoder: encoder to use; a default ``DCPGEncoder`` is built if None.
        source: ``"crdt"`` parses the summary with
            ``DCPGGraph.from_crdt_summary``; any other value uses
            ``DCPGGraph.from_summary``.
        risk_key: summary key holding the provisional risk for CRDT
            summaries (only read when ``source == "crdt"``).
    """
    enc = encoder if encoder is not None else DCPGEncoder()
    if source == "crdt":
        graph = DCPGGraph.from_crdt_summary(
            graph_summary,
            provisional_risk=float(graph_summary.get(risk_key, 0.0)),
        )
    else:
        graph = DCPGGraph.from_summary(graph_summary)
    return enc.encode(graph).to_dict()


# ---------------------------------------------------------------------------
# Smoke test
# ---------------------------------------------------------------------------

def _smoke_test() -> None:
    """Encode a small three-node example graph and print the result."""
    summary = {
        "node_count": 3,
        "edge_count": 2,
        "nodes": [
            {
                "node_id": "p1::text::NAME_DATE_MRN_FACILITY",
                "modality": "text",
                "phi_type": "NAME_DATE_MRN_FACILITY",
                "risk_entropy": 0.72,
                "context_confidence": 0.9,
                "pseudonym_version": 1,
            },
            {
                "node_id": "p1::asr::NAME_DATE_MRN",
                "modality": "asr",
                "phi_type": "NAME_DATE_MRN",
                "risk_entropy": 0.61,
                "context_confidence": 0.7,
                "pseudonym_version": 1,
            },
            {
                "node_id": "p1::image_proxy::FACE_IMAGE",
                "modality": "image_proxy",
                "phi_type": "FACE_IMAGE",
                "risk_entropy": 0.45,
                "context_confidence": 0.5,
                "pseudonym_version": 0,
            },
        ],
        "edges": [
            {
                "source": "p1::text::NAME_DATE_MRN_FACILITY",
                "target": "p1::asr::NAME_DATE_MRN",
                "type": "co_occurrence",
                "weight": 0.71,
            },
            {
                "source": "p1::text::NAME_DATE_MRN_FACILITY",
                "target": "p1::image_proxy::FACE_IMAGE",
                "type": "cross_modal",
                "weight": 0.58,
            },
        ],
        "provisional_risk": 0.664,
    }

    result = encode_patient(summary)
    print(json.dumps(result, indent=2))
    print(f"\nrisk_score: {result['risk_score']}")
    print(f"embed_dim: {result['embed_dim']}")
    print(f"nodes encoded: {len(result['node_embeddings'])}")


if __name__ == "__main__":
    _smoke_test()