Upload 3 files
Browse files- config_dcpg.json +22 -0
- dcpg_encoder.py +112 -257
- inference_dcpg.py +40 -0
config_dcpg.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "dcpg_encoder",
|
| 3 |
+
"architecture": "GAT",
|
| 4 |
+
"node_feat_dim": 19,
|
| 5 |
+
"hidden_dim": 32,
|
| 6 |
+
"embed_dim": 16,
|
| 7 |
+
"num_layers": 2,
|
| 8 |
+
"pooling": "attention",
|
| 9 |
+
"attention": "single_head",
|
| 10 |
+
"edge_weight_formula": "0.30*f_temporal + 0.30*f_semantic + 0.25*f_modality + 0.15*f_trust",
|
| 11 |
+
"input_sources": [
|
| 12 |
+
"DCPGAdapter.graph_summary",
|
| 13 |
+
"CRDTGraph.summary"
|
| 14 |
+
],
|
| 15 |
+
"output": {
|
| 16 |
+
"patient_embedding": 16,
|
| 17 |
+
"node_embeddings": "per_node",
|
| 18 |
+
"risk_score": "scalar_sigmoid"
|
| 19 |
+
},
|
| 20 |
+
"dependencies": [],
|
| 21 |
+
"framework": "pure_python"
|
| 22 |
+
}
|
dcpg_encoder.py
CHANGED
|
@@ -6,40 +6,34 @@ from dataclasses import dataclass, field
|
|
| 6 |
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
|
| 8 |
|
| 9 |
-
# ---------------------------------------------------------------------------
|
| 10 |
-
# Node feature extraction
|
| 11 |
-
# ---------------------------------------------------------------------------
|
| 12 |
-
|
| 13 |
MODALITY_INDEX = {
|
| 14 |
-
"text": 0,
|
| 15 |
-
"
|
| 16 |
-
"image_proxy": 2,
|
| 17 |
-
"waveform_proxy": 3,
|
| 18 |
-
"audio_proxy": 4,
|
| 19 |
-
"image_link": 5,
|
| 20 |
-
"audio_link": 6,
|
| 21 |
}
|
| 22 |
-
MODALITY_DIM = len(MODALITY_INDEX) + 1
|
| 23 |
|
| 24 |
PHI_TYPE_INDEX = {
|
| 25 |
-
"NAME_DATE_MRN_FACILITY": 0,
|
| 26 |
-
"
|
| 27 |
-
"FACE_IMAGE": 2,
|
| 28 |
-
"WAVEFORM_HEADER": 3,
|
| 29 |
-
"VOICE": 4,
|
| 30 |
-
"FACE_LINK": 5,
|
| 31 |
-
"VOICE_LINK": 6,
|
| 32 |
}
|
| 33 |
PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1
|
| 34 |
|
| 35 |
-
NODE_SCALAR_DIM = 3
|
| 36 |
-
NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
|
| 40 |
vec = [0.0] * dim
|
| 41 |
-
|
| 42 |
-
vec[i] = 1.0
|
| 43 |
return vec
|
| 44 |
|
| 45 |
|
|
@@ -51,30 +45,15 @@ def node_features(
|
|
| 51 |
pseudonym_version: int,
|
| 52 |
max_pv: int = 10,
|
| 53 |
) -> List[float]:
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# ---------------------------------------------------------------------------
|
| 65 |
-
# Linear layer (no deps)
|
| 66 |
-
# ---------------------------------------------------------------------------
|
| 67 |
-
|
| 68 |
-
def _matmul(A: List[List[float]], B: List[List[float]]) -> List[List[float]]:
|
| 69 |
-
rows, mid, cols = len(A), len(B), len(B[0])
|
| 70 |
-
out = [[0.0] * cols for _ in range(rows)]
|
| 71 |
-
for i in range(rows):
|
| 72 |
-
for k in range(mid):
|
| 73 |
-
if A[i][k] == 0.0:
|
| 74 |
-
continue
|
| 75 |
-
for j in range(cols):
|
| 76 |
-
out[i][j] += A[i][k] * B[k][j]
|
| 77 |
-
return out
|
| 78 |
|
| 79 |
|
| 80 |
def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
|
|
@@ -92,12 +71,8 @@ def _softmax(x: List[float]) -> List[float]:
|
|
| 92 |
return [v / s for v in e]
|
| 93 |
|
| 94 |
|
| 95 |
-
def _norm(x: List[float]) -> float:
|
| 96 |
-
return math.sqrt(sum(v * v for v in x)) or 1.0
|
| 97 |
-
|
| 98 |
-
|
| 99 |
def _normalize(x: List[float]) -> List[float]:
|
| 100 |
-
n =
|
| 101 |
return [v / n for v in x]
|
| 102 |
|
| 103 |
|
|
@@ -105,13 +80,24 @@ def _add(a: List[float], b: List[float]) -> List[float]:
|
|
| 105 |
return [a[i] + b[i] for i in range(len(a))]
|
| 106 |
|
| 107 |
|
| 108 |
-
def
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
@dataclass
|
| 117 |
class GATLayer:
|
|
@@ -138,7 +124,6 @@ class GATLayer:
|
|
| 138 |
n = len(node_feats)
|
| 139 |
h = [_relu(_matvec(self.W, x)) for x in node_feats]
|
| 140 |
|
| 141 |
-
# attention coefficients
|
| 142 |
e: Dict[Tuple[int, int], float] = {}
|
| 143 |
for (src, dst), w in zip(edge_index, edge_weights):
|
| 144 |
score = (
|
|
@@ -147,167 +132,23 @@ class GATLayer:
|
|
| 147 |
)
|
| 148 |
e[(src, dst)] = math.exp(score) * float(w)
|
| 149 |
|
| 150 |
-
# per-node normalization
|
| 151 |
norm_sum: List[float] = [0.0] * n
|
| 152 |
for (src, dst), v in e.items():
|
| 153 |
norm_sum[dst] += v
|
| 154 |
for (src, dst) in e:
|
| 155 |
-
|
| 156 |
-
e[(src, dst)] /= denom
|
| 157 |
|
| 158 |
-
# aggregate
|
| 159 |
out = [[0.0] * self.out_dim for _ in range(n)]
|
| 160 |
for (src, dst), alpha in e.items():
|
| 161 |
for k in range(self.out_dim):
|
| 162 |
out[dst][k] += alpha * h[src][k]
|
| 163 |
|
| 164 |
-
# residual add (project if needed)
|
| 165 |
for i in range(n):
|
| 166 |
out[i] = _add(out[i], h[i])
|
| 167 |
|
| 168 |
return out
|
| 169 |
|
| 170 |
|
| 171 |
-
def _xavier_init(rows: int, cols: int) -> List[List[float]]:
|
| 172 |
-
limit = math.sqrt(6.0 / (rows + cols))
|
| 173 |
-
import random
|
| 174 |
-
rng = random.Random(42)
|
| 175 |
-
return [
|
| 176 |
-
[rng.uniform(-limit, limit) for _ in range(cols)]
|
| 177 |
-
for _ in range(rows)
|
| 178 |
-
]
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# ---------------------------------------------------------------------------
|
| 182 |
-
# Pooling
|
| 183 |
-
# ---------------------------------------------------------------------------
|
| 184 |
-
|
| 185 |
-
def mean_pool(node_embeds: List[List[float]]) -> List[float]:
|
| 186 |
-
if not node_embeds:
|
| 187 |
-
return []
|
| 188 |
-
dim = len(node_embeds[0])
|
| 189 |
-
out = [0.0] * dim
|
| 190 |
-
for h in node_embeds:
|
| 191 |
-
for k in range(dim):
|
| 192 |
-
out[k] += h[k]
|
| 193 |
-
return [v / len(node_embeds) for v in out]
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def max_pool(node_embeds: List[List[float]]) -> List[float]:
|
| 197 |
-
if not node_embeds:
|
| 198 |
-
return []
|
| 199 |
-
dim = len(node_embeds[0])
|
| 200 |
-
out = [-1e9] * dim
|
| 201 |
-
for h in node_embeds:
|
| 202 |
-
for k in range(dim):
|
| 203 |
-
if h[k] > out[k]:
|
| 204 |
-
out[k] = h[k]
|
| 205 |
-
return out
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def attention_pool(
|
| 209 |
-
node_embeds: List[List[float]],
|
| 210 |
-
risk_entropies: List[float],
|
| 211 |
-
) -> List[float]:
|
| 212 |
-
if not node_embeds:
|
| 213 |
-
return []
|
| 214 |
-
weights = _softmax(risk_entropies)
|
| 215 |
-
dim = len(node_embeds[0])
|
| 216 |
-
out = [0.0] * dim
|
| 217 |
-
for h, w in zip(node_embeds, weights):
|
| 218 |
-
for k in range(dim):
|
| 219 |
-
out[k] += w * h[k]
|
| 220 |
-
return out
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
# ---------------------------------------------------------------------------
|
| 224 |
-
# Encoder
|
| 225 |
-
# ---------------------------------------------------------------------------
|
| 226 |
-
|
| 227 |
-
HIDDEN_DIM = 32
|
| 228 |
-
EMBED_DIM = 16
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
@dataclass
|
| 232 |
-
class DCPGEncoder:
|
| 233 |
-
"""
|
| 234 |
-
Two-layer GAT encoder over a DCPG graph.
|
| 235 |
-
|
| 236 |
-
Input: graph_summary dict from DCPGAdapter.graph_summary()
|
| 237 |
-
or CRDTGraph.summary() enriched with node features
|
| 238 |
-
Output: patient_embedding (EMBED_DIM floats) + risk_score (float)
|
| 239 |
-
"""
|
| 240 |
-
layer1: GATLayer = field(default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM))
|
| 241 |
-
layer2: GATLayer = field(default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM))
|
| 242 |
-
risk_head: List[List[float]] = field(default_factory=lambda: _xavier_init(1, EMBED_DIM))
|
| 243 |
-
|
| 244 |
-
def encode(self, graph: "DCPGGraph") -> "EncoderOutput":
|
| 245 |
-
if not graph.nodes:
|
| 246 |
-
zero = [0.0] * EMBED_DIM
|
| 247 |
-
return EncoderOutput(
|
| 248 |
-
patient_embedding=zero,
|
| 249 |
-
node_embeddings=[],
|
| 250 |
-
risk_score=0.0,
|
| 251 |
-
node_ids=[],
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
feats = [n.feature_vec() for n in graph.nodes]
|
| 255 |
-
ei = graph.edge_index()
|
| 256 |
-
ew = graph.edge_weights()
|
| 257 |
-
|
| 258 |
-
h1 = self.layer1.forward(feats, ei, ew)
|
| 259 |
-
h2 = self.layer2.forward(h1, ei, ew)
|
| 260 |
-
|
| 261 |
-
risk_entropies = [n.risk_entropy for n in graph.nodes]
|
| 262 |
-
patient_emb = attention_pool(h2, risk_entropies)
|
| 263 |
-
patient_emb = _normalize(patient_emb)
|
| 264 |
-
|
| 265 |
-
risk_score = math.sigmoid_approx(
|
| 266 |
-
sum(self.risk_head[0][k] * patient_emb[k] for k in range(EMBED_DIM))
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
return EncoderOutput(
|
| 270 |
-
patient_embedding=patient_emb,
|
| 271 |
-
node_embeddings=[_normalize(h) for h in h2],
|
| 272 |
-
risk_score=round(risk_score, 4),
|
| 273 |
-
node_ids=[n.node_id for n in graph.nodes],
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
def _sigmoid(x: float) -> float:
|
| 278 |
-
if x >= 0:
|
| 279 |
-
return 1.0 / (1.0 + math.exp(-x))
|
| 280 |
-
e = math.exp(x)
|
| 281 |
-
return e / (1.0 + e)
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
# patch into math namespace for use above
|
| 285 |
-
math.sigmoid_approx = _sigmoid # type: ignore[attr-defined]
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
@dataclass
|
| 289 |
-
class EncoderOutput:
|
| 290 |
-
patient_embedding: List[float]
|
| 291 |
-
node_embeddings: List[List[float]]
|
| 292 |
-
risk_score: float
|
| 293 |
-
node_ids: List[str]
|
| 294 |
-
|
| 295 |
-
def to_dict(self) -> Dict[str, Any]:
|
| 296 |
-
return {
|
| 297 |
-
"patient_embedding": [round(v, 5) for v in self.patient_embedding],
|
| 298 |
-
"node_embeddings": {
|
| 299 |
-
nid: [round(v, 5) for v in emb]
|
| 300 |
-
for nid, emb in zip(self.node_ids, self.node_embeddings)
|
| 301 |
-
},
|
| 302 |
-
"risk_score": self.risk_score,
|
| 303 |
-
"embed_dim": len(self.patient_embedding),
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
# ---------------------------------------------------------------------------
|
| 308 |
-
# DCPGGraph — thin wrapper to consume DCPGAdapter.graph_summary() output
|
| 309 |
-
# ---------------------------------------------------------------------------
|
| 310 |
-
|
| 311 |
@dataclass
|
| 312 |
class DCPGGraphNode:
|
| 313 |
node_id: str
|
|
@@ -319,11 +160,8 @@ class DCPGGraphNode:
|
|
| 319 |
|
| 320 |
def feature_vec(self) -> List[float]:
|
| 321 |
return node_features(
|
| 322 |
-
self.modality,
|
| 323 |
-
self.
|
| 324 |
-
self.risk_entropy,
|
| 325 |
-
self.context_confidence,
|
| 326 |
-
self.pseudonym_version,
|
| 327 |
)
|
| 328 |
|
| 329 |
|
|
@@ -339,68 +177,99 @@ class DCPGGraph:
|
|
| 339 |
idx = self._node_index()
|
| 340 |
ei: List[Tuple[int, int]] = []
|
| 341 |
for e in self.edges:
|
| 342 |
-
s = idx.get(e["source"])
|
| 343 |
-
t = idx.get(e["target"])
|
| 344 |
if s is not None and t is not None:
|
| 345 |
-
ei
|
| 346 |
-
ei.append((t, s)) # undirected
|
| 347 |
return ei
|
| 348 |
|
| 349 |
def edge_weights(self) -> List[float]:
|
| 350 |
idx = self._node_index()
|
| 351 |
ew: List[float] = []
|
| 352 |
for e in self.edges:
|
| 353 |
-
s = idx.get(e["source"])
|
| 354 |
-
t = idx.get(e["target"])
|
| 355 |
if s is not None and t is not None:
|
| 356 |
w = float(e.get("weight", 1.0))
|
| 357 |
-
ew
|
| 358 |
return ew
|
| 359 |
|
| 360 |
@classmethod
|
| 361 |
def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
|
| 362 |
nodes = [
|
| 363 |
DCPGGraphNode(
|
| 364 |
-
node_id=n["node_id"],
|
| 365 |
-
modality=n["modality"],
|
| 366 |
-
phi_type=n["phi_type"],
|
| 367 |
risk_entropy=float(n.get("risk_entropy", 0.0)),
|
| 368 |
context_confidence=float(n.get("context_confidence", 1.0)),
|
| 369 |
pseudonym_version=int(n.get("pseudonym_version", 0)),
|
| 370 |
)
|
| 371 |
for n in summary.get("nodes", [])
|
| 372 |
]
|
| 373 |
-
|
| 374 |
-
return cls(nodes=nodes, edges=edges)
|
| 375 |
|
| 376 |
@classmethod
|
| 377 |
-
def from_crdt_summary(
|
| 378 |
-
cls,
|
| 379 |
-
summary: Dict[str, Any],
|
| 380 |
-
provisional_risk: float = 0.0,
|
| 381 |
-
) -> "DCPGGraph":
|
| 382 |
nodes = []
|
| 383 |
for n in summary.get("nodes", []):
|
| 384 |
parts = str(n["node_id"]).split("::")
|
| 385 |
modality = parts[1] if len(parts) > 1 else "text"
|
| 386 |
-
nodes.append(
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
1.0, float(n.get("total_phi_units", 1)) / 10.0
|
| 394 |
-
),
|
| 395 |
-
pseudonym_version=int(n.get("pseudonym_version", 0)),
|
| 396 |
-
)
|
| 397 |
-
)
|
| 398 |
return cls(nodes=nodes, edges=[])
|
| 399 |
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
def encode_patient(
|
| 406 |
graph_summary: Dict[str, Any],
|
|
@@ -415,18 +284,11 @@ def encode_patient(
|
|
| 415 |
)
|
| 416 |
else:
|
| 417 |
g = DCPGGraph.from_summary(graph_summary)
|
| 418 |
-
|
| 419 |
-
return out.to_dict()
|
| 420 |
|
| 421 |
|
| 422 |
-
# ---------------------------------------------------------------------------
|
| 423 |
-
# Smoke test
|
| 424 |
-
# ---------------------------------------------------------------------------
|
| 425 |
-
|
| 426 |
if __name__ == "__main__":
|
| 427 |
summary = {
|
| 428 |
-
"node_count": 3,
|
| 429 |
-
"edge_count": 2,
|
| 430 |
"nodes": [
|
| 431 |
{"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
|
| 432 |
"phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.72,
|
|
@@ -440,17 +302,10 @@ if __name__ == "__main__":
|
|
| 440 |
],
|
| 441 |
"edges": [
|
| 442 |
{"source": "p1::text::NAME_DATE_MRN_FACILITY",
|
| 443 |
-
"target": "p1::asr::NAME_DATE_MRN",
|
| 444 |
-
"type": "co_occurrence", "weight": 0.71},
|
| 445 |
{"source": "p1::text::NAME_DATE_MRN_FACILITY",
|
| 446 |
-
"target": "p1::image_proxy::FACE_IMAGE",
|
| 447 |
-
"type": "cross_modal", "weight": 0.58},
|
| 448 |
],
|
| 449 |
-
"provisional_risk": 0.664,
|
| 450 |
}
|
| 451 |
-
|
| 452 |
result = encode_patient(summary)
|
| 453 |
print(json.dumps(result, indent=2))
|
| 454 |
-
print(f"\nrisk_score: {result['risk_score']}")
|
| 455 |
-
print(f"embed_dim: {result['embed_dim']}")
|
| 456 |
-
print(f"nodes encoded: {len(result['node_embeddings'])}")
|
|
|
|
| 6 |
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Categorical vocabularies for node features.  Unknown keys fall into the
# trailing "other" slot, hence the +1 on each dimension.
MODALITY_INDEX = {
    "text": 0,
    "asr": 1,
    "image_proxy": 2,
    "waveform_proxy": 3,
    "audio_proxy": 4,
    "image_link": 5,
    "audio_link": 6,
}
MODALITY_DIM = len(MODALITY_INDEX) + 1

PHI_TYPE_INDEX = {
    "NAME_DATE_MRN_FACILITY": 0,
    "NAME_DATE_MRN": 1,
    "FACE_IMAGE": 2,
    "WAVEFORM_HEADER": 3,
    "VOICE": 4,
    "FACE_LINK": 5,
    "VOICE_LINK": 6,
}
PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1

# Three scalar features: risk_entropy, context_confidence, pseudonym_version.
NODE_SCALAR_DIM = 3
NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM  # 19
HIDDEN_DIM = 32
EMBED_DIM = 16
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _sigmoid(x: float) -> float:
|
| 28 |
+
if x >= 0:
|
| 29 |
+
return 1.0 / (1.0 + math.exp(-x))
|
| 30 |
+
e = math.exp(x)
|
| 31 |
+
return e / (1.0 + e)
|
| 32 |
|
| 33 |
|
| 34 |
def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
|
| 35 |
vec = [0.0] * dim
|
| 36 |
+
vec[idx_map.get(key, dim - 1)] = 1.0
|
|
|
|
| 37 |
return vec
|
| 38 |
|
| 39 |
|
|
|
|
| 45 |
def node_features(
    modality: str,
    phi_type: str,
    risk_entropy: float,
    context_confidence: float,
    pseudonym_version: int,
    max_pv: int = 10,
) -> List[float]:
    """Build the NODE_FEAT_DIM-sized feature vector for one DCPG node.

    Layout: modality one-hot (MODALITY_DIM) + phi_type one-hot
    (PHI_TYPE_DIM) + three scalars, each normalized into [0, 1].
    Unknown modality/phi_type values fall into the "other" slot.
    """
    return (
        _one_hot(MODALITY_INDEX, modality, MODALITY_DIM)
        + _one_hot(PHI_TYPE_INDEX, phi_type, PHI_TYPE_DIM)
        + [
            float(max(0.0, min(1.0, risk_entropy))),
            float(max(0.0, min(1.0, context_confidence))),
            # Clamp to [0, max_pv] before scaling: a negative version must
            # not produce a negative feature (the other scalars are clamped
            # to their valid range the same way).
            float(max(0, min(pseudonym_version, max_pv))) / float(max_pv),
        ]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
|
|
|
|
| 71 |
return [v / s for v in e]
|
| 72 |
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def _normalize(x: List[float]) -> List[float]:
|
| 75 |
+
n = math.sqrt(sum(v * v for v in x)) or 1.0
|
| 76 |
return [v / n for v in x]
|
| 77 |
|
| 78 |
|
|
|
|
| 80 |
return [a[i] + b[i] for i in range(len(a))]
|
| 81 |
|
| 82 |
|
| 83 |
+
def _xavier_init(rows: int, cols: int) -> List[List[float]]:
|
| 84 |
+
import random
|
| 85 |
+
limit = math.sqrt(6.0 / (rows + cols))
|
| 86 |
+
rng = random.Random(42)
|
| 87 |
+
return [[rng.uniform(-limit, limit) for _ in range(cols)] for _ in range(rows)]
|
| 88 |
|
| 89 |
|
| 90 |
+
def attention_pool(node_embeds: List[List[float]], risk_entropies: List[float]) -> List[float]:
    """Softmax-weighted mean of node embeddings.

    Node risk entropies act as attention logits, so riskier nodes
    contribute more to the pooled patient vector.  Returns [] for an
    empty graph.
    """
    if not node_embeds:
        return []
    alphas = _softmax(risk_entropies)
    dim = len(node_embeds[0])
    pooled = [0.0] * dim
    for embed, alpha in zip(node_embeds, alphas):
        for k in range(dim):
            pooled[k] += alpha * embed[k]
    return pooled
|
| 100 |
+
|
| 101 |
|
| 102 |
@dataclass
|
| 103 |
class GATLayer:
|
|
|
|
| 124 |
n = len(node_feats)
|
| 125 |
h = [_relu(_matvec(self.W, x)) for x in node_feats]
|
| 126 |
|
|
|
|
| 127 |
e: Dict[Tuple[int, int], float] = {}
|
| 128 |
for (src, dst), w in zip(edge_index, edge_weights):
|
| 129 |
score = (
|
|
|
|
| 132 |
)
|
| 133 |
e[(src, dst)] = math.exp(score) * float(w)
|
| 134 |
|
|
|
|
| 135 |
norm_sum: List[float] = [0.0] * n
|
| 136 |
for (src, dst), v in e.items():
|
| 137 |
norm_sum[dst] += v
|
| 138 |
for (src, dst) in e:
|
| 139 |
+
e[(src, dst)] /= norm_sum[dst] or 1.0
|
|
|
|
| 140 |
|
|
|
|
| 141 |
out = [[0.0] * self.out_dim for _ in range(n)]
|
| 142 |
for (src, dst), alpha in e.items():
|
| 143 |
for k in range(self.out_dim):
|
| 144 |
out[dst][k] += alpha * h[src][k]
|
| 145 |
|
|
|
|
| 146 |
for i in range(n):
|
| 147 |
out[i] = _add(out[i], h[i])
|
| 148 |
|
| 149 |
return out
|
| 150 |
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
@dataclass
|
| 153 |
class DCPGGraphNode:
|
| 154 |
node_id: str
|
|
|
|
| 160 |
|
| 161 |
def feature_vec(self) -> List[float]:
|
| 162 |
return node_features(
|
| 163 |
+
self.modality, self.phi_type,
|
| 164 |
+
self.risk_entropy, self.context_confidence, self.pseudonym_version,
|
|
|
|
|
|
|
|
|
|
| 165 |
)
|
| 166 |
|
| 167 |
|
|
|
|
| 177 |
idx = self._node_index()
|
| 178 |
ei: List[Tuple[int, int]] = []
|
| 179 |
for e in self.edges:
|
| 180 |
+
s, t = idx.get(e["source"]), idx.get(e["target"])
|
|
|
|
| 181 |
if s is not None and t is not None:
|
| 182 |
+
ei += [(s, t), (t, s)]
|
|
|
|
| 183 |
return ei
|
| 184 |
|
| 185 |
def edge_weights(self) -> List[float]:
|
| 186 |
idx = self._node_index()
|
| 187 |
ew: List[float] = []
|
| 188 |
for e in self.edges:
|
| 189 |
+
s, t = idx.get(e["source"]), idx.get(e["target"])
|
|
|
|
| 190 |
if s is not None and t is not None:
|
| 191 |
w = float(e.get("weight", 1.0))
|
| 192 |
+
ew += [w, w]
|
| 193 |
return ew
|
| 194 |
|
| 195 |
@classmethod
|
| 196 |
def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
|
| 197 |
nodes = [
|
| 198 |
DCPGGraphNode(
|
| 199 |
+
node_id=n["node_id"], modality=n["modality"], phi_type=n["phi_type"],
|
|
|
|
|
|
|
| 200 |
risk_entropy=float(n.get("risk_entropy", 0.0)),
|
| 201 |
context_confidence=float(n.get("context_confidence", 1.0)),
|
| 202 |
pseudonym_version=int(n.get("pseudonym_version", 0)),
|
| 203 |
)
|
| 204 |
for n in summary.get("nodes", [])
|
| 205 |
]
|
| 206 |
+
return cls(nodes=nodes, edges=summary.get("edges", []))
|
|
|
|
| 207 |
|
| 208 |
@classmethod
|
| 209 |
+
def from_crdt_summary(cls, summary: Dict[str, Any], provisional_risk: float = 0.0) -> "DCPGGraph":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
nodes = []
|
| 211 |
for n in summary.get("nodes", []):
|
| 212 |
parts = str(n["node_id"]).split("::")
|
| 213 |
modality = parts[1] if len(parts) > 1 else "text"
|
| 214 |
+
nodes.append(DCPGGraphNode(
|
| 215 |
+
node_id=n["node_id"], modality=modality,
|
| 216 |
+
phi_type=modality.upper(),
|
| 217 |
+
risk_entropy=float(n.get("risk_entropy", provisional_risk)),
|
| 218 |
+
context_confidence=min(1.0, float(n.get("total_phi_units", 1)) / 10.0),
|
| 219 |
+
pseudonym_version=int(n.get("pseudonym_version", 0)),
|
| 220 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
return cls(nodes=nodes, edges=[])
|
| 222 |
|
| 223 |
|
| 224 |
+
@dataclass
class EncoderOutput:
    """Result bundle produced by DCPGEncoder.encode()."""

    patient_embedding: List[float]      # pooled patient vector
    node_embeddings: List[List[float]]  # per-node vectors, aligned with node_ids
    risk_score: float                   # scalar risk value
    node_ids: List[str]                 # ids aligned with node_embeddings

    def to_dict(self) -> Dict[str, Any]:
        """JSON-friendly view; embedding floats are rounded to 5 decimals."""
        per_node = {
            node_id: [round(v, 5) for v in embedding]
            for node_id, embedding in zip(self.node_ids, self.node_embeddings)
        }
        return {
            "patient_embedding": [round(v, 5) for v in self.patient_embedding],
            "node_embeddings": per_node,
            "risk_score": self.risk_score,
            "embed_dim": len(self.patient_embedding),
        }
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@dataclass
class DCPGEncoder:
    """Two-layer GAT encoder producing a patient embedding and risk score.

    Consumes a DCPGGraph (built from DCPGAdapter.graph_summary() or
    CRDTGraph.summary()); all weights are deterministic (seed 42).
    """

    layer1: GATLayer = field(default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM))
    layer2: GATLayer = field(default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM))
    risk_head: List[List[float]] = field(default_factory=lambda: _xavier_init(1, EMBED_DIM))

    def encode(self, graph: DCPGGraph) -> EncoderOutput:
        """Run the GAT stack over *graph* and pool into one patient vector."""
        if not graph.nodes:
            # Empty graph: zero embedding, zero risk.
            return EncoderOutput(
                patient_embedding=[0.0] * EMBED_DIM,
                node_embeddings=[],
                risk_score=0.0,
                node_ids=[],
            )

        features = [node.feature_vec() for node in graph.nodes]
        edge_idx = graph.edge_index()
        edge_wts = graph.edge_weights()

        hidden = self.layer1.forward(features, edge_idx, edge_wts)
        embedded = self.layer2.forward(hidden, edge_idx, edge_wts)

        # Pool with risk-entropy attention, then project to a scalar risk.
        entropies = [node.risk_entropy for node in graph.nodes]
        patient_emb = _normalize(attention_pool(embedded, entropies))
        logit = sum(w * v for w, v in zip(self.risk_head[0], patient_emb))

        return EncoderOutput(
            patient_embedding=patient_emb,
            node_embeddings=[_normalize(vec) for vec in embedded],
            risk_score=round(_sigmoid(logit), 4),
            node_ids=[node.node_id for node in graph.nodes],
        )
|
| 272 |
+
|
| 273 |
|
| 274 |
def encode_patient(
|
| 275 |
graph_summary: Dict[str, Any],
|
|
|
|
| 284 |
)
|
| 285 |
else:
|
| 286 |
g = DCPGGraph.from_summary(graph_summary)
|
| 287 |
+
return enc.encode(g).to_dict()
|
|
|
|
| 288 |
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
if __name__ == "__main__":
|
| 291 |
summary = {
|
|
|
|
|
|
|
| 292 |
"nodes": [
|
| 293 |
{"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
|
| 294 |
"phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.72,
|
|
|
|
| 302 |
],
|
| 303 |
"edges": [
|
| 304 |
{"source": "p1::text::NAME_DATE_MRN_FACILITY",
|
| 305 |
+
"target": "p1::asr::NAME_DATE_MRN", "type": "co_occurrence", "weight": 0.71},
|
|
|
|
| 306 |
{"source": "p1::text::NAME_DATE_MRN_FACILITY",
|
| 307 |
+
"target": "p1::image_proxy::FACE_IMAGE", "type": "cross_modal", "weight": 0.58},
|
|
|
|
| 308 |
],
|
|
|
|
| 309 |
}
|
|
|
|
| 310 |
result = encode_patient(summary)
|
| 311 |
print(json.dumps(result, indent=2))
|
|
|
|
|
|
|
|
|
inference_dcpg.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from dcpg_encoder import DCPGEncoder, encode_patient
|
| 7 |
+
|
| 8 |
+
# Shared module-level encoder: its weights are deterministic (seed 42), so
# reusing one instance keeps every call's output consistent.
_encoder = DCPGEncoder()


def predict(graph_summary: dict, source: str = "dcpg") -> dict:
    """Encode one patient graph summary into embeddings plus a risk score.

    *source* selects the summary schema ("dcpg" or "crdt" — forwarded to
    encode_patient); the return value is encode_patient's result dict.
    """
    return encode_patient(graph_summary, encoder=_encoder, source=source)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def predict_batch(summaries: list, source: str = "dcpg") -> list:
    """Apply predict() to each summary in order; one result dict per input."""
    return [predict(s, source=source) for s in summaries]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # A path to a JSON graph summary was supplied on the command line.
        with open(sys.argv[1]) as f:
            summary = json.load(f)
    else:
        # Built-in demo: a text PHI node cross-linked to a voice node.
        summary = {
            "nodes": [
                {"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
                 "phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.8,
                 "context_confidence": 0.9, "pseudonym_version": 2},
                {"node_id": "p1::audio_proxy::VOICE", "modality": "audio_proxy",
                 "phi_type": "VOICE", "risk_entropy": 0.55,
                 "context_confidence": 0.6, "pseudonym_version": 1},
            ],
            "edges": [
                {"source": "p1::text::NAME_DATE_MRN_FACILITY",
                 "target": "p1::audio_proxy::VOICE",
                 "type": "cross_modal", "weight": 0.63},
            ],
        }
    result = predict(summary)
    print(json.dumps(result, indent=2))
|