vkatg committed on
Commit
74f6e67
·
verified ·
1 Parent(s): 51a62e8

Upload 3 files

Browse files
Files changed (3) hide show
  1. config_dcpg.json +22 -0
  2. dcpg_encoder.py +112 -257
  3. inference_dcpg.py +40 -0
config_dcpg.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "dcpg_encoder",
3
+ "architecture": "GAT",
4
+ "node_feat_dim": 19,
5
+ "hidden_dim": 32,
6
+ "embed_dim": 16,
7
+ "num_layers": 2,
8
+ "pooling": "attention",
9
+ "attention": "single_head",
10
+ "edge_weight_formula": "0.30*f_temporal + 0.30*f_semantic + 0.25*f_modality + 0.15*f_trust",
11
+ "input_sources": [
12
+ "DCPGAdapter.graph_summary",
13
+ "CRDTGraph.summary"
14
+ ],
15
+ "output": {
16
+ "patient_embedding": 16,
17
+ "node_embeddings": "per_node",
18
+ "risk_score": "scalar_sigmoid"
19
+ },
20
+ "dependencies": [],
21
+ "framework": "pure_python"
22
+ }
dcpg_encoder.py CHANGED
@@ -6,40 +6,34 @@ from dataclasses import dataclass, field
6
  from typing import Any, Dict, List, Optional, Tuple
7
 
8
 
9
- # ---------------------------------------------------------------------------
10
- # Node feature extraction
11
- # ---------------------------------------------------------------------------
12
-
13
  MODALITY_INDEX = {
14
- "text": 0,
15
- "asr": 1,
16
- "image_proxy": 2,
17
- "waveform_proxy": 3,
18
- "audio_proxy": 4,
19
- "image_link": 5,
20
- "audio_link": 6,
21
  }
22
- MODALITY_DIM = len(MODALITY_INDEX) + 1 # +1 for unknown
23
 
24
  PHI_TYPE_INDEX = {
25
- "NAME_DATE_MRN_FACILITY": 0,
26
- "NAME_DATE_MRN": 1,
27
- "FACE_IMAGE": 2,
28
- "WAVEFORM_HEADER": 3,
29
- "VOICE": 4,
30
- "FACE_LINK": 5,
31
- "VOICE_LINK": 6,
32
  }
33
  PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1
34
 
35
- NODE_SCALAR_DIM = 3 # risk_entropy, context_confidence, pseudonym_version_norm
36
- NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM # 18
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
40
  vec = [0.0] * dim
41
- i = idx_map.get(key, dim - 1)
42
- vec[i] = 1.0
43
  return vec
44
 
45
 
@@ -51,30 +45,15 @@ def node_features(
51
  pseudonym_version: int,
52
  max_pv: int = 10,
53
  ) -> List[float]:
54
- mod_oh = _one_hot(MODALITY_INDEX, modality, MODALITY_DIM)
55
- phi_oh = _one_hot(PHI_TYPE_INDEX, phi_type, PHI_TYPE_DIM)
56
- scalars = [
57
- float(max(0.0, min(1.0, risk_entropy))),
58
- float(max(0.0, min(1.0, context_confidence))),
59
- float(min(pseudonym_version, max_pv)) / float(max_pv),
60
- ]
61
- return mod_oh + phi_oh + scalars
62
-
63
-
64
- # ---------------------------------------------------------------------------
65
- # Linear layer (no deps)
66
- # ---------------------------------------------------------------------------
67
-
68
- def _matmul(A: List[List[float]], B: List[List[float]]) -> List[List[float]]:
69
- rows, mid, cols = len(A), len(B), len(B[0])
70
- out = [[0.0] * cols for _ in range(rows)]
71
- for i in range(rows):
72
- for k in range(mid):
73
- if A[i][k] == 0.0:
74
- continue
75
- for j in range(cols):
76
- out[i][j] += A[i][k] * B[k][j]
77
- return out
78
 
79
 
80
  def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
@@ -92,12 +71,8 @@ def _softmax(x: List[float]) -> List[float]:
92
  return [v / s for v in e]
