# exposureguard-dcpg-encoder / dcpg_encoder.py
# Uploaded by vkatg — "Upload 5 files" (commit 51a62e8, verified)
from __future__ import annotations
import math
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Node feature extraction
# ---------------------------------------------------------------------------
# Index maps for the categorical node features.  Order matters: it fixes
# the one-hot slot each key occupies.
MODALITY_INDEX = {
    name: slot
    for slot, name in enumerate(
        (
            "text",
            "asr",
            "image_proxy",
            "waveform_proxy",
            "audio_proxy",
            "image_link",
            "audio_link",
        )
    )
}
MODALITY_DIM = len(MODALITY_INDEX) + 1  # last slot catches unknown modalities

PHI_TYPE_INDEX = {
    name: slot
    for slot, name in enumerate(
        (
            "NAME_DATE_MRN_FACILITY",
            "NAME_DATE_MRN",
            "FACE_IMAGE",
            "WAVEFORM_HEADER",
            "VOICE",
            "FACE_LINK",
            "VOICE_LINK",
        )
    )
}
PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1  # last slot catches unknown PHI types
# Scalar features: risk_entropy, context_confidence, pseudonym_version_norm.
NODE_SCALAR_DIM = 3
# 8 (modality one-hot) + 8 (PHI-type one-hot) + 3 scalars = 19.
NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM
def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
vec = [0.0] * dim
i = idx_map.get(key, dim - 1)
vec[i] = 1.0
return vec
def node_features(
    modality: str,
    phi_type: str,
    risk_entropy: float,
    context_confidence: float,
    pseudonym_version: int,
    max_pv: int = 10,
) -> List[float]:
    """Build the flat feature vector for one DCPG node.

    Layout: modality one-hot (MODALITY_DIM) + PHI-type one-hot
    (PHI_TYPE_DIM) + 3 scalars, i.e. NODE_FEAT_DIM entries total.

    Args:
        modality: modality key; unknown values map to the last one-hot slot.
        phi_type: PHI-type key; unknown values map to the last one-hot slot.
        risk_entropy: clamped into [0, 1].
        context_confidence: clamped into [0, 1].
        pseudonym_version: version counter, clamped to [0, max_pv] and
            normalized by max_pv.
        max_pv: normalization ceiling for pseudonym_version.

    Returns:
        List of NODE_FEAT_DIM floats, each in [0, 1].
    """
    mod_oh = _one_hot(MODALITY_INDEX, modality, MODALITY_DIM)
    phi_oh = _one_hot(PHI_TYPE_INDEX, phi_type, PHI_TYPE_DIM)
    # BUGFIX: clamp pseudonym_version at 0 (a negative input previously
    # produced a negative feature) and guard against max_pv <= 0, which
    # previously raised ZeroDivisionError or flipped the sign.
    ceiling = max(1, max_pv)
    scalars = [
        float(max(0.0, min(1.0, risk_entropy))),
        float(max(0.0, min(1.0, context_confidence))),
        float(max(0, min(pseudonym_version, ceiling))) / float(ceiling),
    ]
    return mod_oh + phi_oh + scalars
