Mayank Keoliya commited on
Commit
cb9d9d1
·
1 Parent(s): c2708da

Deploy with bundled camel_inference and LFS for demo data

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.dat filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,4 +1,3 @@
1
- camel_inference/
2
  checkpoints/
3
  __pycache__/
4
  *.pyc
 
 
1
  checkpoints/
2
  __pycache__/
3
  *.pyc
camel_inference/.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MacOS
2
+ .DS_Store
3
+
4
+ # Python
5
+ /.env
6
+ __pycache__
7
+
8
+ # Ignore model checkpoints
9
+ checkpoints/
10
+ *.pt
11
+
12
+ *.egg-info
camel_inference/README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CAMEL Inference
2
+
3
+ Inference-only repository for running CAMEL ECG-language checkpoints.
4
+
5
+ Only `run_camel.py` is intended as a public entrypoint. Modules under `src/camel/` are internal implementation details and may change.
6
+
7
+ ## Repository Layout
8
+
9
+ - `run_camel.py`: public inference CLI
10
+ - `src/camel/`: internal model, tokenizer, ECG packing, and loading utilities
11
+ - `checkpoints/`: local adapter/checkpoint files
12
+
13
+ ## Requirements
14
+
15
+ - Python 3.10+
16
+ - CUDA-enabled PyTorch recommended for practical inference latency
17
+
18
+ ## Install
19
+
20
+ ```bash
21
+ conda create -n camel python=3.10 -y
22
+ conda activate camel
23
+ pip install -e .
24
+ ```
25
+
26
+ ## Checkpoints
27
+
28
+ Checkpoints must be downloaded from huggingface `CAMEL-ECG/CAMEL` or with the repository script:
29
+
30
+ ```bash
31
+ bash scripts/download_checkpoints.sh
32
+ ```
33
+
34
+ ## Usage
35
+
36
+ * CAMEL is available in three modes:
37
+ - `base`
38
+ - `ecgbench`
39
+ - `forecast`
40
+
41
+ ```bash
42
+ python run_camel.py \
43
+ --mode forecast \
44
+ --text "Forecast cardiac rhythm for the next 5 minutes." \
45
+ --ecgs demo/08704_hr \
46
+ --device cuda:0
47
+ ```
48
+
49
+ ```bash
50
+ python run_camel.py \
51
+ --mode base \
52
+ --text "Compare the two ECG waveforms." \
53
+ --ecgs demo/12585_hr demo/12646_hr \
54
+ --device cuda:0
55
+ ```
56
+
57
+ * Optionally, you can set start, end, and leads with `--ecg-configs`.
58
+
59
+ ```bash
60
+ python run_camel.py \
61
+ --mode forecast \
62
+ --text "Forecast cardiac rhythm for the next 5 minutes." \
63
+ --ecgs demo/08704_hr \
64
+ --ecg-configs "start:0;end:5;use_leads:I,II" \
65
+ --device cuda:0
66
+ ```
67
+
68
+ * Using `--text` and `--ecgs` defaults to text followed by the ecg in order.
69
+ For arbitrary text/ECG interleaving use `--json`.
70
+ ```bash
71
+ python run_camel.py --mode base --json demo/example_prompt.json --device cuda:0
72
+ ```
73
+
74
+ * Sampling flags:
75
+ - `--temperature`
76
+ - `--top-k`
77
+ - `--top-p`
78
+ - `--min-p`
79
+ - `--max-new-tokens`
80
+
81
+ Implementation notes:
82
+ - ECG loading is currently implemented for WFDB-format inputs. To support additional formats, extend `src/read_ecg.py`.
camel_inference/demo/08704_hr.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ade700b2d1a1e3a0889d4135ee179a33b7dc2eaf76d6702bc9bc64f6cef4a3f4
3
+ size 120000
camel_inference/demo/08704_hr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 08704_hr 12 500 5000
2
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -50 41247 0 I
3
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -20 64030 0 II
4
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 30 22786 0 III
5
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 35 13135 0 aVR
6
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -40 42425 0 aVL
7
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 5 11132 0 aVF
8
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 210 35283 0 V1
9
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -605 1875 0 V2
10
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -360 20664 0 V3
11
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -255 26244 0 V4
12
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 -230 25112 0 V5
13
+ 08704_hr.dat 16 1000.0(0)/mV 16 0 60 30065 0 V6
camel_inference/demo/12585_hr.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dbc29ba7d9857a606c23d1a661dee034285f471908206e5428c9b6b7c8ff467
3
+ size 120000
camel_inference/demo/12585_hr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 12585_hr 12 500 5000
2
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 540 5223 0 I
3
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 40 11911 0 II
4
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -500 6676 0 III
5
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -290 57071 0 aVR
6
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 520 65066 0 aVL
7
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -230 9479 0 aVF
8
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -130 26077 0 V1
9
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -145 27523 0 V2
10
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -155 35476 0 V3
11
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -350 16663 0 V4
12
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -545 60445 0 V5
13
+ 12585_hr.dat 16 1000.0(0)/mV 16 0 -105 58137 0 V6
camel_inference/demo/12646_hr.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cabc7399daa14ebf97cc705b9da09aeb3daee12939ca797d145237cdfcf4ed8
3
+ size 120000
camel_inference/demo/12646_hr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 12646_hr 12 500 5000
2
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -260 38739 0 I
3
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -465 47305 0 II
4
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -205 8550 0 III
5
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 363 22259 0 aVR
6
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -27 47829 0 aVL
7
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -335 60490 0 aVF
8
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 220 55705 0 V1
9
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 120 51889 0 V2
10
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -390 40903 0 V3
11
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -660 39373 0 V4
12
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 -770 43235 0 V5
13
+ 12646_hr.dat 16 1000.0(0)/mV 16 0 135 65290 0 V6
camel_inference/demo/example_prompt.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [
2
+ { "type": "text", "text": "You are given two ECG waveforms of a same patient from two different time points."},
3
+ { "type": "text", "text": "This is the first ECG."},
4
+ { "type": "ecg", "ecg": "demo/12585_hr" },
5
+ { "type": "text", "text": "This is the second ECG." },
6
+ { "type": "ecg", "ecg": "demo/12646_hr" },
7
+ { "type": "text", "text": "Have non-diagnostic T abnormalities been resolved in the recent tracing compared to the previous one?"}
8
+ ]
camel_inference/pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.setuptools]
6
+ py-modules = ["run_camel"]
7
+
8
+ [tool.setuptools.packages.find]
9
+ where = ["src"]
10
+
11
+ [tool.setuptools.package-data]
12
+ camel = ["model_registry.yaml"]
13
+
14
+ [project]
15
+ name = "camel-inference"
16
+ version = "0.1.0"
17
+ description = "Inference-only CLI for CAMEL ECG-language checkpoints"
18
+ readme = "README.md"
19
+ requires-python = ">=3.10"
20
+ authors = [
21
+ { name = "CAMEL contributors" }
22
+ ]
23
+ dependencies = [
24
+ "numpy",
25
+ "scipy",
26
+ "pyyaml",
27
+ "torch",
28
+ "transformers",
29
+ "peft",
30
+ "accelerate",
31
+ "sentencepiece",
32
+ "protobuf"
33
+ ]
34
+
35
+ [project.scripts]
36
+ camel-infer = "run_camel:main"
camel_inference/run_camel.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from camel.camel_model import CAMEL
3
+
4
+ def main():
5
+ parser = argparse.ArgumentParser(description="CAMEL")
6
+ parser.add_argument("--mode", type=str, choices=['forecast', 'base', 'ecgbench'], default='base')
7
+ parser.add_argument("--device", type=str, default='cuda:0')
8
+ parser.add_argument("--json", type=str, default=None)
9
+ parser.add_argument("--text", type=str, default=None)
10
+ parser.add_argument("--ecgs", type=str, default=None, nargs='+')
11
+ parser.add_argument("--ecg-configs", type=str, default=None, nargs='+')
12
+ parser.add_argument("--temperature", type=float, default=0.0)
13
+ parser.add_argument(
14
+ "--top-k",
15
+ dest="top_k",
16
+ type=int,
17
+ default=64,
18
+ help="Top-k sampling cutoff (set <=0 to disable).",
19
+ )
20
+ parser.add_argument(
21
+ "--top-p",
22
+ dest="top_p",
23
+ type=float,
24
+ default=0.95,
25
+ help="Nucleus sampling cumulative probability cutoff.",
26
+ )
27
+ parser.add_argument(
28
+ "--min-p",
29
+ dest="min_p",
30
+ type=float,
31
+ default=0.0,
32
+ help="Minimum per-token probability threshold applied after temperature scaling.",
33
+ )
34
+ parser.add_argument(
35
+ "--max-new-tokens",
36
+ type=int,
37
+ default=512,
38
+ help="Maximum number of tokens to generate per sample.",
39
+ )
40
+ args = parser.parse_args()
41
+
42
+ model = CAMEL(mode=args.mode, device=args.device)
43
+ output, prompt = model.run(args)
44
+
45
+ print(f'Prompt: {prompt}')
46
+ print(f'Prediction: {output}')
47
+
48
+ if __name__ == "__main__":
49
+ main()
camel_inference/scripts/download_checkpoints.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ echo "Installing huggingface_hub if needed..."
5
+ python3 -m pip install -q --user huggingface_hub
6
+
7
+ echo "Downloading CAMEL checkpoints from Hugging Face..."
8
+ mkdir -p checkpoints
9
+
10
+ python3 - <<'PY'
11
+ import os, shutil
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ repo = "CAMEL-ECG/CAMEL"
15
+ files = [
16
+ "camel_base.pt",
17
+ "camel_ecginstruct.pt",
18
+ "camel_forecast.pt"
19
+ ]
20
+
21
+ os.makedirs("checkpoints", exist_ok=True)
22
+
23
+ for f in files:
24
+ print(f"Downloading {f}...")
25
+ src = hf_hub_download(
26
+ repo_id=repo,
27
+ filename=f,
28
+ repo_type="model"
29
+ )
30
+ dst = os.path.join("checkpoints", f)
31
+ shutil.copy2(src, dst)
32
+ print(f"Saved to {dst}")
33
+
34
+ print("All checkpoints downloaded.")
35
+ PY
36
+
37
+ echo "Done."
camel_inference/src/camel/__init__.py ADDED
File without changes
camel_inference/src/camel/assertions.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Assertions and summaries for ECG language-model wrappers and their LoRA adapters.
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from typing import Any, Dict, Iterable, List, Mapping, Set
7
+ import functools
8
+ import os
9
+ import torch
10
+
11
+ from camel.ecg_attention_masks import ECGBlockLayout
12
+
13
_ASSERTIONS_ENABLED = os.getenv("ASSERTIONS") == "1"


def _skip_if_assertions_disabled(func):
    """Decorator that makes an assertion helper a no-op unless ASSERTIONS=1.

    The flag is sampled once at module import time, so toggling the
    environment variable afterwards has no effect.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if _ASSERTIONS_ENABLED:
            return func(*args, **kwargs)
        return None

    return wrapper
24
+
25
+ @_skip_if_assertions_disabled
26
def assert_tensor_dtype(tensor: torch.Tensor, *, expected: torch.dtype, context: str) -> None:
    """Fail with AssertionError unless *tensor* carries the dtype *expected*."""
    actual = tensor.dtype
    if actual == expected:
        return
    raise AssertionError(f"{context}: expected dtype {expected}, got {actual}")
30
+
31
+ @_skip_if_assertions_disabled
32
def assert_ecg_blocks_consistent(
    *,
    turn_parts: Iterable[Iterable[Dict[str, Any]]],
    ecg_blocks: Iterable[Dict[str, Any]],
) -> None:
    """
    Validate that structured turn parts contain the expected ECG markers per block.

    Checks, per declared block:
      - every part references a known block_index and (for lead-scoped parts) a
        declared lead;
      - each "ecg" part carries an integer sec in [1, segments_per_lead] and no
        second is marked twice;
      - each lead has exactly two lead-scoped "special" markers (start + end);
      - each block has exactly two global "special" markers;
      - every declared second of every lead was actually seen.

    Raises AssertionError on the first violation found.
    """
    # Pass 1: build expected-count tables from the declared blocks.
    blocks_list = list(ecg_blocks)
    expected_counts: Dict[int, Dict[str, int]] = {}
    per_lead_special_counts: Dict[int, Dict[str, int]] = {}
    per_lead_secs: Dict[int, Dict[str, Set[int]]] = {}
    global_special_count: Dict[int, int] = {}

    for idx, blk in enumerate(blocks_list):
        leads = [str(ld) for ld in blk.get("lead_names", [])]
        segs = [int(n) for n in blk.get("segments_per_lead", [])]
        expected_counts[idx] = {ld: int(n) for ld, n in zip(leads, segs)}
        per_lead_special_counts[idx] = {ld: 0 for ld in leads}
        per_lead_secs[idx] = {ld: set() for ld in leads}
        global_special_count[idx] = 0

    # Pass 2: walk every part of every turn, tallying markers and validating
    # each one against the declared tables as we go.
    for turn in turn_parts:
        for part in turn:
            kind = part.get("kind")
            if kind == "text":
                continue
            block_idx = part.get("block_index")
            if block_idx is None or int(block_idx) not in expected_counts:
                raise AssertionError("ECG part references unknown block_index.")
            block_idx = int(block_idx)
            allowed_leads = set(expected_counts[block_idx].keys())
            if kind == "special":
                # A special with a lead is a per-lead start/end marker; without
                # a lead it counts toward the block's two global markers.
                lead = part.get("lead")
                if lead:
                    if lead not in allowed_leads:
                        raise AssertionError(f"Special token references unknown lead '{lead}'.")
                    per_lead_special_counts[block_idx][lead] = per_lead_special_counts[block_idx].get(lead, 0) + 1
                else:
                    global_special_count[block_idx] = global_special_count.get(block_idx, 0) + 1
                token_literal = part.get("token")
                if not isinstance(token_literal, str) or len(token_literal) == 0:
                    raise AssertionError("Special turn part lacks a string token literal.")
                continue
            if kind == "ecg":
                lead = part.get("lead")
                if lead not in allowed_leads:
                    raise AssertionError(f"ECG segment references unknown lead '{lead}'.")
                sec_val = part.get("sec")
                expected_total = expected_counts[block_idx].get(lead, 0)
                if expected_total <= 0:
                    raise AssertionError(
                        f"Lead '{lead}' has non-positive declared segments ({expected_total}) but ECG markers are present."
                    )
                try:
                    sec = int(sec_val)
                except Exception as exc:  # noqa: BLE001
                    raise AssertionError(f"ECG segment for lead '{lead}' has non-integer sec {sec_val!r}.") from exc
                if sec < 1 or sec > expected_total:
                    raise AssertionError(
                        f"ECG segment for lead '{lead}' has second={sec}, expected within [1,{expected_total}]."
                    )
                # Each (lead, sec) pair may appear at most once.
                if sec in per_lead_secs[block_idx][lead]:
                    raise AssertionError(f"Duplicate ECG segment marker for lead '{lead}' second {sec}.")
                per_lead_secs[block_idx][lead].add(sec)
                continue
            raise AssertionError(f"Unknown turn_parts kind '{kind}'.")

    # Pass 3: confirm the tallies match the declarations exactly.
    for block_idx, expected in expected_counts.items():
        if global_special_count.get(block_idx, 0) != 2:
            raise AssertionError(
                f"Expected exactly two global ECG markers for block {block_idx}; "
                f"found {global_special_count.get(block_idx, 0)}."
            )
        for lead, expected_total in expected.items():
            expected_specials = per_lead_special_counts[block_idx].get(lead, 0)
            if expected_specials != 2:
                raise AssertionError(
                    f"Lead '{lead}' has {expected_specials} special markers; expected start and end (2 total)."
                )
            seen_secs = per_lead_secs[block_idx].get(lead, set())
            if expected_total != len(seen_secs):
                missing = sorted(set(range(1, expected_total + 1)) - seen_secs)
                raise AssertionError(
                    f"Lead '{lead}' missing ECG segment markers for seconds {missing} (expected {expected_total})."
                )
118
+
119
+
120
+ # ---------------- Trainer batch validation helpers -----------------------------------------------
121
+
122
+ @_skip_if_assertions_disabled
123
def assert_prefix_split_complete(*, offset: int, total_prefix_rows: int) -> None:
    """Check that the prefix split consumed every available row; raises RuntimeError otherwise."""
    if offset == total_prefix_rows:
        return
    raise RuntimeError(
        f"Prefix split mismatch: consumed {offset} rows but have {total_prefix_rows}"
    )
129
+
130
+
131
+ @_skip_if_assertions_disabled
132
def assert_prefix_matches_segments(
    *,
    prefix_rows: int,
    segments_per_lead: Iterable[int],
    lead_names: Iterable[str],
    sample_index: int,
    block_index: int,
) -> None:
    """Check that the prefix row count equals the summed declared segment count.

    Raises RuntimeError (with sample/block context) on mismatch.
    """
    declared = sum(int(n) for n in segments_per_lead)
    if prefix_rows == declared:
        return
    raise RuntimeError(
        f"Sample {sample_index} block {block_index}: Prefix rows ({prefix_rows}) "
        f"!= sum(segments_per_lead) ({declared}). "
        f"lead_names={list(lead_names)} segments_per_lead={list(segments_per_lead)}"
    )
148
+
149
+ @_skip_if_assertions_disabled
150
def assert_ecg_part_bounds(
    *,
    lead: str,
    sec: int,
    lead_to_offset: Mapping[str, int],
    declared_segments: Mapping[str, int],
    total_prefix_rows: int,
    sample_index: int,
    block_index: int,
) -> None:
    """Check that an ECG part (lead, sec) maps to a valid prefix row.

    Validates that the lead is declared, sec lies in [1, nseg], and the
    resulting row index falls both inside [0, total_prefix_rows) and inside
    the half-open row range assigned to this lead. Raises RuntimeError.
    """
    if lead not in declared_segments:
        raise RuntimeError(f"Unknown lead {lead} in parts for sample {sample_index} block {block_index}")

    nseg = int(declared_segments[lead])
    if sec < 1 or sec > nseg:
        raise RuntimeError(
            f"sec out of range for lead {lead}: got {sec}, expected 1..{nseg}"
        )

    start = lead_to_offset[lead]
    end = start + nseg  # exclusive upper bound of this lead's rows
    row_idx = start + sec - 1

    # Global bound check first, then the per-lead window.
    if row_idx < 0 or row_idx >= total_prefix_rows:
        raise RuntimeError(
            f"Bad (lead,sec)=({lead},{sec}) for sample {sample_index} block {block_index}: "
            f"row_idx {row_idx} not in [0,{total_prefix_rows})"
        )
    if row_idx < start or row_idx >= end:
        raise RuntimeError(
            f"(lead,sec)=({lead},{sec}) maps outside this lead block "
            f"[{start},{end}) (row_idx={row_idx}) for sample {sample_index}"
        )
186
+
187
+
188
+ @_skip_if_assertions_disabled
189
def assert_layout_specials_complete(
    *,
    block_layout: ECGBlockLayout,
    lead_names: Iterable[str],
) -> None:
    """Check start/end special markers are paired and ordered.

    Per declared lead and for the global markers: start and end must be both
    present or both absent, and when present start must precede end.
    Raises RuntimeError on any violation.
    """
    # Per-lead markers.
    for lead in lead_names:
        start = block_layout.lead_start_idx.get(lead)
        end = block_layout.lead_end_idx.get(lead)
        if (start is None) != (end is None):
            raise RuntimeError(f"Lead {lead} missing start/end special (s={start}, e={end})")
        if start is not None and start >= end:
            raise RuntimeError(f"Lead {lead} specials out of order: start={start}, end={end}")

    # Global markers.
    g_start = block_layout.global_start_idx
    g_end = block_layout.global_end_idx
    if (g_start is None) != (g_end is None):
        raise RuntimeError("Global start/end special mismatch")
    if g_start is not None and g_end is not None and g_start >= g_end:
        raise RuntimeError(
            f"Global specials out of order: start={g_start} "
            f"end={g_end}"
        )
222
+
223
+ # ---------------- Wrapper embedding validations --------------------------------------------------
224
+
225
+ @_skip_if_assertions_disabled
226
def assert_wrapper_embed_length(
    *,
    embeddings: torch.Tensor,
    ids: List[int],
    context: str,
) -> None:
    """Check the 1:1 correspondence between token ids and embedding rows.

    Args:
        embeddings: Output embedding tensor; must be at least 1-D.
        ids: Input token IDs; must be a python list.
        context: Label identifying the call site, used in error messages.

    Raises:
        RuntimeError: when ids is not a list, embeddings is 0-D, or the
            leading embedding dimension differs from len(ids).
    """
    if not isinstance(ids, list):
        raise RuntimeError(f"{context}: ids must be a python list of ints")
    if embeddings.dim() < 1:
        raise RuntimeError(f"{context}: embeddings must be at least 1-D, got shape {tuple(embeddings.shape)}")
    n_embed = embeddings.size(0)
    if n_embed != len(ids):
        raise RuntimeError(f"{context}: embed length {n_embed} != token count {len(ids)}")
252
+
253
+ @_skip_if_assertions_disabled
254
def assert_rest_length_nonnegative(*, rest_length: int) -> None:
    """Guard against a negative ids_rest length.

    A negative value can only come from a bug in label construction, so it is
    surfaced immediately as ValueError rather than propagating downstream.
    """
    if rest_length >= 0:
        return
    raise ValueError("ids_rest length is negative (internal error).")
268
+
269
+
270
+ # ---------------- Utility assertion helpers ------------------------------------------------------
271
+
272
+ @_skip_if_assertions_disabled
273
def assert_sorted_non_overlapping_spans(spans: List[tuple[int, int]], length: int, ctx: str) -> None:
    """Check spans are in-bounds, sorted by start, and pairwise non-overlapping.

    Touching spans (one ends where the next begins) are allowed.
    Raises AssertionError with *ctx* in the message on the first violation.
    """
    last_end = -1
    for i, (s, e) in enumerate(spans):
        in_bounds = 0 <= s and s <= e and e <= length
        if not in_bounds:
            raise AssertionError(f"{ctx}: span {i}={(s,e)} out of bounds for length {length}")
        if s < last_end:
            raise AssertionError(f"{ctx}: spans overlap or not sorted at {i-1},{i}: prev_end={last_end}, curr_start={s}")
        last_end = e
282
+
283
+
284
+ @_skip_if_assertions_disabled
285
def assert_equal_int(a: int, b: int, msg: str) -> None:
    """Raise AssertionError carrying *msg* unless int(a) == int(b)."""
    if int(a) == int(b):
        return
    raise AssertionError(f"{msg}: {a} != {b}")
289
+
290
+
291
+ @_skip_if_assertions_disabled
292
def assert_positive_int(n: int, msg: str) -> None:
    """Raise AssertionError carrying *msg* unless int(n) is strictly positive."""
    if int(n) > 0:
        return
    raise AssertionError(f"{msg}: expected > 0, got {n}")
296
+
297
+
298
+ # ---------------- Schema and catalog validations -------------------------------------------------
299
+
300
+ @_skip_if_assertions_disabled
301
def assert_ecg_catalog_valid(catalog: Any, schema: Any) -> None:
    """Validate the ECG special-token catalog against the schema.

    Checks: token uniqueness; every canonical lead present in both
    lead_to_indices and lead_to_tokens; token/index mappings mutually
    consistent; the global start/end tokens resolvable via token_to_index.
    Raises AssertionError on the first inconsistency.
    """
    tokens = catalog.tokens
    if len(tokens) != len(set(tokens)):
        raise AssertionError("ECG special tokens contain duplicates")

    for lead in schema.ecg.canonical_leads:
        if lead not in catalog.lead_to_indices or lead not in catalog.lead_to_tokens:
            raise AssertionError(f"Missing lead in catalog: {lead}")
        for kind in ("start", "end"):
            tok = catalog.lead_to_tokens[lead][kind]
            idx = catalog.lead_to_indices[lead][kind]
            if tokens[idx] != tok:
                raise AssertionError(f"Catalog mismatch for {lead}:{kind}: tokens[idx] != tok")
            if catalog.token_to_index.get(tok, None) != idx:
                raise AssertionError(f"token_to_index mismatch for {lead}:{kind}")

    for tok in (schema.ecg.global_start, schema.ecg.global_end):
        if tok not in catalog.token_to_index:
            raise AssertionError(f"Global ECG token missing from catalog: {tok}")
330
+
331
+
332
+ # ---------------- Conversation and role validations ----------------------------------------------
333
+
334
+ @_skip_if_assertions_disabled
335
def assert_normalized_role_canonical(role: str, schema: Any) -> None:
    """Ensure *role* is one of the schema's two canonical prompt roles."""
    canonical = (schema.prompt.user_role, schema.prompt.model_role)
    if role in canonical:
        return
    raise AssertionError(f"Normalized role '{role}' did not resolve to a canonical prompt role")
339
+
340
+ # ---------------- Tokenization and span validations ----------------------------------------------
341
+
342
+ @_skip_if_assertions_disabled
343
def assert_tokenization_cursor_matches(cursor: int, ids_length: int) -> None:
    """Raise AssertionError when the running cursor disagrees with len(text_ids)."""
    if cursor == ids_length:
        return
    raise AssertionError(f"cursor ({cursor}) != len(text_ids) ({ids_length})")
347
+
348
+
349
+ @_skip_if_assertions_disabled
350
def assert_model_spans_valid(model_spans: List[tuple[int, int]], ids_length: int) -> None:
    """Check model spans are well-formed and that at least one span exists."""
    # Delegate the structural checks (bounds, ordering, overlap).
    assert_sorted_non_overlapping_spans(model_spans, ids_length, ctx="model_spans_in_text")
    if not model_spans:
        raise AssertionError("No model spans found in text ids")
355
+
356
+
357
+ @_skip_if_assertions_disabled
358
def assert_eos_appended(ids: List[int], tokenizer: Any, require_eos: bool) -> None:
    """Check the final token is EOS whenever EOS is required and defined."""
    eos = tokenizer.eos_token_id
    if not require_eos or eos is None:
        return
    if not ids or ids[-1] != eos:
        raise AssertionError("Required EOS was not appended at the end of text_ids")
363
+
364
+
365
+ # ---------------- u0 parts structure validations -------------------------------------------------
366
+
367
+ @_skip_if_assertions_disabled
368
def assert_turn_parts_structure_valid(
    parts: List[Dict[str, Any]],
    ecg_blocks: List[Dict[str, Any]],
    schema: Any,
    catalog: Any,
) -> None:
    """Validate the complete structure of turn parts for all blocks present.

    For every block index referenced by *parts*, checks that:
      - the index addresses an entry of *ecg_blocks*;
      - both global start/end specials appear, in order;
      - every declared lead has its start/end specials in order, with the
        lead's ECG markers between them forming exactly the sequence 1..nseg;
      - the specials' token_index values match the catalog.

    Raises AssertionError on the first violation.
    """
    # Only validate blocks that are actually referenced by the parts.
    block_indices = sorted({int(p.get("block_index")) for p in parts if p.get("block_index") is not None})
    for block_idx in block_indices:
        if block_idx < 0 or block_idx >= len(ecg_blocks):
            raise AssertionError(f"Unknown block_index {block_idx} in turn parts")
        blk = ecg_blocks[block_idx]
        leads_present = [str(ld) for ld in blk.get("lead_names", [])]
        segments_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]

        # Both global markers must exist for this block.
        special_tokens = [p.get("token") for p in parts if p.get("kind") == "special" and p.get("block_index") == block_idx]
        if schema.ecg.global_start not in special_tokens:
            raise AssertionError(f"Missing global_start special in block {block_idx}")
        if schema.ecg.global_end not in special_tokens:
            raise AssertionError(f"Missing global_end special in block {block_idx}")

        # ...and the start must precede the end within the flat parts list.
        idx_global_start = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_start),
            None
        )
        idx_global_end = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_end),
            None
        )
        if idx_global_start is None or idx_global_end is None or not (idx_global_start < idx_global_end):
            raise AssertionError(f"Block {block_idx}: missing or misordered global start/end specials")

        for lead, nseg in zip(leads_present, segments_per_lead):
            assert_positive_int(nseg, f"segments_per_lead for {lead}")
            # Locate this lead's start/end specials by their catalog tokens.
            idx_start = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["start"]),
                None
            )
            idx_end = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["end"]),
                None
            )
            if idx_start is None or idx_end is None or not (idx_start < idx_end):
                raise AssertionError(f"Lead {lead}: missing or misordered start/end specials in block {block_idx}")
            # Between the lead's specials, the ECG markers must cover seconds
            # 1..nseg exactly, in order, with no gaps or duplicates.
            secs = [p["sec"] for p in parts[idx_start+1:idx_end] if p.get("kind") == "ecg" and p.get("lead") == lead]
            if secs != list(range(1, int(nseg) + 1)):
                raise AssertionError(f"Lead {lead}: ECG seconds sequence invalid: {secs} vs 1..{nseg}")
            # The specials' token indices must agree with the catalog.
            if parts[idx_start]["token_index"] != catalog.lead_to_indices[lead]["start"]:
                raise AssertionError(f"Lead {lead}: start token_index mismatch")
            if parts[idx_end]["token_index"] != catalog.lead_to_indices[lead]["end"]:
                raise AssertionError(f"Lead {lead}: end token_index mismatch")