93
 
94
 
95
- def _norm(x: List[float]) -> float:
96
- return math.sqrt(sum(v * v for v in x)) or 1.0
97
-
98
-
99
  def _normalize(x: List[float]) -> List[float]:
100
- n = _norm(x)
101
  return [v / n for v in x]
102
 
103
 
@@ -105,13 +80,24 @@ def _add(a: List[float], b: List[float]) -> List[float]:
105
  return [a[i] + b[i] for i in range(len(a))]
106
 
107
 
108
- def _scale(a: List[float], s: float) -> List[float]:
109
- return [v * s for v in a]
 
 
 
110
 
111
 
112
- # ---------------------------------------------------------------------------
113
- # GAT message passing (single attention head, numpy-free)
114
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
115
 
116
  @dataclass
117
  class GATLayer:
@@ -138,7 +124,6 @@ class GATLayer:
138
  n = len(node_feats)
139
  h = [_relu(_matvec(self.W, x)) for x in node_feats]
140
 
141
- # attention coefficients
142
  e: Dict[Tuple[int, int], float] = {}
143
  for (src, dst), w in zip(edge_index, edge_weights):
144
  score = (
@@ -147,167 +132,23 @@ class GATLayer:
147
  )
148
  e[(src, dst)] = math.exp(score) * float(w)
149
 
150
- # per-node normalization
151
  norm_sum: List[float] = [0.0] * n
152
  for (src, dst), v in e.items():
153
  norm_sum[dst] += v
154
  for (src, dst) in e:
155
- denom = norm_sum[dst] or 1.0
156
- e[(src, dst)] /= denom
157
 
158
- # aggregate
159
  out = [[0.0] * self.out_dim for _ in range(n)]
160
  for (src, dst), alpha in e.items():
161
  for k in range(self.out_dim):
162
  out[dst][k] += alpha * h[src][k]
163
 
164
- # residual add (project if needed)
165
  for i in range(n):
166
  out[i] = _add(out[i], h[i])
167
 
168
  return out
169
 
170
 