# ---------------------------------------------------------------------------
# Linear layer (no deps)
# ---------------------------------------------------------------------------
def _matmul(A: List[List[float]], B: List[List[float]]) -> List[List[float]]:
rows, mid, cols = len(A), len(B), len(B[0])
out = [[0.0] * cols for _ in range(rows)]
for i in range(rows):
for k in range(mid):
if A[i][k] == 0.0:
continue
for j in range(cols):
out[i][j] += A[i][k] * B[k][j]
return out
def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
return [sum(W[i][j] * x[j] for j in range(len(x))) for i in range(len(W))]
def _relu(x: List[float]) -> List[float]:
return [max(0.0, v) for v in x]
def _softmax(x: List[float]) -> List[float]:
m = max(x)
e = [math.exp(v - m) for v in x]
s = sum(e) or 1.0
return [v / s for v in e]
def _norm(x: List[float]) -> float:
return math.sqrt(sum(v * v for v in x)) or 1.0
def _normalize(x: List[float]) -> List[float]:
n = _norm(x)
return [v / n for v in x]
def _add(a: List[float], b: List[float]) -> List[float]:
return [a[i] + b[i] for i in range(len(a))]
def _scale(a: List[float], s: float) -> List[float]:
return [v * s for v in a]
# ---------------------------------------------------------------------------
# GAT message passing (single attention head, numpy-free)
# ---------------------------------------------------------------------------
@dataclass
class GATLayer:
    """Single-head graph-attention layer (pure Python, no numpy).

    Each directed edge gets an additive attention logit from the
    projected endpoint states; logits are exponentiated, scaled by the
    edge weight, normalized per destination node, then used to average
    neighbor messages.  A residual connection adds each node's own
    projected state.
    """

    in_dim: int
    out_dim: int
    W: List[List[float]] = field(default_factory=list)   # out_dim x in_dim
    a_src: List[float] = field(default_factory=list)     # attention vec (source)
    a_dst: List[float] = field(default_factory=list)     # attention vec (destination)

    def __post_init__(self) -> None:
        # Lazy parameter init so callers may inject fixed weights (tests).
        if not self.W:
            self.W = _xavier_init(self.out_dim, self.in_dim)
        if not self.a_src:
            self.a_src = [1.0 / self.out_dim] * self.out_dim
        if not self.a_dst:
            self.a_dst = [1.0 / self.out_dim] * self.out_dim

    def forward(
        self,
        node_feats: List[List[float]],
        edge_index: List[Tuple[int, int]],
        edge_weights: List[float],
    ) -> List[List[float]]:
        """Run one round of attention-weighted message passing.

        Args:
            node_feats: per-node input features (in_dim floats each).
            edge_index: directed (src, dst) index pairs.
            edge_weights: multiplicative weight per edge, parallel list.

        Returns:
            Per-node out_dim embeddings (aggregated messages + residual).
        """
        n = len(node_feats)
        # Projected + ReLU-activated node states (inlined matvec/relu so
        # the layer has no helper dependencies).
        h = [
            [max(0.0, sum(w * f for w, f in zip(row, feats))) for row in self.W]
            for feats in node_feats
        ]
        # Raw additive attention logits per directed edge.  A duplicate
        # (src, dst) pair overwrites the earlier entry, as before.
        logits: Dict[Tuple[int, int], float] = {}
        weight_of: Dict[Tuple[int, int], float] = {}
        for (src, dst), w in zip(edge_index, edge_weights):
            score = sum(a * v for a, v in zip(self.a_src, h[src]))
            score += sum(a * v for a, v in zip(self.a_dst, h[dst]))
            logits[(src, dst)] = score
            weight_of[(src, dst)] = float(w)
        # BUGFIX: subtract the per-destination max logit before exp() so
        # large scores no longer raise OverflowError; the shift cancels
        # exactly in the per-destination normalization below.
        max_logit: Dict[int, float] = {}
        for (_, dst), s in logits.items():
            if dst not in max_logit or s > max_logit[dst]:
                max_logit[dst] = s
        e = {
            key: math.exp(s - max_logit[key[1]]) * weight_of[key]
            for key, s in logits.items()
        }
        norm_sum = [0.0] * n
        for (_, dst), v in e.items():
            norm_sum[dst] += v
        # Aggregate attention-weighted neighbor states.
        out = [[0.0] * self.out_dim for _ in range(n)]
        for (src, dst), v in e.items():
            alpha = v / (norm_sum[dst] or 1.0)
            for k in range(self.out_dim):
                out[dst][k] += alpha * h[src][k]
        # Residual connection: add each node's own projected state.
        for i in range(n):
            out[i] = [out[i][k] + h[i][k] for k in range(self.out_dim)]
        return out
def _xavier_init(rows: int, cols: int) -> List[List[float]]:
limit = math.sqrt(6.0 / (rows + cols))
import random
rng = random.Random(42)
return [
[rng.uniform(-limit, limit) for _ in range(cols)]
for _ in range(rows)
]
# ---------------------------------------------------------------------------
# Pooling
# ---------------------------------------------------------------------------
def mean_pool(node_embeds: List[List[float]]) -> List[float]:
    """Column-wise mean of the node embeddings; [] for an empty graph."""
    if not node_embeds:
        return []
    count = len(node_embeds)
    return [sum(column) / count for column in zip(*node_embeds)]