427
+
428
+
429
+ @_skip_if_assertions_disabled
430
def assert_turn_content_ends_with_eot(text_block: str, end_of_turn: str) -> None:
    """Ensure the turn text terminates with the end-of-turn suffix."""
    if text_block.endswith(end_of_turn):
        return
    raise AssertionError("Turn content must end with end_of_turn suffix")
434
+
435
+
436
+ # ---------------- Per-sample packing validations -------------------------------------------------
437
+
438
+ @_skip_if_assertions_disabled
439
def assert_leads_canonical_and_ordered(leads_present: List[str], canonical_leads: tuple) -> None:
    """Check every lead is canonical and appears at most once.

    Despite the name, relative ordering is not enforced here — it is treated
    as explicit at the call site.
    """
    leads = list(leads_present)
    unknown = [ld for ld in leads if ld not in canonical_leads]
    if unknown:
        raise AssertionError(f"Non-canonical lead found in leads_present: {leads}")
    if len(set(leads)) != len(leads):
        raise AssertionError(f"Duplicate lead detected in leads_present: {leads}")
446
+
447
+
448
+ @_skip_if_assertions_disabled
449
def assert_waveform_shapes_valid(
    leads_present: List[str],
    segments_per_lead: List[int],
    waveform_segments: Dict[str, Any],
) -> None:
    """Check each lead's waveform is [T, 256] with T equal to its declared segment count.

    Raises AssertionError via the shared integer helpers or directly on a
    shape violation.
    """
    for lead, nseg in zip(leads_present, segments_per_lead):
        assert_positive_int(nseg, f"segments_per_lead[{lead}]")
        waveform = waveform_segments[lead]
        # Expect a 2-D [seconds, 256-sample] layout per lead.
        if waveform.ndim != 2 or waveform.shape[1] != 256:
            raise AssertionError(f"Waveform for {lead} must be [T,256], got {tuple(waveform.shape)}")
        assert_equal_int(waveform.shape[0], nseg, f"Waveform seconds vs segments_per_lead for {lead}")
464
+
465
+
466
# Public API of this module.  The list is restricted to helpers that are
# actually defined in this file: the previous version also advertised names
# from other modules (adapter/LoRA snapshot helpers, mask validators, etc.),
# which made `from camel.assertions import *` raise AttributeError.
__all__ = [
    "assert_tensor_dtype",
    "assert_ecg_blocks_consistent",
    "assert_prefix_split_complete",
    "assert_prefix_matches_segments",
    "assert_ecg_part_bounds",
    "assert_layout_specials_complete",
    "assert_wrapper_embed_length",
    "assert_rest_length_nonnegative",
    "assert_sorted_non_overlapping_spans",
    "assert_equal_int",
    "assert_positive_int",
    "assert_ecg_catalog_valid",
    "assert_normalized_role_canonical",
    "assert_tokenization_cursor_matches",
    "assert_model_spans_valid",
    "assert_eos_appended",
    "assert_turn_parts_structure_valid",
    "assert_turn_content_ends_with_eot",
    "assert_leads_canonical_and_ordered",
    "assert_waveform_shapes_valid",
]
camel_inference/src/camel/camel_model.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Compatibility wrapper around inference.KardiaLM."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Any, Optional
7
+ import json
8
+ import torch
9
+
10
+ from camel.inference import KardiaLM
11
+ from camel.process_ecg import get_waveform
12
+
13
class CAMEL:
    """
    Inference facade over :class:`KardiaLM` for CAMEL ECG-language checkpoints.

    Selects the adapter checkpoint for the requested ``mode`` and exposes a
    small API: :meth:`run` (CLI-args driven) and :meth:`generate`
    (programmatic).
    """

    # Maps an inference mode to its bundled adapter checkpoint.
    _MODE_CHECKPOINTS = {
        'base': 'checkpoints/camel_base.pt',
        'ecgbench': 'checkpoints/camel_ecginstruct.pt',
        'forecast': 'checkpoints/camel_forecast.pt',
    }

    def __init__(
        self,
        device: torch.device,
        mode: str,
        model_config_name: str = 'medgemma-4b-it',
        conv_ckpt: Optional[str] = None,
        no_lora: bool = False,
        mask_strategy: str = 'semantic',
        **model_args,
    ) -> None:
        """
        Args:
            device: Torch device the session should run on.
            mode: One of 'base', 'ecgbench', 'forecast'.
            model_config_name: Registry name of the backbone configuration.
            conv_ckpt: Optional separate conv-encoder checkpoint path.
            no_lora: Disable LoRA adapters.
            mask_strategy: ECG attention-mask strategy name.
            **model_args: Optional generation defaults (default_top_k,
                default_top_p, default_min_p, default_max_new_tokens,
                default_temperature).

        Raises:
            ValueError: If ``mode`` is not a recognized mode name.
        """
        default_top_k = model_args.pop("default_top_k", 64)
        default_top_p = float(model_args.pop("default_top_p", 0.95))
        default_min_p = float(model_args.pop("default_min_p", 0.0))

        # Fix: the original if/elif chain left `ckpt` unbound for unknown
        # modes, producing a confusing NameError later; fail fast instead.
        try:
            ckpt = self._MODE_CHECKPOINTS[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {sorted(self._MODE_CHECKPOINTS)}"
            ) from None

        # Initialize model
        self.session = KardiaLM(
            model_registry_path=None,
            hf_model_id_override=None,
            model_config_name=model_config_name,
            adapter_ckpt=ckpt,
            conv_ckpt=conv_ckpt,
            no_lora=no_lora,
            default_max_new_tokens=int(model_args.pop("default_max_new_tokens", 1000)),
            default_temperature=float(model_args.pop("default_temperature", 1.0)),
            default_top_k=None if default_top_k is None else int(default_top_k),
            default_top_p=default_top_p,
            default_min_p=default_min_p,
            mask_strategy=mask_strategy,
            device=device
        )
        self.prompt_tokens = self.session.packing_schema.prompt
        self.device = device

    def run(self, args):
        """
        Execute one generation from parsed CLI ``args``.

        Builds the conversation either from ``args.json`` (a serialized content
        list) or from ``args.text`` + ``args.ecgs``, then delegates to
        :meth:`generate`.

        Raises:
            ValueError: If neither --json nor --ecgs was provided, or the JSON
                file cannot be read/parsed.
        """
        if args.json is None and (args.ecgs is None):
            raise ValueError("Either one of --json or --ecgs should be non-empty.")

        if args.json is None:
            text = args.text or ''
            raw_context = [{'type': 'text', 'text': text}]
            raw_context.extend({'type': 'ecg', 'ecg': ecg} for ecg in args.ecgs)
        else:
            # Fix: replaced a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) with targeted handling + chaining.
            try:
                with open(args.json, "r") as f:
                    raw_context = json.load(f)
            except (OSError, json.JSONDecodeError) as exc:
                raise ValueError(f'Failed during reading json: {args.json}') from exc

        content = self._build_content(raw_content=raw_context, ecg_configs=args.ecg_configs)
        generate_kwargs = dict(
            content=content,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            min_p=args.min_p,
        )
        generated_text = self.generate(**generate_kwargs)
        return generated_text

    def generate(
        self,
        content: list[dict[str, Any]],
        max_new_tokens: int = 1000,
        temperature: float = 1.0,
        top_k: Optional[int] = 64,
        top_p: float = 0.95,
        min_p: float = 0.0,
    ) -> tuple:
        """
        Run the underlying chat session.

        Returns:
            ``(generated_text, prompt_preview)`` as produced by
            ``KardiaLM.chat``. Note: the original annotated this as ``-> str``
            although it returns a 2-tuple; only the annotation is corrected
            here (behavior unchanged).
        """
        text, prompt_preview = self.session.chat(
            conversation=content,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
        )
        return text, prompt_preview

    def _parse_ecg_config(self, ecg_configs: Optional[list[str]], n_ecgs: int) -> tuple:
        """
        Parse per-ECG config strings of the form ``"start:0; end:10; leads:I,II"``.

        Returns three parallel lists (start indices, end indices, lead lists),
        each of length ``n_ecgs``; a single config is broadcast to all ECGs.

        Raises:
            ValueError: On a malformed field or a config-count mismatch.
        """
        def _parse_single_config(config):
            start_ind, end_ind, leads = None, None, None
            for field in config.split(";"):
                field = field.strip()
                if not field:
                    continue
                if ":" not in field:
                    raise ValueError(f"Invalid field: {field}. Expected key:value.")

                key, value = field.split(":", 1)
                key = key.strip().lower()
                value = value.strip()

                if key == "start":
                    start_ind = int(value)
                elif key == "end":
                    end_ind = int(value)
                elif key in ("use_leads", "leads"):
                    leads = [x.strip() for x in value.split(",") if x.strip()]
                else:
                    print(f"Ignoring the unknown key: {key}")
            return start_ind, end_ind, leads

        if ecg_configs is None:
            output = [None] * n_ecgs
            return output, output, output

        n_configs = len(ecg_configs)
        if n_configs != 1 and n_configs != n_ecgs:
            raise ValueError(f'Found {n_configs} ECG configs for {n_ecgs} ECG inputs. The number of config should be 1 or match the number of ECGs.')

        start_inds, end_inds, leads = [], [], []
        for config in ecg_configs:
            start_ind, end_ind, lead = _parse_single_config(config)
            print(f'ECG Config: {start_ind}, {end_ind}, {lead}')
            start_inds.append(start_ind)
            end_inds.append(end_ind)
            leads.append(lead)

        # Broadcast a single config to every ECG input.
        if n_configs == 1 and n_ecgs > 1:
            start_inds = start_inds * n_ecgs
            end_inds = end_inds * n_ecgs
            leads = leads * n_ecgs

        return start_inds, end_inds, leads

    def _build_content(self, *, raw_content: list[dict[str, str]], ecg_configs: Optional[list[str]]) -> list[dict[str, Any]]:
        """
        Convert raw text/ecg entries into a single-user-turn conversation,
        loading each ECG waveform with its parsed (start, end, leads) config.
        """
        n_ecgs = sum(1 for c in raw_content if c['type'] == 'ecg')
        starts, ends, leads = self._parse_ecg_config(ecg_configs, n_ecgs)
        ecg_ind = 0

        content: list[dict[str, Any]] = []
        for c in raw_content:
            if c['type'] == 'text':
                content.append({"type": "text", "text": c['text']})
            elif c['type'] == 'ecg':
                waveform = get_waveform(ecg_path=c['ecg'], start_sec=starts[ecg_ind], end_sec=ends[ecg_ind], leads=leads[ecg_ind], device=self.device)
                ecg_ind += 1
                content.append({"type": "ecg", "waveform_segments": waveform})

        conversation = [{"from": self.prompt_tokens.user_role, "content": content}]
        return conversation

__all__ = ["CAMEL"]
camel_inference/src/camel/checkpoint_utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Checkpoint- and state-management utilities shared by ECG training scripts.
3
+ These helpers were extracted from train_ecg_text.py to keep the training entrypoint
4
+ focused on orchestration while preserving original behaviour.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Dict, Optional, Tuple
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.distributed.checkpoint.state_dict import (
12
+ set_model_state_dict,
13
+ StateDictOptions,
14
+ )
15
+ from torch.distributed.tensor import DTensor
16
+ from peft import (
17
+ LoraConfig,
18
+ TaskType,
19
+ set_peft_model_state_dict,
20
+ )
21
+
22
+ from camel.training_setup import is_main_process
23
+
24
+ def _module_has_dtensor_params(mod: nn.Module) -> bool:
25
+ """
26
+ Return True if any parameter tensor underlying the module is a DTensor.
27
+ Under FSDP2 it is typically parameter.data that carries the DTensor type.
28
+ """
29
+ for param in mod.parameters(recurse=True):
30
+ if isinstance(getattr(param, "data", None), DTensor):
31
+ return True
32
+ return False
33
+
34
+ def _extract_projector_name(payload: Dict[str, Any]) -> Optional[str]:
35
+ """Return the stored projector name if present."""
36
+ name = payload.get("projector_name")
37
+ if isinstance(name, str) and name:
38
+ return name
39
+ extra = payload.get("extra")
40
+ if isinstance(extra, dict):
41
+ extra_name = extra.get("projector_name")
42
+ if isinstance(extra_name, str) and extra_name:
43
+ return extra_name
44
+ return None
45
+
46
def peek_projector_name(path: str) -> Optional[str]:
    """Read only the projector-name metadata out of a checkpoint file."""
    if path is None:
        return None
    ckpt = torch.load(path, map_location="cpu")
    if isinstance(ckpt, dict):
        return _extract_projector_name(ckpt)
    raise RuntimeError(f"Checkpoint {path} must be a dict to inspect projector metadata.")
54
+
55
def extract_lora_config_from_checkpoints(
    resume_ckpt_path: Optional[str],
    load_llava_from: Optional[str],
) -> Optional[Dict[str, Any]]:
    """Pull the embedded LoRA config out of a checkpoint (resume path preferred)."""

    def _read_lora_config(ckpt_path: Optional[str]) -> Optional[Dict[str, Any]]:
        if not ckpt_path:
            return None
        try:
            payload = torch.load(ckpt_path, map_location="cpu")
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load checkpoint '{ckpt_path}' while extracting LoRA config"
            ) from exc
        lora_section = payload.get("lora")
        if not isinstance(lora_section, dict):
            return None
        raw_cfg = lora_section.get("config")
        if not isinstance(raw_cfg, dict):
            return None
        cfg = dict(raw_cfg)
        # Serialized flags may come back as ints/strings; normalize to bool.
        if "use_dora" in cfg:
            cfg["use_dora"] = bool(cfg["use_dora"])
        return cfg

    primary = _read_lora_config(resume_ckpt_path)
    if primary is not None:
        return primary
    return _read_lora_config(load_llava_from)