171
- def _xavier_init(rows: int, cols: int) -> List[List[float]]:
172
- limit = math.sqrt(6.0 / (rows + cols))
173
- import random
174
- rng = random.Random(42)
175
- return [
176
- [rng.uniform(-limit, limit) for _ in range(cols)]
177
- for _ in range(rows)
178
- ]
179
-
180
-
181
- # ---------------------------------------------------------------------------
182
- # Pooling
183
- # ---------------------------------------------------------------------------
184
-
185
- def mean_pool(node_embeds: List[List[float]]) -> List[float]:
186
- if not node_embeds:
187
- return []
188
- dim = len(node_embeds[0])
189
- out = [0.0] * dim
190
- for h in node_embeds:
191
- for k in range(dim):
192
- out[k] += h[k]
193
- return [v / len(node_embeds) for v in out]
194
-
195
-
196
- def max_pool(node_embeds: List[List[float]]) -> List[float]:
197
- if not node_embeds:
198
- return []
199
- dim = len(node_embeds[0])
200
- out = [-1e9] * dim
201
- for h in node_embeds:
202
- for k in range(dim):
203
- if h[k] > out[k]:
204
- out[k] = h[k]
205
- return out
206
-
207
-
208
- def attention_pool(
209
- node_embeds: List[List[float]],
210
- risk_entropies: List[float],
211
- ) -> List[float]:
212
- if not node_embeds:
213
- return []
214
- weights = _softmax(risk_entropies)
215
- dim = len(node_embeds[0])
216
- out = [0.0] * dim
217
- for h, w in zip(node_embeds, weights):
218
- for k in range(dim):
219
- out[k] += w * h[k]
220
- return out
221
-
222
-
223
- # ---------------------------------------------------------------------------
224
- # Encoder
225
- # ---------------------------------------------------------------------------
226
-
227
- HIDDEN_DIM = 32
228
- EMBED_DIM = 16
229
-
230
-
231
- @dataclass
232
- class DCPGEncoder:
233
- """
234
- Two-layer GAT encoder over a DCPG graph.
235
-
236
- Input: graph_summary dict from DCPGAdapter.graph_summary()
237
- or CRDTGraph.summary() enriched with node features
238
- Output: patient_embedding (EMBED_DIM floats) + risk_score (float)
239
- """
240
- layer1: GATLayer = field(default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM))
241
- layer2: GATLayer = field(default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM))
242
- risk_head: List[List[float]] = field(default_factory=lambda: _xavier_init(1, EMBED_DIM))
243
-
244
- def encode(self, graph: "DCPGGraph") -> "EncoderOutput":
245
- if not graph.nodes:
246
- zero = [0.0] * EMBED_DIM
247
- return EncoderOutput(
248
- patient_embedding=zero,
249
- node_embeddings=[],
250
- risk_score=0.0,
251
- node_ids=[],
252
- )
253
-
254
- feats = [n.feature_vec() for n in graph.nodes]
255
- ei = graph.edge_index()
256
- ew = graph.edge_weights()
257
-
258
- h1 = self.layer1.forward(feats, ei, ew)
259
- h2 = self.layer2.forward(h1, ei, ew)
260
-
261
- risk_entropies = [n.risk_entropy for n in graph.nodes]
262
- patient_emb = attention_pool(h2, risk_entropies)
263
- patient_emb = _normalize(patient_emb)
264
-
265
- risk_score = math.sigmoid_approx(
266
- sum(self.risk_head[0][k] * patient_emb[k] for k in range(EMBED_DIM))
267
- )
268
-
269
- return EncoderOutput(
270
- patient_embedding=patient_emb,
271
- node_embeddings=[_normalize(h) for h in h2],
272
- risk_score=round(risk_score, 4),
273
- node_ids=[n.node_id for n in graph.nodes],
274
- )
275
-
276
-
277
- def _sigmoid(x: float) -> float:
278
- if x >= 0:
279
- return 1.0 / (1.0 + math.exp(-x))
280
- e = math.exp(x)
281
- return e / (1.0 + e)
282
-
283
-
284
- # patch into math namespace for use above
285
- math.sigmoid_approx = _sigmoid # type: ignore[attr-defined]
286
-
287
-
288
- @dataclass
289
- class EncoderOutput:
290
- patient_embedding: List[float]
291
- node_embeddings: List[List[float]]
292
- risk_score: float
293
- node_ids: List[str]
294
-
295
- def to_dict(self) -> Dict[str, Any]:
296
- return {
297
- "patient_embedding": [round(v, 5) for v in self.patient_embedding],
298
- "node_embeddings": {
299
- nid: [round(v, 5) for v in emb]
300
- for nid, emb in zip(self.node_ids, self.node_embeddings)
301
- },
302
- "risk_score": self.risk_score,
303
- "embed_dim": len(self.patient_embedding),
304
- }
305
-
306
-
307
- # ---------------------------------------------------------------------------
308
- # DCPGGraph — thin wrapper to consume DCPGAdapter.graph_summary() output
309
- # ---------------------------------------------------------------------------
310
-
311
  @dataclass
312
  class DCPGGraphNode:
313
  node_id: str
@@ -319,11 +160,8 @@ class DCPGGraphNode:
319
 
320
  def feature_vec(self) -> List[float]:
321
  return node_features(
322
- self.modality,
323
- self.phi_type,
324
- self.risk_entropy,
325
- self.context_confidence,
326
- self.pseudonym_version,
327
  )
328
 
329
 
@@ -339,68 +177,99 @@ class DCPGGraph:
339
  idx = self._node_index()
340
  ei: List[Tuple[int, int]] = []
341
  for e in self.edges:
342
- s = idx.get(e["source"])
343
- t = idx.get(e["target"])
344
  if s is not None and t is not None:
345
- ei.append((s, t))
346
- ei.append((t, s)) # undirected
347
  return ei
348
 
349
  def edge_weights(self) -> List[float]:
350
  idx = self._node_index()
351
  ew: List[float] = []