def max_pool(node_embeds: List[List[float]]) -> List[float]:
    """Column-wise maximum of the node embeddings; [] for an empty graph.

    BUGFIX: the previous -1e9 sentinel silently clipped any embedding
    value below -1e9; taking the true column max has no such floor.
    """
    if not node_embeds:
        return []
    return [max(column) for column in zip(*node_embeds)]
def attention_pool(
    node_embeds: List[List[float]],
    risk_entropies: List[float],
) -> List[float]:
    """Softmax-weighted sum of node embeddings, weighted by risk entropy.

    Higher-entropy nodes contribute more; [] for an empty graph.
    """
    if not node_embeds:
        return []
    # Inlined max-shifted softmax over the entropies.
    peak = max(risk_entropies)
    exps = [math.exp(v - peak) for v in risk_entropies]
    total = sum(exps) or 1.0
    weights = [v / total for v in exps]
    pooled = [0.0] * len(node_embeds[0])
    for emb, wt in zip(node_embeds, weights):
        for k, val in enumerate(emb):
            pooled[k] += wt * val
    return pooled
# ---------------------------------------------------------------------------
# Encoder
# ---------------------------------------------------------------------------
# Width of the hidden node states produced by the first GAT layer.
HIDDEN_DIM = 32
# Width of the final node / patient embeddings (second layer output).
EMBED_DIM = 16
@dataclass
class DCPGEncoder:
    """
    Two-layer GAT encoder over a DCPG graph.
    Input: graph_summary dict from DCPGAdapter.graph_summary()
    or CRDTGraph.summary() enriched with node features
    Output: patient_embedding (EMBED_DIM floats) + risk_score (float)
    """

    # First message-passing layer: raw node features -> hidden states.
    layer1: GATLayer = field(default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM))
    # Second layer: hidden states -> final embedding space.
    layer2: GATLayer = field(default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM))
    # Linear risk head: 1 x EMBED_DIM weights dotted with the pooled embedding.
    risk_head: List[List[float]] = field(default_factory=lambda: _xavier_init(1, EMBED_DIM))

    def encode(self, graph: "DCPGGraph") -> "EncoderOutput":
        """Encode *graph* into a unit-norm patient embedding + risk score.

        Returns a zero embedding with risk 0.0 for an empty graph.
        """
        if not graph.nodes:
            return EncoderOutput(
                patient_embedding=[0.0] * EMBED_DIM,
                node_embeddings=[],
                risk_score=0.0,
                node_ids=[],
            )
        feats = [n.feature_vec() for n in graph.nodes]
        ei = graph.edge_index()
        ew = graph.edge_weights()
        h1 = self.layer1.forward(feats, ei, ew)
        h2 = self.layer2.forward(h1, ei, ew)
        # Pool node embeddings, weighting higher-entropy nodes more.
        risk_entropies = [n.risk_entropy for n in graph.nodes]
        patient_emb = _normalize(attention_pool(h2, risk_entropies))
        logit = sum(
            self.risk_head[0][k] * patient_emb[k] for k in range(EMBED_DIM)
        )
        # BUGFIX: compute the stable sigmoid inline instead of calling
        # math.sigmoid_approx, which only existed because module-level code
        # monkey-patched the stdlib math module (fragile global mutation).
        if logit >= 0:
            risk_score = 1.0 / (1.0 + math.exp(-logit))
        else:
            exp_logit = math.exp(logit)
            risk_score = exp_logit / (1.0 + exp_logit)
        return EncoderOutput(
            patient_embedding=patient_emb,
            node_embeddings=[_normalize(h) for h in h2],
            risk_score=round(risk_score, 4),
            node_ids=[n.node_id for n in graph.nodes],
        )