80
+
81
def load_llava_and_lora(
    wrapper: nn.Module,
    model: nn.Module,
    ckpt_path: str,
    *,
    expect_lora: bool,
    load_lora: bool = True,
    missing_lora_ok: bool = False,
) -> Tuple[Dict[str, Any], nn.Module, Optional[LoraConfig]]:
    """
    Load llava_proj (mandatory), optional conv encoder weights, ECG special-token
    embeddings, and LoRA adapters from a checkpoint.

    Args:
        wrapper: Module exposing ``llava_proj``, ``enc`` and ``ecg_special_embed``
            submodules (and optionally a ``projector_name`` attribute).
        model: The (possibly PEFT-wrapped) language model that receives LoRA state.
        ckpt_path: Path to a torch checkpoint saved as a dict of sections.
        expect_lora: Whether the caller requires LoRA adapters to be present.
        load_lora: If False, any LoRA payload in the checkpoint is ignored.
        missing_lora_ok: Downgrade a missing-but-expected LoRA payload from an
            error to a rank-0 warning.

    Returns:
        ``(extra_payload, model, created_cfg)``: the checkpoint's "extra"
        metadata dict, the language model, and the reconstructed ``LoraConfig``
        (``None`` when absent or unparsable).

    Raises:
        RuntimeError: On malformed payload sections, projector mismatch, or
            (unless ``missing_lora_ok``) when expected LoRA state is absent.
    """
    payload = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(payload, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} must be a dict, got {type(payload).__name__}")
    extra_payload = payload.get("extra") or {}
    if not isinstance(extra_payload, dict):
        raise RuntimeError(
            f"Checkpoint {ckpt_path} has non-dict extra payload of type {type(extra_payload).__name__}"
        )
    # Refuse to mix checkpoints across projector architectures.
    ckpt_projector = _extract_projector_name(payload)
    wrapper_projector = getattr(wrapper, "projector_name", None)
    if ckpt_projector is not None and wrapper_projector is not None and wrapper_projector != ckpt_projector:
        raise RuntimeError(
            f"Checkpoint {ckpt_path} projector '{ckpt_projector}' does not match wrapper projector '{wrapper_projector}'."
        )
    # --- llava_proj (mandatory) ---
    llava_sd = payload.get("llava_proj")
    if not isinstance(llava_sd, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} missing llava_proj state_dict.")
    if is_main_process():
        print(f"[load-llava] Loading llava_proj weights from {ckpt_path}", flush=True)
    # DTensor-backed (FSDP2-sharded) modules need the distributed state-dict API.
    any_llava_dt = _module_has_dtensor_params(wrapper.llava_proj)
    if any_llava_dt:
        set_model_state_dict(
            model=wrapper.llava_proj,
            model_state_dict=llava_sd,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
    else:
        wrapper.llava_proj.load_state_dict(llava_sd, strict=True)
    # --- conv encoder (optional; required iff the checkpoint marked it trainable) ---
    conv_sd = payload.get("conv")
    conv_expected = bool(extra_payload.get("conv_trainable"))
    if conv_sd is None:
        if conv_expected:
            raise RuntimeError(
                f"Checkpoint {ckpt_path} indicates conv_trainable=True but conv weights are missing."
            )
    else:
        if not isinstance(conv_sd, dict):
            raise RuntimeError(
                f"Checkpoint {ckpt_path} conv payload must be a state_dict, got {type(conv_sd).__name__}"
            )
        if is_main_process():
            print(f"[load-llava] Loading conv encoder weights from {ckpt_path}", flush=True)
        any_conv_dt = _module_has_dtensor_params(wrapper.enc)
        if any_conv_dt:
            set_model_state_dict(
                model=wrapper.enc,
                model_state_dict=conv_sd,
                options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
            )
        else:
            wrapper.enc.load_state_dict(conv_sd, strict=True)
    # --- ECG special-token embedding (mandatory) ---
    ecg_special_sd = payload.get("ecg_special")
    if not isinstance(ecg_special_sd, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} missing ECG special-token embedding state.")
    if is_main_process():
        print(f"[load-llava] Loading ECG special-token embedding from {ckpt_path}", flush=True)
    any_special_dt = _module_has_dtensor_params(wrapper.ecg_special_embed)
    if any_special_dt:
        set_model_state_dict(
            model=wrapper.ecg_special_embed,
            model_state_dict=ecg_special_sd,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
    else:
        wrapper.ecg_special_embed.load_state_dict(ecg_special_sd, strict=True)
    # --- LoRA adapters (optional) ---
    lora_payload = payload.get("lora")
    loaded_lora = False
    created_cfg: Optional[LoraConfig] = None
    if load_lora and lora_payload is not None:
        if not isinstance(lora_payload, dict):
            raise RuntimeError(
                f"Checkpoint {ckpt_path} has non-dict LoRA payload of type {type(lora_payload).__name__}"
            )
        lora_state = lora_payload.get("state_dict")
        if not isinstance(lora_state, dict):
            raise RuntimeError(f"Checkpoint {ckpt_path} LoRA payload missing state_dict.")
        if is_main_process():
            print(f"[load-llava] Loading LoRA adapters from {ckpt_path}", flush=True)
        set_peft_model_state_dict(model, lora_state)
        loaded_lora = True
        cfg_dict = lora_payload.get("config")
        if isinstance(cfg_dict, dict):
            # Rebuild a LoraConfig defensively: serialized values may have been
            # stored as plain strings/floats, so coerce field-by-field and fall
            # back to safe defaults on any malformed entry.
            cfg_args = dict(cfg_dict)
            task_type_raw = cfg_args.get("task_type", TaskType.CAUSAL_LM)
            if not isinstance(task_type_raw, TaskType):
                try:
                    task_type_raw = TaskType(task_type_raw)
                except Exception:
                    task_type_raw = TaskType.CAUSAL_LM
            cfg_args["task_type"] = task_type_raw
            if "lora_dropout" in cfg_args:
                try:
                    cfg_args["lora_dropout"] = float(cfg_args["lora_dropout"])
                except Exception:
                    cfg_args["lora_dropout"] = 0.0
            if "r" in cfg_args:
                try:
                    cfg_args["r"] = int(cfg_args["r"])
                except Exception:
                    cfg_args["r"] = 0
            if "lora_alpha" in cfg_args:
                try:
                    cfg_args["lora_alpha"] = int(cfg_args["lora_alpha"])
                except Exception:
                    cfg_args["lora_alpha"] = 0
            if "target_modules" in cfg_args and cfg_args["target_modules"] is not None:
                cfg_args["target_modules"] = list(cfg_args["target_modules"])
            cfg_args.setdefault("bias", "none")
            cfg_args.setdefault("inference_mode", False)
            cfg_args["use_dora"] = bool(cfg_args.get("use_dora", False))
            try:
                created_cfg = LoraConfig(**cfg_args)
            except Exception:
                created_cfg = None
    # Enforce/soften the expect_lora contract.
    if expect_lora and load_lora and not loaded_lora:
        if missing_lora_ok:
            if is_main_process():
                print(
                    f"[load-llava] Warning: checkpoint {ckpt_path} contains no LoRA adapters; "
                    "continuing with the currently-initialized adapters.",
                    flush=True,
                )
        else:
            raise RuntimeError(
                f"[load-llava] Warning: expected LoRA adapters in {ckpt_path} but none were loaded."
            )
    return extra_payload, model, created_cfg
221
+
222
def update_wrapper_language_model(wrapper: nn.Module, model: nn.Module) -> None:
    """Point the wrapper at the most recent language-model instance."""
    setattr(wrapper, "language_model", model)
225
+
226
+ __all__ = [
227
+ "find_latest_step_checkpoint",
228
+ "extract_lora_config_from_checkpoints",
229
+ "dump_lora_state_fsdp_safe",
230
+ "prepare_optimizer_state_payload",
231
+ "load_llava_and_lora",
232
+ "update_wrapper_language_model",
233
+ "ensure_no_dtensor",
234
+ "peek_projector_name",
235
+ ]
camel_inference/src/camel/ecg_attention_masks.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ECG-aware attention mask utilities with pluggable strategy support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Protocol
7
+ import torch
8
+
9
@dataclass
class ECGBlockLayout:
    """Layout describing a single ECG block inside the assembled sequence."""

    # Absolute token index where the block starts (None when unknown).
    start_idx: Optional[int]
    # One past the block's last token index (None when unknown).
    end_idx_exclusive: Optional[int]
    # Positions of the block-level start/end delimiter tokens, if present.
    global_start_idx: Optional[int] = None
    global_end_idx: Optional[int] = None
    # Per-lead start/end delimiter positions, keyed by lead name.
    lead_start_idx: Dict[str, int] = field(default_factory=dict)
    lead_end_idx: Dict[str, int] = field(default_factory=dict)
    # Signal-token positions grouped by lead.
    signal_pos_by_lead: Dict[str, List[int]] = field(default_factory=dict)
    # Maps a time step to all signal positions emitted at that step.
    time_to_signal_idxs: Dict[int, List[int]] = field(default_factory=dict)
    # Cached sorted union of all special (delimiter) positions; filled lazily
    # by the mask builders when left empty.
    special_idxs_sorted: List[int] = field(default_factory=list)
    # Cached sorted union of all signal positions; filled lazily as above.
    signal_pos_list: List[int] = field(default_factory=list)
    # Declared number of segments per lead (as announced by the packer).
    declared_segments_per_lead: Dict[str, int] = field(default_factory=dict)
    # Positions treated as causal "conv" rows by the mask logic — presumably
    # conversation tokens following the block; TODO confirm against the packer.
    conv_idxs: List[int] = field(default_factory=list)
25
+
26
+
27
@dataclass
class ECGSequenceLayout:
    """Compact description of the assembled token layout for one training sample."""

    # Total number of tokens in the assembled sequence.
    seq_len: int
    # Positions of ordinary text tokens.
    text_idxs: List[int] = field(default_factory=list)
    # One layout entry per embedded ECG block, in sequence order.
    blocks: List[ECGBlockLayout] = field(default_factory=list)
34
+
35
+
36
+ def _as_tensor(indices: List[int], device: torch.device) -> torch.Tensor:
37
+ if not indices:
38
+ return torch.empty(0, dtype=torch.long, device=device)
39
+ return torch.tensor(sorted(indices), dtype=torch.long, device=device)
40
+
41
+
42
def apply_single_block_semantic_mask_(
    allowed: torch.Tensor,
    block_layout: ECGBlockLayout,
    *,
    visible_prefix_len: int,
    key_limit_exclusive: int,
    apply_header_causal: bool = True,
) -> None:
    """
    In-place semantic mask update for a single ECG block. Mirrors the historical
    single-block logic, but operates on a provided boolean mask.

    Args:
        allowed: Square [L, L] boolean matrix; ``allowed[q, k]`` is True when
            query position ``q`` may attend to key position ``k``. Mutated in place.
        block_layout: Index layout of the ECG block. Its ``special_idxs_sorted``
            and ``signal_pos_list`` caches are populated here if empty
            (intentional side effect).
        visible_prefix_len: Number of leading header tokens visible to block rows.
        key_limit_exclusive: Block rows may not attend to keys >= this index.
        apply_header_causal: If True, re-impose causal attention inside the header.
    """
    L = int(allowed.size(0))
    device = allowed.device
    # Header = the leading `visible_prefix_len` positions (prompt/text prefix).
    header = (
        torch.arange(int(visible_prefix_len), dtype=torch.long, device=device)
        if int(visible_prefix_len) > 0
        else torch.empty(0, dtype=torch.long, device=device)
    )

    # Special (delimiter) positions; derive and cache when not precomputed.
    specials_list = block_layout.special_idxs_sorted or []
    if not specials_list:
        specials_list = sorted(
            ([block_layout.global_start_idx] if block_layout.global_start_idx is not None else [])
            + list(block_layout.lead_start_idx.values())
            + list(block_layout.lead_end_idx.values())
            + ([block_layout.global_end_idx] if block_layout.global_end_idx is not None else [])
        )
        block_layout.special_idxs_sorted = specials_list
    # Keep only keys inside the block's visible range.
    if key_limit_exclusive is not None:
        specials_list = [i for i in specials_list if int(i) < int(key_limit_exclusive)]
    specials = _as_tensor([int(i) for i in specials_list], device)

    # Signal-token positions; derive and cache when not precomputed.
    signals_list = block_layout.signal_pos_list or []
    if not signals_list:
        signal_all: List[int] = []
        for lst in block_layout.signal_pos_by_lead.values():
            signal_all.extend(lst)
        signals_list = sorted(signal_all)
        block_layout.signal_pos_list = signals_list
    if key_limit_exclusive is not None:
        signals_list = [i for i in signals_list if int(i) < int(key_limit_exclusive)]
    signals = _as_tensor([int(i) for i in signals_list], device)

    lead_starts = _as_tensor(list(block_layout.lead_start_idx.values()), device)
    lead_ends = _as_tensor(list(block_layout.lead_end_idx.values()), device)

    # Header tokens attend causally among themselves.
    if apply_header_causal and header.numel():
        allowed[header[:, None], header[None, :]] = header[:, None] >= header[None, :]

    # The global-start delimiter sees the entire header.
    gs = block_layout.global_start_idx
    if gs is not None and header.numel():
        allowed[int(gs), header] = True

    # Lead delimiters and signal tokens: full view of the header, and of
    # specials strictly before their own position.
    rows_before = []
    if lead_starts.numel():
        rows_before.append(lead_starts)
    if signals.numel():
        rows_before.append(signals)
    if lead_ends.numel():
        rows_before.append(lead_ends)
    rows_before_t = (
        torch.cat(rows_before, dim=0) if rows_before else torch.empty(0, dtype=torch.long, device=device)
    )

    if rows_before_t.numel():
        if header.numel():
            allowed[rows_before_t[:, None], header[None, :]] = True
        if specials.numel():
            allowed[rows_before_t[:, None], specials[None, :]] = specials[None, :] < rows_before_t[:, None]

    # Signal-to-signal visibility is governed by time step: a query emitted at
    # time t may see every signal position whose earliest emission time <= t.
    ttsi: Dict[int, List[int]] = block_layout.time_to_signal_idxs
    if signals.numel() and ttsi:
        pos_min_time: Dict[int, int] = {}
        pos_to_time: Dict[int, int] = {}
        for t, idxs in ttsi.items():
            for p in idxs:
                pos_to_time[p] = t
                prev = pos_min_time.get(p)
                if prev is None or t < prev:
                    pos_min_time[p] = t

        u_pos_list = sorted(pos_min_time.keys())
        if u_pos_list:
            u_pos = torch.tensor(u_pos_list, dtype=torch.long, device=device)
            u_time = torch.tensor([pos_min_time[p] for p in u_pos_list], dtype=torch.long, device=device)
            # Positions missing from the map default to time 0.
            q_time = torch.tensor([pos_to_time.get(p, 0) for p in signals_list], dtype=torch.long, device=device)
            allowed[signals[:, None], u_pos[None, :]] = (u_time[None, :] <= q_time[:, None])

    # Each lead-end delimiter sees all of its own lead's signal tokens.
    for lead, eidx in block_layout.lead_end_idx.items():
        lead_sigs = block_layout.signal_pos_by_lead.get(lead, [])
        if lead_sigs:
            allowed[int(eidx), torch.tensor(lead_sigs, dtype=torch.long, device=device)] = True

    # The global-end delimiter sees the header, all specials and all signals.
    ge = block_layout.global_end_idx
    if ge is not None:
        gei = int(ge)
        if header.numel():
            allowed[gei, header] = True
        if specials.numel():
            allowed[gei, specials] = True
        if signals.numel():
            allowed[gei, signals] = True

    # "conv" rows: causal among themselves, full view of header/specials/
    # signals, then clipped so no conv row sees any key beyond itself.
    conv = _as_tensor(block_layout.conv_idxs, device)
    if conv.numel():
        allowed[conv[:, None], conv[None, :]] = conv[:, None] >= conv[None, :]
        if header.numel():
            allowed[conv[:, None], header[None, :]] = True
        if specials.numel():
            allowed[conv[:, None], specials[None, :]] = True
        if signals.numel():
            allowed[conv[:, None], signals[None, :]] = True
        cols = torch.arange(L, device=device)
        conv_rows = allowed[conv, :]
        conv_rows &= (cols.unsqueeze(0) <= conv.unsqueeze(1))
        allowed[conv, :] = conv_rows

    # Every special token can always attend to itself.
    if specials.numel():
        allowed[specials, specials] = True

    # Hide all keys beyond the block's visible range from the block's rows.
    if key_limit_exclusive is not None and int(key_limit_exclusive) < L:
        block_rows_list = list(specials_list) + list(signals_list)
        if block_rows_list:
            block_rows = _as_tensor(block_rows_list, device)
            allowed[block_rows, int(key_limit_exclusive):] = False
168
+
169
@dataclass
class MaskBuildResult:
    """Container for per-sample mask artifacts produced by a strategy."""

    # Additive float mask: 0 where attention is allowed, -inf where blocked.
    additive: torch.Tensor
    # Optional boolean "allowed" matrix kept around for incremental updates.
    boolean: Optional[torch.Tensor] = None
175
+
176
+
177
class ECGMaskStrategy(Protocol):
    """Protocol for strategies that build and update per-sample attention masks."""

    # Registry key identifying the strategy (e.g. "semantic").
    name: str

    def build(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> MaskBuildResult:
        """Construct the full attention mask for *layout* from scratch."""
        ...

    def update_for_generated_token(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
        previous: MaskBuildResult,
    ) -> MaskBuildResult:
        """Extend *previous* by one row/column for a newly generated token."""
        ...
200
+
201
+
202
class SemanticMaskStrategy:
    """Default strategy reproducing the historical ECG-aware attention mask."""

    name = "semantic"

    def build(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> MaskBuildResult:
        """
        Build the full [L, L] mask: causal everywhere, then overwrite the rows
        belonging to each ECG block with the semantic (lead/time-aware) rules.
        """
        L = int(layout.seq_len)
        if L <= 0:
            raise ValueError("Semantic mask requires a positive sequence length")
        # Start fully causal; ECG-block rows are cleared and rebuilt below.
        allowed = torch.tril(torch.ones((L, L), dtype=torch.bool, device=device))
        multi_block = len(layout.blocks) > 1
        for block in layout.blocks:
            if block.start_idx is None or block.end_idx_exclusive is None:
                continue
            # Collect (and cache) the block's special + signal positions so the
            # corresponding causal rows can be reset before re-masking.
            block_rows: List[int] = []
            specials_list = block.special_idxs_sorted or []
            if not specials_list:
                specials_list = sorted(
                    ([block.global_start_idx] if block.global_start_idx is not None else [])
                    + list(block.lead_start_idx.values())
                    + list(block.lead_end_idx.values())
                    + ([block.global_end_idx] if block.global_end_idx is not None else [])
                )
                block.special_idxs_sorted = specials_list
            signals_list = block.signal_pos_list or []
            if not signals_list:
                signal_all: List[int] = []
                for lst in block.signal_pos_by_lead.values():
                    signal_all.extend(lst)
                signals_list = sorted(signal_all)
                block.signal_pos_list = signals_list
            if specials_list:
                block_rows.extend(int(i) for i in specials_list)
            if signals_list:
                block_rows.extend(int(i) for i in signals_list)
            if block_rows:
                rows = torch.tensor(sorted(set(block_rows)), dtype=torch.long, device=device)
                allowed[rows, :] = False
            # Header causality is skipped for multi-block layouts (other
            # blocks live inside this block's prefix region).
            apply_single_block_semantic_mask_(
                allowed,
                block,
                visible_prefix_len=int(block.start_idx),
                key_limit_exclusive=int(block.end_idx_exclusive),
                apply_header_causal=not multi_block,
            )
        additive = self._boolean_to_additive(allowed, device=device, dtype=dtype)
        return MaskBuildResult(additive=additive, boolean=allowed)

    def update_for_generated_token(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
        previous: MaskBuildResult,
    ) -> MaskBuildResult:
        """
        Grow the cached boolean mask by one position for a freshly generated
        token; the new row attends to everything up to and including itself.
        Falls back to a full rebuild when no boolean mask was cached.
        """
        allowed = previous.boolean
        if allowed is None:
            return self.build(layout, device=device, dtype=dtype)
        prev_len = int(allowed.size(0))
        new_allowed = torch.zeros((prev_len + 1, prev_len + 1), dtype=torch.bool, device=device)
        new_allowed[:prev_len, :prev_len] = allowed
        new_allowed[prev_len, : prev_len + 1] = True
        additive = self._boolean_to_additive(new_allowed, device=device, dtype=dtype)
        return MaskBuildResult(additive=additive, boolean=new_allowed)


    @classmethod
    def _build_boolean_mask(cls, layout: "ECGSequenceLayout", device: torch.device) -> torch.Tensor:
        """
        Build a boolean mask for a strictly single-block layout from scratch.

        NOTE(review): not referenced elsewhere in this file; appears to serve
        as a reference implementation for validation — confirm before removal.
        """
        if len(layout.blocks) != 1:
            raise ValueError("Single-block mask builder requires exactly one ECG block")
        L = int(layout.seq_len)
        allowed = torch.zeros((L, L), dtype=torch.bool, device=device)
        block = layout.blocks[0]
        prefix_len = int(block.start_idx or 0)
        end_idx = int(block.end_idx_exclusive or L)
        # Text tokens after the block act as "conv" rows in the reference mask.
        conv_idxs = [idx for idx in layout.text_idxs if int(idx) >= end_idx]
        # Shallow-copy the block so the shared layout's caches are not mutated.
        block_ref = ECGBlockLayout(
            start_idx=block.start_idx,
            end_idx_exclusive=block.end_idx_exclusive,
            global_start_idx=block.global_start_idx,
            global_end_idx=block.global_end_idx,
            lead_start_idx=dict(block.lead_start_idx),
            lead_end_idx=dict(block.lead_end_idx),
            signal_pos_by_lead=dict(block.signal_pos_by_lead),
            time_to_signal_idxs=dict(block.time_to_signal_idxs),
            special_idxs_sorted=list(block.special_idxs_sorted),
            signal_pos_list=list(block.signal_pos_list),
            declared_segments_per_lead=dict(block.declared_segments_per_lead),
            conv_idxs=conv_idxs,
        )
        apply_single_block_semantic_mask_(
            allowed,
            block_ref,
            visible_prefix_len=prefix_len,
            key_limit_exclusive=end_idx,
        )
        return allowed

    @staticmethod
    def _boolean_to_additive(allowed: torch.Tensor, *, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Convert a boolean allowed-matrix into an additive mask (0 / -inf)."""
        additive = torch.zeros(allowed.shape, dtype=dtype, device=device)
        additive.masked_fill_(~allowed, float("-inf"))
        return additive
312
+
313
+
314
+ DEFAULT_MASK_STRATEGY = SemanticMaskStrategy()
315
+
316
+
317
+ MASK_STRATEGY_REGISTRY: Dict[str, ECGMaskStrategy] = {
318
+ DEFAULT_MASK_STRATEGY.name: DEFAULT_MASK_STRATEGY,
319
+ }
320
+
321
+
322
def get_mask_strategy(name: Optional[str]) -> ECGMaskStrategy:
    """Look up a registered mask strategy by name, ignoring case."""
    if name is None:
        return DEFAULT_MASK_STRATEGY
    lookup_key = str(name).lower()
    try:
        return MASK_STRATEGY_REGISTRY[lookup_key]
    except KeyError as missing:
        available = ", ".join(sorted(MASK_STRATEGY_REGISTRY))
        raise ValueError(f"Unknown ECG mask strategy '{name}'. Known strategies: {available}") from missing
332
+
333
+
334
+ __all__ = [
335
+ "ECGBlockLayout",
336
+ "ECGSequenceLayout",
337
+ "MaskBuildResult",
338
+ "ECGMaskStrategy",
339
+ "SemanticMaskStrategy",
340
+ "DEFAULT_MASK_STRATEGY",
341
+ "MASK_STRATEGY_REGISTRY",
342
+ "get_mask_strategy",
343
+ ]
camel_inference/src/camel/ecg_gemma_model.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ecg_gemma_model.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Optional
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ from transformers import AutoModelForCausalLM
9
+
10
+ from camel.ecg_model_wrapper import ECGLanguageModelWrapper
11
+
12
class ECGGemmaPrefix(ECGLanguageModelWrapper):
    """
    Frozen: Gemma-IT (language model), 1D conv signal encoder (loaded from disk)
    Trainable: llava_proj (Linear 64 -> Gemma hidden size)
    Optionally trainable: conv encoder when explicitly requested

    This wrapper turns per-second ECG windows into single "pseudo-token" rows
    that are interleaved into user turns at embedding time.
    """

    def __init__(
        self,
        gemma: AutoModelForCausalLM,
        enc: nn.Module,
        hidden_size: int,
        num_ecg_special_tokens: int,
        dtype: Optional[torch.dtype] = torch.bfloat16,
        enc_out_dim: int = 64,  # from the specified conv stack: 4 channels * 16 length = 64
        freeze_encoder: bool = True,
        inference: bool = False,
        projector_name: str = "linear",
    ):
        # All setup (projector construction, freezing, special-token embedding)
        # is delegated to the shared base class.
        super().__init__(
            language_model=gemma,
            conv_encoder=enc,
            hidden_size=hidden_size,
            num_ecg_special_tokens=num_ecg_special_tokens,
            dtype=dtype,
            enc_out_dim=enc_out_dim,
            freeze_encoder=freeze_encoder,
            inference=inference,
            projector_name=projector_name,
        )
        # Use language_model consistently; no Gemma alias is set.

    def forward_language_model(
        self,
        inputs_embeds: Tensor,
        attention_mask: Tensor,
        labels: Optional[Tensor],
        output_hidden_states = False,
    ):
        """
        Forward pre-built embeddings through the wrapped language model,
        first moving inputs to the device of the LM's input-embedding weights
        so the call works when the model is sharded across devices.
        """
        # Ensure inputs live on the device of the LM input embeddings when sharded
        embedder_fn = getattr(self.language_model, "get_input_embeddings", None)
        if callable(embedder_fn):
            try:
                embed_module = embedder_fn()
            except Exception as exc:
                raise RuntimeError("Failed to obtain language-model input embeddings for device inference") from exc
            if not hasattr(embed_module, "weight"):
                raise RuntimeError("Input embedding module lacks a weight parameter; cannot infer device")
            dev0 = embed_module.weight.device
        else:
            # Fallback: infer placement from the first model parameter.
            params_iter = self.language_model.parameters()
            try:
                first_param = next(params_iter)
            except StopIteration as exc:
                raise RuntimeError("Language model exposes no parameters to infer device placement") from exc
            dev0 = first_param.device

        # Move only what is on the wrong device (avoids needless copies).
        if inputs_embeds.device != dev0:
            inputs_embeds = inputs_embeds.to(dev0)
        if attention_mask.device != dev0:
            attention_mask = attention_mask.to(dev0)
        if labels is not None and labels.device != dev0:
            labels = labels.to(dev0)

        # use_cache=False: KV caching for generation is handled elsewhere.
        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,
            output_hidden_states=output_hidden_states,
        )
86
+
87
+
88
+ # FSDP2 helpers removed (not used)
89
+
90
+
91
+ __all__ = ["ECGGemmaPrefix"]
camel_inference/src/camel/ecg_model_wrapper.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract ECG-language model adapter with shared conv→projection logic.
3
+
4
+ Subclasses can override only the pieces that differ (e.g., prompt format or
5
+ language-model forwarding) while inheriting the common ECG prefix handling.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Dict, Iterable, List, Mapping, Optional, Tuple
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+ from camel.assertions import assert_wrapper_embed_length, assert_rest_length_nonnegative, assert_tensor_dtype
16
+ from camel.projectors import build_projector
17
+
18
class ECGNonFiniteInputError(RuntimeError):
    """Signals that a waveform handed to the conv encoder held NaN/Inf values."""

    def __init__(self, sample_idx: int, lead: Optional[str] = None) -> None:
        # Record which sample (and optionally which lead) was bad, so callers
        # can skip or report the offending record.
        self.sample_idx = int(sample_idx)
        self.lead = lead
        suffix = "" if lead is None else f", lead={lead}"
        super().__init__(f"Non-finite waveform detected (sample_idx={self.sample_idx}{suffix})")
26
+
27
+
28
+ class ECGLanguageModelWrapper(nn.Module):
29
+ """
30
+ Turns per-lead ECG waveforms into prefix embeddings consumable by a language model.
31
+ Stores the ECG encoder, trainable adapter, and the target language model.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ *,
37
+ language_model: nn.Module,
38
+ conv_encoder: nn.Module,
39
+ hidden_size: int,
40
+ num_ecg_special_tokens: int,
41
+ dtype: Optional[torch.dtype] = torch.bfloat16,
42
+ enc_out_dim: int = 64,
43
+ freeze_encoder: bool = True,
44
+ inference: bool = False,
45
+ projector_name: str = "linear",
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ if int(num_ecg_special_tokens) <= 0:
50
+ raise ValueError("num_ecg_special_tokens must be positive")
51
+
52
+ # Keep LM in train mode (actual freezing handled by caller).
53
+ self.language_model = language_model.train()
54
+
55
+ # Conv encoder may be frozen or trainable depending on configuration.
56
+ self.enc = conv_encoder
57
+ if freeze_encoder:
58
+ self.enc = self.enc.eval()
59
+ for p in self.enc.parameters():
60
+ p.requires_grad = False
61
+ else:
62
+ self.enc = self.enc.train()
63
+ for p in self.enc.parameters():
64
+ p.requires_grad = True
65
+
66
+ self.hidden_size = int(hidden_size)
67
+ self.dtype = dtype or torch.bfloat16
68
+ self.num_ecg_special_tokens = int(num_ecg_special_tokens)
69
+ self.inference = bool(inference)
70
+ self.projector_name = str(projector_name or "linear")
71
+
72
+ # Trainable adapter: conv output → LM hidden size (kept in fp32 for stability).
73
+ self.llava_proj = build_projector(self.projector_name, int(enc_out_dim), self.hidden_size)
74
+ if self.inference:
75
+ self.llava_proj.to(dtype=torch.float32)
76
+
77
+ self.ecg_special_embed = nn.Embedding(self.num_ecg_special_tokens, self.hidden_size)
78
+ self.ecg_special_embed.to(dtype=self.dtype)
79
+ if self.inference:
80
+ self.enc.to(dtype=torch.float32)
81
+ self._grad_ckpt_enabled = self._detect_grad_ckpt_state()
82
+
83
+ def _projector_param_dtype(self) -> torch.dtype:
84
+ """Return dtype of the first projector parameter (defaults to fp32)."""
85
+ first_param = next(self.llava_proj.parameters(), None)
86
+ if first_param is None:
87
+ return torch.float32
88
+ return first_param.dtype
89
+
90
+ # ---- Gradient checkpointing --------------------------------------------------------------
91
+
92
+ def _detect_grad_ckpt_state(self) -> bool:
93
+ if hasattr(self.language_model, "is_gradient_checkpointing"):
94
+ try:
95
+ return bool(self.language_model.is_gradient_checkpointing)
96
+ except Exception:
97
+ return False
98
+ if hasattr(self.language_model, "gradient_checkpointing"):
99
+ try:
100
+ return bool(self.language_model.gradient_checkpointing)
101
+ except Exception:
102
+ return False
103
+ return False
104
+
105
+ def set_gradient_checkpointing(self, enabled: bool) -> bool:
106
+ enabled = bool(enabled)
107
+ current = self._detect_grad_ckpt_state()
108
+ if current == enabled:
109
+ self._grad_ckpt_enabled = enabled
110
+ return False
111
+ if enabled:
112
+ if hasattr(self.language_model, "gradient_checkpointing_enable"):
113
+ self.language_model.gradient_checkpointing_enable()
114
+ if hasattr(getattr(self.language_model, "config", None), "use_cache"):
115
+ self.language_model.config.use_cache = False
116
+ else:
117
+ if hasattr(self.language_model, "gradient_checkpointing_disable"):
118
+ self.language_model.gradient_checkpointing_disable()
119
+ elif hasattr(self.language_model, "gradient_checkpointing"):
120
+ try:
121
+ self.language_model.gradient_checkpointing = False
122
+ except Exception:
123
+ pass
124
+ if hasattr(getattr(self.language_model, "config", None), "use_cache"):
125
+ self.language_model.config.use_cache = True
126
+ self._grad_ckpt_enabled = enabled
127
+ return True
128
+
129
+ def enable_gradient_checkpointing(self) -> None:
130
+ self.set_gradient_checkpointing(True)
131
+
132
+ def disable_gradient_checkpointing(self) -> None:
133
+ self.set_gradient_checkpointing(False)
134
+
135
+ def is_gradient_checkpointing_enabled(self) -> bool:
136
+ return bool(self._grad_ckpt_enabled)
137
+
138
+ # ---- Token helpers -----------------------------------------------------------------------
139
+
140
+ def tokens_to_embeds(self, input_embedder: nn.Embedding, ids: List[int], device: torch.device) -> Tensor:
141
+ ids_t = torch.tensor(ids, dtype=torch.long, device=device)
142
+ embeddings = input_embedder(ids_t)
143
+ embeddings = embeddings.to(dtype=self.dtype)
144
+ # Defensive: enforce 1:1 mapping between ids and embeddings
145
+ assert_wrapper_embed_length(embeddings=embeddings, ids=ids, context="tokens_to_embeds")
146
+ return embeddings
147
+
148
+ def ecg_special_tokens_to_embeds(self, indices: torch.Tensor | List[int], device: torch.device) -> Tensor:
149
+ if torch.is_tensor(indices):
150
+ idx = indices.to(device=device, dtype=torch.long)
151
+ else:
152
+ idx = torch.tensor(indices, dtype=torch.long, device=device)
153
+ embeds = self.ecg_special_embed(idx)
154
+ return embeds.to(dtype=self.dtype)
155
+
156
+ # ---- ECG prefix encoding -----------------------------------------------------------------
157
+
158
+ def ecg_prefix(
159
+ self,
160
+ waveform_segments: Dict[str, Tensor],
161
+ device: torch.device,
162
+ lead_order: Optional[List[str]] = None,
163
+ ) -> Tensor:
164
+ """Encode a single sample's ECG prefix in a deterministic lead order.
165
+
166
+ Args:
167
+ waveform_segments: Mapping lead -> [T,256] windows.
168
+ device: Target device for encoder input.
169
+ lead_order: Optional explicit order of leads to iterate. If provided,
170
+ segments are concatenated in this exact order; otherwise relies on
171
+ insertion order of the mapping.
172
+ """
173
+ seqs: List[Tensor] = []
174
+ if lead_order is None:
175
+ items = list(waveform_segments.items())
176
+ else:
177
+ items = [(ld, waveform_segments[ld]) for ld in lead_order if ld in waveform_segments]
178
+ # Validate presence when an explicit order is provided
179
+ missing = [ld for ld in lead_order if ld not in waveform_segments]
180
+ if missing:
181
+ raise ValueError(f"Missing leads in waveform_segments for requested order: {missing}")
182
+ for lead, seg in items:
183
+ seg = torch.as_tensor(seg)
184
+ if seg.ndim != 2 or seg.shape[1] != 256:
185
+ raise ValueError(f"Waveform for lead {lead} must be [T,256], got {seg.shape}")
186
+ seqs.append(seg)
187
+ x = torch.cat(seqs, dim=0)
188
+ x = x.to(device=device, dtype=torch.float32)
189
+ x = x.unsqueeze(1) # [P, 1, 256]
190
+
191
+ enc_trainable = any(p.requires_grad for p in self.enc.parameters())
192
+ ctx = torch.enable_grad() if enc_trainable else torch.no_grad()
193
+ with ctx:
194
+ z = self.enc(x) # [P, C, L]
195
+ if self.inference:
196
+ z = z.to(dtype=torch.float32)
197
+ assert_tensor_dtype(z, expected=torch.float32, context="conv encoder output (single)")
198
+ self.ensure_finite(z, "conv encoder output")
199
+ z = z.flatten(1) # [P, 64] for conv stack
200
+
201
+ proj_dtype = self._projector_param_dtype()
202
+ if z.dtype != proj_dtype:
203
+ z = z.to(dtype=proj_dtype)
204
+ y = self.llava_proj(z)
205
+ if self.inference:
206
+ y = y.to(dtype=torch.float32)
207
+ assert_tensor_dtype(y, expected=torch.float32, context="llava_proj output (single)")
208
+ return y.to(dtype=self.dtype)
209
+
210
+ def ecg_prefix_batch(
211
+ self,
212
+ waveform_segments_batch: List[Dict[str, Tensor]],
213
+ device: torch.device,
214
+ lead_orders: Optional[List[List[str]]] = None,
215
+ ) -> Tuple[Tensor, List[int]]:
216
+ """Encode a batch of ECG prefixes with explicit per-sample lead orders.
217
+
218
+ Args:
219
+ waveform_segments_batch: List of dicts; each maps lead -> [T,256].
220
+ device: Target device for encoder input.
221
+ lead_orders: Optional list of lead-order lists, one per sample. If
222
+ provided, each sample's segments are concatenated in that exact
223
+ order; otherwise relies on insertion order of each mapping.
224
+ """
225
+ all_seqs: List[Tensor] = []
226
+ prefix_lengths: List[int] = []
227
+
228
+ for i, wv_dict in enumerate(waveform_segments_batch):
229
+ seqs: List[Tensor] = []
230
+ leads_for_sample: List[str] = []
231
+ if lead_orders is not None and i < len(lead_orders) and lead_orders[i] is not None:
232
+ order = lead_orders[i]
233
+ missing = [ld for ld in order if ld not in wv_dict]
234
+ if missing:
235
+ raise ValueError(f"Missing leads for sample {i}: {missing}")
236
+ items = [(ld, wv_dict[ld]) for ld in order]
237
+ else:
238
+ items = list(wv_dict.items())
239
+ for lead, seg in items:
240
+ seg = torch.as_tensor(seg)
241
+ if seg.ndim != 2 or seg.shape[1] != 256:
242
+ raise ValueError(f"Waveform for lead {lead} must be [T,256], got {seg.shape}")
243
+ seqs.append(seg)
244
+ leads_for_sample.append(str(lead))
245
+ sample_segments = torch.cat(seqs, dim=0)
246
+ if not torch.isfinite(sample_segments).all().item():
247
+ bad_lead = None
248
+ for lead_name, seg_tensor in zip(leads_for_sample, seqs):
249
+ if not torch.isfinite(seg_tensor).all().item():
250
+ bad_lead = lead_name
251
+ break
252
+ raise ECGNonFiniteInputError(sample_idx=i, lead=bad_lead)
253
+ all_seqs.append(sample_segments)
254
+ prefix_lengths.append(sample_segments.size(0))
255
+
256
+ x = torch.cat(all_seqs, dim=0)
257
+ x = x.to(device=device, dtype=torch.float32).unsqueeze(1)
258
+
259
+ enc_trainable = any(p.requires_grad for p in self.enc.parameters())
260
+ ctx = torch.enable_grad() if enc_trainable else torch.no_grad()
261
+ with ctx:
262
+ z = self.enc(x)
263
+ if self.inference:
264
+ z = z.to(dtype=torch.float32)
265
+ assert_tensor_dtype(z, expected=torch.float32, context="conv encoder output (batch)")
266
+ self.ensure_finite(z, "conv encoder output (batch)")
267
+ z = z.flatten(1)
268
+
269
+ proj_dtype = self._projector_param_dtype()
270
+ if z.dtype != proj_dtype:
271
+ z = z.to(dtype=proj_dtype)
272
+ y = self.llava_proj(z)
273
+ if self.inference:
274
+ y = y.to(dtype=torch.float32)
275
+ assert_tensor_dtype(y, expected=torch.float32, context="llava_proj output (batch)")
276
+ return y.to(dtype=self.dtype), prefix_lengths
277
+
278
+ # ---- Language-model forward --------------------------------------------------------------
279
+
280
+ def forward_language_model(
281
+ self,
282
+ inputs_embeds: Tensor,
283
+ attention_mask: Tensor,
284
+ labels: Optional[Tensor],
285
+ ):
286
+ """
287
+ Default HF-style forward call; subclasses can override for custom behavior.
288
+ """
289
+ embedder_fn = getattr(self.language_model, "get_input_embeddings", None)
290
+ if callable(embedder_fn):
291
+ embed_module = embedder_fn()
292
+ if not hasattr(embed_module, "weight"):
293
+ raise RuntimeError("Language model embeddings missing weight parameter; cannot infer device.")
294
+ target_device = embed_module.weight.device
295
+ else:
296
+ try:
297
+ first_param = next(self.language_model.parameters())
298
+ except StopIteration as exc:
299
+ raise RuntimeError("Language model exposes no parameters to infer device placement.") from exc
300
+ target_device = first_param.device
301
+
302
+ if inputs_embeds.device != target_device:
303
+ inputs_embeds = inputs_embeds.to(target_device)
304
+ if attention_mask.device != target_device:
305
+ attention_mask = attention_mask.to(target_device)
306
+ if labels is not None and labels.device != target_device:
307
+ labels = labels.to(target_device)
308
+
309
+ return self.language_model(
310
+ inputs_embeds=inputs_embeds,
311
+ attention_mask=attention_mask,
312
+ labels=labels,
313
+ use_cache=False,
314
+ )
315
+
316
+ # ---- Label helpers ----------------------------------------------------------------------
317
+
318
+ @staticmethod
319
+ def build_labels_from_lengths(
320
+ *,
321
+ ids_rest: List[int],
322
+ model_spans_in_rest: List[Tuple[int, int]],
323
+ total_len: int,
324
+ offset_rest: int,
325
+ ) -> Tensor:
326
+ """Build label tensor using explicit offsets to satisfy any packing schema.
327
+
328
+ Args:
329
+ ids_rest: Token ids for the supervised span of the prompt.
330
+ model_spans_in_rest: List of (start, end) spans (relative to ids_rest) to supervise.
331
+ total_len: Total sequence length of the assembled prompt.
332
+ offset_rest: Absolute position in the sequence where `ids_rest` begins.
333
+
334
+ Returns:
335
+ Tensor of shape (total_len,) with supervised ids placed according to spans; all other
336
+ positions are filled with -100.
337
+
338
+ Raises:
339
+ ValueError: if `offset_rest` is invalid or spans fall outside the provided bounds.
340
+ """
341
+ labels = torch.full((total_len,), fill_value=-100, dtype=torch.long)
342
+
343
+ offset_rest = int(offset_rest)
344
+ if offset_rest < 0 or offset_rest > total_len:
345
+ raise ValueError(f"Invalid rest offset {offset_rest} for sequence length {total_len}.")
346
+ rest_len = len(ids_rest)
347
+ if offset_rest + rest_len > total_len:
348
+ raise ValueError(
349
+ f"Rest tokens (len={rest_len}) exceed total_len {total_len} with offset {offset_rest}."
350
+ )
351
+ # Defensive: no negative rest length (should not happen, but keep invariant tight)
352
+ assert_rest_length_nonnegative(rest_length=rest_len)
353
+
354
+ for (s, e) in model_spans_in_rest:
355
+ if not (0 <= s <= e <= rest_len):
356
+ raise ValueError(
357
+ f"Model span {(s, e)} is out of bounds for ids_rest length {rest_len}."
358
+ )
359
+ s_abs = offset_rest + s
360
+ e_abs = offset_rest + e
361
+ if not (0 <= s_abs <= e_abs <= total_len):
362
+ raise ValueError(
363
+ f"Model span {(s, e)} with rest offset {offset_rest} is out of bounds for length {total_len}."
364
+ )
365
+ labels[s_abs:e_abs] = torch.tensor(ids_rest[s:e], dtype=torch.long)
366
+ return labels
367
+
368
+ # ---- Trainable summaries -----------------------------------------------------------------
369
+
370
+ def summarize_trainables(self) -> Mapping[str, int]:
371
+ def _count(params: Iterable[nn.Parameter]) -> int:
372
+ return sum(int(p.numel()) for p in params if p.requires_grad)
373
+
374
+ llava_train = _count(self.llava_proj.parameters())
375
+ ecg_special_train = _count(self.ecg_special_embed.parameters())
376
+ enc_train = _count(self.enc.parameters())
377
+ lm_train = _count(self.language_model.parameters())
378
+
379
+ total = llava_train + ecg_special_train + enc_train + lm_train
380
+ return {
381
+ "llava_proj_trainable": llava_train,
382
+ "ecg_special_trainable": ecg_special_train,
383
+ "enc_trainable": enc_train,
384
+ "lm_trainable": lm_train,
385
+ "total_trainable": total,
386
+ }
387
+
388
+ # ---- Utility ---------------------------------------------------------------------------
389
+
390
+ @staticmethod
391
+ def ensure_finite(tensor: Tensor, context: str) -> None:
392
+ if not torch.isfinite(tensor).all():
393
+ xin_nan = torch.isnan(tensor).any().item()
394
+ raise RuntimeError(f"Encountered non-finite values in {context} (input_has_nan={bool(xin_nan)}).")
camel_inference/src/camel/ecg_text_packing.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from collections import OrderedDict
3
+ from dataclasses import dataclass
4
+ from typing import List, Dict, Tuple, Optional, Any
5
+ from transformers import PreTrainedTokenizer
6
+
7
+ from camel.assertions import (
8
+ assert_ecg_catalog_valid,
9
+ assert_normalized_role_canonical,
10
+ assert_turn_parts_structure_valid,
11
+ assert_turn_content_ends_with_eot,
12
+ assert_leads_canonical_and_ordered,
13
+ assert_waveform_shapes_valid,
14
+ )
15
+ from camel.prompt_renderers import turn_wrappers
16
+
17
+ # NOTE: Local BOS/EOS assertions are implemented at the bottom of this file.
18
@dataclass(frozen=True)
class PromptTokens:
    """Chat-template control tokens and role markers for one prompt format."""

    start_of_turn: str
    end_of_turn: str
    user_role: str
    model_role: str
    # BOS/EOS validation knobs consumed by the packing-time assertions.
    require_bos: bool = True
    require_eos: bool = True
    allow_multiple_eos: bool = False
27
+
28
+
29
@dataclass(frozen=True)
class ConversationRules:
    """Normalization rules mapping raw conversation roles onto a schema."""

    format_id: str
    # Raw role spellings that collapse onto the canonical user/model roles.
    user_role_aliases: Tuple[str, ...] = ("human", "user")
    model_role_aliases: Tuple[str, ...] = ("gpt", "assistant")
    # Roles whose text content has <image> placeholders removed.
    strip_image_from_roles: Tuple[str, ...] = ("human",)
    merge_system_with_first_user: bool = True
36
+
37
+
38
@dataclass(frozen=True)
class ECGTokenSchema:
    """Spelling of ECG special tokens: global delimiters plus per-lead templates."""

    global_start: str
    global_end: str
    # Templates may reference {lead}, {lead_lower}, {lead_upper}.
    lead_start_template: str
    lead_end_template: str
    canonical_leads: Tuple[str, ...]
45
+
46
+
47
@dataclass(frozen=True)
class PackingSchema:
    """Bundle of prompt tokens, conversation rules, and the ECG token schema."""

    prompt: PromptTokens
    conversation: ConversationRules
    ecg: ECGTokenSchema
52
+
53
+
54
@dataclass(frozen=True)
class ECGSpecialTokenCatalog:
    """Materialized ECG special-token list plus lookup maps for a schema."""

    tokens: Tuple[str, ...]
    # lead -> {"start": idx, "end": idx} into `tokens`.
    lead_to_indices: Dict[str, Dict[str, int]]
    # lead -> {"start": token, "end": token} literal strings.
    lead_to_tokens: Dict[str, Dict[str, str]]
    token_to_index: Dict[str, int]
60
+
61
+
62
+ def _render_lead_template(template: str, lead: str) -> str:
63
+ return template.format(
64
+ lead=lead,
65
+ lead_lower=lead.lower(),
66
+ lead_upper=lead.upper(),
67
+ )
68
+
69
# Per-schema catalog cache; PackingSchema is a frozen (hashable) dataclass.
_ECG_TOKEN_CACHE: Dict[PackingSchema, ECGSpecialTokenCatalog] = {}


def get_ecg_special_token_catalog(schema: PackingSchema) -> ECGSpecialTokenCatalog:
    """Build (or fetch from cache) the ECG special-token catalog for `schema`.

    The token list starts with the global start/end tokens, followed by a
    start/end token pair for each canonical lead, in schema order.
    """
    try:
        return _ECG_TOKEN_CACHE[schema]
    except KeyError:
        pass

    tokens: List[str] = [schema.ecg.global_start, schema.ecg.global_end]
    lead_to_indices: Dict[str, Dict[str, int]] = {}
    lead_to_tokens: Dict[str, Dict[str, str]] = {}

    for lead in schema.ecg.canonical_leads:
        start_token = _render_lead_template(schema.ecg.lead_start_template, lead)
        end_token = _render_lead_template(schema.ecg.lead_end_template, lead)
        lead_to_indices[lead] = {"start": len(tokens), "end": len(tokens) + 1}
        lead_to_tokens[lead] = {"start": start_token, "end": end_token}
        tokens.extend((start_token, end_token))

    catalog = ECGSpecialTokenCatalog(
        tokens=tuple(tokens),
        lead_to_indices=lead_to_indices,
        lead_to_tokens=lead_to_tokens,
        token_to_index={tok: idx for idx, tok in enumerate(tokens)},
    )

    # Validate catalog consistency before caching.
    assert_ecg_catalog_valid(catalog, schema)

    _ECG_TOKEN_CACHE[schema] = catalog
    return catalog
105
+
106
+
107
+ # ---- Conversation normalization + validation ----------------------------------------------------
108
+
109
def canonical_leads(schema: PackingSchema) -> List[str]:
    """Return the schema's canonical lead names as a fresh, mutable list."""
    return [lead for lead in schema.ecg.canonical_leads]
111
+
112
+
113
+ def _strip_image_tag(text: str) -> str:
114
+ """
115
+ Remove any <image> placeholder without gluing surrounding words.
116
+ Replace the token and surrounding whitespace with a single space and
117
+ normalize repeated spaces.
118
+ """
119
+ cleaned = re.sub(r"\s*<image>\s*", " ", text)
120
+ cleaned = re.sub(r"[ \t]{2,}", " ", cleaned)
121
+ return cleaned.strip()
122
+
123
+
124
def _normalize_role(role_value: Any, schema: PackingSchema) -> str:
    """Map a raw role string onto the schema's canonical role name.

    'system'/'developer' pass through unchanged; user/model aliases collapse
    onto the schema's canonical user/model roles (case-insensitively).

    Raises:
        ValueError: for an empty or unrecognized role.
    """
    role = (role_value or "").strip()
    if not role:
        raise ValueError("Conversation turn is missing a role identifier.")
    lowered = role.lower()
    if lowered in ("system", "developer"):
        return lowered
    user_names = {n.lower() for n in [schema.prompt.user_role, *schema.conversation.user_role_aliases]}
    if lowered in user_names:
        canonical = schema.prompt.user_role
        assert_normalized_role_canonical(canonical, schema)
        return canonical
    model_names = {n.lower() for n in [schema.prompt.model_role, *schema.conversation.model_role_aliases]}
    if lowered in model_names:
        canonical = schema.prompt.model_role
        assert_normalized_role_canonical(canonical, schema)
        return canonical
    raise ValueError(f"Unknown conversation role '{role_value}' for schema '{schema.conversation.format_id}'.")
143
+
144
def _normalize_conversation(
    convo: List[Dict[str, Any]],
    schema: PackingSchema,
    system_text: Optional[str],
    developer_text: Optional[str],
) -> List[Dict[str, Any]]:
    """Canonicalize raw turns and inject optional system/developer preambles.

    When `schema.conversation.merge_system_with_first_user` is true, system
    content is folded into the first user turn (developer turns are rejected);
    otherwise system and developer turns are hoisted to the front, in that order.

    Raises:
        ValueError: on empty conversations, malformed turns, or ordering violations.
    """
    if not convo:
        raise ValueError("Conversation must contain at least one turn.")

    def _raw_role(turn: Dict[str, Any]) -> Any:
        # Accept either "from" (legacy) or "role" keys.
        value = turn.get("from")
        return turn.get("role") if value is None else value

    # First pass: note whether explicit system/developer turns already exist.
    saw_system = False
    saw_developer = False
    for turn in convo:
        if not isinstance(turn, dict):
            continue
        tag = str(_raw_role(turn) or "").strip().lower()
        if tag == "system":
            saw_system = True
        elif tag == "developer":
            saw_developer = True

    normalized: List[Dict[str, Any]] = []
    if system_text and system_text.strip() and not saw_system:
        normalized.append({
            "role": "system",
            "content": [{"type": "text", "text": system_text.strip()}],
        })
    if developer_text and developer_text.strip() and not saw_developer:
        normalized.append({
            "role": "developer",
            "content": [{"type": "text", "text": developer_text.strip()}],
        })

    for idx, turn in enumerate(convo):
        role = _normalize_role(_raw_role(turn), schema)
        content = turn.get("content")
        if not isinstance(content, list):
            raise ValueError(f"Turn {idx} content must be a list of items.")
        normalized.append({"role": role, "content": content})

    if schema.conversation.merge_system_with_first_user:
        pooled_system: List[Dict[str, Any]] = []
        merged: List[Dict[str, Any]] = []
        first_user_at: Optional[int] = None
        for turn in normalized:
            if turn["role"] == "system":
                pooled_system.extend(list(turn["content"]))
                continue
            if turn["role"] == "developer":
                raise ValueError("Developer turn present but merge_system_with_first_user is true.")
            if first_user_at is None and turn["role"] == schema.prompt.user_role:
                first_user_at = len(merged)
            merged.append(turn)

        if pooled_system:
            if first_user_at is None:
                raise ValueError("System turn present but no user turn to merge into.")
            target = merged[first_user_at]
            target["content"] = list(pooled_system) + list(target["content"])

        if not merged:
            raise ValueError("Conversation must contain at least one non-system turn.")
        if merged[0]["role"] != schema.prompt.user_role:
            raise ValueError("Conversation must start with a user/human turn.")
        return merged

    # Non-merging schemas: hoist system, then developer, then the dialogue body.
    ordered = list(normalized)
    if not ordered:
        raise ValueError("Conversation must contain at least one turn.")
    preamble_system = [t for t in ordered if t["role"] == "system"]
    preamble_developer = [t for t in ordered if t["role"] == "developer"]
    body = [t for t in ordered if t["role"] not in ("system", "developer")]
    if not body:
        raise ValueError("Conversation must contain at least one non-system/developer turn.")
    if body[0]["role"] != schema.prompt.user_role:
        raise ValueError("Conversation must start with a user/human turn.")
    return preamble_system + preamble_developer + body
225
+
226
+
227
def _maybe_strip_content(text: str, canonical_role: str, schema: PackingSchema) -> str:
    """Strip <image> placeholders from `text` when `canonical_role` is configured for it."""
    stripped_roles: set = set()
    for raw_role in schema.conversation.strip_image_from_roles:
        try:
            stripped_roles.add(_normalize_role(raw_role, schema).lower())
        except ValueError:
            # Unknown alias: fall back to its literal lowercase spelling.
            stripped_roles.add(str(raw_role).lower())
    if canonical_role.lower() in stripped_roles:
        return _strip_image_tag(text)
    return text
237
+
238
+ # ---- Tokenize & mark assistant spans (exclude control tokens from loss) --------------------------
239
+
240
+ # ---- Build structured turn parts ---------------------------------------------------------------
241
+
242
def build_structured_turn_parts(
    *,
    content: List[Dict[str, Any]],
    canonical_role: str,
    schema: PackingSchema,
    ecg_blocks: List[Dict[str, Any]],
    sampling_rate: Optional[float],
    turn_suffix: Optional[str] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Render one turn's content into a text block plus structured parts.

    Text items are appended verbatim (after optional <image> stripping). ECG
    items are expanded into global/lead special tokens with per-second "ecg"
    placeholder parts, and their waveforms are registered in `ecg_blocks`.

    Args:
        content: List of {"type": "text"|"ecg", ...} items for this turn.
        canonical_role: Canonical role name (already normalized).
        schema: Packing schema providing tokens and canonical lead order.
        ecg_blocks: Mutable registry; one entry is appended per ECG item.
        sampling_rate: Sample-level rate used to cross-check ECG items.
        turn_suffix: Explicit suffix; defaults to the schema's turn wrapper.

    Returns:
        (turn_text_block, parts): rendered text (content + suffix) and the
        structured part list aligned with it.

    Raises:
        ValueError: on malformed items or sampling-rate mismatches.
    """
    catalog = get_ecg_special_token_catalog(schema)
    parts: List[Dict[str, Any]] = []
    text_segments: List[str] = []

    def _append_text(txt: str) -> None:
        # Coalesce adjacent text so the parts list stays minimal.
        if not txt:
            return
        if parts and parts[-1].get("kind") == "text":
            parts[-1]["text"] += txt
        else:
            parts.append({"kind": "text", "text": txt})
        text_segments.append(txt)

    def _append_special(token: str, token_index: int, block_index: int, lead: Optional[str] = None) -> None:
        # Special tokens are tracked both as parts and as literal prompt text.
        # Key insertion order mirrors the original part layout.
        part: Dict[str, Any] = {"kind": "special", "token": token, "token_index": token_index}
        if lead is not None:
            part["lead"] = lead
        part["block_index"] = block_index
        parts.append(part)
        text_segments.append(token)

    for item in content:
        if not isinstance(item, dict):
            raise ValueError("Conversation content items must be dicts.")
        item_type = item.get("type")
        if item_type == "text":
            text_val = item.get("text", "")
            if not isinstance(text_val, str):
                raise ValueError("Text content item must have a string 'text' field.")
            _append_text(_maybe_strip_content(text_val, canonical_role, schema))
            continue
        if item_type == "ecg":
            waveform_segments = item.get("waveform_segments")
            if not isinstance(waveform_segments, dict):
                raise ValueError("ECG content item missing waveform_segments mapping.")
            item_rate = item.get("sampling_rate")
            if item_rate is not None and sampling_rate is not None and float(item_rate) != float(sampling_rate):
                raise ValueError(
                    f"ECG item sampling_rate {item_rate} does not match sample sampling_rate {sampling_rate}."
                )
            lead_names = [str(ld) for ld in waveform_segments.keys()]
            if not lead_names:
                raise ValueError("ECG content item has no leads.")
            segments_per_lead = [int(waveform_segments[ld].shape[0]) for ld in lead_names]
            assert_leads_canonical_and_ordered(lead_names, schema.ecg.canonical_leads)
            assert_waveform_shapes_valid(lead_names, segments_per_lead, waveform_segments)
            block_index = len(ecg_blocks)
            ecg_blocks.append({
                "lead_names": lead_names,
                "segments_per_lead": segments_per_lead,
                "waveform_segments": OrderedDict((ld, waveform_segments[ld]) for ld in lead_names),
            })

            _append_special(schema.ecg.global_start, catalog.token_to_index[schema.ecg.global_start], block_index)
            for lead, nseg in zip(lead_names, segments_per_lead):
                lead_idx = catalog.lead_to_indices[lead]
                _append_special(catalog.lead_to_tokens[lead]["start"], lead_idx["start"], block_index, lead=lead)
                # One placeholder part per one-second window of this lead.
                for sec in range(1, int(nseg) + 1):
                    parts.append({"kind": "ecg", "lead": lead, "sec": sec, "block_index": block_index})
                _append_special(catalog.lead_to_tokens[lead]["end"], lead_idx["end"], block_index, lead=lead)
            _append_special(schema.ecg.global_end, catalog.token_to_index[schema.ecg.global_end], block_index)
            continue
        raise ValueError(f"Unknown content item type '{item_type}'.")

    turn_content = "".join(text_segments)
    if turn_suffix is None:
        _, suffix = turn_wrappers(schema, canonical_role)
    else:
        suffix = turn_suffix
    turn_text_block = turn_content + suffix

    assert_turn_parts_structure_valid(parts, ecg_blocks, schema, catalog)
    assert_turn_content_ends_with_eot(turn_text_block, suffix)

    return turn_text_block, parts
353
+
354
+
355
def build_text_only_turn_parts(
    *,
    content: List[Dict[str, Any]],
    canonical_role: str,
    schema: PackingSchema,
    turn_suffix: Optional[str] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Render a text-only (model) turn into a text block plus parts.

    ECG items are rejected. For harmony-format assistant turns a
    `<|channel|>final<|message|>` header is prepended; the assistant text
    itself must not already carry channel headers.

    Returns:
        (turn_text_block, parts): rendered text (content + suffix) and parts.

    Raises:
        ValueError: on non-text items, embedded channel headers, or a turn
            that is empty after preprocessing (with per-item diagnostics).
    """
    parts: List[Dict[str, Any]] = []
    text_segments: List[str] = []
    needs_channel_header = (
        schema.conversation.format_id == "harmony_chat_v1"
        and canonical_role == schema.prompt.model_role
    )
    # Track raw content so an empty turn can be diagnosed precisely.
    raw_content_debug: List[Dict[str, Any]] = []

    def _append_text(txt: str) -> None:
        # Coalesce adjacent text so the parts list stays minimal.
        if not txt:
            return
        if parts and parts[-1].get("kind") == "text":
            parts[-1]["text"] += txt
        else:
            parts.append({"kind": "text", "text": txt})
        text_segments.append(txt)

    for item in content:
        if not isinstance(item, dict):
            raise ValueError("Conversation content items must be dicts.")
        if item.get("type") != "text":
            raise ValueError("Model turns cannot contain ECG content items.")
        text_val = item.get("text", "")
        if not isinstance(text_val, str):
            raise ValueError("Text content item must have a string 'text' field.")
        if needs_channel_header and text_val.lstrip().startswith("<|channel|>"):
            raise ValueError("Assistant content must not include harmony channel headers.")
        cleaned = _maybe_strip_content(text_val, canonical_role, schema)
        raw_content_debug.append({
            "raw_text": repr(text_val[:200]) + ("..." if len(text_val) > 200 else ""),
            "raw_len": len(text_val),
            "cleaned_text": repr(cleaned[:200]) + ("..." if len(cleaned) > 200 else ""),
            "cleaned_len": len(cleaned),
        })
        _append_text(cleaned)

    turn_content = "".join(text_segments)
    if not turn_content:
        # Build a detailed diagnostic message covering every content item.
        diag_lines = [
            "Model turn is empty after preprocessing.",
            f"  role: {canonical_role}",
            f"  num_content_items: {len(content)}",
        ]
        for i, dbg in enumerate(raw_content_debug):
            diag_lines.append(f"  item[{i}]: raw_len={dbg['raw_len']}, cleaned_len={dbg['cleaned_len']}")
            diag_lines.append(f"    raw: {dbg['raw_text']}")
            diag_lines.append(f"    cleaned: {dbg['cleaned_text']}")
        raise ValueError("\n".join(diag_lines))
    if needs_channel_header:
        channel_header = "<|channel|>final<|message|>"
        parts.insert(0, {"kind": "text", "text": channel_header})
        text_segments.insert(0, channel_header)
        turn_content = "".join(text_segments)
    if turn_suffix is None:
        _, suffix = turn_wrappers(schema, canonical_role)
    else:
        suffix = turn_suffix
    turn_text_block = turn_content + suffix
    assert_turn_content_ends_with_eot(turn_text_block, suffix)
    return turn_text_block, parts
426
+
427
+
428
def annotate_turn_parts_with_ids(
    turn_parts: List[List[Dict[str, Any]]],
    tokenizer: "PreTrainedTokenizer",
) -> List[List[Dict[str, Any]]]:
    """Attach token ids to text parts so the trainer can skip per-step tokenization."""
    for parts in turn_parts:
        for part in parts:
            # Only text parts carry tokenizable content; leave others untouched.
            if part.get("kind") != "text":
                continue
            txt = part.get("text", "")
            part["ids"] = [] if not txt else tokenizer.encode(txt, add_special_tokens=False)
    return turn_parts
+
440
def assert_single_bos_eos(
    text_ids: List[int],
    tok: "PreTrainedTokenizer",
    *,
    require_bos_at_start: bool,
    require_single_terminal_eos: bool,
    allow_multiple_eos: bool = False,
) -> None:
    """Validate BOS/EOS placement under various schema-specific policies."""

    def _positions(token_id: int) -> List[int]:
        # All indices where the given special token occurs.
        return [idx for idx, tid in enumerate(text_ids) if tid == token_id]

    last = len(text_ids) - 1

    bos_id = tok.bos_token_id
    if require_bos_at_start:
        if bos_id is None:
            raise AssertionError("BOS token required but tokenizer has none")
        bos_pos = _positions(bos_id)
        # Exactly one BOS, and it must be the first token.
        if bos_pos != [0]:
            raise AssertionError(f"BOS placement invalid: positions={bos_pos}")
    elif bos_id is not None:
        bos_pos = _positions(bos_id)
        # BOS is optional here, but at most one and only at position 0.
        if len(bos_pos) > 1 or (bos_pos and bos_pos[0] != 0):
            raise AssertionError(f"Unexpected BOS placement: positions={bos_pos}")

    # EOS checks
    eos_id = tok.eos_token_id
    if eos_id is None:
        return
    eos_pos = _positions(eos_id)
    if allow_multiple_eos:
        if require_single_terminal_eos:
            # Allow multiple EOS (e.g., ChatML-style per turn), but the last one
            # must close the sequence.
            if not eos_pos or eos_pos[-1] != last:
                raise AssertionError(f"EOS bad: positions={eos_pos}")
        # Otherwise any count anywhere is acceptable (matches Qwen3 training practice).
    elif require_single_terminal_eos:
        # Exactly one EOS, and it must be terminal.
        if eos_pos != [last]:
            raise AssertionError(f"EOS bad: positions={eos_pos}")
    else:
        # Allow 0 or 1; when present it must be terminal.
        if len(eos_pos) > 1 or (eos_pos and eos_pos[0] != last):
            raise AssertionError(f"EOS bad: positions={eos_pos}")
+
482
+
483
def assert_struct_bos_eos(
    token_struct: Dict[str, List[int]],
    tok: "PreTrainedTokenizer",
    *,
    require_bos_at_start: bool,
    require_single_terminal_eos: bool,
    allow_multiple_eos: bool = False,
) -> None:
    """Wrapper that validates concatenated ids per schema policy."""
    # Delegate directly on the struct's flat id list.
    assert_single_bos_eos(
        token_struct["text_ids"],
        tok,
        require_bos_at_start=require_bos_at_start,
        require_single_terminal_eos=require_single_terminal_eos,
        allow_multiple_eos=allow_multiple_eos,
    )
camel_inference/src/camel/inference.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # inference.py — ECGText inference (model loading + generation helpers)
3
+ import os
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Tuple, Optional, Any
7
+ from collections import OrderedDict
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import AutoModelForCausalLM
12
+ from peft import LoraConfig
13
+
14
+ # Local imports
15
+ from camel.model_introspect import resolve_hidden_size as _resolve_hidden_size
16
+ from camel.model_registry import load_registry
17
+ from camel.training_setup import initialize_tokenizer, build_packing_schema, register_ecg_special_tokens
18
+ from camel.model_init import build_wrapper, attach_lora, build_conv_encoder
19
+ from camel.ecg_text_packing import (
20
+ _normalize_conversation,
21
+ annotate_turn_parts_with_ids,
22
+ build_structured_turn_parts,
23
+ build_text_only_turn_parts,
24
+ get_ecg_special_token_catalog,
25
+ )
26
+ from camel.prompt_renderers import render_prompt_and_spans, turn_wrappers, assistant_generation_prefix
27
+ from camel.ecg_attention_masks import (
28
+ ECGBlockLayout,
29
+ ECGSequenceLayout,
30
+ MaskBuildResult,
31
+ ECGMaskStrategy,
32
+ get_mask_strategy,
33
+ )
34
+ from camel.assertions import (
35
+ assert_ecg_blocks_consistent,
36
+ assert_ecg_part_bounds,
37
+ assert_layout_specials_complete,
38
+ assert_prefix_matches_segments,
39
+ assert_prefix_split_complete,
40
+ )
41
+ from camel.checkpoint_utils import (
42
+ load_llava_and_lora,
43
+ update_wrapper_language_model,
44
+ extract_lora_config_from_checkpoints,
45
+ peek_projector_name,
46
+ )
47
+
48
+ # ------------------------------
49
+ # Device & conv builder
50
+ # ------------------------------
51
+
52
+ def _device(device=None) -> torch.device:
53
+ if device:
54
+ return device
55
+ if torch.cuda.is_available():
56
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
57
+ return torch.device(f"cuda:{local_rank}")
58
+ return torch.device("cpu")
59
+
60
+
61
+ _HARMONY_CHANNEL_RE = re.compile(r"<\|channel\|>(.*?)<\|message\|>", re.DOTALL)
62
+ _HARMONY_DELIM_RE = re.compile(r"<\|end\|>|<\|return\|>|<\|call\|>|<\|start\|>")
63
+
64
+
65
+ def _extract_harmony_messages(text: str) -> List[Tuple[str, str]]:
66
+ matches = list(_HARMONY_CHANNEL_RE.finditer(text))
67
+ if not matches:
68
+ raise ValueError("No harmony channel headers found in model output.")
69
+ out: List[Tuple[str, str]] = []
70
+ for match in matches:
71
+ channel_raw = match.group(1).strip()
72
+ channel = channel_raw.split()[0] if channel_raw else ""
73
+ if not channel:
74
+ raise ValueError("Harmony channel header is empty.")
75
+ start = match.end()
76
+ end_match = _HARMONY_DELIM_RE.search(text, start)
77
+ end = end_match.start() if end_match else len(text)
78
+ out.append((channel, text[start:end]))
79
+ return out
80
+
81
+
82
+ def _checkpoint_has_conv(ckpt_path: Optional[str]) -> bool:
83
+ if not ckpt_path:
84
+ return False
85
+ payload = torch.load(ckpt_path, map_location="cpu")
86
+ if not isinstance(payload, dict):
87
+ raise RuntimeError(f"Checkpoint {ckpt_path} must be a dict to inspect conv metadata.")
88
+ return isinstance(payload.get("conv"), dict)
89
+
90
+
91
+ # ------------------------------
92
+ # Prompt building & stopping
93
+ # ------------------------------
94
+
95
@dataclass
class PromptContext:
    """Container describing the prepared prompt state for autoregressive generation."""

    # Embedded prompt sequence passed to the LM as inputs_embeds
    # (presumably built from text ids + ECG features — see _prepare_prompt_context).
    inputs_embeds: torch.Tensor
    # Positional layout of text tokens and ECG blocks within the sequence.
    layout: ECGSequenceLayout
    # Human-readable rendering of the prompt, for logging/inspection.
    prompt_preview: str
    # Token ids that should terminate generation when produced.
    stop_ids: List[int]
    # The LM's input embedding table, used to embed newly generated tokens.
    input_embedder: nn.Embedding
    # Strategy used to build ECG-aware attention masks for the prompt.
    mask_strategy: ECGMaskStrategy
    # Prebuilt mask result produced by mask_strategy for this prompt.
    mask_result: MaskBuildResult
+
107
+
108
+ def _sanitize_segments(tensor: torch.Tensor) -> torch.Tensor:
109
+ """Detach → float32 → replace NaN/Inf so downstream encoders stay numerically stable."""
110
+ out = tensor.detach().cpu().to(dtype=torch.float32)
111
+ return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
112
+
113
+
114
+ def _sample_next_token(
115
+ logits: torch.Tensor,
116
+ *,
117
+ temperature: float,
118
+ top_k: Optional[int],
119
+ top_p: float,
120
+ min_p: float,
121
+ ) -> torch.Tensor:
122
+ """
123
+ Draw the next token id given final-step logits and sampling parameters.
124
+
125
+ Uses greedy decoding when temperature <= 0; otherwise applies temperature
126
+ scaling, optional nucleus sampling, and multinomial sampling.
127
+ """
128
+ if logits.ndim != 1:
129
+ raise ValueError(f"Expected 1D logits, got shape {tuple(logits.shape)}")
130
+
131
+ if temperature <= 0.0:
132
+ return torch.argmax(logits, dim=-1)
133
+
134
+ scaled = logits / max(temperature, 1e-5)
135
+ probs = torch.softmax(scaled, dim=-1)
136
+
137
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
138
+ keep_mask = torch.ones_like(sorted_probs, dtype=torch.bool)
139
+
140
+ if top_k is not None and top_k > 0:
141
+ top_k = min(int(top_k), sorted_probs.numel())
142
+ top_k_mask = torch.zeros_like(sorted_probs, dtype=torch.bool)
143
+ top_k_mask[:top_k] = True
144
+ keep_mask &= top_k_mask
145
+
146
+ if 0.0 < top_p < 1.0:
147
+ cumulative = torch.cumsum(sorted_probs, dim=-1)
148
+ cutoff_mask = (cumulative - sorted_probs) < top_p
149
+ cutoff_mask[0] = True # always keep the highest-prob token
150
+ keep_mask &= cutoff_mask
151
+
152
+ if min_p is not None and min_p > 0.0:
153
+ keep_mask &= sorted_probs >= float(min_p)
154
+
155
+ filtered_probs = sorted_probs[keep_mask]
156
+ filtered_indices = sorted_indices[keep_mask]
157
+ if filtered_probs.numel() == 0:
158
+ filtered_probs = sorted_probs[:1]
159
+ filtered_indices = sorted_indices[:1]
160
+
161
+ prob_sum = filtered_probs.sum()
162
+ if not torch.isfinite(prob_sum) or prob_sum <= 0:
163
+ return sorted_indices[0]
164
+ normalized = filtered_probs / prob_sum
165
+ next_idx = torch.multinomial(normalized, num_samples=1, replacement=False)
166
+ return filtered_indices[next_idx].squeeze(0)
167
+
168
+
169
+ class KardiaLM:
170
+ """High-level chat interface around an ECG language model."""
171
+
172
+ def __init__(
173
+ self,
174
+ *,
175
+ model_registry_path: Optional[str],
176
+ model_config_name: str,
177
+ hf_model_id_override: Optional[str],
178
+ adapter_ckpt: str,
179
+ conv_ckpt: Optional[str] = None,
180
+ no_lora: bool = False,
181
+ use_dora: bool = False,
182
+ default_max_new_tokens: int = 1000,
183
+ default_temperature: float = 1.0,
184
+ default_top_k: Optional[int] = 64,
185
+ default_top_p: float = 0.95,
186
+ default_min_p: float = 0.0,
187
+ mask_strategy: Optional[str] = None,
188
+ device: Optional[torch.device] = None,
189
+ ) -> None:
190
+ registry = load_registry(registry_path=model_registry_path)
191
+ model_cfg = registry.get(model_config_name)
192
+
193
+ self.model_cfg = model_cfg
194
+ self.hf_model_id = hf_model_id_override or model_cfg.hf_id
195
+ self.packing_schema = build_packing_schema(self.hf_model_id)
196
+ self.tokenizer_cfg = model_cfg.tokenizer_config()
197
+ self.arch_cfg = model_cfg.architecture_config()
198
+ self.system_text = None
199
+ self.developer_text = None
200
+ if self.packing_schema.conversation.format_id == "harmony_chat_v1":
201
+ self.system_text = model_cfg.required_prompt_text("system_prompt")
202
+ self.developer_text = model_cfg.required_prompt_text("developer_prompt")
203
+ if not self.system_text.strip():
204
+ raise RuntimeError("System prompt text for harmony format must be non-empty.")
205
+ if not self.developer_text.strip():
206
+ raise RuntimeError("Developer prompt text for harmony format must be non-empty.")
207
+
208
+ self.device = _device(device)
209
+ self.dtype = torch.bfloat16
210
+ self.mask_strategy: ECGMaskStrategy = get_mask_strategy(mask_strategy)
211
+ self.expect_dora = bool(use_dora)
212
+
213
+ tok = initialize_tokenizer(
214
+ self.hf_model_id,
215
+ trust_remote_code=True,
216
+ use_fast=self.tokenizer_cfg.use_fast,
217
+ add_prefix_space=self.tokenizer_cfg.add_prefix_space,
218
+ )
219
+ self.tokenizer = tok
220
+
221
+ catalog = get_ecg_special_token_catalog(self.packing_schema)
222
+ self.ecg_special_token_id_map = register_ecg_special_tokens(tok, catalog)
223
+
224
+ pad_strategy = self.tokenizer_cfg.pad_token_strategy.lower()
225
+ if pad_strategy == "eos":
226
+ if tok.eos_token is None:
227
+ raise RuntimeError(
228
+ f"Tokenizer for model '{model_cfg.name}' lacks an EOS token required for pad_token_strategy='eos'."
229
+ )
230
+ tok.pad_token = tok.eos_token
231
+ elif pad_strategy not in ("existing", "keep"):
232
+ raise RuntimeError(f"Unsupported pad_token_strategy '{self.tokenizer_cfg.pad_token_strategy}'.")
233
+
234
+ if self.tokenizer_cfg.require_bos and tok.bos_token is None:
235
+ raise RuntimeError(f"Tokenizer for model '{model_cfg.name}' is missing a BOS token.")
236
+ if self.tokenizer_cfg.require_eos and tok.eos_token is None:
237
+ raise RuntimeError(f"Tokenizer for model '{model_cfg.name}' is missing an EOS token.")
238
+
239
+ attn_impl = self.arch_cfg.attn_implementation or "flash_attention_2"
240
+ try:
241
+ model = AutoModelForCausalLM.from_pretrained(
242
+ self.hf_model_id,
243
+ torch_dtype=self.dtype,
244
+ trust_remote_code=True,
245
+ attn_implementation=attn_impl,
246
+ device_map=None,
247
+ ).to(self.device)
248
+ except Exception:
249
+ model = AutoModelForCausalLM.from_pretrained(
250
+ self.hf_model_id,
251
+ torch_dtype=self.dtype,
252
+ trust_remote_code=True,
253
+ attn_implementation="eager",
254
+ device_map=None,
255
+ ).to(self.device)
256
+ if model.get_input_embeddings().weight.shape[0] != len(tok):
257
+ model.resize_token_embeddings(len(tok))
258
+ for p in model.parameters():
259
+ p.requires_grad = False
260
+ model.eval()
261
+ if hasattr(model, "gradient_checkpointing_disable"):
262
+ model.gradient_checkpointing_disable()
263
+ if hasattr(model.config, "use_cache"):
264
+ model.config.use_cache = True
265
+ self.model = model
266
+
267
+ adapter_ckpt_path = os.path.expanduser(adapter_ckpt)
268
+ adapter_has_conv = _checkpoint_has_conv(adapter_ckpt_path)
269
+ if not adapter_has_conv and not conv_ckpt:
270
+ raise RuntimeError(
271
+ "Adapter checkpoint lacks conv weights; supply --conv_ckpt to match training."
272
+ )
273
+ lora_cfg_dict = extract_lora_config_from_checkpoints(adapter_ckpt_path, None)
274
+ active_lora_cfg: Optional[LoraConfig] = None
275
+ if self.expect_dora and no_lora:
276
+ raise RuntimeError("--use-dora cannot be combined with --no-lora since no adapters would be loaded.")
277
+ if lora_cfg_dict and not no_lora:
278
+ cfg_use_dora = bool(lora_cfg_dict.get("use_dora", False))
279
+ if cfg_use_dora and not self.expect_dora:
280
+ raise RuntimeError(
281
+ "Checkpoint adapters were trained with DoRA; re-run inference with --use-dora to load them."
282
+ )
283
+ if self.expect_dora and not cfg_use_dora:
284
+ raise RuntimeError(
285
+ "Checkpoint adapters were trained without DoRA; omit --use-dora or use a checkpoint with DoRA."
286
+ )
287
+ model, active_lora_cfg = attach_lora(model, lora_cfg_dict, self.device)
288
+ model.eval()
289
+ elif no_lora and lora_cfg_dict:
290
+ print("[LoRA] --no-lora set; skipping LoRA adapters from checkpoint.", flush=True)
291
+ elif self.expect_dora:
292
+ raise RuntimeError("--use-dora was provided, but no LoRA/DoRA adapters were found in the checkpoint.")
293
+
294
+ conv = build_conv_encoder(
295
+ conv_ckpt_path=None if adapter_has_conv else conv_ckpt,
296
+ device=self.device,
297
+ unfreeze=False,
298
+ )
299
+ conv.eval()
300
+ for p in conv.parameters():
301
+ p.requires_grad = False
302
+ self.conv_encoder = conv
303
+
304
+ hidden_size = _resolve_hidden_size(model, self.arch_cfg.hidden_size_attrs)
305
+ wrapper_cls = model_cfg.resolve_wrapper_class()
306
+ enc_out_dim = self.arch_cfg.conv_out_dim if getattr(self.arch_cfg, "conv_out_dim", None) is not None else 64
307
+ projector_name = peek_projector_name(adapter_ckpt_path) or "linear"
308
+ wrapper = build_wrapper(
309
+ wrapper_cls=wrapper_cls,
310
+ language_model=model,
311
+ conv_encoder=conv,
312
+ hidden_size=hidden_size,
313
+ num_ecg_special_tokens=len(catalog.tokens),
314
+ dtype=self.dtype,
315
+ enc_out_dim=int(enc_out_dim),
316
+ freeze_encoder=True,
317
+ inference=True,
318
+ projector_name=projector_name,
319
+ )
320
+ self.projector_name = projector_name
321
+ self.wrapper = wrapper
322
+
323
+ _extra_payload, model, inferred_lora_cfg = load_llava_and_lora(
324
+ wrapper,
325
+ model,
326
+ adapter_ckpt_path,
327
+ expect_lora=(active_lora_cfg is not None),
328
+ load_lora=not no_lora,
329
+ )
330
+ update_wrapper_language_model(wrapper, model)
331
+ if active_lora_cfg is None and inferred_lora_cfg is not None and not no_lora:
332
+ active_lora_cfg = inferred_lora_cfg
333
+ model.eval()
334
+ for p in model.parameters():
335
+ p.requires_grad = False
336
+
337
+ inp_emb = model.get_input_embeddings().weight
338
+ inp_dev = inp_emb.device
339
+ target_dtype = inp_emb.dtype
340
+ wrapper.llava_proj.to(device=inp_dev, dtype=torch.float32)
341
+ wrapper.enc.to(device=inp_dev, dtype=torch.float32)
342
+ wrapper.ecg_special_embed.to(device=inp_dev, dtype=target_dtype)
343
+ llava_param = next(wrapper.llava_proj.parameters(), None)
344
+ if llava_param is None:
345
+ raise AssertionError("llava_proj unexpectedly has no parameters.")
346
+ if llava_param.device != inp_dev:
347
+ raise AssertionError(f"llava_proj on {llava_param.device}, expected {inp_dev}")
348
+ if llava_param.dtype != torch.float32:
349
+ raise AssertionError(f"llava_proj dtype {llava_param.dtype}, expected torch.float32")
350
+ conv_param = next(wrapper.enc.parameters(), None)
351
+ if conv_param is None:
352
+ raise AssertionError("Convolutional encoder unexpectedly has no parameters.")
353
+ if conv_param.device != inp_dev:
354
+ raise AssertionError(f"Conv encoder on {conv_param.device}, expected {inp_dev}")
355
+ if conv_param.dtype != torch.float32:
356
+ raise AssertionError(f"Conv encoder dtype {conv_param.dtype}, expected torch.float32")
357
+ assert next(wrapper.ecg_special_embed.parameters()).device == inp_dev, (
358
+ f"ecg_special_embed on {next(wrapper.ecg_special_embed.parameters()).device}, expected {inp_dev}"
359
+ )
360
+ try:
361
+ wrapper.language_model.eval()
362
+ except Exception:
363
+ pass
364
+
365
+ self.default_max_new_tokens = int(default_max_new_tokens)
366
+ self.default_temperature = float(default_temperature)
367
+ self.default_top_k = int(default_top_k) if default_top_k is not None else None
368
+ self.default_top_p = float(default_top_p)
369
+ self.default_min_p = float(default_min_p)
370
+
371
+ def chat(
372
+ self,
373
+ *,
374
+ conversation: List[Dict[str, Any]],
375
+ max_new_tokens: Optional[int] = None,
376
+ temperature: Optional[float] = None,
377
+ top_k: Optional[int] = None,
378
+ top_p: Optional[float] = None,
379
+ min_p: Optional[float] = None,
380
+ harmony_output: Optional[str] = None,
381
+ ) -> Tuple[str, str]:
382
+ """Generate a response for a structured multi-turn conversation."""
383
+ context = self._prepare_prompt_context(conversation=conversation)
384
+
385
+ max_new_tokens = int(max_new_tokens if max_new_tokens is not None else self.default_max_new_tokens)
386
+ temperature = float(temperature if temperature is not None else self.default_temperature)
387
+ resolved_top_k = int(top_k) if top_k is not None else (self.default_top_k if self.default_top_k is not None else None)
388
+ top_p = float(top_p if top_p is not None else self.default_top_p)
389
+ min_p = float(min_p if min_p is not None else self.default_min_p)
390
+
391
+ token_ids = self._autoregressive_generate(
392
+ context=context,
393
+ max_new_tokens=max_new_tokens,
394
+ temperature=temperature,
395
+ top_k=resolved_top_k,
396
+ top_p=top_p,
397
+ min_p=min_p,
398
+ )
399
+ text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
400
+ if self.packing_schema.conversation.format_id == "harmony_chat_v1":
401
+ mode = harmony_output if harmony_output is not None else "all"
402
+ if mode != "raw":
403
+ messages = _extract_harmony_messages(text)
404
+ if mode == "all":
405
+ text = "\n".join(msg for _, msg in messages)
406
+ elif mode == "final":
407
+ finals = [msg for channel, msg in messages if channel == "final"]
408
+ if not finals:
409
+ raise ValueError("No final channel output found in harmony response.")
410
+ text = finals[-1]
411
+ else:
412
+ raise ValueError(f"Unknown harmony_output '{mode}'.")
413
+ return text, context.prompt_preview
414
+
415
+ def _to_waveform_tensor(self, value: Any) -> torch.Tensor:
416
+ if isinstance(value, torch.Tensor):
417
+ tensor = value.detach().cpu()
418
+ elif isinstance(value, np.ndarray):
419
+ tensor = torch.from_numpy(np.asarray(value))
420
+ else:
421
+ tensor = torch.tensor(value, dtype=torch.float32)
422
+ tensor = tensor.to(dtype=torch.float32)
423
+ if tensor.ndim == 1:
424
+ if tensor.numel() != 256:
425
+ raise ValueError("Expected a 256-sample vector for a single lead second.")
426
+ tensor = tensor.view(1, 256)
427
+ elif tensor.ndim == 2:
428
+ if tensor.size(-1) != 256:
429
+ raise ValueError("Waveform segments must have length 256 along the last dimension.")
430
+ else:
431
+ raise ValueError("Waveform tensor must be rank 1 or 2 with 256-sample segments.")
432
+ return tensor.contiguous()
433
+
434
+ def _prepare_prompt_context(
435
+ self,
436
+ *,
437
+ conversation: List[Dict[str, Any]],
438
+ ) -> PromptContext:
439
+ tok = self.tokenizer
440
+ wrapper = self.wrapper
441
+ packing_schema = self.packing_schema
442
+ device = self.device
443
+
444
+ prompt_tokens = packing_schema.prompt
445
+ if not isinstance(conversation, list) or not conversation:
446
+ raise ValueError("conversation must be a non-empty list of turns.")
447
+
448
+ conv_input: List[Dict[str, Any]] = []
449
+ for turn in conversation:
450
+ if not isinstance(turn, dict):
451
+ raise ValueError("Conversation turns must be dicts.")
452
+ if "from" not in turn and "role" in turn:
453
+ turn = dict(turn)
454
+ turn["from"] = turn.get("role")
455
+ conv_input.append(turn)
456
+
457
+ turns = _normalize_conversation(conv_input, packing_schema, self.system_text, self.developer_text)
458
+ if turns[-1]["role"] != prompt_tokens.user_role:
459
+ raise ValueError("Conversation must end with a user turn to generate.")
460
+
461
+ def _sanitize_content(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
462
+ sanitized: List[Dict[str, Any]] = []
463
+ for item in content:
464
+ if not isinstance(item, dict):
465
+ raise ValueError("Conversation content items must be dicts.")
466
+ item_type = item.get("type")
467
+ if item_type == "ecg":
468
+ waveform = item.get("waveform_segments")
469
+ if not isinstance(waveform, dict):
470
+ raise ValueError("ECG content item missing waveform_segments mapping.")
471
+ wf_out: "OrderedDict[str, torch.Tensor]" = OrderedDict()
472
+ for ld, value in waveform.items():
473
+ wf_out[str(ld)] = _sanitize_segments(self._to_waveform_tensor(value))
474
+ new_item = dict(item)
475
+ new_item["waveform_segments"] = wf_out
476
+ sanitized.append(new_item)
477
+ continue
478
+ if item_type == "text":
479
+ text_val = item.get("text")
480
+ if not isinstance(text_val, str):
481
+ raise ValueError("Text content item must have a string 'text' field.")
482
+ if "<image>" in text_val:
483
+ raise ValueError("Conversation text must not contain <image> in inference mode.")
484
+ sanitized.append(item)
485
+ return sanitized
486
+
487
+ ecg_blocks: List[Dict[str, Any]] = []
488
+ turn_parts: List[List[Dict[str, Any]]] = []
489
+ token_turns: List[Dict[str, str]] = []
490
+
491
+ for turn in turns:
492
+ role = turn["role"]
493
+ content = _sanitize_content(turn["content"])
494
+ if role == prompt_tokens.model_role:
495
+ turn_text_block, content_parts = build_text_only_turn_parts(
496
+ content=content,
497
+ canonical_role=role,
498
+ schema=packing_schema,
499
+ )
500
+ else:
501
+ turn_text_block, content_parts = build_structured_turn_parts(
502
+ content=content,
503
+ canonical_role=role,
504
+ schema=packing_schema,
505
+ ecg_blocks=ecg_blocks,
506
+ sampling_rate=None,
507
+ )
508
+ prefix, suffix = turn_wrappers(packing_schema, role)
509
+ parts = [{"kind": "text", "text": prefix}]
510
+ parts.extend(content_parts)
511
+ parts.append({"kind": "text", "text": suffix})
512
+ turn_parts.append(parts)
513
+ token_turns.append({"role": role, "text_block": turn_text_block})
514
+
515
+ if not ecg_blocks:
516
+ raise ValueError("No ECG blocks found in conversation.")
517
+
518
+ turn_parts = annotate_turn_parts_with_ids(turn_parts, tok)
519
+ assert_ecg_blocks_consistent(turn_parts=turn_parts, ecg_blocks=ecg_blocks)
520
+
521
+ token_struct = render_prompt_and_spans(tok, token_turns, schema=packing_schema)
522
+ text_ids = list(token_struct["text_ids"])
523
+ if (
524
+ prompt_tokens.require_eos
525
+ and tok.eos_token_id is not None
526
+ and text_ids
527
+ and text_ids[-1] == tok.eos_token_id
528
+ ):
529
+ text_ids = text_ids[:-1]
530
+ text_preview = token_struct.get("text_preview", "")
531
+
532
+ model_prefix = assistant_generation_prefix(packing_schema)
533
+ ids_model_prefix = tok.encode(model_prefix, add_special_tokens=False)
534
+ text_ids.extend(ids_model_prefix)
535
+ prompt_preview = text_preview + model_prefix
536
+
537
+ model_prefix_parts = [{"kind": "text", "text": model_prefix, "ids": ids_model_prefix}]
538
+ all_parts = list(turn_parts) + [model_prefix_parts]
539
+
540
+ flat_blocks = [blk["waveform_segments"] for blk in ecg_blocks]
541
+ lead_orders = [blk["lead_names"] for blk in ecg_blocks]
542
+ prefix_all, prefix_lens = wrapper.ecg_prefix_batch(
543
+ flat_blocks,
544
+ device=device,
545
+ lead_orders=lead_orders,
546
+ )
547
+ prefixes: List[torch.Tensor] = []
548
+ offset = 0
549
+ for n in prefix_lens:
550
+ prefixes.append(prefix_all[offset:offset + int(n)])
551
+ offset += int(n)
552
+ assert_prefix_split_complete(offset=offset, total_prefix_rows=int(prefix_all.size(0)))
553
+
554
+ block_layouts: List[ECGBlockLayout] = []
555
+ lead_offsets: List[Dict[str, int]] = []
556
+ lead_special_counts: List[Dict[str, int]] = []
557
+ for blk_idx, blk in enumerate(ecg_blocks):
558
+ lead_names = [str(ld) for ld in blk.get("lead_names", [])]
559
+ segs_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]
560
+ prefix_rows = prefixes[blk_idx].size(0) if blk_idx < len(prefixes) else 0
561
+ assert_prefix_matches_segments(
562
+ prefix_rows=prefix_rows,
563
+ segments_per_lead=segs_per_lead,
564
+ lead_names=lead_names,
565
+ sample_index=0,
566
+ block_index=blk_idx,
567
+ )
568
+ lead_to_offset: Dict[str, int] = {}
569
+ c = 0
570
+ for ld, nseg in zip(lead_names, segs_per_lead):
571
+ lead_to_offset[ld] = c
572
+ c += int(nseg)
573
+ lead_offsets.append(lead_to_offset)
574
+ block_layouts.append(ECGBlockLayout(
575
+ start_idx=None,
576
+ end_idx_exclusive=None,
577
+ global_start_idx=None,
578
+ global_end_idx=None,
579
+ lead_start_idx={},
580
+ lead_end_idx={},
581
+ signal_pos_by_lead={ld: [None] * int(nseg) for ld, nseg in zip(lead_names, segs_per_lead)},
582
+ time_to_signal_idxs={},
583
+ declared_segments_per_lead={ld: int(nseg) for ld, nseg in zip(lead_names, segs_per_lead)},
584
+ ))
585
+ lead_special_counts.append({})
586
+
587
+ special_indices: List[int] = [
588
+ int(part["token_index"])
589
+ for turn in all_parts
590
+ for part in turn
591
+ if part.get("kind") == "special"
592
+ ]
593
+ if special_indices:
594
+ special_idx_tensor = torch.tensor(special_indices, dtype=torch.long, device=device)
595
+ special_embeds = wrapper.ecg_special_tokens_to_embeds(special_idx_tensor, device=device)
596
+ else:
597
+ special_embeds = torch.empty((0, wrapper.hidden_size), dtype=wrapper.dtype, device=device)
598
+
599
+ input_embedder = wrapper.language_model.get_input_embeddings()
600
+ if text_ids:
601
+ E_text_all = wrapper.tokens_to_embeds(input_embedder, text_ids, device=device)
602
+ else:
603
+ E_text_all = torch.empty((0, wrapper.hidden_size), dtype=wrapper.dtype, device=device)
604
+
605
+ text_cursor = 0
606
+ empty_text = E_text_all[:0]
607
+ chunks: List[torch.Tensor] = []
608
+ layout = ECGSequenceLayout(seq_len=0, text_idxs=[], blocks=block_layouts)
609
+
610
+ def _take_text(count: int) -> torch.Tensor:
611
+ nonlocal text_cursor
612
+ if count <= 0:
613
+ return empty_text
614
+ end = text_cursor + count
615
+ if end > E_text_all.size(0):
616
+ raise RuntimeError("Text embedding cursor exceeded available embeddings")
617
+ out = E_text_all[text_cursor:end]
618
+ text_cursor = end
619
+ return out
620
+
621
+ def _record_text(count: int, cursor: int) -> None:
622
+ for i in range(count):
623
+ layout.text_idxs.append(cursor + i)
624
+
625
+ cursor = 0
626
+ special_cursor = 0
627
+
628
+ if (
629
+ text_ids
630
+ and prompt_tokens.require_bos
631
+ and tok.bos_token_id is not None
632
+ and text_ids[0] == tok.bos_token_id
633
+ ):
634
+ E_bos = _take_text(1)
635
+ chunks.append(E_bos)
636
+ _record_text(1, cursor)
637
+ cursor += 1
638
+
639
+ for turn in all_parts:
640
+ for part in turn:
641
+ kind = part.get("kind")
642
+ if kind == "text":
643
+ ids_chunk = part.get("ids")
644
+ if ids_chunk is None:
645
+ txt = part.get("text", "")
646
+ ids_chunk = tok.encode(txt, add_special_tokens=False) if txt else []
647
+ if ids_chunk:
648
+ if ids_chunk != text_ids[text_cursor:text_cursor + len(ids_chunk)]:
649
+ raise RuntimeError("Special token id does not match text_ids cursor.")
650
+ E_chunk = _take_text(len(ids_chunk))
651
+ chunks.append(E_chunk)
652
+ _record_text(len(ids_chunk), cursor)
653
+ cursor += len(ids_chunk)
654
+ continue
655
+ if kind == "special":
656
+ if special_cursor >= special_embeds.size(0):
657
+ raise RuntimeError("Special-token cursor exceeded embeddings.")
658
+ tok_idx = int(part.get("token_index", -1))
659
+ expected_id = self.ecg_special_token_id_map.get(tok_idx)
660
+ if expected_id is None:
661
+ raise RuntimeError(f"Unknown ECG special token index {tok_idx}.")
662
+ if text_cursor >= len(text_ids):
663
+ raise RuntimeError("Text cursor exceeded available text ids.")
664
+ if text_ids[text_cursor] != expected_id:
665
+ raise RuntimeError("Special token id does not match text_ids cursor.")
666
+ _take_text(1)
667
+ chunks.append(special_embeds[special_cursor:special_cursor + 1])
668
+ _record_text(1, cursor)
669
+
670
+ block_index = int(part.get("block_index", -1))
671
+ if block_index < 0 or block_index >= len(block_layouts):
672
+ raise RuntimeError("ECG part references unknown block_index.")
673
+ block_layout = block_layouts[block_index]
674
+ lead_name = part.get("lead")
675
+ if lead_name:
676
+ cnt = lead_special_counts[block_index].get(lead_name, 0)
677
+ if cnt == 0:
678
+ block_layout.lead_start_idx[lead_name] = cursor
679
+ else:
680
+ block_layout.lead_end_idx[lead_name] = cursor
681
+ lead_special_counts[block_index][lead_name] = cnt + 1
682
+ else:
683
+ if block_layout.global_start_idx is None:
684
+ block_layout.global_start_idx = cursor
685
+ block_layout.start_idx = cursor
686
+ else:
687
+ block_layout.global_end_idx = cursor
688
+ block_layout.end_idx_exclusive = cursor + 1
689
+
690
+ cursor += 1
691
+ special_cursor += 1
692
+ continue
693
+ if kind == "ecg":
694
+ block_index = int(part.get("block_index", -1))
695
+ if block_index < 0 or block_index >= len(block_layouts):
696
+ raise RuntimeError("ECG part references unknown block_index.")
697
+ ld = part["lead"]
698
+ sec = int(part["sec"])
699
+ lead_to_offset = lead_offsets[block_index]
700
+ block_layout = block_layouts[block_index]
701
+ prefix_all = prefixes[block_index]
702
+ assert_ecg_part_bounds(
703
+ lead=ld,
704
+ sec=sec,
705
+ lead_to_offset=lead_to_offset,
706
+ declared_segments=block_layout.declared_segments_per_lead,
707
+ total_prefix_rows=prefix_all.size(0),
708
+ sample_index=0,
709
+ block_index=block_index,
710
+ )
711
+ base = lead_to_offset[ld]
712
+ row_idx = base + (sec - 1)
713
+ chunks.append(prefix_all[row_idx:row_idx + 1])
714
+ sig_list = block_layout.signal_pos_by_lead[ld]
715
+ if sec - 1 >= len(sig_list):
716
+ raise RuntimeError("ECG segment index exceeds declared segments_per_lead")
717
+ sig_list[sec - 1] = cursor
718
+ block_layout.time_to_signal_idxs.setdefault(sec, []).append(cursor)
719
+ cursor += 1
720
+ continue
721
+ raise RuntimeError(f"Unknown turn part kind '{kind}'.")
722
+
723
+ remaining = len(text_ids) - text_cursor
724
+ if remaining > 0:
725
+ E_tail = _take_text(remaining)
726
+ chunks.append(E_tail)
727
+ _record_text(remaining, cursor)
728
+ cursor += remaining
729
+
730
+ if special_cursor != special_embeds.size(0):
731
+ raise RuntimeError("Did not consume all special-token embeddings for prompt")
732
+ if text_cursor != E_text_all.size(0):
733
+ raise RuntimeError("Text embedding cursor did not consume all embeddings")
734
+
735
+ inputs_embeds = torch.cat(chunks, dim=0)
736
+ layout.seq_len = inputs_embeds.size(0)
737
+
738
+ for blk_idx, blk_layout in enumerate(block_layouts):
739
+ for ld, expected in blk_layout.declared_segments_per_lead.items():
740
+ slots = blk_layout.signal_pos_by_lead[ld]
741
+ if any(pos is None for pos in slots):
742
+ raise RuntimeError(f"Lead {ld} missing ECG slots; expected {expected}.")
743
+ blk_layout.signal_pos_by_lead[ld] = [int(pos) for pos in slots]
744
+ if blk_layout.global_start_idx is None or blk_layout.global_end_idx is None:
745
+ raise RuntimeError("ECG block missing global start/end specials.")
746
+ if blk_layout.end_idx_exclusive is None:
747
+ blk_layout.end_idx_exclusive = int(blk_layout.global_end_idx) + 1
748
+ if blk_layout.start_idx is None:
749
+ blk_layout.start_idx = int(blk_layout.global_start_idx)
750
+ assert_layout_specials_complete(
751
+ block_layout=blk_layout,
752
+ lead_names=ecg_blocks[blk_idx]["lead_names"],
753
+ )
754
+ all_specials = []
755
+ if blk_layout.global_start_idx is not None:
756
+ all_specials.append(blk_layout.global_start_idx)
757
+ all_specials.extend(list(blk_layout.lead_start_idx.values()))
758
+ all_specials.extend(list(blk_layout.lead_end_idx.values()))
759
+ if blk_layout.global_end_idx is not None:
760
+ all_specials.append(blk_layout.global_end_idx)
761
+ blk_layout.special_idxs_sorted = sorted(all_specials)
762
+ blk_layout.signal_pos_list = sorted(
763
+ [p for lst in blk_layout.signal_pos_by_lead.values() for p in lst]
764
+ )
765
+
766
+ mask_result = self.mask_strategy.build(
767
+ layout,
768
+ device=device,
769
+ dtype=inputs_embeds.dtype,
770
+ )
771
+ use_return = self.packing_schema.conversation.format_id == "harmony_chat_v1"
772
+ _, stop_text = turn_wrappers(self.packing_schema, prompt_tokens.model_role, use_return=use_return)
773
+ stop_ids = tok.encode(stop_text, add_special_tokens=False)
774
+ return PromptContext(
775
+ inputs_embeds=inputs_embeds,
776
+ layout=layout,
777
+ prompt_preview=prompt_preview,
778
+ stop_ids=stop_ids,
779
+ input_embedder=input_embedder,
780
+ mask_strategy=self.mask_strategy,
781
+ mask_result=mask_result,
782
+ )
783
+
784
+ def _autoregressive_generate(
785
+ self,
786
+ *,
787
+ context: PromptContext,
788
+ max_new_tokens: int,
789
+ temperature: float,
790
+ top_k: Optional[int],
791
+ top_p: float,
792
+ min_p: float,
793
+ ) -> List[int]:
794
+ tok = self.tokenizer
795
+ wrapper = self.wrapper
796
+ device = context.inputs_embeds.device
797
+ embeds = context.inputs_embeds.clone()
798
+ layout = context.layout
799
+ input_embedder = context.input_embedder
800
+ mask_result = context.mask_result
801
+
802
+ generated: List[int] = []
803
+ stop_ids = context.stop_ids
804
+ stop_len = len(stop_ids)
805
+ eos_id = tok.eos_token_id
806
+
807
+ for _ in range(max_new_tokens):
808
+ additive = mask_result.additive.unsqueeze(0).unsqueeze(0)
809
+ outputs = wrapper.forward_language_model(
810
+ inputs_embeds=embeds.unsqueeze(0),
811
+ attention_mask=additive,
812
+ labels=None,
813
+ )
814
+ logits = outputs.logits[0, -1, :].float()
815
+ next_token = _sample_next_token(
816
+ logits,
817
+ temperature=temperature,
818
+ top_k=top_k,
819
+ top_p=top_p,
820
+ min_p=min_p,
821
+ )
822
+ token_id = int(next_token.item())
823
+ generated.append(token_id)
824
+
825
+ if stop_len and generated[-stop_len:] == stop_ids:
826
+ generated = generated[:-stop_len]
827
+ break
828
+ if eos_id is not None and token_id == eos_id:
829
+ break
830
+
831
+ with torch.no_grad():
832
+ new_embed = wrapper.tokens_to_embeds(input_embedder, [token_id], device=device)
833
+ embeds = torch.cat([embeds, new_embed], dim=0)
834
+
835
+ layout.seq_len = embeds.size(0)
836
+ new_idx = layout.seq_len - 1
837
+ layout.text_idxs.append(new_idx)
838
+ mask_result = context.mask_strategy.update_for_generated_token(
839
+ layout,
840
+ device=device,
841
+ dtype=embeds.dtype,
842
+ previous=mask_result,
843
+ )
844
+ context.mask_result = mask_result
845
+
846
+ return generated
camel_inference/src/camel/model_init.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional, Tuple, Type
4
+ import torch
5
+ import torch.nn as nn
6
+ from peft import (
7
+ LoraConfig,
8
+ get_peft_model,
9
+ TaskType,
10
+ )
11
+
12
+ from camel.ecg_gemma_model import ECGGemmaPrefix as ECGModelPrefix
13
+
14
def attach_lora(
    model: nn.Module,
    lora_cfg_dict: Dict[str, Any],
    device: torch.device,
) -> Tuple[nn.Module, LoraConfig]:
    """Wrap the frozen model with LoRA adapters; only the adapters are trainable.

    Returns the PEFT-wrapped model (moved to *device*) and the LoraConfig used.
    """
    rank = int(lora_cfg_dict["r"])
    peft_cfg = LoraConfig(
        r=rank,
        lora_alpha=int(lora_cfg_dict.get("lora_alpha", rank * 2)),
        lora_dropout=float(lora_cfg_dict.get("lora_dropout", 0.0)),
        target_modules=list(lora_cfg_dict.get("target_modules", [])),
        task_type=TaskType(lora_cfg_dict.get("task_type", "CAUSAL_LM")),
        bias=lora_cfg_dict.get("bias", "none"),
        inference_mode=False,
        use_dora=bool(lora_cfg_dict.get("use_dora", False)),
    )
    wrapped = get_peft_model(model, peft_cfg)
    wrapped.to(device)
    return wrapped, peft_cfg
33
+
34
def build_conv_encoder(
    *,
    conv_ckpt_path: Optional[str],
    device: torch.device,
    unfreeze: bool = False,
) -> nn.Module:
    """
    Build the 1D conv stack and optionally load weights from *conv_ckpt_path*.

    Checkpoint keys are normalized (``module.`` / ``_orig_mod.`` / ``enc.``
    prefixes stripped) before loading, and only the conv weight/bias entries
    this architecture expects are kept. Loading is strict, so an incomplete
    checkpoint raises immediately. The encoder is returned in eval mode.

    NOTE(review): `unfreeze` is accepted but not used in this body — confirm
    intended behavior against train_ecg_text.py.
    """
    enc = nn.Sequential(
        nn.Conv1d(1, 32, kernel_size=4, stride=2, padding=1),   # L:256->128
        nn.ReLU(inplace=True),
        nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),  # L:128->64
        nn.ReLU(inplace=True),
        nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1), # L:64->32
        nn.ReLU(inplace=True),
        nn.Conv1d(128, 4, kernel_size=4, stride=2, padding=1),  # L:32->16, C:4
        nn.ReLU(inplace=True),
    ).to(device=device, dtype=torch.float32)

    if conv_ckpt_path:
        ckpt = torch.load(conv_ckpt_path, map_location="cpu")
        raw_sd = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
        norm_sd: Dict[str, torch.Tensor] = {}
        for k, v in raw_sd.items():
            kk = k
            # Strip wrapper prefixes added by DataParallel / torch.compile / the
            # training-time module attribute, in that order.
            for prefix in ("module.", "_orig_mod.", "enc."):
                if kk.startswith(prefix):
                    kk = kk[len(prefix):]
            norm_sd[kk] = v
        wanted = {f"{i}.{w}" for i in (0, 2, 4, 6) for w in ("weight", "bias")}
        conv_sd = {k: v for k, v in norm_sd.items() if k in wanted}
        # strict=True raises on any missing/unexpected key, so the previous
        # post-load warning branch was unreachable dead code and is removed.
        enc.load_state_dict(conv_sd, strict=True)
    enc.eval()
    return enc
75
+
76
def build_wrapper(
    *,
    wrapper_cls: Type[nn.Module] = ECGModelPrefix,
    language_model: nn.Module,
    conv_encoder: nn.Module,
    hidden_size: int,
    num_ecg_special_tokens: int,
    dtype: torch.dtype,
    enc_out_dim: int = 64,
    freeze_encoder: bool = True,
    inference: bool = False,
    projector_name: str = "linear",
) -> ECGModelPrefix:
    """Instantiate the ECG-language wrapper around the given language model."""
    return wrapper_cls(
        language_model,
        enc=conv_encoder,
        hidden_size=hidden_size,
        num_ecg_special_tokens=num_ecg_special_tokens,
        dtype=dtype,
        enc_out_dim=enc_out_dim,
        freeze_encoder=freeze_encoder,
        inference=inference,
        projector_name=projector_name,
    )
102
+
103
+
104
+ __all__ = [
105
+ "attach_lora",
106
+ "build_conv_encoder",
107
+ "build_wrapper",
108
+ ]
camel_inference/src/camel/model_introspect.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model introspection utilities driven by registry hints.
2
+
3
+ Centralizes resolution of model-internal structures to avoid hardcoded
4
+ attribute names in call sites. Use the hint paths defined in the model
5
+ registry to locate transformer layers and config attributes.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import List, Optional, Sequence
10
+ import torch.nn as nn
11
+
12
def _walk_attr_path(root: object, dotted_path: str) -> Optional[object]:
    """Follow *dotted_path* attribute-by-attribute from *root*.

    Returns the final attribute, or None as soon as any hop is missing.
    """
    node: object = root
    for attr in dotted_path.split("."):
        try:
            node = getattr(node, attr)
        except AttributeError:
            return None
    return node
19
+
20
def resolve_layers(model: nn.Module, path_hints: Sequence[str]) -> List[nn.Module]:
    """
    Resolve the text transformer layer sequence using the first successful
    dotted path from `path_hints` relative to common roots.

    We try against several candidate roots to be robust to wrappers (e.g., PEFT):
    - the model itself
    - model.base_model (if present)
    - model.base_model.model (if present)

    We also try a small set of generic fallback hints ("model.language_model.layers",
    "language_model.layers", "model.layers", "layers") if the provided hints fail.

    Raises:
        RuntimeError: if no hint resolves to a sequence of nn.Module layers.
    """
    roots: List[object] = [model]
    base = getattr(model, "base_model", None)
    if base is not None:
        roots.append(base)
        base_model_attr = getattr(base, "model", None)
        if base_model_attr is not None:
            roots.append(base_model_attr)

    # Note: the previous version accumulated attempted hints into an unused
    # `tried` list; that dead state has been removed.
    def _try_hints(root: object, hints: Sequence[str]) -> Optional[List[nn.Module]]:
        for hint in hints:
            obj = _walk_attr_path(root, hint)
            if obj is None:
                continue
            if isinstance(obj, (list, tuple)) and all(isinstance(x, nn.Module) for x in obj):
                return list(obj)
            # Accept any iterable container of modules (e.g., nn.ModuleList).
            if hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)):
                try:
                    seq = list(obj)
                    if seq and all(isinstance(x, nn.Module) for x in seq):
                        return seq
                except Exception:
                    pass
        return None

    # Try provided hints against all roots
    for root in roots:
        found = _try_hints(root, path_hints)
        if found is not None:
            return found

    # Fallback generic hints against all roots
    generic_hints = (
        "model.language_model.layers",
        "language_model.layers",
        "model.layers",
        "layers",
    )
    for root in roots:
        found = _try_hints(root, generic_hints)
        if found is not None:
            return found

    raise RuntimeError(
        f"Could not resolve transformer layers via provided hints: {list(path_hints)}"
    )
80
+
81
def resolve_hidden_size(model: nn.Module, attr_paths: Sequence[str]) -> int:
    """Return the hidden size from the first dotted attribute path that yields a number."""
    for candidate in attr_paths:
        resolved = _walk_attr_path(model, candidate)
        if isinstance(resolved, (int, float)):
            return int(resolved)
    raise AttributeError(
        f"Could not resolve hidden size from any of: {list(attr_paths)}"
    )
90
+
91
+
92
+ __all__ = [
93
+ "resolve_layers",
94
+ "resolve_hidden_size",
95
+ ]
camel_inference/src/camel/model_registry.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for loading per-model configuration metadata used across training and inference.
3
+
4
+ The registry is defined in YAML (see model_registry.yaml in this directory) and exposes
5
+ immutable ModelConfig objects for downstream consumers. The intent is to centralize
6
+ model-specific defaults (prompt format, tokenizer quirks, wrapper class path, LoRA constraints, etc.)
7
+ so that adding support for a new backbone primarily involves updating the registry.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import copy
12
+ import importlib
13
+ import dataclasses
14
+ import os
15
+ from collections.abc import Mapping as ABCMapping, Sequence as ABCSequence
16
+ from pathlib import Path
17
+ from types import MappingProxyType
18
+ from typing import Any, Dict, Iterable, Mapping, Optional, Tuple
19
+ import yaml
20
+
21
class ModelRegistryError(RuntimeError):
    """Error raised when the registry file is missing or malformed."""
23
+
24
@dataclasses.dataclass(frozen=True)
class PromptConfig:
    """Chat prompt formatting rules for a single backbone (registry 'prompt' section)."""

    start_of_turn: str  # literal string opening a speaker turn
    end_of_turn: str  # literal string closing a speaker turn
    roles: Mapping[str, str]  # canonical role names (validation requires "user" and "model" keys)
    enforce_bos: bool  # whether BOS must appear at the start of the prompt
    enforce_eos: bool  # whether EOS must appear at the end of the prompt
    allow_multiple_eos: bool  # optional in YAML; defaults to False via _optional_bool
32
+
33
+
34
@dataclasses.dataclass(frozen=True)
class TokenizerConfig:
    """Tokenizer quirks for a single backbone (registry 'tokenizer' section)."""

    pad_token_strategy: str  # how the pad token is chosen (e.g. "eos" per the YAML header)
    require_bos: bool  # tokenizer must expose a BOS token
    require_eos: bool  # tokenizer must expose an EOS token
    use_fast: bool = True  # optional in YAML; kept defaulted for backward compatibility
    add_prefix_space: bool = False  # optional in YAML; kept defaulted for backward compatibility
41
+
42
+
43
@dataclasses.dataclass(frozen=True)
class ArchitectureConfig:
    """Wrapper/introspection metadata for a backbone (registry 'architecture' section)."""

    wrapper_class: str  # dotted 'module.ClassName' path to the ECG wrapper class
    hidden_size_attrs: Tuple[str, ...]  # ordered config attribute paths to read hidden size
    language_model_path_hints: Tuple[str, ...]  # ordered attribute paths to locate transformer layers
    attn_implementation: str  # preferred attention backend when loading the model
    conv_out_dim: Optional[int] = None  # optional; coerced to int (None on failure) in architecture_config()
50
+
51
+
52
@dataclasses.dataclass(frozen=True)
class LoRAPolicyConfig:
    """Constraints on where LoRA adapters may attach (registry 'lora_policy' section)."""

    expect_language_only: bool  # adapters are expected only on language-model modules
    allowed_markers: Tuple[str, ...]  # module-name markers where LoRA is allowed
    blocked_markers: Tuple[str, ...]  # module-name markers where LoRA is forbidden (validated disjoint from allowed)
    freeze_vision: bool  # whether vision components stay frozen
58
+
59
+
60
@dataclasses.dataclass(frozen=True)
class PackingConversationConfig:
    """Conversation normalization rules (registry 'packing.conversation' section)."""

    format_id: str  # mirrors packing.prompt_format (copied in packing_config())
    user_role_aliases: Tuple[str, ...]  # role names treated as the user
    model_role_aliases: Tuple[str, ...]  # role names treated as the model
    strip_image_from_roles: Tuple[str, ...]  # roles whose image markers are stripped
    merge_system_with_first_user: bool  # fold the system turn into the first user turn
67
+
68
+
69
@dataclasses.dataclass(frozen=True)
class PackingECGTokensConfig:
    """ECG special-token vocabulary for packing (registry 'packing.ecg_tokens' section)."""

    global_start: str  # token opening a whole ECG block
    global_end: str  # token closing a whole ECG block
    lead_start_template: str  # per-lead start token template
    lead_end_template: str  # per-lead end token template
    canonical_leads: Tuple[str, ...]  # lead names in canonical order (validated non-empty)
76
+
77
+
78
@dataclasses.dataclass(frozen=True)
class PackingConfig:
    """Top-level packing settings composed from the registry 'packing' section."""

    prompt_format: str  # identifier of the prompt format (e.g. checked against "harmony_chat_v1" elsewhere)
    conversation: PackingConversationConfig  # conversation normalization rules
    ecg_tokens: PackingECGTokensConfig  # ECG special-token vocabulary
83
+
84
+
85
@dataclasses.dataclass(frozen=True)
class ModelConfig:
    """Typed wrapper over a single model entry in the registry."""

    name: str  # registry key for this model
    data: Mapping[str, Any]  # frozen (read-only) raw entry data

    @property
    def hf_id(self) -> str:
        """Hugging Face identifier of the frozen language model."""
        return _require_str(self.data, "hf_id", self.name)

    @property
    def prompt(self) -> Mapping[str, Any]:
        """Raw 'prompt' section mapping."""
        return _require_mapping(self.data, "prompt", self.name)

    @property
    def tokenizer(self) -> Mapping[str, Any]:
        """Raw 'tokenizer' section mapping."""
        return _require_mapping(self.data, "tokenizer", self.name)

    @property
    def architecture(self) -> Mapping[str, Any]:
        """Raw 'architecture' section mapping."""
        return _require_mapping(self.data, "architecture", self.name)

    @property
    def lora_policy(self) -> Mapping[str, Any]:
        """Raw 'lora_policy' section mapping."""
        return _require_mapping(self.data, "lora_policy", self.name)

    @property
    def packing(self) -> Mapping[str, Any]:
        """Raw 'packing' section mapping."""
        return _require_mapping(self.data, "packing", self.name)

    def prompt_config(self) -> PromptConfig:
        """Build a typed PromptConfig from the 'prompt' section."""
        prompt = self.prompt
        roles = _require_mapping(prompt, "roles", self.name, section="prompt")
        return PromptConfig(
            start_of_turn=_require_str(prompt, "start_of_turn", self.name, section="prompt"),
            end_of_turn=_require_str(prompt, "end_of_turn", self.name, section="prompt"),
            # Role values are coerced to str regardless of YAML scalar type.
            roles={k: str(v) for k, v in roles.items()},
            enforce_bos=_require_bool(prompt, "enforce_bos", self.name, section="prompt"),
            enforce_eos=_require_bool(prompt, "enforce_eos", self.name, section="prompt"),
            allow_multiple_eos=_optional_bool(prompt, "allow_multiple_eos", False, self.name, section="prompt"),
        )

    def required_prompt_text(self, key: str) -> str:
        """Fetch a required string field from the 'prompt' section by key."""
        return _require_str(self.prompt, key, self.name, section="prompt")

    def tokenizer_config(self) -> TokenizerConfig:
        """Build a typed TokenizerConfig from the 'tokenizer' section."""
        tokenizer = self.tokenizer
        pad_strategy = _require_str(tokenizer, "pad_token_strategy", self.name, section="tokenizer")
        require_bos = _require_bool(tokenizer, "require_bos", self.name, section="tokenizer")
        require_eos = _require_bool(tokenizer, "require_eos", self.name, section="tokenizer")
        # Optional fields with defaults for backward compatibility
        use_fast = _optional_bool(tokenizer, "use_fast", True, self.name, section="tokenizer")
        add_prefix_space = _optional_bool(tokenizer, "add_prefix_space", False, self.name, section="tokenizer")
        return TokenizerConfig(
            pad_token_strategy=pad_strategy,
            require_bos=require_bos,
            require_eos=require_eos,
            use_fast=use_fast,
            add_prefix_space=add_prefix_space,
        )

    def architecture_config(self) -> ArchitectureConfig:
        """Build a typed ArchitectureConfig from the 'architecture' section."""
        arch = self.architecture
        conv_out = arch.get("conv_out_dim") if isinstance(arch, ABCMapping) else None
        # conv_out_dim is optional; silently falls back to None when not int-like.
        try:
            conv_out_int = int(conv_out) if conv_out is not None else None
        except Exception:
            conv_out_int = None
        return ArchitectureConfig(
            wrapper_class=_require_str(arch, "wrapper_class", self.name, section="architecture"),
            hidden_size_attrs=tuple(
                _require_sequence_of_str(arch, "hidden_size_attrs", self.name, section="architecture")
            ),
            language_model_path_hints=tuple(
                _require_sequence_of_str(arch, "language_model_path_hints", self.name, section="architecture")
            ),
            attn_implementation=_require_str(arch, "attn_implementation", self.name, section="architecture"),
            conv_out_dim=conv_out_int,
        )

    def lora_policy_config(self) -> LoRAPolicyConfig:
        """Build a typed LoRAPolicyConfig from the 'lora_policy' section."""
        lora = self.lora_policy
        return LoRAPolicyConfig(
            expect_language_only=_require_bool(lora, "expect_language_only", self.name, section="lora_policy"),
            allowed_markers=tuple(
                _require_sequence_of_str(lora, "allowed_markers", self.name, section="lora_policy")
            ),
            blocked_markers=tuple(
                _require_sequence_of_str(lora, "blocked_markers", self.name, section="lora_policy")
            ),
            freeze_vision=_require_bool(lora, "freeze_vision", self.name, section="lora_policy"),
        )

    def packing_config(self) -> PackingConfig:
        """Build a typed PackingConfig (conversation + ECG token settings)."""
        packing = self.packing
        format_id = _require_str(packing, "prompt_format", self.name, section="packing")

        conversation = _require_mapping(packing, "conversation", self.name, section="packing")
        user_aliases = tuple(
            _require_sequence_of_str(conversation, "user_role_aliases", self.name, section="packing.conversation")
        )
        model_aliases = tuple(
            _require_sequence_of_str(conversation, "model_role_aliases", self.name, section="packing.conversation")
        )
        strip_roles = tuple(
            _require_sequence_of_str(conversation, "strip_image_from_roles", self.name, section="packing.conversation")
        )
        merge_system = _require_bool(
            conversation, "merge_system_with_first_user", self.name, section="packing.conversation"
        )
        conv_cfg = PackingConversationConfig(
            format_id=format_id,
            user_role_aliases=user_aliases,
            model_role_aliases=model_aliases,
            strip_image_from_roles=strip_roles,
            merge_system_with_first_user=merge_system,
        )

        ecg_tokens = _require_mapping(packing, "ecg_tokens", self.name, section="packing")
        global_start = _require_str(ecg_tokens, "global_start", self.name, section="packing.ecg_tokens")
        global_end = _require_str(ecg_tokens, "global_end", self.name, section="packing.ecg_tokens")
        lead_start_template = _require_str(ecg_tokens, "lead_start_template", self.name, section="packing.ecg_tokens")
        lead_end_template = _require_str(ecg_tokens, "lead_end_template", self.name, section="packing.ecg_tokens")
        canonical_leads = tuple(
            _require_sequence_of_str(ecg_tokens, "canonical_leads", self.name, section="packing.ecg_tokens")
        )
        ecg_cfg = PackingECGTokensConfig(
            global_start=global_start,
            global_end=global_end,
            lead_start_template=lead_start_template,
            lead_end_template=lead_end_template,
            canonical_leads=canonical_leads,
        )

        return PackingConfig(
            prompt_format=format_id,
            conversation=conv_cfg,
            ecg_tokens=ecg_cfg,
        )

    def resolve_wrapper_class(self):
        """Import and return the wrapper class named by architecture.wrapper_class.

        The path must be 'module.ClassName'; the module is imported relative to
        this package. Raises ModelRegistryError on any import/lookup failure.
        """
        arch = self.architecture_config()
        if "." not in arch.wrapper_class:
            raise ModelRegistryError(
                f"Wrapper class path '{arch.wrapper_class}' for model '{self.name}' must be in 'module.ClassName' form."
            )
        module_name, class_name = arch.wrapper_class.rsplit(".", 1)
        try:
            module = importlib.import_module("." + module_name, package=__package__)
        except ImportError as exc:
            raise ModelRegistryError(
                f"Failed to import wrapper module '{module_name}' for model '{self.name}': {exc}"
            ) from exc
        try:
            wrapper_cls = getattr(module, class_name)
        except AttributeError as exc:
            raise ModelRegistryError(
                f"Wrapper class '{class_name}' not found in module '{module_name}' for model '{self.name}'."
            ) from exc
        return wrapper_cls
246
+
247
+
248
def _default_registry_path() -> Path:
    """Path of `model_registry.yaml` sitting next to this module."""
    return Path(__file__).resolve().parent / "model_registry.yaml"
250
+
251
+
252
def load_registry(
    *,
    registry_path: Optional[os.PathLike[str] | str] = None,
    model_overrides: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> "ModelRegistry":
    """
    Load and validate the model registry from YAML.

    Args:
        registry_path: Optional path to the YAML file. Defaults to `model_registry.yaml` alongside this module.
        model_overrides: Optional mapping of model name -> override dict that will be deep-merged
            onto the YAML entry (useful for ad-hoc experimentation).
    """
    path = _default_registry_path() if registry_path is None else Path(registry_path)
    if not path.exists():
        raise ModelRegistryError(f"Model registry file not found at {path}")
    try:
        with path.open("r", encoding="utf-8") as fh:
            raw = yaml.safe_load(fh)
    except yaml.YAMLError as exc:
        raise ModelRegistryError(f"Failed to parse model registry YAML: {exc}") from exc

    if not isinstance(raw, ABCMapping):
        raise ModelRegistryError("Model registry root must be a mapping.")

    models = raw.get("models")
    if not isinstance(models, ABCMapping) or len(models) == 0:
        raise ModelRegistryError("Model registry must define a non-empty 'models' mapping.")

    validated: Dict[str, Mapping[str, Any]] = {}
    for model_name, model_cfg in models.items():
        if not isinstance(model_name, str):
            raise ModelRegistryError("Model names must be strings.")
        if not isinstance(model_cfg, ABCMapping):
            raise ModelRegistryError(f"Model '{model_name}' entry must be a mapping.")
        entry = copy.deepcopy(model_cfg)
        if model_overrides and model_name in model_overrides:
            _deep_update(entry, model_overrides[model_name])
        _validate_model_entry(model_name, entry)
        validated[model_name] = entry

    return ModelRegistry(validated, source_path=path)
294
+
295
+
296
class ModelRegistry:
    """Read-only, in-memory view over validated registry entries."""

    def __init__(self, models: Mapping[str, Mapping[str, Any]], *, source_path: Path):
        self._models = dict(models)
        self._source_path = Path(source_path)

    @property
    def source_path(self) -> Path:
        """Path the registry was loaded from."""
        return self._source_path

    def names(self) -> Iterable[str]:
        """All registered model names, as a tuple."""
        return tuple(self._models)

    def get(self, name: str) -> ModelConfig:
        """Return the frozen ModelConfig for *name*; raise ModelRegistryError if unknown."""
        if name in self._models:
            return ModelConfig(name=name, data=_deep_freeze(self._models[name]))
        raise ModelRegistryError(f"Unknown model '{name}'. Known models: {sorted(self._models)}")
314
+
315
+
316
def _validate_model_entry(name: str, entry: Mapping[str, Any]) -> None:
    """Eagerly validate one registry entry, section by section.

    Raises ModelRegistryError (via the _require_* helpers or directly) on the
    first missing/ill-typed field, so a bad registry fails at load time rather
    than deep inside training/inference.
    """
    # Required string
    _require_str(entry, "hf_id", name)

    # Prompt section
    prompt = _require_mapping(entry, "prompt", name)
    _require_str(prompt, "start_of_turn", name, section="prompt")
    _require_str(prompt, "end_of_turn", name, section="prompt")
    roles = _require_mapping(prompt, "roles", name, section="prompt")
    for role_key in ("user", "model"):
        _require_str(roles, role_key, name, section="prompt.roles")
    for flag in ("enforce_bos", "enforce_eos"):
        _require_bool(prompt, flag, name, section="prompt")

    # Tokenizer section
    tokenizer = _require_mapping(entry, "tokenizer", name)
    _require_str(tokenizer, "pad_token_strategy", name, section="tokenizer")
    for flag in ("require_bos", "require_eos"):
        _require_bool(tokenizer, flag, name, section="tokenizer")

    # Architecture
    architecture = _require_mapping(entry, "architecture", name)
    _require_str(architecture, "wrapper_class", name, section="architecture")
    hidden_attrs = _require_sequence_of_str(architecture, "hidden_size_attrs", name, section="architecture")
    if len(hidden_attrs) == 0:
        raise ModelRegistryError(
            f"Model '{name}' architecture.hidden_size_attrs must contain at least one attribute path."
        )
    _require_sequence_of_str(architecture, "language_model_path_hints", name, section="architecture")
    _require_str(architecture, "attn_implementation", name, section="architecture")
    # conv_out_dim is optional; if present, ensure it is an int-like
    if "conv_out_dim" in architecture:
        val = architecture.get("conv_out_dim")
        try:
            int(val)  # type: ignore[arg-type]
        except Exception:
            raise ModelRegistryError(
                f"Model '{name}' field 'architecture.conv_out_dim' must be an integer when provided."
            )

    # LoRA policy
    lora_policy = _require_mapping(entry, "lora_policy", name)
    _require_bool(lora_policy, "expect_language_only", name, section="lora_policy")
    allowed = _require_sequence_of_str(lora_policy, "allowed_markers", name, section="lora_policy")
    blocked = _require_sequence_of_str(lora_policy, "blocked_markers", name, section="lora_policy")
    # Allowed/blocked markers must be disjoint, or the policy is ambiguous.
    overlap = set(allowed).intersection(blocked)
    if overlap:
        raise ModelRegistryError(
            f"Model '{name}' lora_policy.allowed_markers and lora_policy.blocked_markers overlap: {sorted(overlap)}"
        )
    _require_bool(lora_policy, "freeze_vision", name, section="lora_policy")

    # Packing
    packing = _require_mapping(entry, "packing", name)
    _require_str(packing, "prompt_format", name, section="packing")
    conversation = _require_mapping(packing, "conversation", name, section="packing")
    _require_sequence_of_str(conversation, "user_role_aliases", name, section="packing.conversation")
    _require_sequence_of_str(conversation, "model_role_aliases", name, section="packing.conversation")
    _require_sequence_of_str(conversation, "strip_image_from_roles", name, section="packing.conversation")
    _require_bool(conversation, "merge_system_with_first_user", name, section="packing.conversation")
    ecg_tokens = _require_mapping(packing, "ecg_tokens", name, section="packing")
    _require_str(ecg_tokens, "global_start", name, section="packing.ecg_tokens")
    _require_str(ecg_tokens, "global_end", name, section="packing.ecg_tokens")
    _require_str(ecg_tokens, "lead_start_template", name, section="packing.ecg_tokens")
    _require_str(ecg_tokens, "lead_end_template", name, section="packing.ecg_tokens")
    if len(_require_sequence_of_str(ecg_tokens, "canonical_leads", name, section="packing.ecg_tokens")) == 0:
        raise ModelRegistryError(
            f"Model '{name}' packing.ecg_tokens.canonical_leads must contain at least one lead."
        )
385
+
386
+
387
def _require_mapping(
    parent: Mapping[str, Any],
    key: str,
    model_name: str,
    *,
    section: Optional[str] = None,
) -> Mapping[str, Any]:
    """Fetch parent[key] as a mapping, or raise ModelRegistryError."""
    found = parent.get(key)
    if isinstance(found, ABCMapping):
        return found
    loc = f"{section}.{key}" if section else key
    raise ModelRegistryError(f"Model '{model_name}' is missing mapping '{loc}'.")
399
+
400
+
401
def _require_str(
    parent: Mapping[str, Any],
    key: str,
    model_name: str,
    *,
    section: Optional[str] = None,
) -> str:
    """Fetch parent[key] as a string, or raise ModelRegistryError."""
    found = parent.get(key)
    if isinstance(found, str):
        return found
    loc = f"{section}.{key}" if section else key
    raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a string.")
413
+
414
+
415
def _require_bool(
    parent: Mapping[str, Any],
    key: str,
    model_name: str,
    *,
    section: Optional[str] = None,
) -> bool:
    """Fetch parent[key] as a strict bool, or raise ModelRegistryError."""
    found = parent.get(key)
    if isinstance(found, bool):
        return found
    loc = f"{section}.{key}" if section else key
    raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a boolean.")
427
+
428
+
429
def _require_sequence_of_str(
    parent: Mapping[str, Any],
    key: str,
    model_name: str,
    *,
    section: Optional[str] = None,
) -> Tuple[str, ...]:
    """Fetch parent[key] as a sequence of strings (str/bytes rejected), or raise."""
    found = parent.get(key)
    if isinstance(found, (str, bytes)) or not isinstance(found, ABCSequence):
        loc = f"{section}.{key}" if section else key
        raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a sequence of strings.")
    for element in found:
        if not isinstance(element, str):
            loc = f"{section}.{key}" if section else key
            raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must contain only strings.")
    return tuple(found)
447
+
448
+
449
def _deep_update(target: Dict[str, Any], updates: Mapping[str, Any]) -> None:
    """
    Recursively merge `updates` into `target` in-place.

    Nested mapping values merge recursively; everything else is deep-copied
    into `target`, replacing any existing value.
    """
    for key, incoming in updates.items():
        current = target.get(key)
        if isinstance(incoming, ABCMapping) and isinstance(current, dict):
            _deep_update(current, incoming)  # type: ignore[arg-type]
        else:
            target[key] = copy.deepcopy(incoming)
458
+
459
+
460
def _optional_bool(
    parent: Mapping[str, Any],
    key: str,
    default: bool,
    model_name: str,
    *,
    section: Optional[str] = None,
) -> bool:
    """Read an optional boolean-ish field, returning *default* when absent.

    Accepts real bools, 0/1 integers, and common yes/no strings; anything else
    raises ModelRegistryError.
    """
    if key not in parent:
        return default
    raw = parent.get(key)
    # bool must be checked before int: bool is a subclass of int.
    if isinstance(raw, bool):
        return raw
    if isinstance(raw, int) and raw in (0, 1):
        return bool(raw)
    if isinstance(raw, str):
        text_map = {
            "true": True, "yes": True, "y": True, "1": True,
            "false": False, "no": False, "n": False, "0": False,
        }
        normalized = raw.strip().lower()
        if normalized in text_map:
            return text_map[normalized]
    loc = f"{section}.{key}" if section else key
    raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a boolean when provided.")
483
+
484
+
485
def _deep_freeze(obj: Any) -> Any:
    """
    Recursively convert mutable containers to immutable/read-only equivalents:
    dict -> MappingProxyType, list/tuple -> tuple, set -> frozenset.
    """
    if isinstance(obj, dict):
        return MappingProxyType({key: _deep_freeze(val) for key, val in obj.items()})
    if isinstance(obj, (list, tuple)):
        return tuple(_deep_freeze(item) for item in obj)
    if isinstance(obj, set):
        return frozenset(_deep_freeze(item) for item in obj)
    return obj
camel_inference/src/camel/model_registry.yaml ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Registry
2
+ #
3
+ # This file defines per-backbone configuration consumed by training and inference.
4
+ # Fields overview:
5
+ # - hf_id: Hugging Face model identifier for the frozen language model.
6
+ # - prompt: Chat formatting tokens and roles used to build prompts
7
+ # - start_of_turn / end_of_turn: literal strings delimiting speaker turns
8
+ # - roles.user / roles.model: canonical role names for user/model
9
+ # - enforce_bos / enforce_eos: whether BOS/EOS must appear at the start/end
10
+ # - tokenizer:
11
+ # - pad_token_strategy: how to set pad token ("eos" uses eos as pad)
12
+ # - require_bos / require_eos: tokenizer must expose these tokens
13
+ # - architecture:
14
+ # - wrapper_class: import path to the ECG wrapper class (module.Class)
15
+ # - hidden_size_attrs: ordered config attribute paths to read hidden size
16
+ # - language_model_path_hints: ordered attribute paths to locate transformer layers
17
+ # - attn_implementation: preferred attention backend to use when loading
18
+ # - lora_policy:
19
+ # - expect_language_only: LoRA must live under language/text model stacks
20
+ # - allowed_markers / blocked_markers: substrings used to validate LoRA placement
21
+ # - freeze_vision: freeze LoRA params under blocked stacks during training
22
+ # - packing:
23
+ # - prompt_format: prompt/conversation template id
24
+ # - conversation: conversation role normalization and preprocessing
25
+ # - ecg_tokens: special tokens inserted to mark ECG structure
26
+ models:
27
+ medgemma-27b-it:
28
+ hf_id: "google/medgemma-27b-text-it"
29
+ prompt:
30
+ start_of_turn: "<start_of_turn>"
31
+ end_of_turn: "<end_of_turn>\n"
32
+ roles:
33
+ user: "user"
34
+ model: "model"
35
+ enforce_bos: true
36
+ enforce_eos: true
37
+ tokenizer:
38
+ pad_token_strategy: "eos"
39
+ require_bos: true
40
+ require_eos: true
41
+ use_fast: true
42
+ add_prefix_space: false
43
+ architecture:
44
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
45
+ hidden_size_attrs:
46
+ - "config.hidden_size"
47
+ - "config.text_config.hidden_size"
48
+ language_model_path_hints:
49
+ - "base_model.model.language_model.layers"
50
+ - "model.language_model.layers"
51
+ - "model.layers"
52
+ attn_implementation: "eager"
53
+ conv_out_dim: 64
54
+ lora_policy:
55
+ expect_language_only: true
56
+ allowed_markers:
57
+ - "language_model"
58
+ - ".model.layers."
59
+ - ".layers."
60
+ blocked_markers:
61
+ - "vision"
62
+ - "multi_modal"
63
+ - "projector"
64
+ - ".enc."
65
+ - "encoder_proj"
66
+ freeze_vision: true
67
+ packing:
68
+ prompt_format: "gemma_chat_v1"
69
+ conversation:
70
+ user_role_aliases: ["human", "user"]
71
+ model_role_aliases: ["gpt", "assistant"]
72
+ strip_image_from_roles: ["human"]
73
+ merge_system_with_first_user: true
74
+ ecg_tokens:
75
+ global_start: "<ecg_global_start>"
76
+ global_end: "<ecg_global_end>"
77
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
78
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
79
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
80
+ gemma-12b-it:
81
+ hf_id: "google/gemma-3-12b-it"
82
+ prompt:
83
+ start_of_turn: "<start_of_turn>"
84
+ end_of_turn: "<end_of_turn>\n"
85
+ roles:
86
+ user: "user"
87
+ model: "model"
88
+ enforce_bos: true
89
+ enforce_eos: true
90
+ tokenizer:
91
+ pad_token_strategy: "eos"
92
+ require_bos: true
93
+ require_eos: true
94
+ use_fast: true
95
+ add_prefix_space: false
96
+ architecture:
97
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
98
+ hidden_size_attrs:
99
+ - "config.hidden_size"
100
+ - "config.text_config.hidden_size"
101
+ language_model_path_hints:
102
+ - "base_model.model.language_model.layers"
103
+ - "model.language_model.layers"
104
+ - "model.layers"
105
+ attn_implementation: "eager"
106
+ conv_out_dim: 64
107
+ lora_policy:
108
+ expect_language_only: true
109
+ allowed_markers:
110
+ - "language_model"
111
+ - ".model.layers."
112
+ - ".layers."
113
+ blocked_markers:
114
+ - "vision"
115
+ - "multi_modal"
116
+ - "projector"
117
+ - ".enc."
118
+ - "encoder_proj"
119
+ freeze_vision: true
120
+ packing:
121
+ prompt_format: "gemma_chat_v1"
122
+ conversation:
123
+ user_role_aliases: ["human", "user"]
124
+ model_role_aliases: ["gpt", "assistant"]
125
+ strip_image_from_roles: ["human"]
126
+ merge_system_with_first_user: true
127
+ ecg_tokens:
128
+ global_start: "<ecg_global_start>"
129
+ global_end: "<ecg_global_end>"
130
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
131
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
132
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
133
+ medgemma-4b-it:
134
+ hf_id: "google/medgemma-4b-it"
135
+ prompt:
136
+ start_of_turn: "<start_of_turn>"
137
+ end_of_turn: "<end_of_turn>\n"
138
+ roles:
139
+ user: "user"
140
+ model: "model"
141
+ enforce_bos: true
142
+ enforce_eos: true
143
+ tokenizer:
144
+ pad_token_strategy: "eos"
145
+ require_bos: true
146
+ require_eos: true
147
+ use_fast: true
148
+ add_prefix_space: false
149
+ architecture:
150
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
151
+ hidden_size_attrs:
152
+ - "config.hidden_size"
153
+ - "config.text_config.hidden_size"
154
+ language_model_path_hints:
155
+ - "base_model.model.language_model.layers"
156
+ - "model.language_model.layers"
157
+ - "model.layers"
158
+ attn_implementation: "eager"
159
+ conv_out_dim: 64
160
+ lora_policy:
161
+ expect_language_only: true
162
+ allowed_markers:
163
+ - "language_model"
164
+ - ".model.layers."
165
+ - ".layers."
166
+ blocked_markers:
167
+ - "vision"
168
+ - "multi_modal"
169
+ - "projector"
170
+ - ".enc."
171
+ - "encoder_proj"
172
+ freeze_vision: true
173
+ packing:
174
+ prompt_format: "gemma_chat_v1"
175
+ conversation:
176
+ user_role_aliases: ["human", "user"]
177
+ model_role_aliases: ["gpt", "assistant"]
178
+ strip_image_from_roles: ["human"]
179
+ merge_system_with_first_user: true
180
+ ecg_tokens:
181
+ global_start: "<ecg_global_start>"
182
+ global_end: "<ecg_global_end>"
183
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
184
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
185
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
186
+ qwen3-4b-instruct:
187
+ hf_id: "Qwen/Qwen3-4B-Instruct-2507"
188
+ prompt:
189
+ start_of_turn: "<|im_start|>"
190
+ end_of_turn: "<|im_end|>"
191
+ roles:
192
+ user: "user"
193
+ model: "assistant"
194
+ enforce_bos: false
195
+ enforce_eos: false
196
+ allow_multiple_eos: true
197
+ tokenizer:
198
+ pad_token_strategy: "existing"
199
+ require_bos: false
200
+ require_eos: false
201
+ use_fast: true
202
+ add_prefix_space: false
203
+ architecture:
204
+ wrapper_class: "ecg_qwen_model.ECGQwenPrefix"
205
+ hidden_size_attrs:
206
+ - "config.hidden_size"
207
+ language_model_path_hints:
208
+ - "model.layers"
209
+ - "model.model.layers"
210
+ attn_implementation: "eager"
211
+ conv_out_dim: 64
212
+ lora_policy:
213
+ expect_language_only: true
214
+ allowed_markers:
215
+ - "model.layers."
216
+ - ".layers."
217
+ blocked_markers:
218
+ - "vision"
219
+ - "multi_modal"
220
+ - "projector"
221
+ - ".enc."
222
+ - "encoder_proj"
223
+ freeze_vision: false
224
+ packing:
225
+ prompt_format: "qwen_chat_v1"
226
+ conversation:
227
+ user_role_aliases: ["human", "user"]
228
+ model_role_aliases: ["gpt", "assistant"]
229
+ strip_image_from_roles: ["human"]
230
+ merge_system_with_first_user: true
231
+ ecg_tokens:
232
+ global_start: "<ecg_global_start>"
233
+ global_end: "<ecg_global_end>"
234
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
235
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
236
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
237
+ qwen3-4b:  # renamed: this entry (hf_id Qwen/Qwen3-4B, the base model) previously reused the key 'qwen3-4b-instruct', which YAML silently overrides with the later duplicate
238
+ hf_id: "Qwen/Qwen3-4B"
239
+ prompt:
240
+ start_of_turn: "<|im_start|>"
241
+ end_of_turn: "<|im_end|>"
242
+ roles:
243
+ user: "user"
244
+ model: "assistant"
245
+ enforce_bos: false
246
+ enforce_eos: false
247
+ allow_multiple_eos: true
248
+ tokenizer:
249
+ pad_token_strategy: "existing"
250
+ require_bos: false
251
+ require_eos: false
252
+ use_fast: true
253
+ add_prefix_space: false
254
+ architecture:
255
+ wrapper_class: "ecg_qwen_model.ECGQwenPrefix"
256
+ hidden_size_attrs:
257
+ - "config.hidden_size"
258
+ language_model_path_hints:
259
+ - "model.layers"
260
+ - "model.model.layers"
261
+ attn_implementation: "eager"
262
+ conv_out_dim: 64
263
+ lora_policy:
264
+ expect_language_only: true
265
+ allowed_markers:
266
+ - "model.layers."
267
+ - ".layers."
268
+ blocked_markers:
269
+ - "vision"
270
+ - "multi_modal"
271
+ - "projector"
272
+ - ".enc."
273
+ - "encoder_proj"
274
+ freeze_vision: false
275
+ packing:
276
+ prompt_format: "qwen_chat_v1"
277
+ conversation:
278
+ user_role_aliases: ["human", "user"]
279
+ model_role_aliases: ["gpt", "assistant"]
280
+ strip_image_from_roles: ["human"]
281
+ merge_system_with_first_user: true
282
+ ecg_tokens:
283
+ global_start: "<ecg_global_start>"
284
+ global_end: "<ecg_global_end>"
285
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
286
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
287
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
288
+ gpt-oss-120b:
289
+ hf_id: "openai/gpt-oss-120b"
290
+ prompt:
291
+ start_of_turn: "<|start|>"
292
+ end_of_turn: "<|end|>"
293
+ roles:
294
+ user: "user"
295
+ model: "assistant"
296
+ enforce_bos: false
297
+ enforce_eos: false
298
+ allow_multiple_eos: true
299
+ system_prompt: |-
300
+ You are ChatGPT, a large language model trained by OpenAI.
301
+ Knowledge cutoff: 2024-06
302
+ Current date: 2025-06-28
303
+
304
+ Reasoning: low
305
+
306
+ # Valid channels: final. Channel must be included for every message.
307
+ developer_prompt: |-
308
+ # Instructions
309
+
310
+ You are trained to interpret electrocardiograms (ECGs) and must answer questions about them clearly and accurately.
311
+ tokenizer:
312
+ pad_token_strategy: "existing"
313
+ require_bos: false
314
+ require_eos: false
315
+ use_fast: true
316
+ add_prefix_space: false
317
+ architecture:
318
+ wrapper_class: "ecg_gptoss_model.ECGGptOssPrefix"
319
+ hidden_size_attrs:
320
+ - "config.hidden_size"
321
+ language_model_path_hints:
322
+ - "model.layers"
323
+ - "model.model.layers"
324
+ attn_implementation: "eager"
325
+ conv_out_dim: 64
326
+ lora_policy:
327
+ expect_language_only: true
328
+ allowed_markers:
329
+ - "model.layers."
330
+ - ".layers."
331
+ blocked_markers:
332
+ - "vision"
333
+ - "multi_modal"
334
+ - "projector"
335
+ - ".enc."
336
+ - "encoder_proj"
337
+ freeze_vision: false
338
+ packing:
339
+ prompt_format: "harmony_chat_v1"
340
+ conversation:
341
+ user_role_aliases: ["human", "user"]
342
+ model_role_aliases: ["assistant", "gpt"]
343
+ strip_image_from_roles: ["human"]
344
+ merge_system_with_first_user: false
345
+ ecg_tokens:
346
+ global_start: "<ecg_global_start>"
347
+ global_end: "<ecg_global_end>"
348
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
349
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
350
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
351
+ gemma-1b-it:
352
+ hf_id: "google/gemma-3-1b-it"
353
+ prompt:
354
+ start_of_turn: "<start_of_turn>"
355
+ end_of_turn: "<end_of_turn>\n"
356
+ roles:
357
+ user: "user"
358
+ model: "model"
359
+ enforce_bos: true
360
+ enforce_eos: true
361
+ tokenizer:
362
+ pad_token_strategy: "eos"
363
+ require_bos: true
364
+ require_eos: true
365
+ use_fast: true
366
+ add_prefix_space: false
367
+ architecture:
368
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
369
+ hidden_size_attrs:
370
+ - "config.hidden_size"
371
+ - "config.text_config.hidden_size"
372
+ language_model_path_hints:
373
+ - "base_model.model.language_model.layers"
374
+ - "model.language_model.layers"
375
+ - "model.layers"
376
+ attn_implementation: "eager"
377
+ conv_out_dim: 64
378
+ lora_policy:
379
+ expect_language_only: true
380
+ allowed_markers:
381
+ - "language_model"
382
+ - ".model.layers."
383
+ - ".layers."
384
+ blocked_markers:
385
+ - "vision"
386
+ - "multi_modal"
387
+ - "projector"
388
+ - ".enc."
389
+ - "encoder_proj"
390
+ freeze_vision: true
391
+ packing:
392
+ prompt_format: "gemma_chat_v1"
393
+ conversation:
394
+ user_role_aliases: ["human", "user"]
395
+ model_role_aliases: ["gpt", "assistant"]
396
+ strip_image_from_roles: ["human"]
397
+ merge_system_with_first_user: true
398
+ ecg_tokens:
399
+ global_start: "<ecg_global_start>"
400
+ global_end: "<ecg_global_end>"
401
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
402
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
403
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
404
+ gemma-270m-it:
405
+ hf_id: "google/gemma-3-270m-it"
406
+ prompt:
407
+ start_of_turn: "<start_of_turn>"
408
+ end_of_turn: "<end_of_turn>\n"
409
+ roles:
410
+ user: "user"
411
+ model: "model"
412
+ enforce_bos: true
413
+ enforce_eos: true
414
+ tokenizer:
415
+ pad_token_strategy: "eos"
416
+ require_bos: true
417
+ require_eos: true
418
+ use_fast: true
419
+ add_prefix_space: false
420
+ architecture:
421
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
422
+ hidden_size_attrs:
423
+ - "config.hidden_size"
424
+ - "config.text_config.hidden_size"
425
+ language_model_path_hints:
426
+ - "base_model.model.language_model.layers"
427
+ - "model.language_model.layers"
428
+ - "model.layers"
429
+ attn_implementation: "eager"
430
+ conv_out_dim: 64
431
+ lora_policy:
432
+ expect_language_only: true
433
+ allowed_markers:
434
+ - "language_model"
435
+ - ".model.layers."
436
+ - ".layers."
437
+ blocked_markers:
438
+ - "vision"
439
+ - "multi_modal"
440
+ - "projector"
441
+ - ".enc."
442
+ - "encoder_proj"
443
+ freeze_vision: true
444
+ packing:
445
+ prompt_format: "gemma_chat_v1"
446
+ conversation:
447
+ user_role_aliases: ["human", "user"]
448
+ model_role_aliases: ["gpt", "assistant"]
449
+ strip_image_from_roles: ["human"]
450
+ merge_system_with_first_user: true
451
+ ecg_tokens:
452
+ global_start: "<ecg_global_start>"
453
+ global_end: "<ecg_global_end>"
454
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
455
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
456
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
457
+ medgemma-27b-it-dup:  # FIXME: exact duplicate of the 'medgemma-27b-it' entry above; duplicate YAML keys are silently overridden by the last one — remove this copy
458
+ hf_id: "google/medgemma-27b-text-it"
459
+ prompt:
460
+ start_of_turn: "<start_of_turn>"
461
+ end_of_turn: "<end_of_turn>\n"
462
+ roles:
463
+ user: "user"
464
+ model: "model"
465
+ enforce_bos: true
466
+ enforce_eos: true
467
+ tokenizer:
468
+ pad_token_strategy: "eos"
469
+ require_bos: true
470
+ require_eos: true
471
+ use_fast: true
472
+ add_prefix_space: false
473
+ architecture:
474
+ wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
475
+ hidden_size_attrs:
476
+ - "config.hidden_size"
477
+ - "config.text_config.hidden_size"
478
+ language_model_path_hints:
479
+ - "base_model.model.language_model.layers"
480
+ - "model.language_model.layers"
481
+ - "model.layers"
482
+ attn_implementation: "eager"
483
+ conv_out_dim: 64
484
+ lora_policy:
485
+ expect_language_only: true
486
+ allowed_markers:
487
+ - "language_model"
488
+ - ".model.layers."
489
+ - ".layers."
490
+ blocked_markers:
491
+ - "vision"
492
+ - "multi_modal"
493
+ - "projector"
494
+ - ".enc."
495
+ - "encoder_proj"
496
+ freeze_vision: true
497
+ packing:
498
+ prompt_format: "gemma_chat_v1"
499
+ conversation:
500
+ user_role_aliases: ["human", "user"]
501
+ model_role_aliases: ["gpt", "assistant"]
502
+ strip_image_from_roles: ["human"]
503
+ merge_system_with_first_user: true
504
+ ecg_tokens:
505
+ global_start: "<ecg_global_start>"
506
+ global_end: "<ecg_global_end>"
507
+ lead_start_template: "<ecg_lead_{lead_lower}_start>"
508
+ lead_end_template: "<ecg_lead_{lead_lower}_end>"
509
+ canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
camel_inference/src/camel/process_ecg.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Dict, Any
2
+ import numpy as np
3
+ import torch
4
+
5
+ from read_ecg import load_record
6
+
7
+ _LEAD_SYNONYMS: Dict[str, str] = {
8
+ # Limb
9
+ "I": "I", "II": "II", "III": "III",
10
+ "DI": "I", "DII": "II", "DIII": "III",
11
+ "MLII": "II",
12
+ # Augmented
13
+ "AVR": "aVR", "AVL": "aVL", "AVF": "aVF",
14
+ # Precordial
15
+ "V1": "V1", "V2": "V2", "V3": "V3", "V4": "V4", "V5": "V5", "V6": "V6",
16
+ # Dataset-specific
17
+ "ECG": "I", # Apnea-ECG
18
+ "ECG1": "I", "ECG2": "II", # AFDB
19
+ "CM5": "V5", "D3": "V3", "D4": "V4",
20
+ "CM2": "V2", "ML5": "V5",
21
+ "VF": "VF",
22
+ }
23
+
24
+ class NormOp:
25
+ def __init__(self, name: str, params: Dict[str, float] | None = None):
26
+ self.name = name
27
+ self.params = params or {}
28
+
29
+ def to_canonical_lead(name: str) -> Optional[str]:
30
+ if not isinstance(name, str):
31
+ return None
32
+ s = name.strip().upper()
33
+ if s in _LEAD_SYNONYMS:
34
+ return _LEAD_SYNONYMS[s]
35
+ if s in ("A VR", "A VL", "A VF"):
36
+ return "a" + s.replace(" ", "")[1:]
37
+ return None
38
+
39
def parse_pipeline(spec: Optional[str]) -> List[NormOp]:
    """
    Parse a comma-separated normalization spec, e.g. "nonfinite_to_zero,clip:-5:5".

    Each token is an op name optionally followed by colon-separated arguments
    (only 'clip' takes min/max). An empty/None spec yields the default
    pipeline: nonfinite_to_zero followed by clip. Unknown names raise ValueError.
    """
    if not spec:
        return [NormOp("nonfinite_to_zero"), NormOp("clip", {})]
    pipeline: List[NormOp] = []
    for raw in spec.split(","):
        token = raw.strip()
        if not token:
            continue
        fields = token.split(":")
        op_name = fields[0].lower()
        if op_name == "nonfinite_to_zero":
            pipeline.append(NormOp("nonfinite_to_zero"))
        elif op_name == "clip":
            lo = float(fields[1]) if len(fields) >= 2 and fields[1] != "" else None
            hi = float(fields[2]) if len(fields) >= 3 and fields[2] != "" else None
            pipeline.append(NormOp("clip", {"min": lo, "max": hi}))
        else:
            raise ValueError(f"Unknown op '{op_name}'.")
    return pipeline
61
+
62
+ def apply_pre_ops(signal_1d: np.ndarray, ops: List[NormOp]) -> np.ndarray:
63
+ x = signal_1d
64
+ for op in ops:
65
+ if op.name == 'nonfinite_to_zero':
66
+ np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0, copy=False)
67
+ return x
68
+
69
+
70
+ def apply_post_ops(segments: np.ndarray, ops: List[NormOp], lead_name: str,
71
+ clip_stats: Optional[Dict[str, Dict[str, float]]] = None) -> np.ndarray:
72
+ x = segments
73
+ for op in ops:
74
+ if op.name == 'clip':
75
+ mn = op.params.get('min', None)
76
+ mx = op.params.get('max', None)
77
+ if mn is None and mx is None and clip_stats is not None and lead_name in clip_stats:
78
+ mn = clip_stats[lead_name].get('clip_min', None)
79
+ mx = clip_stats[lead_name].get('clip_max', None)
80
+ if mn is None:
81
+ mn = -np.inf
82
+ if mx is None:
83
+ mx = np.inf
84
+ if mn > mx:
85
+ mn, mx = mx, mn
86
+ np.clip(x, mn, mx, out=x)
87
+ return x.astype(np.float32, copy=False)
88
+
89
+ def _segment_data(ecg_signal: np.ndarray, raw_fs: int) -> np.ndarray:
90
+ """
91
+ Segment 1D signal sampled at ``raw_fs`` into 1-second clips and resample to 256 samples.
92
+ Returns np.float32 array [N_segments, 256].
93
+ """
94
+ assert raw_fs > 0, f"raw sampling rate must be positive (got {raw_fs})"
95
+ samples_per_second = int(raw_fs)
96
+ n_samples = int(len(ecg_signal))
97
+ if samples_per_second <= 0 or n_samples <= 0:
98
+ return np.empty((0, 256), dtype=np.float32)
99
+
100
+ # Pre-allocate upper bound
101
+ n_full = n_samples // samples_per_second
102
+ has_partial = (n_samples % samples_per_second) > 0
103
+ max_segments = n_full + (1 if has_partial else 0)
104
+ if max_segments == 0:
105
+ return np.empty((0, 256), dtype=np.float32)
106
+
107
+ out = np.empty((max_segments, 256), dtype=np.float32)
108
+ new_idx = np.linspace(0, 1, num=256, dtype=np.float32)
109
+ actual = 0
110
+ for start in range(0, n_samples, samples_per_second):
111
+ end = min(start + samples_per_second, n_samples)
112
+ seg = ecg_signal[start:end]
113
+ if seg.shape[0] < samples_per_second * 0.5:
114
+ continue
115
+ old_idx = np.linspace(0, 1, num=seg.shape[0], dtype=np.float32)
116
+ out[actual] = np.interp(new_idx, old_idx, seg).astype(np.float32, copy=False)
117
+ actual += 1
118
+ return out[:actual]
119
+
120
+ def _apply_filters(signal_1d, fs: int):
121
+ """
122
+ Apply ECG signal filters: 50/60 Hz notch filters and 0.3 Hz high-pass filter.
123
+
124
+ Removes powerline interference and baseline wander from ECG signals.
125
+ Uses cascaded second-order sections (SOS) for numerical stability.
126
+ """
127
+ import numpy as np
128
+ from scipy.signal import iirnotch, butter, sosfiltfilt, tf2sos
129
+
130
+ x = np.asarray(signal_1d, dtype=np.float64)
131
+ if x.size == 0 or fs <= 0:
132
+ return x.astype('float32')
133
+
134
+ # Design notch filters for powerline interference
135
+ Q = 30.0
136
+ nyq = fs / 2.0
137
+ sos_filters = []
138
+ for freq in (50.0, 60.0):
139
+ if freq >= nyq:
140
+ continue
141
+ b, a = iirnotch(freq, Q, fs)
142
+ sos_filters.append(tf2sos(b, a))
143
+
144
+ # Add high-pass filter for baseline wander removal
145
+ sos_hp = butter(N=2, Wn=0.3 / nyq, btype='highpass', output='sos')
146
+ sos_filters.append(sos_hp)
147
+
148
+ # Apply all filters in cascade
149
+ sos = np.vstack(sos_filters)
150
+ x = sosfiltfilt(sos, x)
151
+ return x.astype('float32')
152
+
153
def get_waveform(device:torch.device, ecg_path:str, start_sec = None, end_sec = None, leads: Optional[List[str]] = None,
                 process:bool = False, norm: str = "nonfinite_to_zero,clip") -> Dict[str, Any]:
    """
    Load an ECG record and return per-lead segment tensors.

    Pipeline per lead: raw signal -> (optional) non-finite cleanup + filtering
    -> 1-second segmentation resampled to 256 samples -> clip normalization ->
    float32 tensor on `device`.

    Args:
        device: Target torch device for the per-lead tensors.
        ecg_path: Path to the ECG record, readable by `load_record`.
        start_sec: Optional start of the time window to load, in seconds.
        end_sec: Optional end of the time window to load, in seconds.
        leads: Optional subset of leads to load.
        process: When True, apply non-finite cleanup and notch/high-pass filters.
        norm: Normalization pipeline spec (e.g. "nonfinite_to_zero,clip").

    Returns:
        Dict mapping canonical lead name -> float32 tensor [n_segments, 256].
        An empty dict is returned (and the error printed) when loading fails.
    """
    ops = parse_pipeline(norm)

    ecg_dict: Dict[str, Any] = {}
    try:
        df, sig_names, original_fs = load_record(ecg_path, start_sec, end_sec, leads)

        # Process each lead in the record.
        for i, raw_name in enumerate(sig_names):
            canon = to_canonical_lead(raw_name)
            if not canon:
                print(f"Skipping unrecognized lead '{raw_name}'")
                continue
            x = df[:, i].astype('float32', copy=False)

            if process:
                x = apply_pre_ops(x, ops)           # zero out NaN/inf values
                x = _apply_filters(x, original_fs)  # remove noise and baseline wander

            # Segment into 1 s windows (256 samples) and normalize.
            segs = _segment_data(x, original_fs)
            segs = apply_post_ops(segs, ops, canon)

            lead_tensor = torch.from_numpy(segs).to(torch.float32).to(device)
            # BUGFIX: the guard was inverted ("if not torch.any(isnan)"), so
            # nan_to_num only ran when there were no NaNs (a no-op). Sanitize
            # exactly when NaNs are actually present.
            if torch.any(lead_tensor.isnan()):
                lead_tensor = lead_tensor.nan_to_num()

            ecg_dict[canon] = lead_tensor

    except Exception as e:
        # Best-effort: skip unreadable/failed records but surface the error.
        print(f"get_waveform failed for {ecg_path}: {e}")

    return ecg_dict
camel_inference/src/camel/projectors.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Projector registry utilities.
2
+
3
+ Provides a lightweight mechanism to swap the adapter architecture that maps the
4
+ conv encoder output to the language-model hidden size. Mirrors the ergonomic
5
+ API used by loss.py: a registry, default implementations, and a simple factory.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Callable, Dict, Iterable
10
+ import torch.nn as nn
11
+
12
# A projector builder maps (in_dim, out_dim) -> adapter module.
ProjectorBuilder = Callable[[int, int], nn.Module]

# name -> builder mapping; populated via @register_projector.
_PROJECTOR_REGISTRY: Dict[str, ProjectorBuilder] = {}

def register_projector(name: str) -> Callable[[ProjectorBuilder], ProjectorBuilder]:
    """Decorator registering a projector builder under a unique lowercased name."""
    normalized = name.strip().lower()

    def _register(builder: ProjectorBuilder) -> ProjectorBuilder:
        if not callable(builder):
            raise TypeError("Projector builder must be callable.")
        if normalized in _PROJECTOR_REGISTRY:
            raise ValueError(f"Projector '{name}' is already registered.")
        _PROJECTOR_REGISTRY[normalized] = builder
        return builder

    return _register

@register_projector("linear")
def _linear_projector(in_dim: int, out_dim: int) -> nn.Module:
    """Single linear adapter (current default)."""
    return nn.Linear(in_dim, out_dim, bias=True)

def available_projectors() -> Iterable[str]:
    """Registered projector names, sorted alphabetically."""
    return sorted(_PROJECTOR_REGISTRY)

def build_projector(name: str, in_dim: int, out_dim: int) -> nn.Module:
    """Instantiate the projector registered under `name` with the given dims."""
    if not _PROJECTOR_REGISTRY:
        raise RuntimeError("No projectors registered.")
    normalized = (name or "").strip().lower()
    if not normalized:
        raise ValueError("Projector name must be a non-empty string.")
    builder = _PROJECTOR_REGISTRY.get(normalized)
    if builder is None:
        raise KeyError(
            f"Unknown projector '{name}'. Available: {', '.join(available_projectors())}"
        )
    return builder(int(in_dim), int(out_dim))

__all__ = [
    "ProjectorBuilder",
    "available_projectors",
    "build_projector",
]
camel_inference/src/camel/prompt_renderers.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt rendering and span construction helpers."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Callable, Dict, List, Tuple
5
+ from transformers import PreTrainedTokenizer
6
+
7
+ from camel.assertions import (
8
+ assert_tokenization_cursor_matches,
9
+ assert_model_spans_valid,
10
+ assert_eos_appended,
11
+ )
12
+
13
+ def _ensure_trailing_newline(s: str) -> str:
14
+ if s.endswith("\n"):
15
+ return s
16
+ return s + "\n"
17
+
18
def _chat_v1_wrappers(schema, role: str) -> Tuple[str, str]:
    """Prefix/suffix strings wrapping one chat-v1 (Gemma/Qwen) turn for `role`."""
    opening = f"{schema.prompt.start_of_turn}{role}\n"
    closing = _ensure_trailing_newline(schema.prompt.end_of_turn)
    return opening, closing
22
+
23
+ def _harmony_v1_wrappers(schema, role: str, *, use_return: bool = False) -> Tuple[str, str]:
24
+ if role == schema.prompt.model_role:
25
+ prefix = f"{schema.prompt.start_of_turn}{role}"
26
+ suffix = "<|return|>" if use_return else str(schema.prompt.end_of_turn)
27
+ return prefix, suffix
28
+ prefix = f"{schema.prompt.start_of_turn}{role}<|message|>"
29
+ suffix = str(schema.prompt.end_of_turn)
30
+ return prefix, suffix
31
+
32
def _render_with_wrappers(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    wrapper_fn,
) -> Dict[str, Any]:
    """
    Tokenize `turns` into one flat id sequence, using `wrapper_fn(schema, role)`
    to supply the per-role (prefix, suffix) strings, and record which token
    spans belong to model turns (so losses can be masked to model output).

    Returns a dict with:
      - "text_ids": full token id list (BOS/EOS added per schema flags)
      - "model_spans_in_text": [start, end) token index pairs for model turns
      - "text_preview": the rendered prompt as plain text (BOS/EOS excluded)
    """
    tok = tokenizer
    prompt_tokens = schema.prompt

    text_ids: List[int] = []
    # Seed with BOS only when required AND the tokenizer actually has one.
    if prompt_tokens.require_bos and tok.bos_token_id is not None:
        text_ids.append(tok.bos_token_id)

    model_spans_in_text: List[Tuple[int, int]] = []
    cursor = len(text_ids)
    text_preview_parts: List[str] = []

    for turn in turns:
        role = turn["role"]
        text_block = turn["text_block"]
        prefix, suffix = wrapper_fn(schema, role)
        content = text_block
        # Avoid doubling the suffix when the source text already ends with it.
        if suffix and content.endswith(suffix):
            content = content[: -len(suffix)]
        # Encode the three pieces separately so span boundaries stay exact.
        ids_prefix = tok.encode(prefix, add_special_tokens=False)
        ids_content = tok.encode(content, add_special_tokens=False)
        ids_suffix = tok.encode(suffix, add_special_tokens=False)

        text_ids.extend(ids_prefix)
        text_ids.extend(ids_content)
        text_ids.extend(ids_suffix)

        if role == prompt_tokens.model_role:
            # Model span covers content + suffix, excluding the role prefix.
            s = cursor + len(ids_prefix)
            e = s + len(ids_content) + len(ids_suffix)
            if e > s:
                model_spans_in_text.append((s, e))
        cursor += len(ids_prefix) + len(ids_content) + len(ids_suffix)
        text_preview_parts.append(prefix + content + suffix)

    assert_tokenization_cursor_matches(cursor, len(text_ids))

    if prompt_tokens.require_eos and tok.eos_token_id is not None:
        text_ids.append(tok.eos_token_id)
        # Stretch the final model span to absorb the appended EOS token.
        if model_spans_in_text and turns[-1]["role"] == prompt_tokens.model_role:
            model_spans_in_text[-1] = (model_spans_in_text[-1][0], len(text_ids))

    assert_eos_appended(text_ids, tok, prompt_tokens.require_eos)
    assert_model_spans_valid(model_spans_in_text, len(text_ids))

    return {
        "text_ids": text_ids,
        "model_spans_in_text": model_spans_in_text,
        "text_preview": "".join(text_preview_parts),
    }
88
+
89
def _render_chat_v1(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
) -> Dict[str, Any]:
    """Render `turns` using the chat-v1 (Gemma/Qwen style) turn wrappers."""
    return _render_with_wrappers(
        tokenizer,
        turns,
        schema=schema,
        wrapper_fn=_chat_v1_wrappers,
    )
101
+
102
def _render_harmony_v1(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    use_return_for_last_assistant: bool = False,
) -> Dict[str, Any]:
    """
    Render `turns` in OpenAI Harmony format, optionally closing the FINAL
    assistant turn with "<|return|>" instead of the regular end token.

    NOTE(review): mirrors _render_with_wrappers almost line-for-line; the
    per-turn `use_return` flag is what prevents reusing that helper directly.

    Returns the same dict shape as _render_with_wrappers: "text_ids",
    "model_spans_in_text", and "text_preview".
    """
    tok = tokenizer
    prompt_tokens = schema.prompt

    text_ids: List[int] = []
    if prompt_tokens.require_bos and tok.bos_token_id is not None:
        text_ids.append(tok.bos_token_id)

    model_spans_in_text: List[Tuple[int, int]] = []
    cursor = len(text_ids)
    text_preview_parts: List[str] = []

    # Locate the last assistant turn so only it receives the <|return|> close.
    last_assistant_idx = None
    if use_return_for_last_assistant:
        for idx in range(len(turns) - 1, -1, -1):
            if turns[idx]["role"] == prompt_tokens.model_role:
                last_assistant_idx = idx
                break

    for idx, turn in enumerate(turns):
        role = turn["role"]
        text_block = turn["text_block"]
        use_return = use_return_for_last_assistant and last_assistant_idx is not None and idx == last_assistant_idx
        prefix, suffix = _harmony_v1_wrappers(schema, role, use_return=use_return)
        content = text_block
        # Avoid doubling the suffix when the source text already ends with it.
        if suffix and content.endswith(suffix):
            content = content[: -len(suffix)]
        # Encode the three pieces separately so span boundaries stay exact.
        ids_prefix = tok.encode(prefix, add_special_tokens=False)
        ids_content = tok.encode(content, add_special_tokens=False)
        ids_suffix = tok.encode(suffix, add_special_tokens=False)

        text_ids.extend(ids_prefix)
        text_ids.extend(ids_content)
        text_ids.extend(ids_suffix)

        if role == prompt_tokens.model_role:
            # Model span covers content + suffix, excluding the role prefix.
            s = cursor + len(ids_prefix)
            e = s + len(ids_content) + len(ids_suffix)
            if e > s:
                model_spans_in_text.append((s, e))
        cursor += len(ids_prefix) + len(ids_content) + len(ids_suffix)
        text_preview_parts.append(prefix + content + suffix)

    assert_tokenization_cursor_matches(cursor, len(text_ids))

    if prompt_tokens.require_eos and tok.eos_token_id is not None:
        text_ids.append(tok.eos_token_id)
        # Stretch the final model span to absorb the appended EOS token.
        if model_spans_in_text and turns[-1]["role"] == prompt_tokens.model_role:
            model_spans_in_text[-1] = (model_spans_in_text[-1][0], len(text_ids))

    assert_eos_appended(text_ids, tok, prompt_tokens.require_eos)
    assert_model_spans_valid(model_spans_in_text, len(text_ids))

    return {
        "text_ids": text_ids,
        "model_spans_in_text": model_spans_in_text,
        "text_preview": "".join(text_preview_parts),
    }
166
+
167
# Dispatch table: prompt_format id -> renderer. "harmony_chat_v1" is handled
# separately in render_prompt_and_spans because it takes an extra keyword flag.
_PROMPT_RENDERERS: Dict[str, Callable[[PreTrainedTokenizer, List[Dict[str, str]], Any], Dict[str, Any]]] = {
    "gemma_chat_v1": _render_chat_v1,
    "qwen_chat_v1": _render_chat_v1,
}
171
+
172
def render_prompt_and_spans(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    use_return_for_last_assistant: bool = False,
) -> Dict[str, Any]:
    """
    Render `turns` into token ids plus model-turn spans, using the prompt
    format declared by `schema.conversation.format_id`.

    Raises ValueError for an unrecognized format id.
    """
    format_id = str(schema.conversation.format_id)
    # Harmony is dispatched explicitly: it accepts an extra keyword flag.
    if format_id == "harmony_chat_v1":
        return _render_harmony_v1(
            tokenizer,
            turns,
            schema=schema,
            use_return_for_last_assistant=use_return_for_last_assistant,
        )
    renderer = _PROMPT_RENDERERS.get(format_id)
    if renderer is None:
        raise ValueError(f"Unknown prompt format '{format_id}'.")
    return renderer(tokenizer, turns, schema=schema)
191
+
192
def turn_wrappers(schema, role: str, *, use_return: bool = False) -> Tuple[str, str]:
    """Return the (prefix, suffix) strings for one turn in the schema's format."""
    format_id = str(schema.conversation.format_id)
    if format_id == "harmony_chat_v1":
        return _harmony_v1_wrappers(schema, role, use_return=use_return)
    if format_id in ("gemma_chat_v1", "qwen_chat_v1"):
        return _chat_v1_wrappers(schema, role)
    raise ValueError(f"Unknown prompt format '{format_id}'.")
199
+
200
def assistant_generation_prefix(schema) -> str:
    """
    Return the prefix text that primes the model to begin an assistant turn.

    Chat-v1 formats append a newline after the role; Harmony does not.
    """
    format_id = str(schema.conversation.format_id)
    start = schema.prompt.start_of_turn
    role = schema.prompt.model_role
    if format_id in ("gemma_chat_v1", "qwen_chat_v1"):
        return f"{start}{role}\n"
    if format_id == "harmony_chat_v1":
        return f"{start}{role}"
    raise ValueError(f"Unknown prompt format '{format_id}'.")
207
+
208
+
209
+ __all__ = ["render_prompt_and_spans", "turn_wrappers", "assistant_generation_prefix"]
camel_inference/src/camel/training_setup.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Runtime and configuration helpers extracted from train_ecg_text.py.
3
+ These utilities keep the training entrypoint concise while preserving the
4
+ original behaviour when preparing distributed state, tokenizer metadata, and
5
+ packing configuration.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from typing import Dict, Optional, List
10
+ import torch.distributed as dist
11
+ from transformers import AutoTokenizer
12
+
13
+ from camel.ecg_text_packing import (
14
+ ECGSpecialTokenCatalog,
15
+ PackingSchema,
16
+ PromptTokens,
17
+ )
18
+ from camel.model_registry import ModelConfig, ModelRegistryError, load_registry
19
+
20
def is_main_process() -> bool:
    """Return True on rank 0 of a distributed run, or when not distributed at all."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
23
+
24
def _resolve_model_config(registry, pretrained_model_id: str) -> Optional[ModelConfig]:
    """Look up a registry entry by registry name, falling back to its HF id."""
    try:
        return registry.get(pretrained_model_id)
    except ModelRegistryError:
        # Callers frequently pass the Hugging Face id rather than the registry
        # name, so scan all entries for a matching ``hf_id`` before giving up.
        for name in registry.names():
            candidate = registry.get(name)
            if candidate.hf_id == pretrained_model_id:
                return candidate
    return None


def build_packing_schema(pretrained_model_id: str) -> PackingSchema:
    """
    Construct the packing schema (prompt + conversation rules + ECG tokens)
    for the given backbone using the shared registry.

    Args:
        pretrained_model_id: registry name or Hugging Face id of the backbone.

    Raises:
        ModelRegistryError: if the model cannot be found in the registry, or
            its prompt configuration lacks the 'user' or 'model' role.
    """
    registry = load_registry()
    cfg = _resolve_model_config(registry, pretrained_model_id)
    if cfg is None:
        raise ModelRegistryError(
            f"Pretrained model '{pretrained_model_id}' not found in registry at {registry.source_path}"
        )
    prompt_cfg = cfg.prompt_config()
    roles = dict(prompt_cfg.roles or {})
    try:
        user_role = str(roles["user"])
        model_role = str(roles["model"])
    except KeyError as exc:
        missing = exc.args[0]
        raise ModelRegistryError(
            f"Prompt configuration for registry entry '{cfg.name}' is missing the '{missing}' role."
        ) from exc
    prompt_tokens = PromptTokens(
        start_of_turn=prompt_cfg.start_of_turn,
        end_of_turn=prompt_cfg.end_of_turn,
        user_role=user_role,
        model_role=model_role,
        require_bos=prompt_cfg.enforce_bos,
        require_eos=prompt_cfg.enforce_eos,
        allow_multiple_eos=prompt_cfg.allow_multiple_eos,
    )
    packing_cfg = cfg.packing_config()
    return PackingSchema(
        prompt=prompt_tokens,
        conversation=packing_cfg.conversation,
        ecg=packing_cfg.ecg_tokens,
    )
71
+
72
def initialize_tokenizer(
    model_id: str,
    *,
    trust_remote_code: bool = True,
    use_fast: Optional[bool] = None,
    add_prefix_space: Optional[bool] = None,
) -> AutoTokenizer:
    """
    Instantiate the HF tokenizer, allowing policy to be driven by the registry
    (use_fast/add_prefix_space). If not provided, defaults are use_fast=True,
    add_prefix_space=False.
    """
    # Built-in defaults, possibly overwritten by the registry entry below.
    resolved_use_fast = True
    resolved_add_prefix_space = False
    try:
        registry = load_registry()
        cfg: Optional[ModelConfig]
        try:
            cfg = registry.get(model_id)
        except ModelRegistryError:
            # Not a registry name; try to match the Hugging Face id instead.
            cfg = None
            for name in registry.names():
                entry = registry.get(name)
                if entry.hf_id == model_id:
                    cfg = entry
                    break
        if cfg is not None:
            tok_cfg = cfg.tokenizer_config()
            resolved_use_fast = bool(tok_cfg.use_fast)
            resolved_add_prefix_space = bool(tok_cfg.add_prefix_space)
    except Exception:
        # Best-effort: if the registry is unavailable, keep built-in defaults.
        pass

    # Explicit caller arguments always win over registry/built-in defaults.
    if use_fast is not None:
        resolved_use_fast = bool(use_fast)
    if add_prefix_space is not None:
        resolved_add_prefix_space = bool(add_prefix_space)

    return AutoTokenizer.from_pretrained(
        model_id,
        use_fast=resolved_use_fast,
        add_prefix_space=resolved_add_prefix_space,
        trust_remote_code=trust_remote_code,
    )
113
+
114
def register_ecg_special_tokens(
    tokenizer: AutoTokenizer,
    catalog: ECGSpecialTokenCatalog,
) -> Dict[int, int]:
    """
    Ensure the tokenizer includes the ECG special tokens from the provided catalog.
    Returns a mapping from catalog index to token ID.

    Raises:
        RuntimeError: if a catalog token fails to register, or does not encode
            to exactly one token id.
    """
    # Only add tokens the tokenizer does not already know about (neither as
    # core specials nor as regular vocabulary entries).
    unknown = [
        token
        for token in catalog.tokens
        if tokenizer.convert_tokens_to_ids(token) in (None, tokenizer.unk_token_id)
    ]
    if unknown:
        tokenizer.add_special_tokens({"additional_special_tokens": unknown})

    id_map: Dict[int, int] = {}
    for token, catalog_index in catalog.token_to_index.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id is None or token_id == tokenizer.unk_token_id:
            raise RuntimeError(f"Tokenizer failed to register ECG special token: {token}")
        # Each ECG token must encode atomically; a multi-piece encoding would
        # silently break span bookkeeping downstream.
        encoded = tokenizer.encode(token, add_special_tokens=False)
        if len(encoded) != 1 or encoded[0] != token_id:
            raise RuntimeError(f"ECG special token does not map to a single id: {token}")
        id_map[catalog_index] = int(token_id)
    return id_map
camel_inference/src/read_ecg.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import numpy as np
3
+ import wfdb
4
+
5
def load_record(ecg_path, start_sec: Optional[int], end_sec: Optional[int], leads: Optional[list[str]]):
    """Load a WFDB record and return ``(signal, lead_names, fs)``.

    Args:
        ecg_path: path to the WFDB record (without file extension).
        start_sec: optional crop start, in seconds (None = from the beginning).
        end_sec: optional crop end, in seconds (None = to the end).
        leads: optional subset of lead names to keep, in the requested order;
            leads missing from the record are skipped with a warning.

    Returns:
        signal: ndarray of shape (n_samples, n_leads).
        lead_names: names of the leads actually kept.
        fs: sampling frequency from the record header.

    Raises:
        ValueError: if *leads* is given and none of them exist in the record.
    """
    record = wfdb.rdrecord(ecg_path)
    fs = record.fs
    lead_names = record.sig_name

    signal = record.p_signal  # n_samples x n_leads
    if leads:
        kept_signals, kept_leads = [], []
        lead_to_idx = {name: i for i, name in enumerate(lead_names)}
        for l in leads:
            if l in lead_to_idx:
                kept_signals.append(signal[:, lead_to_idx[l]])
                kept_leads.append(l)
            else:
                print(f'Lead {l} does not exist. Skipping.')
        if not kept_signals:
            raise ValueError(f"None of the requested leads were found. requested={leads}, available={lead_names}")

        signal = np.stack(kept_signals, axis=1)
        lead_names = kept_leads

    # Optionally subsample the signal.
    # Cast to int: WFDB may report fs as a float, and float slice indices raise TypeError.
    start_ind = 0 if start_sec is None else int(start_sec * fs)
    end_ind = len(signal) if end_sec is None else int(end_sec * fs)
    if end_ind > len(signal):
        # Requested window runs past the recording; slicing below clamps it.
        print(f'ECG is {len(signal) / fs} seconds')
    signal = signal[start_ind:end_ind, :]

    return signal, lead_names, fs