352
  for e in self.edges:
353
- s = idx.get(e["source"])
354
- t = idx.get(e["target"])
355
  if s is not None and t is not None:
356
  w = float(e.get("weight", 1.0))
357
- ew.extend([w, w])
358
  return ew
359
 
360
  @classmethod
361
  def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
362
  nodes = [
363
  DCPGGraphNode(
364
- node_id=n["node_id"],
365
- modality=n["modality"],
366
- phi_type=n["phi_type"],
367
  risk_entropy=float(n.get("risk_entropy", 0.0)),
368
  context_confidence=float(n.get("context_confidence", 1.0)),
369
  pseudonym_version=int(n.get("pseudonym_version", 0)),
370
  )
371
  for n in summary.get("nodes", [])
372
  ]
373
- edges = summary.get("edges", [])
374
- return cls(nodes=nodes, edges=edges)
375
 
376
  @classmethod
377
- def from_crdt_summary(
378
- cls,
379
- summary: Dict[str, Any],
380
- provisional_risk: float = 0.0,
381
- ) -> "DCPGGraph":
382
  nodes = []
383
  for n in summary.get("nodes", []):
384
  parts = str(n["node_id"]).split("::")
385
  modality = parts[1] if len(parts) > 1 else "text"
386
- nodes.append(
387
- DCPGGraphNode(
388
- node_id=n["node_id"],
389
- modality=modality,
390
- phi_type=modality.upper(),
391
- risk_entropy=float(n.get("risk_entropy", provisional_risk)),
392
- context_confidence=min(
393
- 1.0, float(n.get("total_phi_units", 1)) / 10.0
394
- ),
395
- pseudonym_version=int(n.get("pseudonym_version", 0)),
396
- )
397
- )
398
  return cls(nodes=nodes, edges=[])
399
 
400
 