def _sigmoid(x: float) -> float:
    """Numerically stable logistic function 1 / (1 + exp(-x)).

    Branches on the sign of x so exp() is only ever evaluated on a
    non-positive argument, avoiding overflow for large |x|.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)
# patch into math namespace for use above
# HACK(review): DCPGEncoder.encode calls math.sigmoid_approx, which exists
# only because of this module-level monkey-patch.  It works because this
# line runs at import time, but it mutates global state shared by every
# module that imports math — prefer calling _sigmoid directly.
math.sigmoid_approx = _sigmoid  # type: ignore[attr-defined]
@dataclass
class EncoderOutput:
    """Result bundle produced by DCPGEncoder.encode()."""

    patient_embedding: List[float]      # pooled, unit-norm patient vector
    node_embeddings: List[List[float]]  # one unit-norm vector per node
    risk_score: float                   # sigmoid risk score
    node_ids: List[str]                 # node ids parallel to node_embeddings

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to JSON-friendly primitives (floats rounded to 5 dp)."""
        per_node: Dict[str, List[float]] = {}
        for nid, emb in zip(self.node_ids, self.node_embeddings):
            per_node[nid] = [round(v, 5) for v in emb]
        return {
            "patient_embedding": [round(v, 5) for v in self.patient_embedding],
            "node_embeddings": per_node,
            "risk_score": self.risk_score,
            "embed_dim": len(self.patient_embedding),
        }
# ---------------------------------------------------------------------------
# DCPGGraph — thin wrapper to consume DCPGAdapter.graph_summary() output
# ---------------------------------------------------------------------------
@dataclass
class DCPGGraphNode:
    """One PHI node of a patient graph, carrying its encoder inputs."""

    node_id: str
    modality: str
    phi_type: str
    risk_entropy: float
    context_confidence: float
    pseudonym_version: int

    def feature_vec(self) -> List[float]:
        """Assemble this node's NODE_FEAT_DIM-length feature vector."""
        return node_features(
            modality=self.modality,
            phi_type=self.phi_type,
            risk_entropy=self.risk_entropy,
            context_confidence=self.context_confidence,
            pseudonym_version=self.pseudonym_version,
        )
@dataclass
class DCPGGraph:
    """Thin graph wrapper consumed by DCPGEncoder.

    ``edges`` entries are dicts with ``source``/``target`` node ids and
    an optional ``weight`` (default 1.0).  Edges referencing unknown node
    ids are silently dropped; each kept edge is emitted in both
    directions (the graph is treated as undirected).
    """

    nodes: List[DCPGGraphNode] = field(default_factory=list)
    edges: List[Dict[str, Any]] = field(default_factory=list)

    def _node_index(self) -> Dict[str, int]:
        """Map node_id -> positional index into ``nodes``."""
        return {n.node_id: i for i, n in enumerate(self.nodes)}

    def _resolved_edges(self) -> List[Tuple[int, int, float]]:
        """Resolve edge dicts into (src_idx, dst_idx, weight) triples.

        Shared by edge_index() and edge_weights() so their orderings can
        never drift apart (the filtering logic was previously duplicated
        in both methods).
        """
        idx = self._node_index()
        resolved: List[Tuple[int, int, float]] = []
        for e in self.edges:
            s = idx.get(e["source"])
            t = idx.get(e["target"])
            if s is not None and t is not None:
                resolved.append((s, t, float(e.get("weight", 1.0))))
        return resolved

    def edge_index(self) -> List[Tuple[int, int]]:
        """Directed (src, dst) pairs; each edge appears in both directions."""
        ei: List[Tuple[int, int]] = []
        for s, t, _ in self._resolved_edges():
            ei.append((s, t))
            ei.append((t, s))  # undirected
        return ei

    def edge_weights(self) -> List[float]:
        """Weights parallel to edge_index() (duplicated per direction)."""
        ew: List[float] = []
        for _, _, w in self._resolved_edges():
            ew.extend([w, w])
        return ew

    @classmethod
    def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
        """Build from a DCPGAdapter.graph_summary()-style dict."""
        nodes = [
            DCPGGraphNode(
                node_id=n["node_id"],
                modality=n["modality"],
                phi_type=n["phi_type"],
                risk_entropy=float(n.get("risk_entropy", 0.0)),
                context_confidence=float(n.get("context_confidence", 1.0)),
                pseudonym_version=int(n.get("pseudonym_version", 0)),
            )
            for n in summary.get("nodes", [])
        ]
        return cls(nodes=nodes, edges=summary.get("edges", []))

    @classmethod
    def from_crdt_summary(
        cls,
        summary: Dict[str, Any],
        provisional_risk: float = 0.0,
    ) -> "DCPGGraph":
        """Build an edgeless graph from a CRDTGraph.summary()-style dict.

        Node ids are assumed to look like ``patient::modality::...``
        ("text" is used when no modality segment is present).
        # NOTE(review): phi_type is approximated as modality.upper(),
        # which generally lands in the unknown PHI bucket — confirm intended.
        """
        nodes = []
        for n in summary.get("nodes", []):
            parts = str(n["node_id"]).split("::")
            modality = parts[1] if len(parts) > 1 else "text"
            nodes.append(
                DCPGGraphNode(
                    node_id=n["node_id"],
                    modality=modality,
                    phi_type=modality.upper(),
                    risk_entropy=float(n.get("risk_entropy", provisional_risk)),
                    # Confidence grows with PHI-unit count, capped at 1.0.
                    context_confidence=min(
                        1.0, float(n.get("total_phi_units", 1)) / 10.0
                    ),
                    pseudonym_version=int(n.get("pseudonym_version", 0)),
                )
            )
        return cls(nodes=nodes, edges=[])
# ---------------------------------------------------------------------------
# Inference helper
# ---------------------------------------------------------------------------
def encode_patient(
    graph_summary: Dict[str, Any],
    encoder: Optional[DCPGEncoder] = None,
    source: str = "dcpg",
) -> Dict[str, Any]:
    """One-shot inference helper: summary dict in, serialized result out.

    ``source`` selects the summary flavor: "crdt" builds the graph via
    DCPGGraph.from_crdt_summary (seeded with ``merged_risk_patient_1``);
    anything else uses DCPGGraph.from_summary.  A fresh DCPGEncoder is
    created when none is supplied.
    """
    enc = encoder or DCPGEncoder()
    if source == "crdt":
        risk_seed = float(graph_summary.get("merged_risk_patient_1", 0.0))
        graph = DCPGGraph.from_crdt_summary(graph_summary, provisional_risk=risk_seed)
    else:
        graph = DCPGGraph.from_summary(graph_summary)
    return enc.encode(graph).to_dict()
# ---------------------------------------------------------------------------
# Smoke test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: encode a tiny three-node, two-edge patient graph and
    # print the serialized encoder output.
    demo_summary = {
        "node_count": 3,
        "edge_count": 2,
        "nodes": [
            {"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
             "phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.72,
             "context_confidence": 0.9, "pseudonym_version": 1},
            {"node_id": "p1::asr::NAME_DATE_MRN", "modality": "asr",
             "phi_type": "NAME_DATE_MRN", "risk_entropy": 0.61,
             "context_confidence": 0.7, "pseudonym_version": 1},
            {"node_id": "p1::image_proxy::FACE_IMAGE", "modality": "image_proxy",
             "phi_type": "FACE_IMAGE", "risk_entropy": 0.45,
             "context_confidence": 0.5, "pseudonym_version": 0},
        ],
        "edges": [
            {"source": "p1::text::NAME_DATE_MRN_FACILITY",
             "target": "p1::asr::NAME_DATE_MRN",
             "type": "co_occurrence", "weight": 0.71},
            {"source": "p1::text::NAME_DATE_MRN_FACILITY",
             "target": "p1::image_proxy::FACE_IMAGE",
             "type": "cross_modal", "weight": 0.58},
        ],
        "provisional_risk": 0.664,
    }
    result = encode_patient(demo_summary)
    print(json.dumps(result, indent=2))
    print(f"\nrisk_score: {result['risk_score']}")
    print(f"embed_dim: {result['embed_dim']}")
    print(f"nodes encoded: {len(result['node_embeddings'])}")