401
- # ---------------------------------------------------------------------------
402
- # Inference helper
403
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  def encode_patient(
406
  graph_summary: Dict[str, Any],
@@ -415,18 +284,11 @@ def encode_patient(
415
  )
416
  else:
417
  g = DCPGGraph.from_summary(graph_summary)
418
- out = enc.encode(g)
419
- return out.to_dict()
420
 
421
 
422
- # ---------------------------------------------------------------------------
423
- # Smoke test
424
- # ---------------------------------------------------------------------------
425
-
426
  if __name__ == "__main__":
427
  summary = {
428
- "node_count": 3,
429
- "edge_count": 2,
430
  "nodes": [
431
  {"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
432
  "phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.72,
@@ -440,17 +302,10 @@ if __name__ == "__main__":
440
  ],
441
  "edges": [
442
  {"source": "p1::text::NAME_DATE_MRN_FACILITY",
443
- "target": "p1::asr::NAME_DATE_MRN",
444
- "type": "co_occurrence", "weight": 0.71},
445
  {"source": "p1::text::NAME_DATE_MRN_FACILITY",
446
- "target": "p1::image_proxy::FACE_IMAGE",
447
- "type": "cross_modal", "weight": 0.58},
448
  ],
449
- "provisional_risk": 0.664,
450
  }
451
-
452
  result = encode_patient(summary)
453
  print(json.dumps(result, indent=2))
454
- print(f"\nrisk_score: {result['risk_score']}")
455
- print(f"embed_dim: {result['embed_dim']}")
456
- print(f"nodes encoded: {len(result['node_embeddings'])}")
 
6
  from typing import Any, Dict, List, Optional, Tuple
7
 
8
 
 
 
 
 
9
  MODALITY_INDEX = {
10
+ "text": 0, "asr": 1, "image_proxy": 2, "waveform_proxy": 3,
11
+ "audio_proxy": 4, "image_link": 5, "audio_link": 6,
 
 
 
 
 
12
  }
13
+ MODALITY_DIM = len(MODALITY_INDEX) + 1
14
 
15
  PHI_TYPE_INDEX = {
16
+ "NAME_DATE_MRN_FACILITY": 0, "NAME_DATE_MRN": 1, "FACE_IMAGE": 2,
17
+ "WAVEFORM_HEADER": 3, "VOICE": 4, "FACE_LINK": 5, "VOICE_LINK": 6,
 
 
 
 
 
18
  }
19
  PHI_TYPE_DIM = len(PHI_TYPE_INDEX) + 1
20
 
21
+ NODE_SCALAR_DIM = 3
22
+ NODE_FEAT_DIM = MODALITY_DIM + PHI_TYPE_DIM + NODE_SCALAR_DIM # 19
23
+ HIDDEN_DIM = 32
24
+ EMBED_DIM = 16
25
+
26
+
27
+ def _sigmoid(x: float) -> float:
28
+ if x >= 0:
29
+ return 1.0 / (1.0 + math.exp(-x))
30
+ e = math.exp(x)
31
+ return e / (1.0 + e)
32
 
33
 
34
  def _one_hot(idx_map: Dict[str, int], key: str, dim: int) -> List[float]:
35
  vec = [0.0] * dim
36
+ vec[idx_map.get(key, dim - 1)] = 1.0
 
37
  return vec
38
 
39
 
 
45
  pseudonym_version: int,
46
  max_pv: int = 10,
47
  ) -> List[float]:
48
+ return (
49
+ _one_hot(MODALITY_INDEX, modality, MODALITY_DIM)
50
+ + _one_hot(PHI_TYPE_INDEX, phi_type, PHI_TYPE_DIM)
51
+ + [
52
+ float(max(0.0, min(1.0, risk_entropy))),
53
+ float(max(0.0, min(1.0, context_confidence))),
54
+ float(min(pseudonym_version, max_pv)) / float(max_pv),
55
+ ]
56
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def _matvec(W: List[List[float]], x: List[float]) -> List[float]:
 
71
  return [v / s for v in e]
72
 
73
 
 
 
 
 
74
  def _normalize(x: List[float]) -> List[float]:
75
+ n = math.sqrt(sum(v * v for v in x)) or 1.0
76
  return [v / n for v in x]
77
 
78
 
 
80
  return [a[i] + b[i] for i in range(len(a))]
81
 
82
 
83
+ def _xavier_init(rows: int, cols: int) -> List[List[float]]:
84
+ import random
85
+ limit = math.sqrt(6.0 / (rows + cols))
86
+ rng = random.Random(42)
87
+ return [[rng.uniform(-limit, limit) for _ in range(cols)] for _ in range(rows)]
88
 
89
 
90
def attention_pool(
    node_embeds: List[List[float]],
    risk_entropies: List[float],
) -> List[float]:
    """Risk-weighted pooling over node embeddings.

    Softmaxes the per-node risk entropies into attention weights and returns
    the weighted sum of the embeddings; an empty graph pools to ``[]``.
    """
    if not node_embeds:
        return []
    alphas = _softmax(risk_entropies)
    dim = len(node_embeds[0])
    pooled = [0.0] * dim
    for emb, alpha in zip(node_embeds, alphas):
        for k in range(dim):
            pooled[k] += alpha * emb[k]
    return pooled
100
+
101
 
102
  @dataclass
103
  class GATLayer:
 
124
  n = len(node_feats)
125
  h = [_relu(_matvec(self.W, x)) for x in node_feats]
126
 
 
127
  e: Dict[Tuple[int, int], float] = {}
128
  for (src, dst), w in zip(edge_index, edge_weights):
129
  score = (
 
132
  )
133
  e[(src, dst)] = math.exp(score) * float(w)
134
 
 
135
  norm_sum: List[float] = [0.0] * n
136
  for (src, dst), v in e.items():
137
  norm_sum[dst] += v
138
  for (src, dst) in e:
139
+ e[(src, dst)] /= norm_sum[dst] or 1.0
 
140
 
 
141
  out = [[0.0] * self.out_dim for _ in range(n)]
142
  for (src, dst), alpha in e.items():
143
  for k in range(self.out_dim):
144
  out[dst][k] += alpha * h[src][k]
145
 
 
146
  for i in range(n):
147
  out[i] = _add(out[i], h[i])
148
 
149
  return out
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  @dataclass
153
  class DCPGGraphNode:
154
  node_id: str
 
160
 
161
  def feature_vec(self) -> List[float]:
162
  return node_features(
163
+ self.modality, self.phi_type,
164
+ self.risk_entropy, self.context_confidence, self.pseudonym_version,
 
 
 
165
  )
166
 
167
 
 
177
  idx = self._node_index()
178
  ei: List[Tuple[int, int]] = []
179
  for e in self.edges:
180
+ s, t = idx.get(e["source"]), idx.get(e["target"])
 
181
  if s is not None and t is not None:
182
+ ei += [(s, t), (t, s)]
 
183
  return ei
184
 
185
  def edge_weights(self) -> List[float]:
186
  idx = self._node_index()
187
  ew: List[float] = []
188
  for e in self.edges:
189
+ s, t = idx.get(e["source"]), idx.get(e["target"])
 
190
  if s is not None and t is not None:
191
  w = float(e.get("weight", 1.0))
192
+ ew += [w, w]
193
  return ew
194
 
195
  @classmethod
196
  def from_summary(cls, summary: Dict[str, Any]) -> "DCPGGraph":
197
  nodes = [
198
  DCPGGraphNode(
199
+ node_id=n["node_id"], modality=n["modality"], phi_type=n["phi_type"],
 
 
200
  risk_entropy=float(n.get("risk_entropy", 0.0)),
201
  context_confidence=float(n.get("context_confidence", 1.0)),
202
  pseudonym_version=int(n.get("pseudonym_version", 0)),
203
  )
204
  for n in summary.get("nodes", [])
205
  ]
206
+ return cls(nodes=nodes, edges=summary.get("edges", []))
 
207
 
208
  @classmethod
209
+ def from_crdt_summary(cls, summary: Dict[str, Any], provisional_risk: float = 0.0) -> "DCPGGraph":
 
 
 
 
210
  nodes = []
211
  for n in summary.get("nodes", []):
212
  parts = str(n["node_id"]).split("::")
213
  modality = parts[1] if len(parts) > 1 else "text"
214
+ nodes.append(DCPGGraphNode(
215
+ node_id=n["node_id"], modality=modality,
216
+ phi_type=modality.upper(),
217
+ risk_entropy=float(n.get("risk_entropy", provisional_risk)),
218
+ context_confidence=min(1.0, float(n.get("total_phi_units", 1)) / 10.0),
219
+ pseudonym_version=int(n.get("pseudonym_version", 0)),
220
+ ))
 
 
 
 
 
221
  return cls(nodes=nodes, edges=[])
222
 
223
 
224
@dataclass
class EncoderOutput:
    """Result bundle produced by DCPGEncoder.encode()."""

    patient_embedding: List[float]  # graph-level embedding
    node_embeddings: List[List[float]]  # one vector per node, parallel to node_ids
    risk_score: float  # scalar in [0, 1]
    node_ids: List[str]  # parallel to node_embeddings

    def to_dict(self) -> Dict[str, Any]:
        """JSON-friendly view with embedding values rounded to 5 decimals."""
        per_node = {
            nid: [round(v, 5) for v in emb]
            for nid, emb in zip(self.node_ids, self.node_embeddings)
        }
        return {
            "patient_embedding": [round(v, 5) for v in self.patient_embedding],
            "node_embeddings": per_node,
            "risk_score": self.risk_score,
            "embed_dim": len(self.patient_embedding),
        }
241
+
242
+
243
@dataclass
class DCPGEncoder:
    """Two-layer GAT encoder over a DCPG graph.

    Produces a unit-norm patient embedding (EMBED_DIM floats), per-node
    unit-norm embeddings, and a scalar sigmoid risk score.
    """

    layer1: GATLayer = field(default_factory=lambda: GATLayer(NODE_FEAT_DIM, HIDDEN_DIM))
    layer2: GATLayer = field(default_factory=lambda: GATLayer(HIDDEN_DIM, EMBED_DIM))
    risk_head: List[List[float]] = field(default_factory=lambda: _xavier_init(1, EMBED_DIM))

    def encode(self, graph: DCPGGraph) -> EncoderOutput:
        """Run both GAT layers, pool, and score; empty graphs yield zeros."""
        if not graph.nodes:
            return EncoderOutput(
                patient_embedding=[0.0] * EMBED_DIM,
                node_embeddings=[],
                risk_score=0.0,
                node_ids=[],
            )

        features = [node.feature_vec() for node in graph.nodes]
        edges = graph.edge_index()
        weights = graph.edge_weights()

        hidden = self.layer1.forward(features, edges, weights)
        final = self.layer2.forward(hidden, edges, weights)

        # Pool node embeddings weighted by per-node risk entropy, then L2-normalize.
        entropies = [node.risk_entropy for node in graph.nodes]
        pooled = _normalize(attention_pool(final, entropies))

        # Linear risk head + sigmoid over the pooled embedding.
        logit = sum(self.risk_head[0][k] * pooled[k] for k in range(EMBED_DIM))
        return EncoderOutput(
            patient_embedding=pooled,
            node_embeddings=[_normalize(vec) for vec in final],
            risk_score=round(_sigmoid(logit), 4),
            node_ids=[node.node_id for node in graph.nodes],
        )
272
+
273
 
274
  def encode_patient(
275
  graph_summary: Dict[str, Any],
 
284
  )
285
  else:
286
  g = DCPGGraph.from_summary(graph_summary)
287
+ return enc.encode(g).to_dict()
 
288
 
289
 
 
 
 
 
290
  if __name__ == "__main__":
291
  summary = {
 
 
292
  "nodes": [
293
  {"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
294
  "phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.72,
 
302
  ],
303
  "edges": [
304
  {"source": "p1::text::NAME_DATE_MRN_FACILITY",
305
+ "target": "p1::asr::NAME_DATE_MRN", "type": "co_occurrence", "weight": 0.71},
 
306
  {"source": "p1::text::NAME_DATE_MRN_FACILITY",
307
+ "target": "p1::image_proxy::FACE_IMAGE", "type": "cross_modal", "weight": 0.58},
 
308
  ],
 
309
  }
 
310
  result = encode_patient(summary)
311
  print(json.dumps(result, indent=2))
 
 
 
inference_dcpg.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import sys
5
+
6
+ from dcpg_encoder import DCPGEncoder, encode_patient
7
+
8
+ _encoder = DCPGEncoder()
9
+
10
+
11
+ def predict(graph_summary: dict, source: str = "dcpg") -> dict:
12
+ return encode_patient(graph_summary, encoder=_encoder, source=source)
13
+
14
+
15
+ def predict_batch(summaries: list, source: str = "dcpg") -> list:
16
+ return [predict(s, source=source) for s in summaries]
17
+
18
+
19
+ if __name__ == "__main__":
20
+ if len(sys.argv) > 1:
21
+ with open(sys.argv[1]) as f:
22
+ data = json.load(f)
23
+ result = predict(data)
24
+ else:
25
+ result = predict({
26
+ "nodes": [
27
+ {"node_id": "p1::text::NAME_DATE_MRN_FACILITY", "modality": "text",
28
+ "phi_type": "NAME_DATE_MRN_FACILITY", "risk_entropy": 0.8,
29
+ "context_confidence": 0.9, "pseudonym_version": 2},
30
+ {"node_id": "p1::audio_proxy::VOICE", "modality": "audio_proxy",
31
+ "phi_type": "VOICE", "risk_entropy": 0.55,
32
+ "context_confidence": 0.6, "pseudonym_version": 1},
33
+ ],
34
+ "edges": [
35
+ {"source": "p1::text::NAME_DATE_MRN_FACILITY",
36
+ "target": "p1::audio_proxy::VOICE",
37
+ "type": "cross_modal", "weight": 0.63},
38
+ ],
39
+ })
40
+ print(json.dumps(result, indent=2))