Mayank Keoliya commited on
Commit ·
cb9d9d1
1
Parent(s): c2708da
Deploy with bundled camel_inference and LFS for demo data
Browse files- .gitattributes +1 -0
- .gitignore +0 -1
- camel_inference/.gitignore +12 -0
- camel_inference/README.md +82 -0
- camel_inference/demo/08704_hr.dat +3 -0
- camel_inference/demo/08704_hr.hea +13 -0
- camel_inference/demo/12585_hr.dat +3 -0
- camel_inference/demo/12585_hr.hea +13 -0
- camel_inference/demo/12646_hr.dat +3 -0
- camel_inference/demo/12646_hr.hea +13 -0
- camel_inference/demo/example_prompt.json +8 -0
- camel_inference/pyproject.toml +36 -0
- camel_inference/run_camel.py +49 -0
- camel_inference/scripts/download_checkpoints.sh +37 -0
- camel_inference/src/camel/__init__.py +0 -0
- camel_inference/src/camel/assertions.py +509 -0
- camel_inference/src/camel/camel_model.py +166 -0
- camel_inference/src/camel/checkpoint_utils.py +235 -0
- camel_inference/src/camel/ecg_attention_masks.py +343 -0
- camel_inference/src/camel/ecg_gemma_model.py +91 -0
- camel_inference/src/camel/ecg_model_wrapper.py +394 -0
- camel_inference/src/camel/ecg_text_packing.py +499 -0
- camel_inference/src/camel/inference.py +846 -0
- camel_inference/src/camel/model_init.py +108 -0
- camel_inference/src/camel/model_introspect.py +95 -0
- camel_inference/src/camel/model_registry.py +497 -0
- camel_inference/src/camel/model_registry.yaml +509 -0
- camel_inference/src/camel/process_ecg.py +208 -0
- camel_inference/src/camel/projectors.py +57 -0
- camel_inference/src/camel/prompt_renderers.py +209 -0
- camel_inference/src/camel/training_setup.py +140 -0
- camel_inference/src/read_ecg.py +33 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.dat filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
camel_inference/
|
| 2 |
checkpoints/
|
| 3 |
__pycache__/
|
| 4 |
*.pyc
|
|
|
|
|
|
|
| 1 |
checkpoints/
|
| 2 |
__pycache__/
|
| 3 |
*.pyc
|
camel_inference/.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MacOS
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
/.env
|
| 6 |
+
__pycache__
|
| 7 |
+
|
| 8 |
+
# Ignore model checkpoints
|
| 9 |
+
checkpoints/
|
| 10 |
+
*.pt
|
| 11 |
+
|
| 12 |
+
*.egg-info
|
camel_inference/README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CAMEL Inference
|
| 2 |
+
|
| 3 |
+
Inference-only repository for running CAMEL ECG-language checkpoints.
|
| 4 |
+
|
| 5 |
+
Only `run_camel.py` is intended as a public entrypoint. Modules under `src/camel/` are internal implementation details and may change.
|
| 6 |
+
|
| 7 |
+
## Repository Layout
|
| 8 |
+
|
| 9 |
+
- `run_camel.py`: public inference CLI
|
| 10 |
+
- `src/camel/`: internal model, tokenizer, ECG packing, and loading utilities
|
| 11 |
+
- `checkpoints/`: local adapter/checkpoint files
|
| 12 |
+
|
| 13 |
+
## Requirements
|
| 14 |
+
|
| 15 |
+
- Python 3.10+
|
| 16 |
+
- CUDA-enabled PyTorch recommended for practical inference latency
|
| 17 |
+
|
| 18 |
+
## Install
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
conda create -n camel python=3.10 -y
|
| 22 |
+
conda activate camel
|
| 23 |
+
pip install -e .
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Checkpoints
|
| 27 |
+
|
| 28 |
+
Checkpoints must be downloaded from the Hugging Face repository `CAMEL-ECG/CAMEL`, either manually or with the bundled script:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
bash scripts/download_checkpoints.sh
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Usage
|
| 35 |
+
|
| 36 |
+
* CAMEL is available in three modes:
|
| 37 |
+
- `base`
|
| 38 |
+
- `ecgbench`
|
| 39 |
+
- `forecast`
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python run_camel.py \
|
| 43 |
+
--mode forecast \
|
| 44 |
+
--text "Forecast cardiac rhythm for the next 5 minutes." \
|
| 45 |
+
--ecgs demo/08704_hr \
|
| 46 |
+
--device cuda:0
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python run_camel.py \
|
| 51 |
+
--mode base \
|
| 52 |
+
--text "Compare the two ECG waveforms." \
|
| 53 |
+
--ecgs demo/12585_hr demo/12646_hr \
|
| 54 |
+
--device cuda:0
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
* Optionally, you can set the start, end, and leads with `--ecg-configs`.
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
python run_camel.py \
|
| 61 |
+
--mode forecast \
|
| 62 |
+
--text "Forecast cardiac rhythm for the next 5 minutes." \
|
| 63 |
+
--ecgs demo/08704_hr \
|
| 64 |
+
--ecg-configs "start:0;end:5;use_leads:I,II" \
|
| 65 |
+
--device cuda:0
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
* Using `--text` and `--ecgs` defaults to the text followed by the ECGs in order.
|
| 69 |
+
For arbitrary text/ECG interleaving use `--json`.
|
| 70 |
+
```bash
|
| 71 |
+
python run_camel.py --mode base --json demo/example_prompt.json --device cuda:0
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
* Sampling flags:
|
| 75 |
+
- `--temperature`
|
| 76 |
+
- `--top-k`
|
| 77 |
+
- `--top-p`
|
| 78 |
+
- `--min-p`
|
| 79 |
+
- `--max-new-tokens`
|
| 80 |
+
|
| 81 |
+
Implementation notes:
|
| 82 |
+
- ECG loading is currently implemented for WFDB-format inputs. To support additional formats, extend `src/read_ecg.py`.
|
camel_inference/demo/08704_hr.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ade700b2d1a1e3a0889d4135ee179a33b7dc2eaf76d6702bc9bc64f6cef4a3f4
|
| 3 |
+
size 120000
|
camel_inference/demo/08704_hr.hea
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
08704_hr 12 500 5000
|
| 2 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -50 41247 0 I
|
| 3 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -20 64030 0 II
|
| 4 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 30 22786 0 III
|
| 5 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 35 13135 0 aVR
|
| 6 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -40 42425 0 aVL
|
| 7 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 5 11132 0 aVF
|
| 8 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 210 35283 0 V1
|
| 9 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -605 1875 0 V2
|
| 10 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -360 20664 0 V3
|
| 11 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -255 26244 0 V4
|
| 12 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 -230 25112 0 V5
|
| 13 |
+
08704_hr.dat 16 1000.0(0)/mV 16 0 60 30065 0 V6
|
camel_inference/demo/12585_hr.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1dbc29ba7d9857a606c23d1a661dee034285f471908206e5428c9b6b7c8ff467
|
| 3 |
+
size 120000
|
camel_inference/demo/12585_hr.hea
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
12585_hr 12 500 5000
|
| 2 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 540 5223 0 I
|
| 3 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 40 11911 0 II
|
| 4 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -500 6676 0 III
|
| 5 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -290 57071 0 aVR
|
| 6 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 520 65066 0 aVL
|
| 7 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -230 9479 0 aVF
|
| 8 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -130 26077 0 V1
|
| 9 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -145 27523 0 V2
|
| 10 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -155 35476 0 V3
|
| 11 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -350 16663 0 V4
|
| 12 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -545 60445 0 V5
|
| 13 |
+
12585_hr.dat 16 1000.0(0)/mV 16 0 -105 58137 0 V6
|
camel_inference/demo/12646_hr.dat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7cabc7399daa14ebf97cc705b9da09aeb3daee12939ca797d145237cdfcf4ed8
|
| 3 |
+
size 120000
|
camel_inference/demo/12646_hr.hea
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
12646_hr 12 500 5000
|
| 2 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -260 38739 0 I
|
| 3 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -465 47305 0 II
|
| 4 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -205 8550 0 III
|
| 5 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 363 22259 0 aVR
|
| 6 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -27 47829 0 aVL
|
| 7 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -335 60490 0 aVF
|
| 8 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 220 55705 0 V1
|
| 9 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 120 51889 0 V2
|
| 10 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -390 40903 0 V3
|
| 11 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -660 39373 0 V4
|
| 12 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 -770 43235 0 V5
|
| 13 |
+
12646_hr.dat 16 1000.0(0)/mV 16 0 135 65290 0 V6
|
camel_inference/demo/example_prompt.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{ "type": "text", "text": "You are given two ECG waveforms of the same patient from two different time points."},
|
| 3 |
+
{ "type": "text", "text": "This is the first ECG."},
|
| 4 |
+
{ "type": "ecg", "ecg": "demo/12585_hr" },
|
| 5 |
+
{ "type": "text", "text": "This is the second ECG." },
|
| 6 |
+
{ "type": "ecg", "ecg": "demo/12646_hr" },
|
| 7 |
+
{ "type": "text", "text": "Have non-diagnostic T abnormalities been resolved in the recent tracing compared to the previous one?"}
|
| 8 |
+
]
|
camel_inference/pyproject.toml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[tool.setuptools]
|
| 6 |
+
py-modules = ["run_camel"]
|
| 7 |
+
|
| 8 |
+
[tool.setuptools.packages.find]
|
| 9 |
+
where = ["src"]
|
| 10 |
+
|
| 11 |
+
[tool.setuptools.package-data]
|
| 12 |
+
camel = ["model_registry.yaml"]
|
| 13 |
+
|
| 14 |
+
[project]
|
| 15 |
+
name = "camel-inference"
|
| 16 |
+
version = "0.1.0"
|
| 17 |
+
description = "Inference-only CLI for CAMEL ECG-language checkpoints"
|
| 18 |
+
readme = "README.md"
|
| 19 |
+
requires-python = ">=3.9"
|
| 20 |
+
authors = [
|
| 21 |
+
{ name = "CAMEL contributors" }
|
| 22 |
+
]
|
| 23 |
+
dependencies = [
|
| 24 |
+
"numpy",
|
| 25 |
+
"scipy",
|
| 26 |
+
"pyyaml",
|
| 27 |
+
"torch",
|
| 28 |
+
"transformers",
|
| 29 |
+
"peft",
|
| 30 |
+
"accelerate",
|
| 31 |
+
"sentencepiece",
|
| 32 |
+
"protobuf"
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
[project.scripts]
|
| 36 |
+
camel-infer = "run_camel:main"
|
camel_inference/run_camel.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from camel.camel_model import CAMEL
|
| 3 |
+
|
| 4 |
+
def main():
    """CLI entrypoint: parse arguments, run a CAMEL checkpoint, and print the result."""
    parser = argparse.ArgumentParser(description="CAMEL")

    # Model / input selection.
    parser.add_argument("--mode", type=str, choices=['forecast', 'base', 'ecgbench'], default='base')
    parser.add_argument("--device", type=str, default='cuda:0')
    parser.add_argument("--json", type=str, default=None)
    parser.add_argument("--text", type=str, default=None)
    parser.add_argument("--ecgs", type=str, default=None, nargs='+')
    parser.add_argument("--ecg-configs", type=str, default=None, nargs='+')

    # Sampling controls.
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", dest="top_k", type=int, default=64,
                        help="Top-k sampling cutoff (set <=0 to disable).")
    parser.add_argument("--top-p", dest="top_p", type=float, default=0.95,
                        help="Nucleus sampling cumulative probability cutoff.")
    parser.add_argument("--min-p", dest="min_p", type=float, default=0.0,
                        help="Minimum per-token probability threshold applied after temperature scaling.")
    parser.add_argument("--max-new-tokens", type=int, default=512,
                        help="Maximum number of tokens to generate per sample.")

    args = parser.parse_args()

    # Build the model for the requested mode, then run the prompt through it.
    model = CAMEL(mode=args.mode, device=args.device)
    output, prompt = model.run(args)

    print(f'Prompt: {prompt}')
    print(f'Prediction: {output}')


if __name__ == "__main__":
    main()
|
camel_inference/scripts/download_checkpoints.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Fetch the CAMEL model checkpoints from Hugging Face into ./checkpoints.
set -euo pipefail

echo "Installing huggingface_hub if needed..."
python3 -m pip install -q --user huggingface_hub

echo "Downloading CAMEL checkpoints from Hugging Face..."
mkdir -p checkpoints

python3 - <<'PY'
import os
import shutil

from huggingface_hub import hf_hub_download

REPO_ID = "CAMEL-ECG/CAMEL"
CHECKPOINTS = (
    "camel_base.pt",
    "camel_ecginstruct.pt",
    "camel_forecast.pt",
)

os.makedirs("checkpoints", exist_ok=True)

for f in CHECKPOINTS:
    print(f"Downloading {f}...")
    # Resolve (and cache) the file from the hub, then copy it into ./checkpoints.
    cached = hf_hub_download(repo_id=REPO_ID, filename=f, repo_type="model")
    dst = os.path.join("checkpoints", f)
    shutil.copy2(cached, dst)
    print(f"Saved to {dst}")

print("All checkpoints downloaded.")
PY

echo "Done."
|
camel_inference/src/camel/__init__.py
ADDED
|
File without changes
|
camel_inference/src/camel/assertions.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Assertions and summaries for ECG language-model wrappers and their LoRA adapters.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict, Iterable, List, Mapping, Set
|
| 7 |
+
import functools
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from camel.ecg_attention_masks import ECGBlockLayout
|
| 12 |
+
|
| 13 |
+
_ASSERTIONS_ENABLED = os.getenv("ASSERTIONS") == "1"
|
| 14 |
+
|
| 15 |
+
def _skip_if_assertions_disabled(func):
|
| 16 |
+
"""Decorator that no-ops assertion helpers when ASSERTIONS env var is not set."""
|
| 17 |
+
@functools.wraps(func)
|
| 18 |
+
def wrapper(*args, **kwargs):
|
| 19 |
+
if not _ASSERTIONS_ENABLED:
|
| 20 |
+
return None
|
| 21 |
+
return func(*args, **kwargs)
|
| 22 |
+
|
| 23 |
+
return wrapper
|
| 24 |
+
|
| 25 |
+
@_skip_if_assertions_disabled
|
| 26 |
+
def assert_tensor_dtype(tensor: torch.Tensor, *, expected: torch.dtype, context: str) -> None:
    """Raise AssertionError when *tensor* does not carry the *expected* dtype."""
    if tensor.dtype == expected:
        return
    raise AssertionError(f"{context}: expected dtype {expected}, got {tensor.dtype}")
|
| 30 |
+
|
| 31 |
+
@_skip_if_assertions_disabled
|
| 32 |
+
def assert_ecg_blocks_consistent(
    *,
    turn_parts: Iterable[Iterable[Dict[str, Any]]],
    ecg_blocks: Iterable[Dict[str, Any]],
) -> None:
    """Check that structured turn parts carry exactly the ECG markers each block declares.

    Each block must be referenced by exactly two global special markers, two
    special markers per declared lead (start/end), and one ECG segment marker
    per declared second of every lead — no duplicates, no out-of-range seconds.
    """
    blocks = list(ecg_blocks)
    # Per-block bookkeeping, keyed by block index.
    declared: Dict[int, Dict[str, int]] = {}        # lead -> declared segment count
    lead_specials: Dict[int, Dict[str, int]] = {}   # lead -> observed special markers
    seen_secs: Dict[int, Dict[str, Set[int]]] = {}  # lead -> observed segment seconds
    global_specials: Dict[int, int] = {}            # observed block-global markers

    for b_idx, blk in enumerate(blocks):
        lead_names = [str(ld) for ld in blk.get("lead_names", [])]
        seg_counts = [int(n) for n in blk.get("segments_per_lead", [])]
        declared[b_idx] = dict(zip(lead_names, seg_counts))
        lead_specials[b_idx] = dict.fromkeys(lead_names, 0)
        seen_secs[b_idx] = {ld: set() for ld in lead_names}
        global_specials[b_idx] = 0

    for turn in turn_parts:
        for part in turn:
            kind = part.get("kind")
            if kind == "text":
                continue
            raw_idx = part.get("block_index")
            if raw_idx is None or int(raw_idx) not in declared:
                raise AssertionError("ECG part references unknown block_index.")
            block_idx = int(raw_idx)
            known_leads = set(declared[block_idx])
            if kind == "special":
                lead = part.get("lead")
                if lead:
                    if lead not in known_leads:
                        raise AssertionError(f"Special token references unknown lead '{lead}'.")
                    lead_specials[block_idx][lead] = lead_specials[block_idx].get(lead, 0) + 1
                else:
                    # No lead means this is a block-global marker.
                    global_specials[block_idx] = global_specials.get(block_idx, 0) + 1
                token_literal = part.get("token")
                if not isinstance(token_literal, str) or len(token_literal) == 0:
                    raise AssertionError("Special turn part lacks a string token literal.")
                continue
            if kind == "ecg":
                lead = part.get("lead")
                if lead not in known_leads:
                    raise AssertionError(f"ECG segment references unknown lead '{lead}'.")
                sec_val = part.get("sec")
                expected_total = declared[block_idx].get(lead, 0)
                if expected_total <= 0:
                    raise AssertionError(
                        f"Lead '{lead}' has non-positive declared segments ({expected_total}) but ECG markers are present."
                    )
                try:
                    sec = int(sec_val)
                except Exception as exc:  # noqa: BLE001
                    raise AssertionError(f"ECG segment for lead '{lead}' has non-integer sec {sec_val!r}.") from exc
                if sec < 1 or sec > expected_total:
                    raise AssertionError(
                        f"ECG segment for lead '{lead}' has second={sec}, expected within [1,{expected_total}]."
                    )
                if sec in seen_secs[block_idx][lead]:
                    raise AssertionError(f"Duplicate ECG segment marker for lead '{lead}' second {sec}.")
                seen_secs[block_idx][lead].add(sec)
                continue
            raise AssertionError(f"Unknown turn_parts kind '{kind}'.")

    # Final reconciliation: every declared marker must have been observed.
    for block_idx, lead_counts in declared.items():
        if global_specials.get(block_idx, 0) != 2:
            raise AssertionError(
                f"Expected exactly two global ECG markers for block {block_idx}; "
                f"found {global_specials.get(block_idx, 0)}."
            )
        for lead, expected_total in lead_counts.items():
            expected_specials = lead_specials[block_idx].get(lead, 0)
            if expected_specials != 2:
                raise AssertionError(
                    f"Lead '{lead}' has {expected_specials} special markers; expected start and end (2 total)."
                )
            observed = seen_secs[block_idx].get(lead, set())
            if expected_total != len(observed):
                missing = sorted(set(range(1, expected_total + 1)) - observed)
                raise AssertionError(
                    f"Lead '{lead}' missing ECG segment markers for seconds {missing} (expected {expected_total})."
                )
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ---------------- Trainer batch validation helpers -----------------------------------------------
|
| 121 |
+
|
| 122 |
+
@_skip_if_assertions_disabled
|
| 123 |
+
def assert_prefix_split_complete(*, offset: int, total_prefix_rows: int) -> None:
    """Check that splitting the prefix consumed exactly all available rows."""
    if offset == total_prefix_rows:
        return
    raise RuntimeError(
        f"Prefix split mismatch: consumed {offset} rows but have {total_prefix_rows}"
    )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@_skip_if_assertions_disabled
|
| 132 |
+
def assert_prefix_matches_segments(
    *,
    prefix_rows: int,
    segments_per_lead: Iterable[int],
    lead_names: Iterable[str],
    sample_index: int,
    block_index: int,
) -> None:
    """Check that the prefix row count equals the total declared segment count."""
    total_segments = sum(int(n) for n in segments_per_lead)
    if prefix_rows == total_segments:
        return
    raise RuntimeError(
        f"Sample {sample_index} block {block_index}: Prefix rows ({prefix_rows}) "
        f"!= sum(segments_per_lead) ({total_segments}). "
        f"lead_names={list(lead_names)} segments_per_lead={list(segments_per_lead)}"
    )
|
| 148 |
+
|
| 149 |
+
@_skip_if_assertions_disabled
|
| 150 |
+
def assert_ecg_part_bounds(
    *,
    lead: str,
    sec: int,
    lead_to_offset: Mapping[str, int],
    declared_segments: Mapping[str, int],
    total_prefix_rows: int,
    sample_index: int,
    block_index: int,
) -> None:
    """Check that an ECG part's (lead, sec) pair maps onto a valid prefix row."""
    if lead not in declared_segments:
        raise RuntimeError(f"Unknown lead {lead} in parts for sample {sample_index} block {block_index}")

    nseg = int(declared_segments[lead])
    if sec < 1 or sec > nseg:
        raise RuntimeError(
            f"sec out of range for lead {lead}: got {sec}, expected 1..{nseg}"
        )

    start = lead_to_offset[lead]
    end = start + nseg  # exclusive upper bound of this lead's row block
    row_idx = start + sec - 1

    # Validate both global bounds and this lead's own row window.
    if row_idx < 0 or row_idx >= total_prefix_rows:
        raise RuntimeError(
            f"Bad (lead,sec)=({lead},{sec}) for sample {sample_index} block {block_index}: "
            f"row_idx {row_idx} not in [0,{total_prefix_rows})"
        )
    if row_idx < start or row_idx >= end:
        raise RuntimeError(
            f"(lead,sec)=({lead},{sec}) maps outside this lead block "
            f"[{start},{end}) (row_idx={row_idx}) for sample {sample_index}"
        )
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@_skip_if_assertions_disabled
|
| 189 |
+
def assert_layout_specials_complete(
    *,
    block_layout: ECGBlockLayout,
    lead_names: Iterable[str],
) -> None:
    """Check that per-lead and global special-token markers are paired and ordered.

    For every declared lead, its start/end markers must be both present or both
    absent, and when present start must come strictly before end. The same
    pairing and ordering rule applies to the block-global markers.
    """
    for ld in lead_names:
        s = block_layout.lead_start_idx.get(ld)
        e = block_layout.lead_end_idx.get(ld)
        if (s is None) != (e is None):
            raise RuntimeError(f"Lead {ld} missing start/end special (s={s}, e={e})")
        if s is not None and s >= e:
            raise RuntimeError(f"Lead {ld} specials out of order: start={s}, end={e}")

    g_start = block_layout.global_start_idx
    g_end = block_layout.global_end_idx
    if (g_start is None) != (g_end is None):
        raise RuntimeError("Global start/end special mismatch")
    if g_start is not None and g_end is not None and g_start >= g_end:
        raise RuntimeError(
            f"Global specials out of order: start={g_start} "
            f"end={g_end}"
        )
|
| 222 |
+
|
| 223 |
+
# ---------------- Wrapper embedding validations --------------------------------------------------
|
| 224 |
+
|
| 225 |
+
@_skip_if_assertions_disabled
|
| 226 |
+
def assert_wrapper_embed_length(
    *,
    embeddings: torch.Tensor,
    ids: List[int],
    context: str,
) -> None:
    """Ensure the embedding sequence length matches the token count exactly.

    Guards the 1:1 invariant between input token IDs and output embeddings.

    Args:
        embeddings: output embedding tensor; must be at least 1-D.
        ids: input token ID list.
        context: description of the call site, used in error messages.

    Raises:
        RuntimeError: if ids is not a list, embeddings is not at least 1-D,
            or the embedding length differs from the token count.
    """
    if not isinstance(ids, list):
        raise RuntimeError(f"{context}: ids must be a python list of ints")
    if embeddings.dim() < 1:
        raise RuntimeError(f"{context}: embeddings must be at least 1-D, got shape {tuple(embeddings.shape)}")
    if embeddings.size(0) != len(ids):
        raise RuntimeError(f"{context}: embed length {embeddings.size(0)} != token count {len(ids)}")
|
| 252 |
+
|
| 253 |
+
@_skip_if_assertions_disabled
|
| 254 |
+
def assert_rest_length_nonnegative(*, rest_length: int) -> None:
    """Raise ValueError when the ids_rest length is negative.

    A negative length should be impossible in correct code; failing early
    surfaces bugs in the label-construction logic.

    Args:
        rest_length: length of the ids_rest list.

    Raises:
        ValueError: if rest_length is negative.
    """
    if rest_length >= 0:
        return
    raise ValueError("ids_rest length is negative (internal error).")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ---------------- Utility assertion helpers ------------------------------------------------------
|
| 271 |
+
|
| 272 |
+
@_skip_if_assertions_disabled
|
| 273 |
+
def assert_sorted_non_overlapping_spans(spans: List[tuple[int, int]], length: int, ctx: str) -> None:
    """Check every span lies within [0, length] and spans are sorted without overlap."""
    prev_end = -1
    for i, (s, e) in enumerate(spans):
        in_bounds = 0 <= s and s <= e and e <= length
        if not in_bounds:
            raise AssertionError(f"{ctx}: span {i}={(s,e)} out of bounds for length {length}")
        # Touching spans (s == prev_end) are allowed; only a true overlap fails.
        if s < prev_end:
            raise AssertionError(f"{ctx}: spans overlap or not sorted at {i-1},{i}: prev_end={prev_end}, curr_start={s}")
        prev_end = e
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@_skip_if_assertions_disabled
|
| 285 |
+
def assert_equal_int(a: int, b: int, msg: str) -> None:
    """Raise AssertionError (prefixed with *msg*) unless int(a) == int(b)."""
    if int(a) == int(b):
        return
    raise AssertionError(f"{msg}: {a} != {b}")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@_skip_if_assertions_disabled
|
| 292 |
+
def assert_positive_int(n: int, msg: str) -> None:
    """Raise AssertionError (prefixed with *msg*) unless int(n) is strictly positive."""
    if int(n) > 0:
        return
    raise AssertionError(f"{msg}: expected > 0, got {n}")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ---------------- Schema and catalog validations -------------------------------------------------
|
| 299 |
+
|
| 300 |
+
@_skip_if_assertions_disabled
def assert_ecg_catalog_valid(catalog: Any, schema: Any) -> None:
    """Validate ECG special token catalog for uniqueness and mapping consistency.

    Checks:
    - All tokens are unique
    - Every canonical lead has entries in lead_to_indices and lead_to_tokens
    - Token-to-index mappings are consistent across all structures
    - Global markers (start/end) are present in the catalog

    Raises:
        AssertionError: on the first inconsistency found.
    """
    # Uniqueness: a duplicated special token would make token_to_index ambiguous.
    if len(set(catalog.tokens)) != len(catalog.tokens):
        raise AssertionError("ECG special tokens contain duplicates")

    # Per-lead mappings: every canonical lead must have both a start and an end
    # token, and tokens/indices must round-trip through all three structures
    # (lead_to_tokens, lead_to_indices, token_to_index).
    for lead in schema.ecg.canonical_leads:
        if lead not in catalog.lead_to_indices or lead not in catalog.lead_to_tokens:
            raise AssertionError(f"Missing lead in catalog: {lead}")
        for kind in ("start", "end"):
            tok = catalog.lead_to_tokens[lead][kind]
            idx = catalog.lead_to_indices[lead][kind]
            # The token stored at the mapped index must be the lead's token.
            if catalog.tokens[idx] != tok:
                raise AssertionError(f"Catalog mismatch for {lead}:{kind}: tokens[idx] != tok")
            # And the reverse lookup must agree with the forward one.
            if catalog.token_to_index.get(tok, None) != idx:
                raise AssertionError(f"token_to_index mismatch for {lead}:{kind}")

    # Global markers: the block-level start/end specials must be registered too.
    for tok in (schema.ecg.global_start, schema.ecg.global_end):
        if tok not in catalog.token_to_index:
            raise AssertionError(f"Global ECG token missing from catalog: {tok}")
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# ---------------- Conversation and role validations ----------------------------------------------
|
| 333 |
+
|
| 334 |
+
@_skip_if_assertions_disabled
def assert_normalized_role_canonical(role: str, schema: Any) -> None:
    """Check that a normalized role matches one of the two canonical prompt roles."""
    canonical_roles = (schema.prompt.user_role, schema.prompt.model_role)
    if role in canonical_roles:
        return
    raise AssertionError(f"Normalized role '{role}' did not resolve to a canonical prompt role")
|
| 339 |
+
|
| 340 |
+
# ---------------- Tokenization and span validations ----------------------------------------------
|
| 341 |
+
|
| 342 |
+
@_skip_if_assertions_disabled
def assert_tokenization_cursor_matches(cursor: int, ids_length: int) -> None:
    """Check that the running tokenization cursor agrees with the text_ids length."""
    if cursor == ids_length:
        return
    raise AssertionError(f"cursor ({cursor}) != len(text_ids) ({ids_length})")
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
@_skip_if_assertions_disabled
def assert_model_spans_valid(model_spans: List[tuple[int, int]], ids_length: int) -> None:
    """Check model spans are in-bounds, sorted, non-overlapping, and non-empty."""
    # Bounds/order/overlap checks are delegated to the shared span validator.
    assert_sorted_non_overlapping_spans(model_spans, ids_length, ctx="model_spans_in_text")
    if not model_spans:
        raise AssertionError("No model spans found in text ids")
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@_skip_if_assertions_disabled
def assert_eos_appended(ids: List[int], tokenizer: Any, require_eos: bool) -> None:
    """Check that the final id is the tokenizer's EOS token when one is required.

    A no-op when EOS is not required or the tokenizer defines no EOS id.
    """
    eos_id = tokenizer.eos_token_id
    if not require_eos or eos_id is None:
        return
    missing_eos = (not ids) or ids[-1] != eos_id
    if missing_eos:
        raise AssertionError("Required EOS was not appended at the end of text_ids")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ---------------- u0 parts structure validations -------------------------------------------------
|
| 366 |
+
|
| 367 |
+
@_skip_if_assertions_disabled
def assert_turn_parts_structure_valid(
    parts: List[Dict[str, Any]],
    ecg_blocks: List[Dict[str, Any]],
    schema: Any,
    catalog: Any,
) -> None:
    """Validate the complete structure of turn parts for all blocks present.

    For every block index referenced by *parts*, checks that:
    - the index refers to an entry of *ecg_blocks*;
    - the block's global start/end special tokens are present and ordered;
    - each declared lead has ordered start/end specials whose token_index
      matches the catalog, and its ECG parts carry seconds 1..nseg in order.

    Raises:
        AssertionError: on the first structural violation found.
    """
    # Distinct block indices actually referenced by the parts list.
    block_indices = sorted({int(p.get("block_index")) for p in parts if p.get("block_index") is not None})
    for block_idx in block_indices:
        if block_idx < 0 or block_idx >= len(ecg_blocks):
            raise AssertionError(f"Unknown block_index {block_idx} in turn parts")
        blk = ecg_blocks[block_idx]
        leads_present = [str(ld) for ld in blk.get("lead_names", [])]
        segments_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]

        # All special tokens emitted for this block; both global markers must appear.
        special_tokens = [p.get("token") for p in parts if p.get("kind") == "special" and p.get("block_index") == block_idx]
        if schema.ecg.global_start not in special_tokens:
            raise AssertionError(f"Missing global_start special in block {block_idx}")
        if schema.ecg.global_end not in special_tokens:
            raise AssertionError(f"Missing global_end special in block {block_idx}")

        # Locate the positions of the global markers to verify their ordering.
        idx_global_start = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_start),
            None
        )
        idx_global_end = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_end),
            None
        )
        if idx_global_start is None or idx_global_end is None or not (idx_global_start < idx_global_end):
            raise AssertionError(f"Block {block_idx}: missing or misordered global start/end specials")

        # Per-lead checks: start/end specials must bracket exactly nseg ECG parts
        # whose "sec" values run 1..nseg, and the special token indices must
        # agree with the catalog's lead_to_indices mapping.
        for lead, nseg in zip(leads_present, segments_per_lead):
            assert_positive_int(nseg, f"segments_per_lead for {lead}")
            idx_start = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["start"]),
                None
            )
            idx_end = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["end"]),
                None
            )
            if idx_start is None or idx_end is None or not (idx_start < idx_end):
                raise AssertionError(f"Lead {lead}: missing or misordered start/end specials in block {block_idx}")
            # The ECG parts between the specials must carry seconds 1..nseg in order.
            secs = [p["sec"] for p in parts[idx_start+1:idx_end] if p.get("kind") == "ecg" and p.get("lead") == lead]
            if secs != list(range(1, int(nseg) + 1)):
                raise AssertionError(f"Lead {lead}: ECG seconds sequence invalid: {secs} vs 1..{nseg}")
            if parts[idx_start]["token_index"] != catalog.lead_to_indices[lead]["start"]:
                raise AssertionError(f"Lead {lead}: start token_index mismatch")
            if parts[idx_end]["token_index"] != catalog.lead_to_indices[lead]["end"]:
                raise AssertionError(f"Lead {lead}: end token_index mismatch")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
@_skip_if_assertions_disabled
def assert_turn_content_ends_with_eot(text_block: str, end_of_turn: str) -> None:
    """Check that rendered turn text terminates with the end-of-turn suffix."""
    if text_block.endswith(end_of_turn):
        return
    raise AssertionError("Turn content must end with end_of_turn suffix")
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# ---------------- Per-sample packing validations -------------------------------------------------
|
| 437 |
+
|
| 438 |
+
@_skip_if_assertions_disabled
def assert_leads_canonical_and_ordered(leads_present: List[str], canonical_leads: tuple) -> None:
    """Check every lead is canonical and appears at most once.

    Despite the name, ordering is not enforced here — only membership and
    uniqueness (the original behavior).
    """
    lead_list = list(leads_present)
    has_unknown = any(ld not in canonical_leads for ld in lead_list)
    if has_unknown:
        raise AssertionError(f"Non-canonical lead found in leads_present: {lead_list}")
    if len(lead_list) != len(set(lead_list)):
        raise AssertionError(f"Duplicate lead detected in leads_present: {lead_list}")
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@_skip_if_assertions_disabled
def assert_waveform_shapes_valid(
    leads_present: List[str],
    segments_per_lead: List[int],
    waveform_segments: Dict[str, Any],
) -> None:
    """Check each lead's waveform is a [T, 256] tensor with T == its segment count.

    Raises:
        AssertionError: on a non-positive segment count, a wrong tensor rank or
            width, or a mismatch between seconds and segments_per_lead.
    """
    for ld, nseg in zip(leads_present, segments_per_lead):
        assert_positive_int(nseg, f"segments_per_lead[{ld}]")
        wf = waveform_segments[ld]
        wrong_shape = wf.ndim != 2 or wf.shape[1] != 256
        if wrong_shape:
            raise AssertionError(f"Waveform for {ld} must be [T,256], got {tuple(wf.shape)}")
        assert_equal_int(wf.shape[0], nseg, f"Waveform seconds vs segments_per_lead for {ld}")
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# Public API of the assertions module (training- and packing-time validators).
__all__ = [
    "assertions_active",
    "capture_adapter_snapshot",
    "assert_wrapper_adapter_requires_grad",
    "assert_wrapper_optimizer_coverage",
    "assert_adapter_gradients",
    "assert_adapter_updates",
    "assert_trainable_param_sync",
    "assert_tensor_dtype",
    "assert_only_llava_proj_trainable",
    "summarize_trainables_llava_lora",
    "assert_language_lora_only",
    "assert_single_bos_eos",
    "assert_ecg_layout_valid",
    "assert_ecg_mask_against_layout",
    "assert_single_block_mask_matches_reference",
    "assert_additive_mask_padding",
    "assert_nonempty_waveform_segments",
    "assert_prefix_split_complete",
    "assert_prefix_matches_segments",
    "assert_ids_are_lists",
    "assert_embedding_length_matches_tokens",
    "assert_ecg_part_bounds",
    "assert_layout_specials_complete",
    "assert_labels_match_spans",
    "assert_wrapper_embed_length",
    "assert_rest_length_nonnegative",
    "assert_sorted_non_overlapping_spans",
    "assert_equal_int",
    "assert_positive_int",
    "assert_ecg_catalog_valid",
    "assert_normalized_role_canonical",
    "assert_rest_blocks_valid",
    "assert_tokenization_cursor_matches",
    "assert_model_spans_valid",
    "assert_turn_parts_consistent",
    "assert_ecg_blocks_consistent",
    "assert_eos_appended",
    "assert_turn_parts_structure_valid",
    "assert_turn_content_ends_with_eot",
    "assert_leads_canonical_and_ordered",
    "assert_waveform_shapes_valid",
    "assert_collate_item_valid",
]
|
camel_inference/src/camel/camel_model.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Compatibility wrapper around inference.KardiaLM."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from camel.inference import KardiaLM
|
| 11 |
+
from camel.process_ecg import get_waveform
|
| 12 |
+
|
| 13 |
+
class CAMEL:
    """Compatibility wrapper around :class:`camel.inference.KardiaLM`.

    Builds a KardiaLM session for one of the bundled adapter checkpoints and
    exposes a small ``run``/``generate`` API used by the demo entrypoints.
    """

    # Maps the supported run modes onto their bundled adapter checkpoints.
    _MODE_CHECKPOINTS = {
        'base': 'checkpoints/camel_base.pt',
        'ecgbench': 'checkpoints/camel_ecginstruct.pt',
        'forecast': 'checkpoints/camel_forecast.pt',
    }

    def __init__(
        self,
        device: torch.device,
        mode: str,
        model_config_name: str = 'medgemma-4b-it',
        conv_ckpt: Optional[str] = None,
        no_lora: bool = False,
        mask_strategy: str = 'semantic',
        **model_args,
    ) -> None:
        """Initialize the underlying KardiaLM session.

        Args:
            device: Torch device the session runs on.
            mode: One of 'base', 'ecgbench' or 'forecast'; selects the checkpoint.
            model_config_name: Registry name of the base model configuration.
            conv_ckpt: Optional conv-encoder checkpoint path.
            no_lora: Disable LoRA adapters when True.
            mask_strategy: Attention-mask strategy forwarded to KardiaLM.
            **model_args: Optional ``default_*`` sampling overrides
                (max_new_tokens, temperature, top_k, top_p, min_p).

        Raises:
            ValueError: If *mode* is not one of the supported modes.
        """
        default_top_k = model_args.pop("default_top_k", 64)
        default_top_p = float(model_args.pop("default_top_p", 0.95))
        default_min_p = float(model_args.pop("default_min_p", 0.0))

        # Fail fast on an unknown mode. Previously an unrecognized mode fell
        # through the if/elif chain and surfaced later as a confusing
        # NameError when `ckpt` was referenced.
        try:
            ckpt = self._MODE_CHECKPOINTS[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode '{mode}'; expected one of {sorted(self._MODE_CHECKPOINTS)}"
            ) from None

        self.session = KardiaLM(
            model_registry_path=None,
            hf_model_id_override=None,
            model_config_name=model_config_name,
            adapter_ckpt=ckpt,
            conv_ckpt=conv_ckpt,
            no_lora=no_lora,
            default_max_new_tokens=int(model_args.pop("default_max_new_tokens", 1000)),
            default_temperature=float(model_args.pop("default_temperature", 1.0)),
            default_top_k=None if default_top_k is None else int(default_top_k),
            default_top_p=default_top_p,
            default_min_p=default_min_p,
            mask_strategy=mask_strategy,
            device=device
        )
        self.prompt_tokens = self.session.packing_schema.prompt
        self.device = device

    def run(self, args):
        """Build a conversation from CLI-style *args* and generate a reply.

        *args* must provide either ``json`` (path to a serialized content list)
        or ``ecgs`` (list of ECG paths), plus optional ``text``, ``ecg_configs``
        and the sampling knobs forwarded to :meth:`generate`.
        """
        if args.json is None and (args.ecgs is None):
            raise ValueError("Either one of --json or --ecgs should be non-empty.")

        if args.json is None:
            text = args.text or ''
            raw_context = [{'type': 'text', 'text': text}]
            for ecg in args.ecgs:
                raw_context.append({'type': 'ecg', 'ecg': ecg})
        else:
            try:
                with open(args.json, "r") as f:
                    raw_context = json.load(f)
            except Exception as exc:
                # Chain the underlying failure instead of swallowing it with a
                # bare `except:` (which also caught KeyboardInterrupt/SystemExit).
                raise ValueError(f'Failed during reading json: {args.json}') from exc

        content = self._build_content(raw_content=raw_context, ecg_configs=args.ecg_configs)
        generate_kwargs = dict(
            content=content,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            min_p=args.min_p,
        )
        generated_text = self.generate(**generate_kwargs)
        return generated_text

    def generate(
        self,
        content: list[dict[str, Any]],
        max_new_tokens: int = 1000,
        temperature: float = 1.0,
        top_k: Optional[int] = 64,
        top_p: float = 0.95,
        min_p: float = 0.0,
    ) -> tuple[str, Any]:
        """Run one chat turn against the session.

        Returns:
            A ``(generated_text, prompt_preview)`` pair as produced by
            ``KardiaLM.chat``. (The previous ``-> str`` annotation was wrong:
            a 2-tuple has always been returned.)
        """
        text, prompt_preview = self.session.chat(
            conversation=content,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
        )
        return text, prompt_preview

    def _parse_ecg_config(self, ecg_configs: Optional[list[str]], n_ecgs: int) -> tuple:
        """Parse per-ECG config strings into (start_inds, end_inds, leads) lists.

        Each config string is a ';'-separated list of ``key:value`` fields with
        the keys ``start``, ``end`` and ``leads``/``use_leads``. A single
        config is broadcast across all *n_ecgs* inputs.

        Raises:
            ValueError: on a malformed field or a config count that is neither
                1 nor n_ecgs.
        """
        def _parse_single_config(config):
            # Returns (start_ind, end_ind, leads); unknown keys are ignored
            # with a warning print.
            start_ind, end_ind, leads = None, None, None
            for field in config.split(";"):
                field = field.strip()
                if not field:
                    continue
                if ":" not in field:
                    raise ValueError(f"Invalid field: {field}. Expected key:value.")

                key, value = field.split(":", 1)
                key = key.strip().lower()
                value = value.strip()

                if key == "start":
                    start_ind = int(value)
                elif key == "end":
                    end_ind = int(value)
                elif key in ("use_leads", "leads"):
                    leads = [x.strip() for x in value.split(",") if x.strip()]
                else:
                    print(f"Ignoring the unknown key: {key}")
            return start_ind, end_ind, leads

        if ecg_configs is None:
            output = [None] * n_ecgs
            return output, output, output

        n_configs = len(ecg_configs)
        if n_configs != 1 and n_configs != n_ecgs:
            raise ValueError(f'Found {n_configs} ECG configs for {n_ecgs} ECG inputs. The number of config should be 1 or match the number of ECGs.')

        start_inds, end_inds, leads = [], [], []
        for config in ecg_configs:
            start_ind, end_ind, lead = _parse_single_config(config)
            print(f'ECG Config: {start_ind}, {end_ind}, {lead}')
            start_inds.append(start_ind)
            end_inds.append(end_ind)
            leads.append(lead)

        # Broadcast a single config across every ECG input.
        if n_configs == 1 and n_ecgs > 1:
            start_inds = start_inds * n_ecgs
            end_inds = end_inds * n_ecgs
            leads = leads * n_ecgs

        return start_inds, end_inds, leads

    def _build_content(self, *, raw_content: list[dict[str, str]], ecg_configs: Optional[list[str]]) -> list[dict[str, Any]]:
        """Convert raw text/ecg items into a single-turn KardiaLM conversation.

        ECG items are loaded via :func:`get_waveform` with their matching
        parsed config; items of unknown type are silently skipped (original
        behavior).
        """
        n_ecgs = sum(1 for c in raw_content if c['type'] == 'ecg')
        starts, ends, leads = self._parse_ecg_config(ecg_configs, n_ecgs)
        ecg_ind = 0

        content: list[dict[str, Any]] = []
        for c in raw_content:
            if c['type'] == 'text':
                content.append({"type": "text", "text": c['text']})
            elif c['type'] == 'ecg':
                waveform = get_waveform(ecg_path=c['ecg'], start_sec=starts[ecg_ind], end_sec=ends[ecg_ind], leads=leads[ecg_ind], device=self.device)
                ecg_ind += 1
                content.append({"type": "ecg", "waveform_segments": waveform})

        conversation = [{"from": self.prompt_tokens.user_role, "content": content}]
        return conversation
|
| 165 |
+
|
| 166 |
+
# Public API of this compatibility shim.
__all__ = ["CAMEL"]
|
camel_inference/src/camel/checkpoint_utils.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Checkpoint- and state-management utilities shared by ECG training scripts.
|
| 3 |
+
These helpers were extracted from train_ecg_text.py to keep the training entrypoint
|
| 4 |
+
focused on orchestration while preserving original behaviour.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.distributed.checkpoint.state_dict import (
|
| 12 |
+
set_model_state_dict,
|
| 13 |
+
StateDictOptions,
|
| 14 |
+
)
|
| 15 |
+
from torch.distributed.tensor import DTensor
|
| 16 |
+
from peft import (
|
| 17 |
+
LoraConfig,
|
| 18 |
+
TaskType,
|
| 19 |
+
set_peft_model_state_dict,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from camel.training_setup import is_main_process
|
| 23 |
+
|
| 24 |
+
def _module_has_dtensor_params(mod: nn.Module) -> bool:
|
| 25 |
+
"""
|
| 26 |
+
Return True if any parameter tensor underlying the module is a DTensor.
|
| 27 |
+
Under FSDP2 it is typically parameter.data that carries the DTensor type.
|
| 28 |
+
"""
|
| 29 |
+
for param in mod.parameters(recurse=True):
|
| 30 |
+
if isinstance(getattr(param, "data", None), DTensor):
|
| 31 |
+
return True
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
def _extract_projector_name(payload: Dict[str, Any]) -> Optional[str]:
|
| 35 |
+
"""Return the stored projector name if present."""
|
| 36 |
+
name = payload.get("projector_name")
|
| 37 |
+
if isinstance(name, str) and name:
|
| 38 |
+
return name
|
| 39 |
+
extra = payload.get("extra")
|
| 40 |
+
if isinstance(extra, dict):
|
| 41 |
+
extra_name = extra.get("projector_name")
|
| 42 |
+
if isinstance(extra_name, str) and extra_name:
|
| 43 |
+
return extra_name
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
def peek_projector_name(path: Optional[str]) -> Optional[str]:
    """Load a checkpoint just far enough to read the projector name metadata.

    Args:
        path: Checkpoint path, or None. (The previous ``path: str`` annotation
            was wrong — the body has always accepted and handled None.)

    Returns:
        The stored projector name, or None when *path* is None or no name
        is recorded in the payload.

    Raises:
        RuntimeError: if the loaded checkpoint payload is not a dict.
    """
    if path is None:
        return None
    payload = torch.load(path, map_location="cpu")
    if not isinstance(payload, dict):
        raise RuntimeError(f"Checkpoint {path} must be a dict to inspect projector metadata.")
    return _extract_projector_name(payload)
|
| 54 |
+
|
| 55 |
+
def extract_lora_config_from_checkpoints(
    resume_ckpt_path: Optional[str],
    load_llava_from: Optional[str],
) -> Optional[Dict[str, Any]]:
    """Return the LoRA config embedded in either checkpoint, preferring resume.

    Each candidate path is loaded on CPU; the config dict is taken from
    ``payload["lora"]["config"]`` when present. ``use_dora`` is normalized to
    a bool. Returns None when neither checkpoint carries a config.
    """
    def _load_config(path: Optional[str]) -> Optional[Dict[str, Any]]:
        if not path:
            return None
        try:
            payload = torch.load(path, map_location="cpu")
        except Exception as exc:
            raise RuntimeError(
                f"Failed to load checkpoint '{path}' while extracting LoRA config"
            ) from exc
        lora_payload = payload.get("lora")
        has_config = isinstance(lora_payload, dict) and isinstance(lora_payload.get("config"), dict)
        if not has_config:
            return None
        cfg = dict(lora_payload["config"])
        if "use_dora" in cfg:
            cfg["use_dora"] = bool(cfg["use_dora"])
        return cfg

    # Resume checkpoint takes precedence over the llava-init checkpoint.
    for candidate in (resume_ckpt_path, load_llava_from):
        cfg = _load_config(candidate)
        if cfg is not None:
            return cfg
    return None
|
| 80 |
+
|
| 81 |
+
def load_llava_and_lora(
    wrapper: nn.Module,
    model: nn.Module,
    ckpt_path: str,
    *,
    expect_lora: bool,
    load_lora: bool = True,
    missing_lora_ok: bool = False,
) -> Tuple[Dict[str, Any], nn.Module, Optional[LoraConfig]]:
    """
    Load llava_proj (mandatory), optional conv encoder weights, ECG special-token
    embeddings, and LoRA adapters from a checkpoint.

    Args:
        wrapper: ECG wrapper exposing ``llava_proj``, ``enc`` and
            ``ecg_special_embed`` sub-modules (and optionally ``projector_name``).
        model: Language model to receive LoRA adapter weights.
        ckpt_path: Path to the checkpoint file (a dict payload).
        expect_lora: Fail (or warn, see *missing_lora_ok*) if no adapters load.
        load_lora: Skip adapter loading entirely when False.
        missing_lora_ok: Downgrade a missing-adapter failure to a warning.

    Returns:
        (extra_payload, model, created_cfg) where *created_cfg* is a
        LoraConfig rebuilt from checkpoint metadata, or None.

    Raises:
        RuntimeError: on malformed payload sections, a projector-name
            mismatch, or missing-but-expected weights.
    """
    payload = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(payload, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} must be a dict, got {type(payload).__name__}")
    extra_payload = payload.get("extra") or {}
    if not isinstance(extra_payload, dict):
        raise RuntimeError(
            f"Checkpoint {ckpt_path} has non-dict extra payload of type {type(extra_payload).__name__}"
        )
    # Refuse to load weights trained with a different projector architecture.
    ckpt_projector = _extract_projector_name(payload)
    wrapper_projector = getattr(wrapper, "projector_name", None)
    if ckpt_projector is not None and wrapper_projector is not None and wrapper_projector != ckpt_projector:
        raise RuntimeError(
            f"Checkpoint {ckpt_path} projector '{ckpt_projector}' does not match wrapper projector '{wrapper_projector}'."
        )
    # --- llava_proj (mandatory) -------------------------------------------
    llava_sd = payload.get("llava_proj")
    if not isinstance(llava_sd, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} missing llava_proj state_dict.")
    if is_main_process():
        print(f"[load-llava] Loading llava_proj weights from {ckpt_path}", flush=True)
    # FSDP2-sharded modules need the distributed state-dict path (with
    # broadcast from rank 0); plain modules use a strict local load.
    any_llava_dt = _module_has_dtensor_params(wrapper.llava_proj)
    if any_llava_dt:
        set_model_state_dict(
            model=wrapper.llava_proj,
            model_state_dict=llava_sd,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
    else:
        wrapper.llava_proj.load_state_dict(llava_sd, strict=True)
    # --- conv encoder (optional, but required if flagged trainable) -------
    conv_sd = payload.get("conv")
    conv_expected = bool(extra_payload.get("conv_trainable"))
    if conv_sd is None:
        if conv_expected:
            raise RuntimeError(
                f"Checkpoint {ckpt_path} indicates conv_trainable=True but conv weights are missing."
            )
    else:
        if not isinstance(conv_sd, dict):
            raise RuntimeError(
                f"Checkpoint {ckpt_path} conv payload must be a state_dict, got {type(conv_sd).__name__}"
            )
        if is_main_process():
            print(f"[load-llava] Loading conv encoder weights from {ckpt_path}", flush=True)
        any_conv_dt = _module_has_dtensor_params(wrapper.enc)
        if any_conv_dt:
            set_model_state_dict(
                model=wrapper.enc,
                model_state_dict=conv_sd,
                options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
            )
        else:
            wrapper.enc.load_state_dict(conv_sd, strict=True)
    # --- ECG special-token embeddings (mandatory) -------------------------
    ecg_special_sd = payload.get("ecg_special")
    if not isinstance(ecg_special_sd, dict):
        raise RuntimeError(f"Checkpoint {ckpt_path} missing ECG special-token embedding state.")
    if is_main_process():
        print(f"[load-llava] Loading ECG special-token embedding from {ckpt_path}", flush=True)
    any_special_dt = _module_has_dtensor_params(wrapper.ecg_special_embed)
    if any_special_dt:
        set_model_state_dict(
            model=wrapper.ecg_special_embed,
            model_state_dict=ecg_special_sd,
            options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
        )
    else:
        wrapper.ecg_special_embed.load_state_dict(ecg_special_sd, strict=True)
    # --- LoRA adapters (optional) -----------------------------------------
    lora_payload = payload.get("lora")
    loaded_lora = False
    created_cfg: Optional[LoraConfig] = None
    if load_lora and lora_payload is not None:
        if not isinstance(lora_payload, dict):
            raise RuntimeError(
                f"Checkpoint {ckpt_path} has non-dict LoRA payload of type {type(lora_payload).__name__}"
            )
        lora_state = lora_payload.get("state_dict")
        if not isinstance(lora_state, dict):
            raise RuntimeError(f"Checkpoint {ckpt_path} LoRA payload missing state_dict.")
        if is_main_process():
            print(f"[load-llava] Loading LoRA adapters from {ckpt_path}", flush=True)
        set_peft_model_state_dict(model, lora_state)
        loaded_lora = True
        # Best-effort reconstruction of the LoraConfig: values round-tripped
        # through serialization are coerced back to their expected types, and
        # any failure leaves created_cfg as None rather than aborting the load.
        cfg_dict = lora_payload.get("config")
        if isinstance(cfg_dict, dict):
            cfg_args = dict(cfg_dict)
            task_type_raw = cfg_args.get("task_type", TaskType.CAUSAL_LM)
            if not isinstance(task_type_raw, TaskType):
                try:
                    task_type_raw = TaskType(task_type_raw)
                except Exception:
                    task_type_raw = TaskType.CAUSAL_LM
            cfg_args["task_type"] = task_type_raw
            if "lora_dropout" in cfg_args:
                try:
                    cfg_args["lora_dropout"] = float(cfg_args["lora_dropout"])
                except Exception:
                    cfg_args["lora_dropout"] = 0.0
            if "r" in cfg_args:
                try:
                    cfg_args["r"] = int(cfg_args["r"])
                except Exception:
                    cfg_args["r"] = 0
            if "lora_alpha" in cfg_args:
                try:
                    cfg_args["lora_alpha"] = int(cfg_args["lora_alpha"])
                except Exception:
                    cfg_args["lora_alpha"] = 0
            if "target_modules" in cfg_args and cfg_args["target_modules"] is not None:
                cfg_args["target_modules"] = list(cfg_args["target_modules"])
            cfg_args.setdefault("bias", "none")
            cfg_args.setdefault("inference_mode", False)
            cfg_args["use_dora"] = bool(cfg_args.get("use_dora", False))
            try:
                created_cfg = LoraConfig(**cfg_args)
            except Exception:
                created_cfg = None
    # Expected adapters but none were loaded: warn or fail per missing_lora_ok.
    if expect_lora and load_lora and not loaded_lora:
        if missing_lora_ok:
            if is_main_process():
                print(
                    f"[load-llava] Warning: checkpoint {ckpt_path} contains no LoRA adapters; "
                    "continuing with the currently-initialized adapters.",
                    flush=True,
                )
        else:
            raise RuntimeError(
                f"[load-llava] Warning: expected LoRA adapters in {ckpt_path} but none were loaded."
            )
    return extra_payload, model, created_cfg
|
| 221 |
+
|
| 222 |
+
def update_wrapper_language_model(wrapper: nn.Module, model: nn.Module) -> None:
    """Point the wrapper's ``language_model`` attribute at the given instance."""
    setattr(wrapper, "language_model", model)
|
| 225 |
+
|
| 226 |
+
# Public API of the checkpoint utilities.
#
# NOTE: an earlier revision also exported find_latest_step_checkpoint,
# dump_lora_state_fsdp_safe, prepare_optimizer_state_payload and
# ensure_no_dtensor, but those helpers are not defined in this module, so
# listing them broke `from camel.checkpoint_utils import *` with an
# AttributeError. Only names actually defined here are exported.
__all__ = [
    "extract_lora_config_from_checkpoints",
    "load_llava_and_lora",
    "update_wrapper_language_model",
    "peek_projector_name",
]
|
camel_inference/src/camel/ecg_attention_masks.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ECG-aware attention mask utilities with pluggable strategy support."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Dict, List, Optional, Protocol
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
@dataclass
class ECGBlockLayout:
    """Layout describing a single ECG block inside the assembled sequence."""

    # Half-open [start_idx, end_idx_exclusive) span of the block in the token sequence.
    start_idx: Optional[int]
    end_idx_exclusive: Optional[int]
    # Positions of the block-wide start/end special tokens, when present.
    global_start_idx: Optional[int] = None
    global_end_idx: Optional[int] = None
    # Per-lead positions of the lead start/end special tokens.
    lead_start_idx: Dict[str, int] = field(default_factory=dict)
    lead_end_idx: Dict[str, int] = field(default_factory=dict)
    # lead -> positions of that lead's signal tokens.
    signal_pos_by_lead: Dict[str, List[int]] = field(default_factory=dict)
    # time index -> signal-token positions occurring at that time (used for
    # time-causal signal-to-signal visibility in the semantic mask).
    time_to_signal_idxs: Dict[int, List[int]] = field(default_factory=dict)
    # Caches filled in lazily by the mask builders when left empty.
    special_idxs_sorted: List[int] = field(default_factory=list)
    signal_pos_list: List[int] = field(default_factory=list)
    # NOTE(review): presumably lead -> declared segment count; not read by the
    # mask code visible in this module — confirm against the packing code.
    declared_segments_per_lead: Dict[str, int] = field(default_factory=dict)
    # Positions of conversation/text tokens handled causally relative to the block.
    conv_idxs: List[int] = field(default_factory=list)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ECGSequenceLayout:
    """Compact description of the assembled token layout for one training sample."""

    # Total number of tokens in the assembled sequence.
    seq_len: int
    # Positions of plain text tokens in the sequence.
    text_idxs: List[int] = field(default_factory=list)
    # One layout entry per ECG block embedded in the sequence.
    blocks: List[ECGBlockLayout] = field(default_factory=list)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _as_tensor(indices: List[int], device: torch.device) -> torch.Tensor:
|
| 37 |
+
if not indices:
|
| 38 |
+
return torch.empty(0, dtype=torch.long, device=device)
|
| 39 |
+
return torch.tensor(sorted(indices), dtype=torch.long, device=device)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def apply_single_block_semantic_mask_(
    allowed: torch.Tensor,
    block_layout: ECGBlockLayout,
    *,
    visible_prefix_len: int,
    key_limit_exclusive: Optional[int],
    apply_header_causal: bool = True,
) -> None:
    """
    In-place semantic mask update for a single ECG block. Mirrors the historical
    single-block logic, but operates on a provided boolean mask.

    Args:
        allowed: [L, L] boolean matrix; ``allowed[q, k]`` means query ``q`` may
            attend to key ``k``. Mutated in place.
        block_layout: Positions of the block's special/signal/conv tokens.
            Its ``special_idxs_sorted`` / ``signal_pos_list`` caches are filled
            in if empty.
        visible_prefix_len: Length of the header prefix visible to block tokens.
        key_limit_exclusive: When not None, block rows may not attend at or
            beyond this key position.
        apply_header_causal: Whether to (re)impose causal masking among header
            tokens themselves.
    """
    L = int(allowed.size(0))
    device = allowed.device
    # Header = positions [0, visible_prefix_len) preceding the block.
    header = (
        torch.arange(int(visible_prefix_len), dtype=torch.long, device=device)
        if int(visible_prefix_len) > 0
        else torch.empty(0, dtype=torch.long, device=device)
    )

    # Special-token positions; derive from the layout and cache when not precomputed.
    specials_list = block_layout.special_idxs_sorted or []
    if not specials_list:
        specials_list = sorted(
            ([block_layout.global_start_idx] if block_layout.global_start_idx is not None else [])
            + list(block_layout.lead_start_idx.values())
            + list(block_layout.lead_end_idx.values())
            + ([block_layout.global_end_idx] if block_layout.global_end_idx is not None else [])
        )
        block_layout.special_idxs_sorted = specials_list
    # Keys at or beyond the limit are never attendable from block rows.
    if key_limit_exclusive is not None:
        specials_list = [i for i in specials_list if int(i) < int(key_limit_exclusive)]
    specials = _as_tensor([int(i) for i in specials_list], device)

    # Signal-token positions; derive from the per-lead map and cache when empty.
    signals_list = block_layout.signal_pos_list or []
    if not signals_list:
        signal_all: List[int] = []
        for lst in block_layout.signal_pos_by_lead.values():
            signal_all.extend(lst)
        signals_list = sorted(signal_all)
        block_layout.signal_pos_list = signals_list
    if key_limit_exclusive is not None:
        signals_list = [i for i in signals_list if int(i) < int(key_limit_exclusive)]
    signals = _as_tensor([int(i) for i in signals_list], device)

    lead_starts = _as_tensor(list(block_layout.lead_start_idx.values()), device)
    lead_ends = _as_tensor(list(block_layout.lead_end_idx.values()), device)

    # Header tokens attend causally among themselves (skipped in multi-block mode).
    if apply_header_causal and header.numel():
        allowed[header[:, None], header[None, :]] = header[:, None] >= header[None, :]

    # Global-start token sees the full header.
    gs = block_layout.global_start_idx
    if gs is not None and header.numel():
        allowed[int(gs), header] = True

    # Rows for lead markers and signals: see the header, plus earlier specials only.
    rows_before = []
    if lead_starts.numel():
        rows_before.append(lead_starts)
    if signals.numel():
        rows_before.append(signals)
    if lead_ends.numel():
        rows_before.append(lead_ends)
    rows_before_t = (
        torch.cat(rows_before, dim=0) if rows_before else torch.empty(0, dtype=torch.long, device=device)
    )

    if rows_before_t.numel():
        if header.numel():
            allowed[rows_before_t[:, None], header[None, :]] = True
        if specials.numel():
            allowed[rows_before_t[:, None], specials[None, :]] = specials[None, :] < rows_before_t[:, None]

    # Signal-to-signal visibility is time-causal: a signal at time t may attend
    # to signal positions whose minimum time is <= t.
    ttsi: Dict[int, List[int]] = block_layout.time_to_signal_idxs
    if signals.numel() and ttsi:
        pos_min_time: Dict[int, int] = {}
        pos_to_time: Dict[int, int] = {}
        for t, idxs in ttsi.items():
            for p in idxs:
                pos_to_time[p] = t
                prev = pos_min_time.get(p)
                if prev is None or t < prev:
                    pos_min_time[p] = t

        u_pos_list = sorted(pos_min_time.keys())
        if u_pos_list:
            u_pos = torch.tensor(u_pos_list, dtype=torch.long, device=device)
            u_time = torch.tensor([pos_min_time[p] for p in u_pos_list], dtype=torch.long, device=device)
            q_time = torch.tensor([pos_to_time.get(p, 0) for p in signals_list], dtype=torch.long, device=device)
            allowed[signals[:, None], u_pos[None, :]] = (u_time[None, :] <= q_time[:, None])

    # Each lead-end token may attend to all of its own lead's signal tokens.
    for lead, eidx in block_layout.lead_end_idx.items():
        lead_sigs = block_layout.signal_pos_by_lead.get(lead, [])
        if lead_sigs:
            allowed[int(eidx), torch.tensor(lead_sigs, dtype=torch.long, device=device)] = True

    # Global-end token summarizes the block: sees header, specials, and signals.
    ge = block_layout.global_end_idx
    if ge is not None:
        gei = int(ge)
        if header.numel():
            allowed[gei, header] = True
        if specials.numel():
            allowed[gei, specials] = True
        if signals.numel():
            allowed[gei, signals] = True

    # Conversation rows: causal among themselves, see header/specials/signals,
    # then clipped so no conv row attends beyond its own position.
    conv = _as_tensor(block_layout.conv_idxs, device)
    if conv.numel():
        allowed[conv[:, None], conv[None, :]] = conv[:, None] >= conv[None, :]
        if header.numel():
            allowed[conv[:, None], header[None, :]] = True
        if specials.numel():
            allowed[conv[:, None], specials[None, :]] = True
        if signals.numel():
            allowed[conv[:, None], signals[None, :]] = True
        cols = torch.arange(L, device=device)
        conv_rows = allowed[conv, :]
        conv_rows &= (cols.unsqueeze(0) <= conv.unsqueeze(1))
        allowed[conv, :] = conv_rows

    # Every special token may at least attend to itself (diagonal entries).
    if specials.numel():
        allowed[specials, specials] = True

    # Block rows never attend to keys at or beyond the limit.
    if key_limit_exclusive is not None and int(key_limit_exclusive) < L:
        block_rows_list = list(specials_list) + list(signals_list)
        if block_rows_list:
            block_rows = _as_tensor(block_rows_list, device)
            allowed[block_rows, int(key_limit_exclusive):] = False
|
| 168 |
+
|
| 169 |
+
@dataclass
class MaskBuildResult:
    """Container for per-sample mask artifacts produced by a strategy."""

    # Additive float mask: 0 where attention is allowed, -inf where blocked.
    additive: torch.Tensor
    # Optional boolean "allowed" matrix, kept so generation can extend it
    # incrementally instead of rebuilding from scratch.
    boolean: Optional[torch.Tensor] = None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class ECGMaskStrategy(Protocol):
    """Protocol for strategies that build and update per-sample attention masks."""

    # Registry key used by get_mask_strategy().
    name: str

    def build(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> MaskBuildResult:
        """Build the full attention mask for *layout* from scratch."""
        ...

    def update_for_generated_token(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
        previous: MaskBuildResult,
    ) -> MaskBuildResult:
        """Extend *previous* by one position for a newly generated token."""
        ...
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class SemanticMaskStrategy:
    """Default strategy reproducing the historical ECG-aware attention mask."""

    name = "semantic"

    def build(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> MaskBuildResult:
        """Build the full mask: causal base, then per-block semantic rewrites.

        Returns a MaskBuildResult whose ``boolean`` matrix is retained so
        generation can extend it via update_for_generated_token().
        """
        L = int(layout.seq_len)
        if L <= 0:
            raise ValueError("Semantic mask requires a positive sequence length")
        # Start from a plain causal (lower-triangular) mask.
        allowed = torch.tril(torch.ones((L, L), dtype=torch.bool, device=device))
        multi_block = len(layout.blocks) > 1
        for block in layout.blocks:
            # Skip blocks without a resolved span.
            if block.start_idx is None or block.end_idx_exclusive is None:
                continue
            block_rows: List[int] = []
            # Resolve (and cache) the block's special-token positions.
            specials_list = block.special_idxs_sorted or []
            if not specials_list:
                specials_list = sorted(
                    ([block.global_start_idx] if block.global_start_idx is not None else [])
                    + list(block.lead_start_idx.values())
                    + list(block.lead_end_idx.values())
                    + ([block.global_end_idx] if block.global_end_idx is not None else [])
                )
                block.special_idxs_sorted = specials_list
            # Resolve (and cache) the block's signal-token positions.
            signals_list = block.signal_pos_list or []
            if not signals_list:
                signal_all: List[int] = []
                for lst in block.signal_pos_by_lead.values():
                    signal_all.extend(lst)
                signals_list = sorted(signal_all)
                block.signal_pos_list = signals_list
            if specials_list:
                block_rows.extend(int(i) for i in specials_list)
            if signals_list:
                block_rows.extend(int(i) for i in signals_list)
            # Clear the block's rows first; the helper re-grants visibility.
            if block_rows:
                rows = torch.tensor(sorted(set(block_rows)), dtype=torch.long, device=device)
                allowed[rows, :] = False
            apply_single_block_semantic_mask_(
                allowed,
                block,
                visible_prefix_len=int(block.start_idx),
                key_limit_exclusive=int(block.end_idx_exclusive),
                # With multiple blocks, the causal base already covers the header.
                apply_header_causal=not multi_block,
            )
        additive = self._boolean_to_additive(allowed, device=device, dtype=dtype)
        return MaskBuildResult(additive=additive, boolean=allowed)

    def update_for_generated_token(
        self,
        layout: ECGSequenceLayout,
        *,
        device: torch.device,
        dtype: torch.dtype,
        previous: MaskBuildResult,
    ) -> MaskBuildResult:
        """Grow the mask by one row/column for a freshly generated token.

        The new token may attend to everything (including itself); existing
        rows cannot attend to the new column (new column defaults to False).
        Falls back to a full rebuild when *previous* carries no boolean mask.
        """
        allowed = previous.boolean
        if allowed is None:
            return self.build(layout, device=device, dtype=dtype)
        prev_len = int(allowed.size(0))
        new_allowed = torch.zeros((prev_len + 1, prev_len + 1), dtype=torch.bool, device=device)
        new_allowed[:prev_len, :prev_len] = allowed
        new_allowed[prev_len, : prev_len + 1] = True
        additive = self._boolean_to_additive(new_allowed, device=device, dtype=dtype)
        return MaskBuildResult(additive=additive, boolean=new_allowed)


    @classmethod
    def _build_boolean_mask(cls, layout: "ECGSequenceLayout", device: torch.device) -> torch.Tensor:
        """Single-block boolean mask built on a defensive copy of the layout.

        Text tokens at or after the block's end are treated as conversation
        rows (``conv_idxs``) so they get causal visibility over the block.
        """
        if len(layout.blocks) != 1:
            raise ValueError("Single-block mask builder requires exactly one ECG block")
        L = int(layout.seq_len)
        allowed = torch.zeros((L, L), dtype=torch.bool, device=device)
        block = layout.blocks[0]
        prefix_len = int(block.start_idx or 0)
        end_idx = int(block.end_idx_exclusive or L)
        conv_idxs = [idx for idx in layout.text_idxs if int(idx) >= end_idx]
        # Shallow-copy containers so the helper's caching never mutates the caller's layout.
        block_ref = ECGBlockLayout(
            start_idx=block.start_idx,
            end_idx_exclusive=block.end_idx_exclusive,
            global_start_idx=block.global_start_idx,
            global_end_idx=block.global_end_idx,
            lead_start_idx=dict(block.lead_start_idx),
            lead_end_idx=dict(block.lead_end_idx),
            signal_pos_by_lead=dict(block.signal_pos_by_lead),
            time_to_signal_idxs=dict(block.time_to_signal_idxs),
            special_idxs_sorted=list(block.special_idxs_sorted),
            signal_pos_list=list(block.signal_pos_list),
            declared_segments_per_lead=dict(block.declared_segments_per_lead),
            conv_idxs=conv_idxs,
        )
        apply_single_block_semantic_mask_(
            allowed,
            block_ref,
            visible_prefix_len=prefix_len,
            key_limit_exclusive=end_idx,
        )
        return allowed

    @staticmethod
    def _boolean_to_additive(allowed: torch.Tensor, *, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Convert a boolean mask to additive form: blocked positions become -inf."""
        additive = torch.zeros(allowed.shape, dtype=dtype, device=device)
        additive.masked_fill_(~allowed, float("-inf"))
        return additive
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# Shared module-level default strategy instance.
DEFAULT_MASK_STRATEGY = SemanticMaskStrategy()


# Strategies addressable by (lowercase) name via get_mask_strategy().
MASK_STRATEGY_REGISTRY: Dict[str, ECGMaskStrategy] = {
    DEFAULT_MASK_STRATEGY.name: DEFAULT_MASK_STRATEGY,
}
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_mask_strategy(name: Optional[str]) -> ECGMaskStrategy:
    """Look up a registered mask strategy by case-insensitive name.

    ``None`` resolves to the default semantic strategy; an unknown name raises
    ``ValueError`` listing the known strategies.
    """
    if name is None:
        return DEFAULT_MASK_STRATEGY
    try:
        return MASK_STRATEGY_REGISTRY[str(name).lower()]
    except KeyError as exc:
        known = ", ".join(sorted(MASK_STRATEGY_REGISTRY))
        raise ValueError(f"Unknown ECG mask strategy '{name}'. Known strategies: {known}") from exc
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Public API of this module.
__all__ = [
    "ECGBlockLayout",
    "ECGSequenceLayout",
    "MaskBuildResult",
    "ECGMaskStrategy",
    "SemanticMaskStrategy",
    "DEFAULT_MASK_STRATEGY",
    "MASK_STRATEGY_REGISTRY",
    "get_mask_strategy",
]
|
camel_inference/src/camel/ecg_gemma_model.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ecg_gemma_model.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Optional
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from transformers import AutoModelForCausalLM
|
| 9 |
+
|
| 10 |
+
from camel.ecg_model_wrapper import ECGLanguageModelWrapper
|
| 11 |
+
|
| 12 |
+
class ECGGemmaPrefix(ECGLanguageModelWrapper):
    """
    Frozen: Gemma-IT (language model), 1D conv signal encoder (loaded from disk)
    Trainable: llava_proj (Linear 64 -> Gemma hidden size)
    Optionally trainable: conv encoder when explicitly requested

    This wrapper turns per-second ECG windows into single "pseudo-token" rows
    that are interleaved into user turns at embedding time.
    """

    def __init__(
        self,
        gemma: AutoModelForCausalLM,
        enc: nn.Module,
        hidden_size: int,
        num_ecg_special_tokens: int,
        dtype: Optional[torch.dtype] = torch.bfloat16,
        enc_out_dim: int = 64,  # from the specified conv stack: 4 channels * 16 length = 64
        freeze_encoder: bool = True,
        inference: bool = False,
        projector_name: str = "linear",
    ):
        # Thin adapter: all setup is delegated to the generic wrapper.
        super().__init__(
            language_model=gemma,
            conv_encoder=enc,
            hidden_size=hidden_size,
            num_ecg_special_tokens=num_ecg_special_tokens,
            dtype=dtype,
            enc_out_dim=enc_out_dim,
            freeze_encoder=freeze_encoder,
            inference=inference,
            projector_name=projector_name,
        )
        # Use language_model consistently; no Gemma alias is set.

    def forward_language_model(
        self,
        inputs_embeds: Tensor,
        attention_mask: Tensor,
        labels: Optional[Tensor],
        output_hidden_states: bool = False,
    ):
        """Forward pre-built embeddings through the underlying language model.

        Inputs are first moved onto the device of the LM's input-embedding
        weights so sharded/model-parallel placements receive correctly-placed
        tensors.

        Raises:
            RuntimeError: If no device can be inferred from the LM (no input
                embeddings and no parameters).
        """
        # Ensure inputs live on the device of the LM input embeddings when sharded
        embedder_fn = getattr(self.language_model, "get_input_embeddings", None)
        if callable(embedder_fn):
            try:
                embed_module = embedder_fn()
            except Exception as exc:
                raise RuntimeError("Failed to obtain language-model input embeddings for device inference") from exc
            if not hasattr(embed_module, "weight"):
                raise RuntimeError("Input embedding module lacks a weight parameter; cannot infer device")
            dev0 = embed_module.weight.device
        else:
            # Fallback: infer placement from the model's first parameter.
            params_iter = self.language_model.parameters()
            try:
                first_param = next(params_iter)
            except StopIteration as exc:
                raise RuntimeError("Language model exposes no parameters to infer device placement") from exc
            dev0 = first_param.device

        if inputs_embeds.device != dev0:
            inputs_embeds = inputs_embeds.to(dev0)
        if attention_mask.device != dev0:
            attention_mask = attention_mask.to(dev0)
        if labels is not None and labels.device != dev0:
            labels = labels.to(dev0)

        return self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            use_cache=False,  # training-style forward; no KV cache
            output_hidden_states=output_hidden_states,
        )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# FSDP2 helpers removed (not used)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
__all__ = ["ECGGemmaPrefix"]  # public API of this module
|
camel_inference/src/camel/ecg_model_wrapper.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Abstract ECG-language model adapter with shared conv→projection logic.
|
| 3 |
+
|
| 4 |
+
Subclasses can override only the pieces that differ (e.g., prompt format or
|
| 5 |
+
language-model forwarding) while inheriting the common ECG prefix handling.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Dict, Iterable, List, Mapping, Optional, Tuple
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
|
| 15 |
+
from camel.assertions import assert_wrapper_embed_length, assert_rest_length_nonnegative, assert_tensor_dtype
|
| 16 |
+
from camel.projectors import build_projector
|
| 17 |
+
|
| 18 |
+
class ECGNonFiniteInputError(RuntimeError):
    """Raised when the conv encoder input contains NaN or Inf values."""

    def __init__(self, sample_idx: int, lead: Optional[str] = None) -> None:
        # Record where the bad waveform was found for upstream diagnostics.
        self.sample_idx = int(sample_idx)
        self.lead = lead
        message = f"Non-finite waveform detected (sample_idx={self.sample_idx}"
        if lead is not None:
            message += f", lead={lead}"
        message += ")"
        super().__init__(message)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ECGLanguageModelWrapper(nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
Turns per-lead ECG waveforms into prefix embeddings consumable by a language model.
|
| 31 |
+
Stores the ECG encoder, trainable adapter, and the target language model.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
    def __init__(
        self,
        *,
        language_model: nn.Module,
        conv_encoder: nn.Module,
        hidden_size: int,
        num_ecg_special_tokens: int,
        dtype: Optional[torch.dtype] = torch.bfloat16,
        enc_out_dim: int = 64,
        freeze_encoder: bool = True,
        inference: bool = False,
        projector_name: str = "linear",
    ) -> None:
        """Wire the encoder, projector, and special-token table around the LM.

        Args:
            language_model: Target LM; left in train mode (freezing is the caller's job).
            conv_encoder: Waveform encoder; frozen and eval'd when ``freeze_encoder``.
            hidden_size: LM hidden dimension the projector maps into.
            num_ecg_special_tokens: Size of the learned ECG special-token table (> 0).
            dtype: Working dtype for embeddings (defaults to bfloat16 when falsy).
            enc_out_dim: Flattened encoder output width fed to the projector.
            freeze_encoder: Whether the conv encoder's parameters are frozen.
            inference: When True, projector and encoder are pinned to fp32.
            projector_name: Projector architecture key for build_projector.

        Raises:
            ValueError: If ``num_ecg_special_tokens`` is not positive.
        """
        super().__init__()

        if int(num_ecg_special_tokens) <= 0:
            raise ValueError("num_ecg_special_tokens must be positive")

        # Keep LM in train mode (actual freezing handled by caller).
        self.language_model = language_model.train()

        # Conv encoder may be frozen or trainable depending on configuration.
        self.enc = conv_encoder
        if freeze_encoder:
            self.enc = self.enc.eval()
            for p in self.enc.parameters():
                p.requires_grad = False
        else:
            self.enc = self.enc.train()
            for p in self.enc.parameters():
                p.requires_grad = True

        self.hidden_size = int(hidden_size)
        self.dtype = dtype or torch.bfloat16
        self.num_ecg_special_tokens = int(num_ecg_special_tokens)
        self.inference = bool(inference)
        self.projector_name = str(projector_name or "linear")

        # Trainable adapter: conv output → LM hidden size (kept in fp32 for stability).
        self.llava_proj = build_projector(self.projector_name, int(enc_out_dim), self.hidden_size)
        if self.inference:
            self.llava_proj.to(dtype=torch.float32)

        # Learned embeddings for the ECG structural special tokens.
        self.ecg_special_embed = nn.Embedding(self.num_ecg_special_tokens, self.hidden_size)
        self.ecg_special_embed.to(dtype=self.dtype)
        if self.inference:
            self.enc.to(dtype=torch.float32)
        # Mirror the LM's current gradient-checkpointing state.
        self._grad_ckpt_enabled = self._detect_grad_ckpt_state()
|
| 82 |
+
|
| 83 |
+
def _projector_param_dtype(self) -> torch.dtype:
|
| 84 |
+
"""Return dtype of the first projector parameter (defaults to fp32)."""
|
| 85 |
+
first_param = next(self.llava_proj.parameters(), None)
|
| 86 |
+
if first_param is None:
|
| 87 |
+
return torch.float32
|
| 88 |
+
return first_param.dtype
|
| 89 |
+
|
| 90 |
+
# ---- Gradient checkpointing --------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def _detect_grad_ckpt_state(self) -> bool:
|
| 93 |
+
if hasattr(self.language_model, "is_gradient_checkpointing"):
|
| 94 |
+
try:
|
| 95 |
+
return bool(self.language_model.is_gradient_checkpointing)
|
| 96 |
+
except Exception:
|
| 97 |
+
return False
|
| 98 |
+
if hasattr(self.language_model, "gradient_checkpointing"):
|
| 99 |
+
try:
|
| 100 |
+
return bool(self.language_model.gradient_checkpointing)
|
| 101 |
+
except Exception:
|
| 102 |
+
return False
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
    def set_gradient_checkpointing(self, enabled: bool) -> bool:
        """Enable or disable LM gradient checkpointing.

        Returns True when the state actually changed, False when it was already
        in the requested state. Also toggles ``config.use_cache`` (checkpointing
        and the KV cache are mutually exclusive in HF models).
        """
        enabled = bool(enabled)
        current = self._detect_grad_ckpt_state()
        if current == enabled:
            # Already in the requested state; just sync the cached flag.
            self._grad_ckpt_enabled = enabled
            return False
        if enabled:
            if hasattr(self.language_model, "gradient_checkpointing_enable"):
                self.language_model.gradient_checkpointing_enable()
            if hasattr(getattr(self.language_model, "config", None), "use_cache"):
                self.language_model.config.use_cache = False
        else:
            if hasattr(self.language_model, "gradient_checkpointing_disable"):
                self.language_model.gradient_checkpointing_disable()
            elif hasattr(self.language_model, "gradient_checkpointing"):
                # Fallback for models exposing only the raw attribute.
                try:
                    self.language_model.gradient_checkpointing = False
                except Exception:
                    pass  # best-effort: attribute may be read-only
            if hasattr(getattr(self.language_model, "config", None), "use_cache"):
                self.language_model.config.use_cache = True
        self._grad_ckpt_enabled = enabled
        return True
|
| 128 |
+
|
| 129 |
+
    def enable_gradient_checkpointing(self) -> None:
        """Convenience wrapper for set_gradient_checkpointing(True)."""
        self.set_gradient_checkpointing(True)
|
| 131 |
+
|
| 132 |
+
    def disable_gradient_checkpointing(self) -> None:
        """Convenience wrapper for set_gradient_checkpointing(False)."""
        self.set_gradient_checkpointing(False)
|
| 134 |
+
|
| 135 |
+
    def is_gradient_checkpointing_enabled(self) -> bool:
        """Return the wrapper's cached gradient-checkpointing flag."""
        return bool(self._grad_ckpt_enabled)
|
| 137 |
+
|
| 138 |
+
# ---- Token helpers -----------------------------------------------------------------------
|
| 139 |
+
|
| 140 |
+
def tokens_to_embeds(self, input_embedder: nn.Embedding, ids: List[int], device: torch.device) -> Tensor:
|
| 141 |
+
ids_t = torch.tensor(ids, dtype=torch.long, device=device)
|
| 142 |
+
embeddings = input_embedder(ids_t)
|
| 143 |
+
embeddings = embeddings.to(dtype=self.dtype)
|
| 144 |
+
# Defensive: enforce 1:1 mapping between ids and embeddings
|
| 145 |
+
assert_wrapper_embed_length(embeddings=embeddings, ids=ids, context="tokens_to_embeds")
|
| 146 |
+
return embeddings
|
| 147 |
+
|
| 148 |
+
def ecg_special_tokens_to_embeds(self, indices: torch.Tensor | List[int], device: torch.device) -> Tensor:
|
| 149 |
+
if torch.is_tensor(indices):
|
| 150 |
+
idx = indices.to(device=device, dtype=torch.long)
|
| 151 |
+
else:
|
| 152 |
+
idx = torch.tensor(indices, dtype=torch.long, device=device)
|
| 153 |
+
embeds = self.ecg_special_embed(idx)
|
| 154 |
+
return embeds.to(dtype=self.dtype)
|
| 155 |
+
|
| 156 |
+
# ---- ECG prefix encoding -----------------------------------------------------------------
|
| 157 |
+
|
| 158 |
+
    def ecg_prefix(
        self,
        waveform_segments: Dict[str, Tensor],
        device: torch.device,
        lead_order: Optional[List[str]] = None,
    ) -> Tensor:
        """Encode a single sample's ECG prefix in a deterministic lead order.

        Args:
            waveform_segments: Mapping lead -> [T,256] windows.
            device: Target device for encoder input.
            lead_order: Optional explicit order of leads to iterate. If provided,
                segments are concatenated in this exact order; otherwise relies on
                insertion order of the mapping.

        Returns:
            Projected per-window embeddings in ``self.dtype``, one row per window.

        Raises:
            ValueError: If a requested lead is missing or a segment is not [T,256].
        """
        seqs: List[Tensor] = []
        if lead_order is None:
            items = list(waveform_segments.items())
        else:
            items = [(ld, waveform_segments[ld]) for ld in lead_order if ld in waveform_segments]
            # Validate presence when an explicit order is provided
            missing = [ld for ld in lead_order if ld not in waveform_segments]
            if missing:
                raise ValueError(f"Missing leads in waveform_segments for requested order: {missing}")
        for lead, seg in items:
            seg = torch.as_tensor(seg)
            if seg.ndim != 2 or seg.shape[1] != 256:
                raise ValueError(f"Waveform for lead {lead} must be [T,256], got {seg.shape}")
            seqs.append(seg)
        x = torch.cat(seqs, dim=0)
        x = x.to(device=device, dtype=torch.float32)  # encoder always runs in fp32 input
        x = x.unsqueeze(1)  # [P, 1, 256]

        # Only track gradients through the encoder when it is trainable.
        enc_trainable = any(p.requires_grad for p in self.enc.parameters())
        ctx = torch.enable_grad() if enc_trainable else torch.no_grad()
        with ctx:
            z = self.enc(x)  # [P, C, L]
        if self.inference:
            # Pin inference path to fp32 and verify it.
            z = z.to(dtype=torch.float32)
            assert_tensor_dtype(z, expected=torch.float32, context="conv encoder output (single)")
        self.ensure_finite(z, "conv encoder output")
        z = z.flatten(1)  # [P, 64] for conv stack

        # Match the projector's parameter dtype before projecting.
        proj_dtype = self._projector_param_dtype()
        if z.dtype != proj_dtype:
            z = z.to(dtype=proj_dtype)
        y = self.llava_proj(z)
        if self.inference:
            y = y.to(dtype=torch.float32)
            assert_tensor_dtype(y, expected=torch.float32, context="llava_proj output (single)")
        return y.to(dtype=self.dtype)
|
| 209 |
+
|
| 210 |
+
def ecg_prefix_batch(
    self,
    waveform_segments_batch: List[Dict[str, Tensor]],
    device: torch.device,
    lead_orders: Optional[List[List[str]]] = None,
) -> Tuple[Tensor, List[int]]:
    """Encode a batch of ECG prefixes with explicit per-sample lead orders.

    Args:
        waveform_segments_batch: List of dicts; each maps lead -> [T,256].
        device: Target device for encoder input.
        lead_orders: Optional list of lead-order lists, one per sample. If
            provided, each sample's segments are concatenated in that exact
            order; otherwise relies on insertion order of each mapping.

    Returns:
        Tuple of (projected embeddings for all samples concatenated along dim 0,
        per-sample segment counts so callers can split the prefix back apart).

    Raises:
        ValueError: if a requested lead is missing or a waveform is not [T,256].
        ECGNonFiniteInputError: if any sample contains NaN/inf values.
    """
    all_seqs: List[Tensor] = []
    prefix_lengths: List[int] = []

    for i, wv_dict in enumerate(waveform_segments_batch):
        seqs: List[Tensor] = []
        leads_for_sample: List[str] = []
        # Per-sample explicit order wins; a None entry falls back to dict order.
        if lead_orders is not None and i < len(lead_orders) and lead_orders[i] is not None:
            order = lead_orders[i]
            missing = [ld for ld in order if ld not in wv_dict]
            if missing:
                raise ValueError(f"Missing leads for sample {i}: {missing}")
            items = [(ld, wv_dict[ld]) for ld in order]
        else:
            items = list(wv_dict.items())
        for lead, seg in items:
            seg = torch.as_tensor(seg)
            # Each lead must arrive pre-windowed into fixed 256-sample segments.
            if seg.ndim != 2 or seg.shape[1] != 256:
                raise ValueError(f"Waveform for lead {lead} must be [T,256], got {seg.shape}")
            seqs.append(seg)
            leads_for_sample.append(str(lead))
        sample_segments = torch.cat(seqs, dim=0)
        # Check finiteness once per sample; only scan per-lead to name the culprit.
        if not torch.isfinite(sample_segments).all().item():
            bad_lead = None
            for lead_name, seg_tensor in zip(leads_for_sample, seqs):
                if not torch.isfinite(seg_tensor).all().item():
                    bad_lead = lead_name
                    break
            raise ECGNonFiniteInputError(sample_idx=i, lead=bad_lead)
        all_seqs.append(sample_segments)
        prefix_lengths.append(sample_segments.size(0))

    # Flatten all samples into one [P, 1, 256] batch for the conv encoder.
    x = torch.cat(all_seqs, dim=0)
    x = x.to(device=device, dtype=torch.float32).unsqueeze(1)

    # Skip autograd bookkeeping entirely when the encoder is frozen.
    enc_trainable = any(p.requires_grad for p in self.enc.parameters())
    ctx = torch.enable_grad() if enc_trainable else torch.no_grad()
    # NOTE(review): grad context scoped to the encoder call only — confirm
    # against the original formatting (indentation was lost in extraction).
    with ctx:
        z = self.enc(x)
    if self.inference:
        # Inference path pins fp32 and asserts it to catch silent dtype drift.
        z = z.to(dtype=torch.float32)
        assert_tensor_dtype(z, expected=torch.float32, context="conv encoder output (batch)")
    self.ensure_finite(z, "conv encoder output (batch)")
    z = z.flatten(1)

    # Match the projector's parameter dtype before projecting.
    proj_dtype = self._projector_param_dtype()
    if z.dtype != proj_dtype:
        z = z.to(dtype=proj_dtype)
    y = self.llava_proj(z)
    if self.inference:
        y = y.to(dtype=torch.float32)
        assert_tensor_dtype(y, expected=torch.float32, context="llava_proj output (batch)")
    # Final cast to the wrapper's dtype so embeddings match the LM.
    return y.to(dtype=self.dtype), prefix_lengths
|
| 277 |
+
|
| 278 |
+
# ---- Language-model forward --------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
def forward_language_model(
    self,
    inputs_embeds: Tensor,
    attention_mask: Tensor,
    labels: Optional[Tensor],
):
    """Run the wrapped language model on pre-computed input embeddings.

    Infers the LM's device (preferring the input-embedding weight, falling
    back to the first parameter), moves all tensors there, and calls the
    HF-style forward with use_cache disabled. Subclasses may override.

    Raises:
        RuntimeError: if the LM exposes neither embedding weights nor
            parameters from which a device could be inferred.
    """
    lm = self.language_model
    get_embeddings = getattr(lm, "get_input_embeddings", None)
    if callable(get_embeddings):
        embeddings = get_embeddings()
        if not hasattr(embeddings, "weight"):
            raise RuntimeError("Language model embeddings missing weight parameter; cannot infer device.")
        target_device = embeddings.weight.device
    else:
        try:
            target_device = next(lm.parameters()).device
        except StopIteration as exc:
            raise RuntimeError("Language model exposes no parameters to infer device placement.") from exc

    def _on_device(t: Optional[Tensor]) -> Optional[Tensor]:
        # Move only when needed; labels may legitimately be None.
        if t is None or t.device == target_device:
            return t
        return t.to(target_device)

    return lm(
        inputs_embeds=_on_device(inputs_embeds),
        attention_mask=_on_device(attention_mask),
        labels=_on_device(labels),
        use_cache=False,
    )
|
| 315 |
+
|
| 316 |
+
# ---- Label helpers ----------------------------------------------------------------------
|
| 317 |
+
|
| 318 |
+
@staticmethod
def build_labels_from_lengths(
    *,
    ids_rest: List[int],
    model_spans_in_rest: List[Tuple[int, int]],
    total_len: int,
    offset_rest: int,
) -> Tensor:
    """Build label tensor using explicit offsets to satisfy any packing schema.

    Args:
        ids_rest: Token ids for the supervised span of the prompt.
        model_spans_in_rest: List of (start, end) spans (relative to ids_rest) to supervise.
        total_len: Total sequence length of the assembled prompt.
        offset_rest: Absolute position in the sequence where `ids_rest` begins.

    Returns:
        Tensor of shape (total_len,) with supervised ids placed according to spans; all other
        positions are filled with -100.

    Raises:
        ValueError: if `offset_rest` is invalid or spans fall outside the provided bounds.
    """
    out = torch.full((total_len,), fill_value=-100, dtype=torch.long)

    offset_rest = int(offset_rest)
    if not (0 <= offset_rest <= total_len):
        raise ValueError(f"Invalid rest offset {offset_rest} for sequence length {total_len}.")
    n_rest = len(ids_rest)
    if offset_rest + n_rest > total_len:
        raise ValueError(
            f"Rest tokens (len={n_rest}) exceed total_len {total_len} with offset {offset_rest}."
        )
    # Defensive invariant check (len() can never be negative, but keep it tight).
    assert_rest_length_nonnegative(rest_length=n_rest)

    for start, end in model_spans_in_rest:
        # Each span must fit inside ids_rest ...
        if not (0 <= start <= end <= n_rest):
            raise ValueError(
                f"Model span {(start, end)} is out of bounds for ids_rest length {n_rest}."
            )
        abs_start = offset_rest + start
        abs_end = offset_rest + end
        # ... and, once shifted by the offset, inside the full sequence.
        if not (0 <= abs_start <= abs_end <= total_len):
            raise ValueError(
                f"Model span {(start, end)} with rest offset {offset_rest} is out of bounds for length {total_len}."
            )
        out[abs_start:abs_end] = torch.tensor(ids_rest[start:end], dtype=torch.long)
    return out
|
| 367 |
+
|
| 368 |
+
# ---- Trainable summaries -----------------------------------------------------------------
|
| 369 |
+
|
| 370 |
+
def summarize_trainables(self) -> Mapping[str, int]:
    """Report trainable (requires_grad) parameter counts per component and overall."""
    def _trainable(module) -> int:
        return sum(int(p.numel()) for p in module.parameters() if p.requires_grad)

    counts = {
        "llava_proj_trainable": _trainable(self.llava_proj),
        "ecg_special_trainable": _trainable(self.ecg_special_embed),
        "enc_trainable": _trainable(self.enc),
        "lm_trainable": _trainable(self.language_model),
    }
    # Grand total across projector, special embeddings, encoder, and LM.
    counts["total_trainable"] = sum(counts.values())
    return counts
|
| 387 |
+
|
| 388 |
+
# ---- Utility ---------------------------------------------------------------------------
|
| 389 |
+
|
| 390 |
+
@staticmethod
|
| 391 |
+
def ensure_finite(tensor: Tensor, context: str) -> None:
|
| 392 |
+
if not torch.isfinite(tensor).all():
|
| 393 |
+
xin_nan = torch.isnan(tensor).any().item()
|
| 394 |
+
raise RuntimeError(f"Encountered non-finite values in {context} (input_has_nan={bool(xin_nan)}).")
|
camel_inference/src/camel/ecg_text_packing.py
ADDED
|
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Dict, Tuple, Optional, Any
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
|
| 7 |
+
from camel.assertions import (
|
| 8 |
+
assert_ecg_catalog_valid,
|
| 9 |
+
assert_normalized_role_canonical,
|
| 10 |
+
assert_turn_parts_structure_valid,
|
| 11 |
+
assert_turn_content_ends_with_eot,
|
| 12 |
+
assert_leads_canonical_and_ordered,
|
| 13 |
+
assert_waveform_shapes_valid,
|
| 14 |
+
)
|
| 15 |
+
from camel.prompt_renderers import turn_wrappers
|
| 16 |
+
|
| 17 |
+
# NOTE: Local BOS/EOS assertions are implemented at the bottom of this file.
|
| 18 |
+
@dataclass(frozen=True)
class PromptTokens:
    """Literal wrapper tokens and canonical role names used to render chat turns."""

    start_of_turn: str  # token emitted at the start of a turn
    end_of_turn: str    # token emitted at the end of a turn
    user_role: str      # canonical role string for user turns
    model_role: str     # canonical role string for model/assistant turns
    # BOS/EOS placement policy flags — presumably fed to the BOS/EOS
    # validators defined alongside this schema; confirm at call sites.
    require_bos: bool = True
    require_eos: bool = True
    allow_multiple_eos: bool = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
class ConversationRules:
    """Normalization rules applied to incoming conversation transcripts."""

    format_id: str  # chat-format identifier (e.g. "harmony_chat_v1")
    # Accepted aliases mapped onto the canonical user/model roles.
    user_role_aliases: Tuple[str, ...] = ("human", "user")
    model_role_aliases: Tuple[str, ...] = ("gpt", "assistant")
    # Roles whose text content has <image> placeholders stripped.
    strip_image_from_roles: Tuple[str, ...] = ("human",)
    # When True, system content is merged into the first user turn instead of
    # being kept as a standalone preamble turn.
    merge_system_with_first_user: bool = True
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass(frozen=True)
class ECGTokenSchema:
    """Special-token vocabulary describing how ECG blocks are delimited in text."""

    global_start: str  # token opening an entire ECG block
    global_end: str    # token closing an entire ECG block
    # Per-lead templates; may reference {lead}, {lead_lower}, {lead_upper}.
    lead_start_template: str
    lead_end_template: str
    canonical_leads: Tuple[str, ...]  # allowed lead names, in canonical order
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass(frozen=True)
class PackingSchema:
    """Bundle of prompt, conversation, and ECG token settings used for packing.

    Frozen (and all fields hashable), so instances can serve as cache keys.
    """

    prompt: PromptTokens
    conversation: ConversationRules
    ecg: ECGTokenSchema
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass(frozen=True)
class ECGSpecialTokenCatalog:
    """Materialized ECG special tokens plus index lookup tables."""

    tokens: Tuple[str, ...]  # all special tokens; global start/end come first
    lead_to_indices: Dict[str, Dict[str, int]]  # lead -> {"start"|"end": index into tokens}
    lead_to_tokens: Dict[str, Dict[str, str]]   # lead -> {"start"|"end": token string}
    token_to_index: Dict[str, int]  # reverse lookup: token string -> index
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _render_lead_template(template: str, lead: str) -> str:
|
| 63 |
+
return template.format(
|
| 64 |
+
lead=lead,
|
| 65 |
+
lead_lower=lead.lower(),
|
| 66 |
+
lead_upper=lead.upper(),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Process-wide cache; PackingSchema is frozen/hashable so it works as a key.
_ECG_TOKEN_CACHE: Dict[PackingSchema, ECGSpecialTokenCatalog] = {}

def get_ecg_special_token_catalog(schema: PackingSchema) -> ECGSpecialTokenCatalog:
    """Build (or return a cached copy of) the ECG special-token catalog for *schema*.

    Token order: global start, global end, then per-lead start/end pairs in
    canonical lead order. The catalog is validated before being cached.
    """
    if schema in _ECG_TOKEN_CACHE:
        return _ECG_TOKEN_CACHE[schema]

    token_list: List[str] = [schema.ecg.global_start, schema.ecg.global_end]
    lead_to_indices: Dict[str, Dict[str, int]] = {}
    lead_to_tokens: Dict[str, Dict[str, str]] = {}

    for lead in schema.ecg.canonical_leads:
        start_tok = _render_lead_template(schema.ecg.lead_start_template, lead)
        end_tok = _render_lead_template(schema.ecg.lead_end_template, lead)
        # Record positions before extending so indices line up with token_list.
        lead_to_indices[lead] = {"start": len(token_list), "end": len(token_list) + 1}
        lead_to_tokens[lead] = {"start": start_tok, "end": end_tok}
        token_list.extend((start_tok, end_tok))

    catalog = ECGSpecialTokenCatalog(
        tokens=tuple(token_list),
        lead_to_indices=lead_to_indices,
        lead_to_tokens=lead_to_tokens,
        token_to_index={tok: idx for idx, tok in enumerate(token_list)},
    )

    # Fail fast on inconsistent token/index bookkeeping.
    assert_ecg_catalog_valid(catalog, schema)

    _ECG_TOKEN_CACHE[schema] = catalog
    return catalog
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ---- Conversation normalization + validation ----------------------------------------------------
|
| 108 |
+
|
| 109 |
+
def canonical_leads(schema: PackingSchema) -> List[str]:
    """Return the schema's canonical lead names as a fresh mutable list."""
    return [*schema.ecg.canonical_leads]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _strip_image_tag(text: str) -> str:
|
| 114 |
+
"""
|
| 115 |
+
Remove any <image> placeholder without gluing surrounding words.
|
| 116 |
+
Replace the token and surrounding whitespace with a single space and
|
| 117 |
+
normalize repeated spaces.
|
| 118 |
+
"""
|
| 119 |
+
cleaned = re.sub(r"\s*<image>\s*", " ", text)
|
| 120 |
+
cleaned = re.sub(r"[ \t]{2,}", " ", cleaned)
|
| 121 |
+
return cleaned.strip()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _normalize_role(role_value: Any, schema: PackingSchema) -> str:
    """Map a raw role label onto the schema's canonical role vocabulary.

    "system"/"developer" pass through as-is (lowercased); user/model aliases
    collapse onto the schema's canonical user/model roles. Anything else is
    rejected.

    Raises:
        ValueError: for a missing/blank role or an unrecognized label.
    """
    role = (role_value or "").strip()
    if not role:
        raise ValueError("Conversation turn is missing a role identifier.")
    lowered = role.lower()
    if lowered in ("system", "developer"):
        return lowered

    user_names = {schema.prompt.user_role.lower()}
    user_names.update(alias.lower() for alias in schema.conversation.user_role_aliases)
    if lowered in user_names:
        canonical = schema.prompt.user_role
        assert_normalized_role_canonical(canonical, schema)
        return canonical

    model_names = {schema.prompt.model_role.lower()}
    model_names.update(alias.lower() for alias in schema.conversation.model_role_aliases)
    if lowered in model_names:
        canonical = schema.prompt.model_role
        assert_normalized_role_canonical(canonical, schema)
        return canonical

    raise ValueError(f"Unknown conversation role '{role_value}' for schema '{schema.conversation.format_id}'.")
|
| 143 |
+
|
| 144 |
+
def _normalize_conversation(
    convo: List[Dict[str, Any]],
    schema: PackingSchema,
    system_text: Optional[str],
    developer_text: Optional[str],
) -> List[Dict[str, Any]]:
    """Normalize a raw conversation into canonical-role turns.

    Injects `system_text`/`developer_text` as synthetic preamble turns only
    when the conversation does not already contain such turns, canonicalizes
    every role, and then either merges system content into the first user
    turn (when the schema requests it) or reorders preamble turns to the
    front.

    Raises:
        ValueError: on empty conversations, malformed turns/content, unknown
            roles, developer turns under merge mode, or a conversation that
            does not start with a user turn.
    """
    if not convo:
        raise ValueError("Conversation must contain at least one turn.")
    # First pass: detect pre-existing system/developer turns so we don't
    # double-inject the provided preamble texts.
    has_system_turn = False
    has_developer_turn = False
    for turn in convo:
        if not isinstance(turn, dict):
            # NOTE(review): non-dict turns are tolerated here but the second
            # pass below calls turn.get(...) and would fail on them — confirm
            # whether non-dict turns can actually occur upstream.
            continue
        role_val = turn.get("from")
        if role_val is None and "role" in turn:
            role_val = turn.get("role")
        role_lower = str(role_val or "").strip().lower()
        if role_lower == "system":
            has_system_turn = True
        elif role_lower == "developer":
            has_developer_turn = True
    normalized: List[Dict[str, Any]] = []
    if system_text and system_text.strip() and not has_system_turn:
        # NOTE(review): local name shadows the stdlib `sys` module (no stdlib
        # use in this scope, so behavior is unaffected).
        sys = system_text.strip()
        normalized.append({
            "role": "system",
            "content": [{"type": "text", "text": sys}],
        })
    if developer_text and developer_text.strip() and not has_developer_turn:
        dev = developer_text.strip()
        normalized.append({
            "role": "developer",
            "content": [{"type": "text", "text": dev}],
        })
    # Second pass: canonicalize roles; "from" takes precedence over "role".
    for idx, turn in enumerate(convo):
        role_val = turn.get("from")
        if role_val is None:
            role_val = turn.get("role")
        role = _normalize_role(role_val, schema)
        content = turn.get("content")
        if not isinstance(content, list):
            raise ValueError(f"Turn {idx} content must be a list of items.")
        # Content list is shared (not copied) with the caller's turn dict.
        normalized.append({"role": role, "content": content})

    if schema.conversation.merge_system_with_first_user:
        # Merge mode: collect system items, drop system turns, and prepend the
        # collected items to the first user turn's content.
        system_items: List[Dict[str, Any]] = []
        out: List[Dict[str, Any]] = []
        first_user_idx: Optional[int] = None
        for turn in normalized:
            if turn["role"] == "system":
                system_items.extend(list(turn["content"]))
                continue
            if turn["role"] == "developer":
                raise ValueError("Developer turn present but merge_system_with_first_user is true.")
            if first_user_idx is None and turn["role"] == schema.prompt.user_role:
                first_user_idx = len(out)
            out.append(turn)

        if system_items:
            if first_user_idx is None:
                raise ValueError("System turn present but no user turn to merge into.")
            user_turn = out[first_user_idx]
            user_turn["content"] = list(system_items) + list(user_turn["content"])

        if not out:
            raise ValueError("Conversation must contain at least one non-system turn.")
        if out[0]["role"] != schema.prompt.user_role:
            raise ValueError("Conversation must start with a user/human turn.")
        return out

    # Non-merge mode: keep preamble turns, but move them to the front while
    # preserving relative order within each group.
    out = list(normalized)
    if not out:
        raise ValueError("Conversation must contain at least one turn.")
    system_turns = [t for t in out if t["role"] == "system"]
    developer_turns = [t for t in out if t["role"] == "developer"]
    non_preamble_turns = [t for t in out if t["role"] not in ("system", "developer")]
    if not non_preamble_turns:
        raise ValueError("Conversation must contain at least one non-system/developer turn.")
    if non_preamble_turns[0]["role"] != schema.prompt.user_role:
        raise ValueError("Conversation must start with a user/human turn.")
    return system_turns + developer_turns + non_preamble_turns
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _maybe_strip_content(text: str, canonical_role: str, schema: PackingSchema) -> str:
    """Strip <image> placeholders from *text* when the role is configured for it.

    Roles listed in `strip_image_from_roles` are normalized when possible;
    labels the schema cannot normalize are matched case-insensitively as-is.
    """
    stripped_roles = set()
    for raw_role in schema.conversation.strip_image_from_roles:
        try:
            normalized = _normalize_role(raw_role, schema)
        except ValueError:
            normalized = str(raw_role)  # keep unrecognized labels verbatim
        stripped_roles.add(normalized.lower())
    return _strip_image_tag(text) if canonical_role.lower() in stripped_roles else text
|
| 237 |
+
|
| 238 |
+
# ---- Tokenize & mark assistant spans (exclude control tokens from loss) --------------------------
|
| 239 |
+
|
| 240 |
+
# ---- Build structured turn parts ---------------------------------------------------------------
|
| 241 |
+
|
| 242 |
+
def build_structured_turn_parts(
    *,
    content: List[Dict[str, Any]],
    canonical_role: str,
    schema: PackingSchema,
    ecg_blocks: List[Dict[str, Any]],
    sampling_rate: Optional[float],
    turn_suffix: Optional[str] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Render one turn's content into a text block plus structured parts.

    Text items become "text" parts (adjacent text merged); ECG items expand
    into special-token parts (global start/end, per-lead start/end) with one
    "ecg" part per segment, and the block's waveforms are appended to
    `ecg_blocks` (mutated in place). Returns (turn text + suffix, parts).

    Raises:
        ValueError: on malformed content items, sampling-rate mismatch, or
            unknown item types.
    """
    # NOTE(review): prompt_tokens is assigned but never used in this function.
    prompt_tokens = schema.prompt
    catalog = get_ecg_special_token_catalog(schema)
    parts: List[Dict[str, Any]] = []
    # text_segments mirrors parts as flat strings so the turn text can be
    # rebuilt by simple concatenation (special tokens included).
    text_segments: List[str] = []

    def _append_text(txt: str) -> None:
        # Merge consecutive text into one part; skip empties entirely.
        if not txt:
            return
        if parts and parts[-1].get("kind") == "text":
            parts[-1]["text"] += txt
        else:
            parts.append({"kind": "text", "text": txt})
        text_segments.append(txt)

    for item in content:
        if not isinstance(item, dict):
            raise ValueError("Conversation content items must be dicts.")
        item_type = item.get("type")
        if item_type == "text":
            text_val = item.get("text", "")
            if not isinstance(text_val, str):
                raise ValueError("Text content item must have a string 'text' field.")
            cleaned = _maybe_strip_content(text_val, canonical_role, schema)
            _append_text(cleaned)
            continue
        if item_type == "ecg":
            waveform_segments = item.get("waveform_segments")
            if not isinstance(waveform_segments, dict):
                raise ValueError("ECG content item missing waveform_segments mapping.")
            # Per-item rate may be absent; when both are given they must agree.
            item_rate = item.get("sampling_rate")
            if item_rate is not None and sampling_rate is not None and float(item_rate) != float(sampling_rate):
                raise ValueError(
                    f"ECG item sampling_rate {item_rate} does not match sample sampling_rate {sampling_rate}."
                )
            lead_names = [str(ld) for ld in waveform_segments.keys()]
            if not lead_names:
                raise ValueError("ECG content item has no leads.")
            # shape[0] is the per-lead segment count (leads are [T, 256]).
            segments_per_lead = [int(waveform_segments[ld].shape[0]) for ld in lead_names]
            assert_leads_canonical_and_ordered(lead_names, schema.ecg.canonical_leads)
            assert_waveform_shapes_valid(lead_names, segments_per_lead, waveform_segments)
            block_index = len(ecg_blocks)
            ecg_blocks.append({
                "lead_names": lead_names,
                "segments_per_lead": segments_per_lead,
                "waveform_segments": OrderedDict((ld, waveform_segments[ld]) for ld in lead_names),
            })

            # Global ECG opener.
            parts.append({
                "kind": "special",
                "token": schema.ecg.global_start,
                "token_index": catalog.token_to_index[schema.ecg.global_start],
                "block_index": block_index,
            })
            text_segments.append(schema.ecg.global_start)

            for lead, nseg in zip(lead_names, segments_per_lead):
                lead_tokens = catalog.lead_to_indices[lead]
                # Lead opener, then one "ecg" part per segment (1-based sec),
                # then the lead closer.
                parts.append({
                    "kind": "special",
                    "token": catalog.lead_to_tokens[lead]["start"],
                    "token_index": lead_tokens["start"],
                    "lead": lead,
                    "block_index": block_index,
                })
                text_segments.append(catalog.lead_to_tokens[lead]["start"])
                for sec in range(1, int(nseg) + 1):
                    parts.append({
                        "kind": "ecg",
                        "lead": lead,
                        "sec": sec,
                        "block_index": block_index,
                    })
                parts.append({
                    "kind": "special",
                    "token": catalog.lead_to_tokens[lead]["end"],
                    "token_index": lead_tokens["end"],
                    "lead": lead,
                    "block_index": block_index,
                })
                text_segments.append(catalog.lead_to_tokens[lead]["end"])

            # Global ECG closer.
            parts.append({
                "kind": "special",
                "token": schema.ecg.global_end,
                "token_index": catalog.token_to_index[schema.ecg.global_end],
                "block_index": block_index,
            })
            text_segments.append(schema.ecg.global_end)
            continue
        raise ValueError(f"Unknown content item type '{item_type}'.")

    turn_content = "".join(text_segments)
    # Explicit suffix overrides the schema's role-specific turn wrapper.
    if turn_suffix is None:
        _, suffix = turn_wrappers(schema, canonical_role)
    else:
        suffix = turn_suffix
    turn_text_block = turn_content + suffix

    assert_turn_parts_structure_valid(parts, ecg_blocks, schema, catalog)
    assert_turn_content_ends_with_eot(turn_text_block, suffix)

    return turn_text_block, parts
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def build_text_only_turn_parts(
    *,
    content: List[Dict[str, Any]],
    canonical_role: str,
    schema: PackingSchema,
    turn_suffix: Optional[str] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Render a text-only (typically model) turn into text + structured parts.

    Rejects any non-text content item. For the harmony format's model role,
    prepends the harmony final-channel header and forbids content that
    already carries channel headers. Returns (turn text + suffix, parts).

    Raises:
        ValueError: on non-text items, embedded channel headers, or content
            that is empty after preprocessing (with per-item diagnostics).
    """
    # NOTE(review): prompt_tokens is assigned but never used in this function.
    prompt_tokens = schema.prompt
    parts: List[Dict[str, Any]] = []
    text_segments: List[str] = []
    # Harmony model turns need an explicit "<|channel|>final<|message|>" header.
    needs_channel_header = (
        schema.conversation.format_id == "harmony_chat_v1"
        and canonical_role == schema.prompt.model_role
    )
    # Track raw content for diagnostics
    raw_content_debug: List[Dict[str, Any]] = []

    def _append_text(txt: str) -> None:
        # Merge consecutive text into one part; skip empties entirely.
        if not txt:
            return
        if parts and parts[-1].get("kind") == "text":
            parts[-1]["text"] += txt
        else:
            parts.append({"kind": "text", "text": txt})
        text_segments.append(txt)

    for item in content:
        if not isinstance(item, dict):
            raise ValueError("Conversation content items must be dicts.")
        item_type = item.get("type")
        if item_type != "text":
            raise ValueError("Model turns cannot contain ECG content items.")
        text_val = item.get("text", "")
        if not isinstance(text_val, str):
            raise ValueError("Text content item must have a string 'text' field.")
        if needs_channel_header and text_val.lstrip().startswith("<|channel|>"):
            raise ValueError("Assistant content must not include harmony channel headers.")
        cleaned = _maybe_strip_content(text_val, canonical_role, schema)
        # Keep truncated before/after snapshots for the empty-turn diagnostic.
        raw_content_debug.append({
            "raw_text": repr(text_val[:200]) + ("..." if len(text_val) > 200 else ""),
            "raw_len": len(text_val),
            "cleaned_text": repr(cleaned[:200]) + ("..." if len(cleaned) > 200 else ""),
            "cleaned_len": len(cleaned),
        })
        _append_text(cleaned)

    turn_content = "".join(text_segments)
    if not turn_content:
        # Build detailed diagnostic message
        diag_lines = [
            "Model turn is empty after preprocessing.",
            f" role: {canonical_role}",
            f" num_content_items: {len(content)}",
        ]
        for i, dbg in enumerate(raw_content_debug):
            diag_lines.append(f" item[{i}]: raw_len={dbg['raw_len']}, cleaned_len={dbg['cleaned_len']}")
            diag_lines.append(f" raw: {dbg['raw_text']}")
            diag_lines.append(f" cleaned: {dbg['cleaned_text']}")
        raise ValueError("\n".join(diag_lines))
    if needs_channel_header:
        channel_header = "<|channel|>final<|message|>"
        parts.insert(0, {"kind": "text", "text": channel_header})
        text_segments.insert(0, channel_header)
        # Rebuild the turn text now that the header is prepended.
        turn_content = "".join(text_segments)
    # Explicit suffix overrides the schema's role-specific turn wrapper.
    if turn_suffix is None:
        _, suffix = turn_wrappers(schema, canonical_role)
    else:
        suffix = turn_suffix
    turn_text_block = turn_content + suffix
    assert_turn_content_ends_with_eot(turn_text_block, suffix)
    return turn_text_block, parts
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def annotate_turn_parts_with_ids(
    turn_parts: List[List[Dict[str, Any]]],
    tokenizer: PreTrainedTokenizer,
) -> List[List[Dict[str, Any]]]:
    """Attach token ids to text parts so the trainer can skip per-step tokenization.

    Mutates the text parts in place (adding an "ids" key) and returns the
    same nested list for convenience.
    """
    text_parts = (
        part
        for parts in turn_parts
        for part in parts
        if part.get("kind") == "text"
    )
    for part in text_parts:
        text = part.get("text", "")
        part["ids"] = [] if not text else tokenizer.encode(text, add_special_tokens=False)
    return turn_parts
|
| 439 |
+
|
| 440 |
+
def assert_single_bos_eos(
    text_ids: List[int],
    tok: PreTrainedTokenizer,
    *,
    require_bos_at_start: bool,
    require_single_terminal_eos: bool,
    allow_multiple_eos: bool = False,
) -> None:
    """Validate BOS/EOS placement under various schema-specific policies.

    BOS: when required, exactly one BOS at position 0; otherwise at most one,
    and only at position 0. EOS: the (require_single_terminal_eos,
    allow_multiple_eos) pair selects between exactly-one-terminal,
    last-must-be-terminal, at-most-one-terminal, and no-check policies.

    Raises:
        AssertionError: on any policy violation.
    """
    last = len(text_ids) - 1

    # ---- BOS policy ----
    if require_bos_at_start and tok.bos_token_id is None:
        raise AssertionError("BOS token required but tokenizer has none")
    if tok.bos_token_id is not None:
        bos_pos = [i for i, t in enumerate(text_ids) if t == tok.bos_token_id]
        if require_bos_at_start:
            if len(bos_pos) != 1 or bos_pos[0] != 0:
                raise AssertionError(f"BOS placement invalid: positions={bos_pos}")
        elif len(bos_pos) > 1 or (bos_pos and bos_pos[0] != 0):
            raise AssertionError(f"Unexpected BOS placement: positions={bos_pos}")

    # ---- EOS policy ----
    if tok.eos_token_id is None:
        return
    eos_pos = [i for i, t in enumerate(text_ids) if t == tok.eos_token_id]
    if allow_multiple_eos:
        # Multiple EOS allowed (e.g. ChatML-style per turn); optionally the
        # last one must terminate the sequence. Otherwise nothing to check.
        if require_single_terminal_eos and (not eos_pos or eos_pos[-1] != last):
            raise AssertionError(f"EOS bad: positions={eos_pos}")
        return
    if require_single_terminal_eos:
        # Exactly one EOS, and it must be terminal.
        if len(eos_pos) != 1 or eos_pos[0] != last:
            raise AssertionError(f"EOS bad: positions={eos_pos}")
    # Zero or one EOS; if present it must be terminal.
    elif len(eos_pos) > 1 or (eos_pos and eos_pos[0] != last):
        raise AssertionError(f"EOS bad: positions={eos_pos}")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def assert_struct_bos_eos(
|
| 484 |
+
token_struct: Dict[str, List[int]],
|
| 485 |
+
tok: PreTrainedTokenizer,
|
| 486 |
+
*,
|
| 487 |
+
require_bos_at_start: bool,
|
| 488 |
+
require_single_terminal_eos: bool,
|
| 489 |
+
allow_multiple_eos: bool = False,
|
| 490 |
+
) -> None:
|
| 491 |
+
"""Wrapper that validates concatenated ids per schema policy."""
|
| 492 |
+
text_ids = token_struct["text_ids"]
|
| 493 |
+
assert_single_bos_eos(
|
| 494 |
+
text_ids,
|
| 495 |
+
tok,
|
| 496 |
+
require_bos_at_start=require_bos_at_start,
|
| 497 |
+
require_single_terminal_eos=require_single_terminal_eos,
|
| 498 |
+
allow_multiple_eos=allow_multiple_eos,
|
| 499 |
+
)
|
camel_inference/src/camel/inference.py
ADDED
|
@@ -0,0 +1,846 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# inference.py — ECGText inference (model loading + generation helpers)
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from transformers import AutoModelForCausalLM
|
| 12 |
+
from peft import LoraConfig
|
| 13 |
+
|
| 14 |
+
# Local imports
|
| 15 |
+
from camel.model_introspect import resolve_hidden_size as _resolve_hidden_size
|
| 16 |
+
from camel.model_registry import load_registry
|
| 17 |
+
from camel.training_setup import initialize_tokenizer, build_packing_schema, register_ecg_special_tokens
|
| 18 |
+
from camel.model_init import build_wrapper, attach_lora, build_conv_encoder
|
| 19 |
+
from camel.ecg_text_packing import (
|
| 20 |
+
_normalize_conversation,
|
| 21 |
+
annotate_turn_parts_with_ids,
|
| 22 |
+
build_structured_turn_parts,
|
| 23 |
+
build_text_only_turn_parts,
|
| 24 |
+
get_ecg_special_token_catalog,
|
| 25 |
+
)
|
| 26 |
+
from camel.prompt_renderers import render_prompt_and_spans, turn_wrappers, assistant_generation_prefix
|
| 27 |
+
from camel.ecg_attention_masks import (
|
| 28 |
+
ECGBlockLayout,
|
| 29 |
+
ECGSequenceLayout,
|
| 30 |
+
MaskBuildResult,
|
| 31 |
+
ECGMaskStrategy,
|
| 32 |
+
get_mask_strategy,
|
| 33 |
+
)
|
| 34 |
+
from camel.assertions import (
|
| 35 |
+
assert_ecg_blocks_consistent,
|
| 36 |
+
assert_ecg_part_bounds,
|
| 37 |
+
assert_layout_specials_complete,
|
| 38 |
+
assert_prefix_matches_segments,
|
| 39 |
+
assert_prefix_split_complete,
|
| 40 |
+
)
|
| 41 |
+
from camel.checkpoint_utils import (
|
| 42 |
+
load_llava_and_lora,
|
| 43 |
+
update_wrapper_language_model,
|
| 44 |
+
extract_lora_config_from_checkpoints,
|
| 45 |
+
peek_projector_name,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# ------------------------------
|
| 49 |
+
# Device & conv builder
|
| 50 |
+
# ------------------------------
|
| 51 |
+
|
| 52 |
+
def _device(device=None) -> torch.device:
|
| 53 |
+
if device:
|
| 54 |
+
return device
|
| 55 |
+
if torch.cuda.is_available():
|
| 56 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
| 57 |
+
return torch.device(f"cuda:{local_rank}")
|
| 58 |
+
return torch.device("cpu")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
_HARMONY_CHANNEL_RE = re.compile(r"<\|channel\|>(.*?)<\|message\|>", re.DOTALL)
|
| 62 |
+
_HARMONY_DELIM_RE = re.compile(r"<\|end\|>|<\|return\|>|<\|call\|>|<\|start\|>")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _extract_harmony_messages(text: str) -> List[Tuple[str, str]]:
|
| 66 |
+
matches = list(_HARMONY_CHANNEL_RE.finditer(text))
|
| 67 |
+
if not matches:
|
| 68 |
+
raise ValueError("No harmony channel headers found in model output.")
|
| 69 |
+
out: List[Tuple[str, str]] = []
|
| 70 |
+
for match in matches:
|
| 71 |
+
channel_raw = match.group(1).strip()
|
| 72 |
+
channel = channel_raw.split()[0] if channel_raw else ""
|
| 73 |
+
if not channel:
|
| 74 |
+
raise ValueError("Harmony channel header is empty.")
|
| 75 |
+
start = match.end()
|
| 76 |
+
end_match = _HARMONY_DELIM_RE.search(text, start)
|
| 77 |
+
end = end_match.start() if end_match else len(text)
|
| 78 |
+
out.append((channel, text[start:end]))
|
| 79 |
+
return out
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _checkpoint_has_conv(ckpt_path: Optional[str]) -> bool:
|
| 83 |
+
if not ckpt_path:
|
| 84 |
+
return False
|
| 85 |
+
payload = torch.load(ckpt_path, map_location="cpu")
|
| 86 |
+
if not isinstance(payload, dict):
|
| 87 |
+
raise RuntimeError(f"Checkpoint {ckpt_path} must be a dict to inspect conv metadata.")
|
| 88 |
+
return isinstance(payload.get("conv"), dict)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ------------------------------
|
| 92 |
+
# Prompt building & stopping
|
| 93 |
+
# ------------------------------
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class PromptContext:
|
| 97 |
+
"""Container describing the prepared prompt state for autoregressive generation."""
|
| 98 |
+
|
| 99 |
+
inputs_embeds: torch.Tensor
|
| 100 |
+
layout: ECGSequenceLayout
|
| 101 |
+
prompt_preview: str
|
| 102 |
+
stop_ids: List[int]
|
| 103 |
+
input_embedder: nn.Embedding
|
| 104 |
+
mask_strategy: ECGMaskStrategy
|
| 105 |
+
mask_result: MaskBuildResult
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _sanitize_segments(tensor: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
"""Detach → float32 → replace NaN/Inf so downstream encoders stay numerically stable."""
|
| 110 |
+
out = tensor.detach().cpu().to(dtype=torch.float32)
|
| 111 |
+
return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _sample_next_token(
|
| 115 |
+
logits: torch.Tensor,
|
| 116 |
+
*,
|
| 117 |
+
temperature: float,
|
| 118 |
+
top_k: Optional[int],
|
| 119 |
+
top_p: float,
|
| 120 |
+
min_p: float,
|
| 121 |
+
) -> torch.Tensor:
|
| 122 |
+
"""
|
| 123 |
+
Draw the next token id given final-step logits and sampling parameters.
|
| 124 |
+
|
| 125 |
+
Uses greedy decoding when temperature <= 0; otherwise applies temperature
|
| 126 |
+
scaling, optional nucleus sampling, and multinomial sampling.
|
| 127 |
+
"""
|
| 128 |
+
if logits.ndim != 1:
|
| 129 |
+
raise ValueError(f"Expected 1D logits, got shape {tuple(logits.shape)}")
|
| 130 |
+
|
| 131 |
+
if temperature <= 0.0:
|
| 132 |
+
return torch.argmax(logits, dim=-1)
|
| 133 |
+
|
| 134 |
+
scaled = logits / max(temperature, 1e-5)
|
| 135 |
+
probs = torch.softmax(scaled, dim=-1)
|
| 136 |
+
|
| 137 |
+
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
|
| 138 |
+
keep_mask = torch.ones_like(sorted_probs, dtype=torch.bool)
|
| 139 |
+
|
| 140 |
+
if top_k is not None and top_k > 0:
|
| 141 |
+
top_k = min(int(top_k), sorted_probs.numel())
|
| 142 |
+
top_k_mask = torch.zeros_like(sorted_probs, dtype=torch.bool)
|
| 143 |
+
top_k_mask[:top_k] = True
|
| 144 |
+
keep_mask &= top_k_mask
|
| 145 |
+
|
| 146 |
+
if 0.0 < top_p < 1.0:
|
| 147 |
+
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
| 148 |
+
cutoff_mask = (cumulative - sorted_probs) < top_p
|
| 149 |
+
cutoff_mask[0] = True # always keep the highest-prob token
|
| 150 |
+
keep_mask &= cutoff_mask
|
| 151 |
+
|
| 152 |
+
if min_p is not None and min_p > 0.0:
|
| 153 |
+
keep_mask &= sorted_probs >= float(min_p)
|
| 154 |
+
|
| 155 |
+
filtered_probs = sorted_probs[keep_mask]
|
| 156 |
+
filtered_indices = sorted_indices[keep_mask]
|
| 157 |
+
if filtered_probs.numel() == 0:
|
| 158 |
+
filtered_probs = sorted_probs[:1]
|
| 159 |
+
filtered_indices = sorted_indices[:1]
|
| 160 |
+
|
| 161 |
+
prob_sum = filtered_probs.sum()
|
| 162 |
+
if not torch.isfinite(prob_sum) or prob_sum <= 0:
|
| 163 |
+
return sorted_indices[0]
|
| 164 |
+
normalized = filtered_probs / prob_sum
|
| 165 |
+
next_idx = torch.multinomial(normalized, num_samples=1, replacement=False)
|
| 166 |
+
return filtered_indices[next_idx].squeeze(0)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class KardiaLM:
|
| 170 |
+
"""High-level chat interface around an ECG language model."""
|
| 171 |
+
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
*,
|
| 175 |
+
model_registry_path: Optional[str],
|
| 176 |
+
model_config_name: str,
|
| 177 |
+
hf_model_id_override: Optional[str],
|
| 178 |
+
adapter_ckpt: str,
|
| 179 |
+
conv_ckpt: Optional[str] = None,
|
| 180 |
+
no_lora: bool = False,
|
| 181 |
+
use_dora: bool = False,
|
| 182 |
+
default_max_new_tokens: int = 1000,
|
| 183 |
+
default_temperature: float = 1.0,
|
| 184 |
+
default_top_k: Optional[int] = 64,
|
| 185 |
+
default_top_p: float = 0.95,
|
| 186 |
+
default_min_p: float = 0.0,
|
| 187 |
+
mask_strategy: Optional[str] = None,
|
| 188 |
+
device: Optional[torch.device] = None,
|
| 189 |
+
) -> None:
|
| 190 |
+
registry = load_registry(registry_path=model_registry_path)
|
| 191 |
+
model_cfg = registry.get(model_config_name)
|
| 192 |
+
|
| 193 |
+
self.model_cfg = model_cfg
|
| 194 |
+
self.hf_model_id = hf_model_id_override or model_cfg.hf_id
|
| 195 |
+
self.packing_schema = build_packing_schema(self.hf_model_id)
|
| 196 |
+
self.tokenizer_cfg = model_cfg.tokenizer_config()
|
| 197 |
+
self.arch_cfg = model_cfg.architecture_config()
|
| 198 |
+
self.system_text = None
|
| 199 |
+
self.developer_text = None
|
| 200 |
+
if self.packing_schema.conversation.format_id == "harmony_chat_v1":
|
| 201 |
+
self.system_text = model_cfg.required_prompt_text("system_prompt")
|
| 202 |
+
self.developer_text = model_cfg.required_prompt_text("developer_prompt")
|
| 203 |
+
if not self.system_text.strip():
|
| 204 |
+
raise RuntimeError("System prompt text for harmony format must be non-empty.")
|
| 205 |
+
if not self.developer_text.strip():
|
| 206 |
+
raise RuntimeError("Developer prompt text for harmony format must be non-empty.")
|
| 207 |
+
|
| 208 |
+
self.device = _device(device)
|
| 209 |
+
self.dtype = torch.bfloat16
|
| 210 |
+
self.mask_strategy: ECGMaskStrategy = get_mask_strategy(mask_strategy)
|
| 211 |
+
self.expect_dora = bool(use_dora)
|
| 212 |
+
|
| 213 |
+
tok = initialize_tokenizer(
|
| 214 |
+
self.hf_model_id,
|
| 215 |
+
trust_remote_code=True,
|
| 216 |
+
use_fast=self.tokenizer_cfg.use_fast,
|
| 217 |
+
add_prefix_space=self.tokenizer_cfg.add_prefix_space,
|
| 218 |
+
)
|
| 219 |
+
self.tokenizer = tok
|
| 220 |
+
|
| 221 |
+
catalog = get_ecg_special_token_catalog(self.packing_schema)
|
| 222 |
+
self.ecg_special_token_id_map = register_ecg_special_tokens(tok, catalog)
|
| 223 |
+
|
| 224 |
+
pad_strategy = self.tokenizer_cfg.pad_token_strategy.lower()
|
| 225 |
+
if pad_strategy == "eos":
|
| 226 |
+
if tok.eos_token is None:
|
| 227 |
+
raise RuntimeError(
|
| 228 |
+
f"Tokenizer for model '{model_cfg.name}' lacks an EOS token required for pad_token_strategy='eos'."
|
| 229 |
+
)
|
| 230 |
+
tok.pad_token = tok.eos_token
|
| 231 |
+
elif pad_strategy not in ("existing", "keep"):
|
| 232 |
+
raise RuntimeError(f"Unsupported pad_token_strategy '{self.tokenizer_cfg.pad_token_strategy}'.")
|
| 233 |
+
|
| 234 |
+
if self.tokenizer_cfg.require_bos and tok.bos_token is None:
|
| 235 |
+
raise RuntimeError(f"Tokenizer for model '{model_cfg.name}' is missing a BOS token.")
|
| 236 |
+
if self.tokenizer_cfg.require_eos and tok.eos_token is None:
|
| 237 |
+
raise RuntimeError(f"Tokenizer for model '{model_cfg.name}' is missing an EOS token.")
|
| 238 |
+
|
| 239 |
+
attn_impl = self.arch_cfg.attn_implementation or "flash_attention_2"
|
| 240 |
+
try:
|
| 241 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 242 |
+
self.hf_model_id,
|
| 243 |
+
torch_dtype=self.dtype,
|
| 244 |
+
trust_remote_code=True,
|
| 245 |
+
attn_implementation=attn_impl,
|
| 246 |
+
device_map=None,
|
| 247 |
+
).to(self.device)
|
| 248 |
+
except Exception:
|
| 249 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 250 |
+
self.hf_model_id,
|
| 251 |
+
torch_dtype=self.dtype,
|
| 252 |
+
trust_remote_code=True,
|
| 253 |
+
attn_implementation="eager",
|
| 254 |
+
device_map=None,
|
| 255 |
+
).to(self.device)
|
| 256 |
+
if model.get_input_embeddings().weight.shape[0] != len(tok):
|
| 257 |
+
model.resize_token_embeddings(len(tok))
|
| 258 |
+
for p in model.parameters():
|
| 259 |
+
p.requires_grad = False
|
| 260 |
+
model.eval()
|
| 261 |
+
if hasattr(model, "gradient_checkpointing_disable"):
|
| 262 |
+
model.gradient_checkpointing_disable()
|
| 263 |
+
if hasattr(model.config, "use_cache"):
|
| 264 |
+
model.config.use_cache = True
|
| 265 |
+
self.model = model
|
| 266 |
+
|
| 267 |
+
adapter_ckpt_path = os.path.expanduser(adapter_ckpt)
|
| 268 |
+
adapter_has_conv = _checkpoint_has_conv(adapter_ckpt_path)
|
| 269 |
+
if not adapter_has_conv and not conv_ckpt:
|
| 270 |
+
raise RuntimeError(
|
| 271 |
+
"Adapter checkpoint lacks conv weights; supply --conv_ckpt to match training."
|
| 272 |
+
)
|
| 273 |
+
lora_cfg_dict = extract_lora_config_from_checkpoints(adapter_ckpt_path, None)
|
| 274 |
+
active_lora_cfg: Optional[LoraConfig] = None
|
| 275 |
+
if self.expect_dora and no_lora:
|
| 276 |
+
raise RuntimeError("--use-dora cannot be combined with --no-lora since no adapters would be loaded.")
|
| 277 |
+
if lora_cfg_dict and not no_lora:
|
| 278 |
+
cfg_use_dora = bool(lora_cfg_dict.get("use_dora", False))
|
| 279 |
+
if cfg_use_dora and not self.expect_dora:
|
| 280 |
+
raise RuntimeError(
|
| 281 |
+
"Checkpoint adapters were trained with DoRA; re-run inference with --use-dora to load them."
|
| 282 |
+
)
|
| 283 |
+
if self.expect_dora and not cfg_use_dora:
|
| 284 |
+
raise RuntimeError(
|
| 285 |
+
"Checkpoint adapters were trained without DoRA; omit --use-dora or use a checkpoint with DoRA."
|
| 286 |
+
)
|
| 287 |
+
model, active_lora_cfg = attach_lora(model, lora_cfg_dict, self.device)
|
| 288 |
+
model.eval()
|
| 289 |
+
elif no_lora and lora_cfg_dict:
|
| 290 |
+
print("[LoRA] --no-lora set; skipping LoRA adapters from checkpoint.", flush=True)
|
| 291 |
+
elif self.expect_dora:
|
| 292 |
+
raise RuntimeError("--use-dora was provided, but no LoRA/DoRA adapters were found in the checkpoint.")
|
| 293 |
+
|
| 294 |
+
conv = build_conv_encoder(
|
| 295 |
+
conv_ckpt_path=None if adapter_has_conv else conv_ckpt,
|
| 296 |
+
device=self.device,
|
| 297 |
+
unfreeze=False,
|
| 298 |
+
)
|
| 299 |
+
conv.eval()
|
| 300 |
+
for p in conv.parameters():
|
| 301 |
+
p.requires_grad = False
|
| 302 |
+
self.conv_encoder = conv
|
| 303 |
+
|
| 304 |
+
hidden_size = _resolve_hidden_size(model, self.arch_cfg.hidden_size_attrs)
|
| 305 |
+
wrapper_cls = model_cfg.resolve_wrapper_class()
|
| 306 |
+
enc_out_dim = self.arch_cfg.conv_out_dim if getattr(self.arch_cfg, "conv_out_dim", None) is not None else 64
|
| 307 |
+
projector_name = peek_projector_name(adapter_ckpt_path) or "linear"
|
| 308 |
+
wrapper = build_wrapper(
|
| 309 |
+
wrapper_cls=wrapper_cls,
|
| 310 |
+
language_model=model,
|
| 311 |
+
conv_encoder=conv,
|
| 312 |
+
hidden_size=hidden_size,
|
| 313 |
+
num_ecg_special_tokens=len(catalog.tokens),
|
| 314 |
+
dtype=self.dtype,
|
| 315 |
+
enc_out_dim=int(enc_out_dim),
|
| 316 |
+
freeze_encoder=True,
|
| 317 |
+
inference=True,
|
| 318 |
+
projector_name=projector_name,
|
| 319 |
+
)
|
| 320 |
+
self.projector_name = projector_name
|
| 321 |
+
self.wrapper = wrapper
|
| 322 |
+
|
| 323 |
+
_extra_payload, model, inferred_lora_cfg = load_llava_and_lora(
|
| 324 |
+
wrapper,
|
| 325 |
+
model,
|
| 326 |
+
adapter_ckpt_path,
|
| 327 |
+
expect_lora=(active_lora_cfg is not None),
|
| 328 |
+
load_lora=not no_lora,
|
| 329 |
+
)
|
| 330 |
+
update_wrapper_language_model(wrapper, model)
|
| 331 |
+
if active_lora_cfg is None and inferred_lora_cfg is not None and not no_lora:
|
| 332 |
+
active_lora_cfg = inferred_lora_cfg
|
| 333 |
+
model.eval()
|
| 334 |
+
for p in model.parameters():
|
| 335 |
+
p.requires_grad = False
|
| 336 |
+
|
| 337 |
+
inp_emb = model.get_input_embeddings().weight
|
| 338 |
+
inp_dev = inp_emb.device
|
| 339 |
+
target_dtype = inp_emb.dtype
|
| 340 |
+
wrapper.llava_proj.to(device=inp_dev, dtype=torch.float32)
|
| 341 |
+
wrapper.enc.to(device=inp_dev, dtype=torch.float32)
|
| 342 |
+
wrapper.ecg_special_embed.to(device=inp_dev, dtype=target_dtype)
|
| 343 |
+
llava_param = next(wrapper.llava_proj.parameters(), None)
|
| 344 |
+
if llava_param is None:
|
| 345 |
+
raise AssertionError("llava_proj unexpectedly has no parameters.")
|
| 346 |
+
if llava_param.device != inp_dev:
|
| 347 |
+
raise AssertionError(f"llava_proj on {llava_param.device}, expected {inp_dev}")
|
| 348 |
+
if llava_param.dtype != torch.float32:
|
| 349 |
+
raise AssertionError(f"llava_proj dtype {llava_param.dtype}, expected torch.float32")
|
| 350 |
+
conv_param = next(wrapper.enc.parameters(), None)
|
| 351 |
+
if conv_param is None:
|
| 352 |
+
raise AssertionError("Convolutional encoder unexpectedly has no parameters.")
|
| 353 |
+
if conv_param.device != inp_dev:
|
| 354 |
+
raise AssertionError(f"Conv encoder on {conv_param.device}, expected {inp_dev}")
|
| 355 |
+
if conv_param.dtype != torch.float32:
|
| 356 |
+
raise AssertionError(f"Conv encoder dtype {conv_param.dtype}, expected torch.float32")
|
| 357 |
+
assert next(wrapper.ecg_special_embed.parameters()).device == inp_dev, (
|
| 358 |
+
f"ecg_special_embed on {next(wrapper.ecg_special_embed.parameters()).device}, expected {inp_dev}"
|
| 359 |
+
)
|
| 360 |
+
try:
|
| 361 |
+
wrapper.language_model.eval()
|
| 362 |
+
except Exception:
|
| 363 |
+
pass
|
| 364 |
+
|
| 365 |
+
self.default_max_new_tokens = int(default_max_new_tokens)
|
| 366 |
+
self.default_temperature = float(default_temperature)
|
| 367 |
+
self.default_top_k = int(default_top_k) if default_top_k is not None else None
|
| 368 |
+
self.default_top_p = float(default_top_p)
|
| 369 |
+
self.default_min_p = float(default_min_p)
|
| 370 |
+
|
| 371 |
+
def chat(
|
| 372 |
+
self,
|
| 373 |
+
*,
|
| 374 |
+
conversation: List[Dict[str, Any]],
|
| 375 |
+
max_new_tokens: Optional[int] = None,
|
| 376 |
+
temperature: Optional[float] = None,
|
| 377 |
+
top_k: Optional[int] = None,
|
| 378 |
+
top_p: Optional[float] = None,
|
| 379 |
+
min_p: Optional[float] = None,
|
| 380 |
+
harmony_output: Optional[str] = None,
|
| 381 |
+
) -> Tuple[str, str]:
|
| 382 |
+
"""Generate a response for a structured multi-turn conversation."""
|
| 383 |
+
context = self._prepare_prompt_context(conversation=conversation)
|
| 384 |
+
|
| 385 |
+
max_new_tokens = int(max_new_tokens if max_new_tokens is not None else self.default_max_new_tokens)
|
| 386 |
+
temperature = float(temperature if temperature is not None else self.default_temperature)
|
| 387 |
+
resolved_top_k = int(top_k) if top_k is not None else (self.default_top_k if self.default_top_k is not None else None)
|
| 388 |
+
top_p = float(top_p if top_p is not None else self.default_top_p)
|
| 389 |
+
min_p = float(min_p if min_p is not None else self.default_min_p)
|
| 390 |
+
|
| 391 |
+
token_ids = self._autoregressive_generate(
|
| 392 |
+
context=context,
|
| 393 |
+
max_new_tokens=max_new_tokens,
|
| 394 |
+
temperature=temperature,
|
| 395 |
+
top_k=resolved_top_k,
|
| 396 |
+
top_p=top_p,
|
| 397 |
+
min_p=min_p,
|
| 398 |
+
)
|
| 399 |
+
text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
| 400 |
+
if self.packing_schema.conversation.format_id == "harmony_chat_v1":
|
| 401 |
+
mode = harmony_output if harmony_output is not None else "all"
|
| 402 |
+
if mode != "raw":
|
| 403 |
+
messages = _extract_harmony_messages(text)
|
| 404 |
+
if mode == "all":
|
| 405 |
+
text = "\n".join(msg for _, msg in messages)
|
| 406 |
+
elif mode == "final":
|
| 407 |
+
finals = [msg for channel, msg in messages if channel == "final"]
|
| 408 |
+
if not finals:
|
| 409 |
+
raise ValueError("No final channel output found in harmony response.")
|
| 410 |
+
text = finals[-1]
|
| 411 |
+
else:
|
| 412 |
+
raise ValueError(f"Unknown harmony_output '{mode}'.")
|
| 413 |
+
return text, context.prompt_preview
|
| 414 |
+
|
| 415 |
+
def _to_waveform_tensor(self, value: Any) -> torch.Tensor:
|
| 416 |
+
if isinstance(value, torch.Tensor):
|
| 417 |
+
tensor = value.detach().cpu()
|
| 418 |
+
elif isinstance(value, np.ndarray):
|
| 419 |
+
tensor = torch.from_numpy(np.asarray(value))
|
| 420 |
+
else:
|
| 421 |
+
tensor = torch.tensor(value, dtype=torch.float32)
|
| 422 |
+
tensor = tensor.to(dtype=torch.float32)
|
| 423 |
+
if tensor.ndim == 1:
|
| 424 |
+
if tensor.numel() != 256:
|
| 425 |
+
raise ValueError("Expected a 256-sample vector for a single lead second.")
|
| 426 |
+
tensor = tensor.view(1, 256)
|
| 427 |
+
elif tensor.ndim == 2:
|
| 428 |
+
if tensor.size(-1) != 256:
|
| 429 |
+
raise ValueError("Waveform segments must have length 256 along the last dimension.")
|
| 430 |
+
else:
|
| 431 |
+
raise ValueError("Waveform tensor must be rank 1 or 2 with 256-sample segments.")
|
| 432 |
+
return tensor.contiguous()
|
| 433 |
+
|
| 434 |
+
def _prepare_prompt_context(
|
| 435 |
+
self,
|
| 436 |
+
*,
|
| 437 |
+
conversation: List[Dict[str, Any]],
|
| 438 |
+
) -> PromptContext:
|
| 439 |
+
tok = self.tokenizer
|
| 440 |
+
wrapper = self.wrapper
|
| 441 |
+
packing_schema = self.packing_schema
|
| 442 |
+
device = self.device
|
| 443 |
+
|
| 444 |
+
prompt_tokens = packing_schema.prompt
|
| 445 |
+
if not isinstance(conversation, list) or not conversation:
|
| 446 |
+
raise ValueError("conversation must be a non-empty list of turns.")
|
| 447 |
+
|
| 448 |
+
conv_input: List[Dict[str, Any]] = []
|
| 449 |
+
for turn in conversation:
|
| 450 |
+
if not isinstance(turn, dict):
|
| 451 |
+
raise ValueError("Conversation turns must be dicts.")
|
| 452 |
+
if "from" not in turn and "role" in turn:
|
| 453 |
+
turn = dict(turn)
|
| 454 |
+
turn["from"] = turn.get("role")
|
| 455 |
+
conv_input.append(turn)
|
| 456 |
+
|
| 457 |
+
turns = _normalize_conversation(conv_input, packing_schema, self.system_text, self.developer_text)
|
| 458 |
+
if turns[-1]["role"] != prompt_tokens.user_role:
|
| 459 |
+
raise ValueError("Conversation must end with a user turn to generate.")
|
| 460 |
+
|
| 461 |
+
def _sanitize_content(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 462 |
+
sanitized: List[Dict[str, Any]] = []
|
| 463 |
+
for item in content:
|
| 464 |
+
if not isinstance(item, dict):
|
| 465 |
+
raise ValueError("Conversation content items must be dicts.")
|
| 466 |
+
item_type = item.get("type")
|
| 467 |
+
if item_type == "ecg":
|
| 468 |
+
waveform = item.get("waveform_segments")
|
| 469 |
+
if not isinstance(waveform, dict):
|
| 470 |
+
raise ValueError("ECG content item missing waveform_segments mapping.")
|
| 471 |
+
wf_out: "OrderedDict[str, torch.Tensor]" = OrderedDict()
|
| 472 |
+
for ld, value in waveform.items():
|
| 473 |
+
wf_out[str(ld)] = _sanitize_segments(self._to_waveform_tensor(value))
|
| 474 |
+
new_item = dict(item)
|
| 475 |
+
new_item["waveform_segments"] = wf_out
|
| 476 |
+
sanitized.append(new_item)
|
| 477 |
+
continue
|
| 478 |
+
if item_type == "text":
|
| 479 |
+
text_val = item.get("text")
|
| 480 |
+
if not isinstance(text_val, str):
|
| 481 |
+
raise ValueError("Text content item must have a string 'text' field.")
|
| 482 |
+
if "<image>" in text_val:
|
| 483 |
+
raise ValueError("Conversation text must not contain <image> in inference mode.")
|
| 484 |
+
sanitized.append(item)
|
| 485 |
+
return sanitized
|
| 486 |
+
|
| 487 |
+
ecg_blocks: List[Dict[str, Any]] = []
|
| 488 |
+
turn_parts: List[List[Dict[str, Any]]] = []
|
| 489 |
+
token_turns: List[Dict[str, str]] = []
|
| 490 |
+
|
| 491 |
+
for turn in turns:
|
| 492 |
+
role = turn["role"]
|
| 493 |
+
content = _sanitize_content(turn["content"])
|
| 494 |
+
if role == prompt_tokens.model_role:
|
| 495 |
+
turn_text_block, content_parts = build_text_only_turn_parts(
|
| 496 |
+
content=content,
|
| 497 |
+
canonical_role=role,
|
| 498 |
+
schema=packing_schema,
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
turn_text_block, content_parts = build_structured_turn_parts(
|
| 502 |
+
content=content,
|
| 503 |
+
canonical_role=role,
|
| 504 |
+
schema=packing_schema,
|
| 505 |
+
ecg_blocks=ecg_blocks,
|
| 506 |
+
sampling_rate=None,
|
| 507 |
+
)
|
| 508 |
+
prefix, suffix = turn_wrappers(packing_schema, role)
|
| 509 |
+
parts = [{"kind": "text", "text": prefix}]
|
| 510 |
+
parts.extend(content_parts)
|
| 511 |
+
parts.append({"kind": "text", "text": suffix})
|
| 512 |
+
turn_parts.append(parts)
|
| 513 |
+
token_turns.append({"role": role, "text_block": turn_text_block})
|
| 514 |
+
|
| 515 |
+
if not ecg_blocks:
|
| 516 |
+
raise ValueError("No ECG blocks found in conversation.")
|
| 517 |
+
|
| 518 |
+
turn_parts = annotate_turn_parts_with_ids(turn_parts, tok)
|
| 519 |
+
assert_ecg_blocks_consistent(turn_parts=turn_parts, ecg_blocks=ecg_blocks)
|
| 520 |
+
|
| 521 |
+
token_struct = render_prompt_and_spans(tok, token_turns, schema=packing_schema)
|
| 522 |
+
text_ids = list(token_struct["text_ids"])
|
| 523 |
+
if (
|
| 524 |
+
prompt_tokens.require_eos
|
| 525 |
+
and tok.eos_token_id is not None
|
| 526 |
+
and text_ids
|
| 527 |
+
and text_ids[-1] == tok.eos_token_id
|
| 528 |
+
):
|
| 529 |
+
text_ids = text_ids[:-1]
|
| 530 |
+
text_preview = token_struct.get("text_preview", "")
|
| 531 |
+
|
| 532 |
+
model_prefix = assistant_generation_prefix(packing_schema)
|
| 533 |
+
ids_model_prefix = tok.encode(model_prefix, add_special_tokens=False)
|
| 534 |
+
text_ids.extend(ids_model_prefix)
|
| 535 |
+
prompt_preview = text_preview + model_prefix
|
| 536 |
+
|
| 537 |
+
model_prefix_parts = [{"kind": "text", "text": model_prefix, "ids": ids_model_prefix}]
|
| 538 |
+
all_parts = list(turn_parts) + [model_prefix_parts]
|
| 539 |
+
|
| 540 |
+
flat_blocks = [blk["waveform_segments"] for blk in ecg_blocks]
|
| 541 |
+
lead_orders = [blk["lead_names"] for blk in ecg_blocks]
|
| 542 |
+
prefix_all, prefix_lens = wrapper.ecg_prefix_batch(
|
| 543 |
+
flat_blocks,
|
| 544 |
+
device=device,
|
| 545 |
+
lead_orders=lead_orders,
|
| 546 |
+
)
|
| 547 |
+
prefixes: List[torch.Tensor] = []
|
| 548 |
+
offset = 0
|
| 549 |
+
for n in prefix_lens:
|
| 550 |
+
prefixes.append(prefix_all[offset:offset + int(n)])
|
| 551 |
+
offset += int(n)
|
| 552 |
+
assert_prefix_split_complete(offset=offset, total_prefix_rows=int(prefix_all.size(0)))
|
| 553 |
+
|
| 554 |
+
block_layouts: List[ECGBlockLayout] = []
|
| 555 |
+
lead_offsets: List[Dict[str, int]] = []
|
| 556 |
+
lead_special_counts: List[Dict[str, int]] = []
|
| 557 |
+
for blk_idx, blk in enumerate(ecg_blocks):
|
| 558 |
+
lead_names = [str(ld) for ld in blk.get("lead_names", [])]
|
| 559 |
+
segs_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]
|
| 560 |
+
prefix_rows = prefixes[blk_idx].size(0) if blk_idx < len(prefixes) else 0
|
| 561 |
+
assert_prefix_matches_segments(
|
| 562 |
+
prefix_rows=prefix_rows,
|
| 563 |
+
segments_per_lead=segs_per_lead,
|
| 564 |
+
lead_names=lead_names,
|
| 565 |
+
sample_index=0,
|
| 566 |
+
block_index=blk_idx,
|
| 567 |
+
)
|
| 568 |
+
lead_to_offset: Dict[str, int] = {}
|
| 569 |
+
c = 0
|
| 570 |
+
for ld, nseg in zip(lead_names, segs_per_lead):
|
| 571 |
+
lead_to_offset[ld] = c
|
| 572 |
+
c += int(nseg)
|
| 573 |
+
lead_offsets.append(lead_to_offset)
|
| 574 |
+
block_layouts.append(ECGBlockLayout(
|
| 575 |
+
start_idx=None,
|
| 576 |
+
end_idx_exclusive=None,
|
| 577 |
+
global_start_idx=None,
|
| 578 |
+
global_end_idx=None,
|
| 579 |
+
lead_start_idx={},
|
| 580 |
+
lead_end_idx={},
|
| 581 |
+
signal_pos_by_lead={ld: [None] * int(nseg) for ld, nseg in zip(lead_names, segs_per_lead)},
|
| 582 |
+
time_to_signal_idxs={},
|
| 583 |
+
declared_segments_per_lead={ld: int(nseg) for ld, nseg in zip(lead_names, segs_per_lead)},
|
| 584 |
+
))
|
| 585 |
+
lead_special_counts.append({})
|
| 586 |
+
|
| 587 |
+
special_indices: List[int] = [
|
| 588 |
+
int(part["token_index"])
|
| 589 |
+
for turn in all_parts
|
| 590 |
+
for part in turn
|
| 591 |
+
if part.get("kind") == "special"
|
| 592 |
+
]
|
| 593 |
+
if special_indices:
|
| 594 |
+
special_idx_tensor = torch.tensor(special_indices, dtype=torch.long, device=device)
|
| 595 |
+
special_embeds = wrapper.ecg_special_tokens_to_embeds(special_idx_tensor, device=device)
|
| 596 |
+
else:
|
| 597 |
+
special_embeds = torch.empty((0, wrapper.hidden_size), dtype=wrapper.dtype, device=device)
|
| 598 |
+
|
| 599 |
+
input_embedder = wrapper.language_model.get_input_embeddings()
|
| 600 |
+
if text_ids:
|
| 601 |
+
E_text_all = wrapper.tokens_to_embeds(input_embedder, text_ids, device=device)
|
| 602 |
+
else:
|
| 603 |
+
E_text_all = torch.empty((0, wrapper.hidden_size), dtype=wrapper.dtype, device=device)
|
| 604 |
+
|
| 605 |
+
text_cursor = 0
|
| 606 |
+
empty_text = E_text_all[:0]
|
| 607 |
+
chunks: List[torch.Tensor] = []
|
| 608 |
+
layout = ECGSequenceLayout(seq_len=0, text_idxs=[], blocks=block_layouts)
|
| 609 |
+
|
| 610 |
+
def _take_text(count: int) -> torch.Tensor:
|
| 611 |
+
nonlocal text_cursor
|
| 612 |
+
if count <= 0:
|
| 613 |
+
return empty_text
|
| 614 |
+
end = text_cursor + count
|
| 615 |
+
if end > E_text_all.size(0):
|
| 616 |
+
raise RuntimeError("Text embedding cursor exceeded available embeddings")
|
| 617 |
+
out = E_text_all[text_cursor:end]
|
| 618 |
+
text_cursor = end
|
| 619 |
+
return out
|
| 620 |
+
|
| 621 |
+
def _record_text(count: int, cursor: int) -> None:
|
| 622 |
+
for i in range(count):
|
| 623 |
+
layout.text_idxs.append(cursor + i)
|
| 624 |
+
|
| 625 |
+
cursor = 0
|
| 626 |
+
special_cursor = 0
|
| 627 |
+
|
| 628 |
+
if (
|
| 629 |
+
text_ids
|
| 630 |
+
and prompt_tokens.require_bos
|
| 631 |
+
and tok.bos_token_id is not None
|
| 632 |
+
and text_ids[0] == tok.bos_token_id
|
| 633 |
+
):
|
| 634 |
+
E_bos = _take_text(1)
|
| 635 |
+
chunks.append(E_bos)
|
| 636 |
+
_record_text(1, cursor)
|
| 637 |
+
cursor += 1
|
| 638 |
+
|
| 639 |
+
for turn in all_parts:
|
| 640 |
+
for part in turn:
|
| 641 |
+
kind = part.get("kind")
|
| 642 |
+
if kind == "text":
|
| 643 |
+
ids_chunk = part.get("ids")
|
| 644 |
+
if ids_chunk is None:
|
| 645 |
+
txt = part.get("text", "")
|
| 646 |
+
ids_chunk = tok.encode(txt, add_special_tokens=False) if txt else []
|
| 647 |
+
if ids_chunk:
|
| 648 |
+
if ids_chunk != text_ids[text_cursor:text_cursor + len(ids_chunk)]:
|
| 649 |
+
raise RuntimeError("Special token id does not match text_ids cursor.")
|
| 650 |
+
E_chunk = _take_text(len(ids_chunk))
|
| 651 |
+
chunks.append(E_chunk)
|
| 652 |
+
_record_text(len(ids_chunk), cursor)
|
| 653 |
+
cursor += len(ids_chunk)
|
| 654 |
+
continue
|
| 655 |
+
if kind == "special":
|
| 656 |
+
if special_cursor >= special_embeds.size(0):
|
| 657 |
+
raise RuntimeError("Special-token cursor exceeded embeddings.")
|
| 658 |
+
tok_idx = int(part.get("token_index", -1))
|
| 659 |
+
expected_id = self.ecg_special_token_id_map.get(tok_idx)
|
| 660 |
+
if expected_id is None:
|
| 661 |
+
raise RuntimeError(f"Unknown ECG special token index {tok_idx}.")
|
| 662 |
+
if text_cursor >= len(text_ids):
|
| 663 |
+
raise RuntimeError("Text cursor exceeded available text ids.")
|
| 664 |
+
if text_ids[text_cursor] != expected_id:
|
| 665 |
+
raise RuntimeError("Special token id does not match text_ids cursor.")
|
| 666 |
+
_take_text(1)
|
| 667 |
+
chunks.append(special_embeds[special_cursor:special_cursor + 1])
|
| 668 |
+
_record_text(1, cursor)
|
| 669 |
+
|
| 670 |
+
block_index = int(part.get("block_index", -1))
|
| 671 |
+
if block_index < 0 or block_index >= len(block_layouts):
|
| 672 |
+
raise RuntimeError("ECG part references unknown block_index.")
|
| 673 |
+
block_layout = block_layouts[block_index]
|
| 674 |
+
lead_name = part.get("lead")
|
| 675 |
+
if lead_name:
|
| 676 |
+
cnt = lead_special_counts[block_index].get(lead_name, 0)
|
| 677 |
+
if cnt == 0:
|
| 678 |
+
block_layout.lead_start_idx[lead_name] = cursor
|
| 679 |
+
else:
|
| 680 |
+
block_layout.lead_end_idx[lead_name] = cursor
|
| 681 |
+
lead_special_counts[block_index][lead_name] = cnt + 1
|
| 682 |
+
else:
|
| 683 |
+
if block_layout.global_start_idx is None:
|
| 684 |
+
block_layout.global_start_idx = cursor
|
| 685 |
+
block_layout.start_idx = cursor
|
| 686 |
+
else:
|
| 687 |
+
block_layout.global_end_idx = cursor
|
| 688 |
+
block_layout.end_idx_exclusive = cursor + 1
|
| 689 |
+
|
| 690 |
+
cursor += 1
|
| 691 |
+
special_cursor += 1
|
| 692 |
+
continue
|
| 693 |
+
if kind == "ecg":
|
| 694 |
+
block_index = int(part.get("block_index", -1))
|
| 695 |
+
if block_index < 0 or block_index >= len(block_layouts):
|
| 696 |
+
raise RuntimeError("ECG part references unknown block_index.")
|
| 697 |
+
ld = part["lead"]
|
| 698 |
+
sec = int(part["sec"])
|
| 699 |
+
lead_to_offset = lead_offsets[block_index]
|
| 700 |
+
block_layout = block_layouts[block_index]
|
| 701 |
+
prefix_all = prefixes[block_index]
|
| 702 |
+
assert_ecg_part_bounds(
|
| 703 |
+
lead=ld,
|
| 704 |
+
sec=sec,
|
| 705 |
+
lead_to_offset=lead_to_offset,
|
| 706 |
+
declared_segments=block_layout.declared_segments_per_lead,
|
| 707 |
+
total_prefix_rows=prefix_all.size(0),
|
| 708 |
+
sample_index=0,
|
| 709 |
+
block_index=block_index,
|
| 710 |
+
)
|
| 711 |
+
base = lead_to_offset[ld]
|
| 712 |
+
row_idx = base + (sec - 1)
|
| 713 |
+
chunks.append(prefix_all[row_idx:row_idx + 1])
|
| 714 |
+
sig_list = block_layout.signal_pos_by_lead[ld]
|
| 715 |
+
if sec - 1 >= len(sig_list):
|
| 716 |
+
raise RuntimeError("ECG segment index exceeds declared segments_per_lead")
|
| 717 |
+
sig_list[sec - 1] = cursor
|
| 718 |
+
block_layout.time_to_signal_idxs.setdefault(sec, []).append(cursor)
|
| 719 |
+
cursor += 1
|
| 720 |
+
continue
|
| 721 |
+
raise RuntimeError(f"Unknown turn part kind '{kind}'.")
|
| 722 |
+
|
| 723 |
+
remaining = len(text_ids) - text_cursor
|
| 724 |
+
if remaining > 0:
|
| 725 |
+
E_tail = _take_text(remaining)
|
| 726 |
+
chunks.append(E_tail)
|
| 727 |
+
_record_text(remaining, cursor)
|
| 728 |
+
cursor += remaining
|
| 729 |
+
|
| 730 |
+
if special_cursor != special_embeds.size(0):
|
| 731 |
+
raise RuntimeError("Did not consume all special-token embeddings for prompt")
|
| 732 |
+
if text_cursor != E_text_all.size(0):
|
| 733 |
+
raise RuntimeError("Text embedding cursor did not consume all embeddings")
|
| 734 |
+
|
| 735 |
+
inputs_embeds = torch.cat(chunks, dim=0)
|
| 736 |
+
layout.seq_len = inputs_embeds.size(0)
|
| 737 |
+
|
| 738 |
+
for blk_idx, blk_layout in enumerate(block_layouts):
|
| 739 |
+
for ld, expected in blk_layout.declared_segments_per_lead.items():
|
| 740 |
+
slots = blk_layout.signal_pos_by_lead[ld]
|
| 741 |
+
if any(pos is None for pos in slots):
|
| 742 |
+
raise RuntimeError(f"Lead {ld} missing ECG slots; expected {expected}.")
|
| 743 |
+
blk_layout.signal_pos_by_lead[ld] = [int(pos) for pos in slots]
|
| 744 |
+
if blk_layout.global_start_idx is None or blk_layout.global_end_idx is None:
|
| 745 |
+
raise RuntimeError("ECG block missing global start/end specials.")
|
| 746 |
+
if blk_layout.end_idx_exclusive is None:
|
| 747 |
+
blk_layout.end_idx_exclusive = int(blk_layout.global_end_idx) + 1
|
| 748 |
+
if blk_layout.start_idx is None:
|
| 749 |
+
blk_layout.start_idx = int(blk_layout.global_start_idx)
|
| 750 |
+
assert_layout_specials_complete(
|
| 751 |
+
block_layout=blk_layout,
|
| 752 |
+
lead_names=ecg_blocks[blk_idx]["lead_names"],
|
| 753 |
+
)
|
| 754 |
+
all_specials = []
|
| 755 |
+
if blk_layout.global_start_idx is not None:
|
| 756 |
+
all_specials.append(blk_layout.global_start_idx)
|
| 757 |
+
all_specials.extend(list(blk_layout.lead_start_idx.values()))
|
| 758 |
+
all_specials.extend(list(blk_layout.lead_end_idx.values()))
|
| 759 |
+
if blk_layout.global_end_idx is not None:
|
| 760 |
+
all_specials.append(blk_layout.global_end_idx)
|
| 761 |
+
blk_layout.special_idxs_sorted = sorted(all_specials)
|
| 762 |
+
blk_layout.signal_pos_list = sorted(
|
| 763 |
+
[p for lst in blk_layout.signal_pos_by_lead.values() for p in lst]
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
mask_result = self.mask_strategy.build(
|
| 767 |
+
layout,
|
| 768 |
+
device=device,
|
| 769 |
+
dtype=inputs_embeds.dtype,
|
| 770 |
+
)
|
| 771 |
+
use_return = self.packing_schema.conversation.format_id == "harmony_chat_v1"
|
| 772 |
+
_, stop_text = turn_wrappers(self.packing_schema, prompt_tokens.model_role, use_return=use_return)
|
| 773 |
+
stop_ids = tok.encode(stop_text, add_special_tokens=False)
|
| 774 |
+
return PromptContext(
|
| 775 |
+
inputs_embeds=inputs_embeds,
|
| 776 |
+
layout=layout,
|
| 777 |
+
prompt_preview=prompt_preview,
|
| 778 |
+
stop_ids=stop_ids,
|
| 779 |
+
input_embedder=input_embedder,
|
| 780 |
+
mask_strategy=self.mask_strategy,
|
| 781 |
+
mask_result=mask_result,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
    def _autoregressive_generate(
        self,
        *,
        context: PromptContext,
        max_new_tokens: int,
        temperature: float,
        top_k: Optional[int],
        top_p: float,
        min_p: float,
    ) -> List[int]:
        """Decode up to *max_new_tokens* token ids from a pre-built prompt context.

        Runs the wrapped language model once per new token (no KV cache is used),
        samples the next id with temperature/top-k/top-p/min-p filtering, and stops
        on either the turn-closing stop sequence or the tokenizer's EOS id.

        Returns:
            The generated token ids, with a trailing stop sequence stripped off.

        Side effects: mutates ``context.layout`` (seq_len, text_idxs) and replaces
        ``context.mask_result`` as tokens are appended.
        """
        tok = self.tokenizer
        wrapper = self.wrapper
        device = context.inputs_embeds.device
        # Clone so the prompt embeddings held by the context are never mutated.
        embeds = context.inputs_embeds.clone()
        layout = context.layout
        input_embedder = context.input_embedder
        mask_result = context.mask_result

        generated: List[int] = []
        stop_ids = context.stop_ids
        stop_len = len(stop_ids)
        eos_id = tok.eos_token_id

        for _ in range(max_new_tokens):
            # Broadcast the additive attention bias to (1, 1, ...) for the backbone.
            # NOTE(review): assumes mask_result.additive is a 2-D (seq, seq) bias — confirm.
            additive = mask_result.additive.unsqueeze(0).unsqueeze(0)
            outputs = wrapper.forward_language_model(
                inputs_embeds=embeds.unsqueeze(0),
                attention_mask=additive,
                labels=None,
            )
            # Only the final position's logits matter for next-token prediction.
            logits = outputs.logits[0, -1, :].float()
            next_token = _sample_next_token(
                logits,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                min_p=min_p,
            )
            token_id = int(next_token.item())
            generated.append(token_id)

            # Stop-sequence check first: the matched suffix is removed from output.
            if stop_len and generated[-stop_len:] == stop_ids:
                generated = generated[:-stop_len]
                break
            # EOS (if the tokenizer defines one) is kept in the output.
            if eos_id is not None and token_id == eos_id:
                break

            with torch.no_grad():
                new_embed = wrapper.tokens_to_embeds(input_embedder, [token_id], device=device)
            embeds = torch.cat([embeds, new_embed], dim=0)

            # Register the new position as a text token and grow the attention mask.
            layout.seq_len = embeds.size(0)
            new_idx = layout.seq_len - 1
            layout.text_idxs.append(new_idx)
            mask_result = context.mask_strategy.update_for_generated_token(
                layout,
                device=device,
                dtype=embeds.dtype,
                previous=mask_result,
            )
            context.mask_result = mask_result

        return generated
|
camel_inference/src/camel/model_init.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional, Tuple, Type
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from peft import (
|
| 7 |
+
LoraConfig,
|
| 8 |
+
get_peft_model,
|
| 9 |
+
TaskType,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from camel.ecg_gemma_model import ECGGemmaPrefix as ECGModelPrefix
|
| 13 |
+
|
| 14 |
+
def attach_lora(
    model: nn.Module,
    lora_cfg_dict: Dict[str, Any],
    device: torch.device,
) -> Tuple[nn.Module, LoraConfig]:
    """Wrap *model* with PEFT LoRA adapters, leaving only LoRA weights trainable.

    Args:
        model: Frozen backbone to adapt.
        lora_cfg_dict: Raw config; ``r`` is required, everything else optional
            (``lora_alpha`` defaults to ``2 * r``).
        device: Device the adapted model is moved to.

    Returns:
        The PEFT-wrapped model and the ``LoraConfig`` used to build it.
    """
    rank = int(lora_cfg_dict["r"])
    alpha = int(lora_cfg_dict.get("lora_alpha", rank * 2))
    dropout = float(lora_cfg_dict.get("lora_dropout", 0.0))
    targets = list(lora_cfg_dict.get("target_modules", []))
    task = TaskType(lora_cfg_dict.get("task_type", "CAUSAL_LM"))

    cfg = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=targets,
        task_type=task,
        bias=lora_cfg_dict.get("bias", "none"),
        inference_mode=False,
        use_dora=bool(lora_cfg_dict.get("use_dora", False)),
    )
    peft_model = get_peft_model(model, cfg)
    peft_model.to(device)
    return peft_model, cfg
|
| 33 |
+
|
| 34 |
+
def build_conv_encoder(
    *,
    conv_ckpt_path: Optional[str],
    device: torch.device,
    unfreeze: bool = False,
) -> nn.Module:
    """
    Build the 1D conv stack (1 -> 4 channels, length 256 -> 16) and optionally load
    weights from *conv_ckpt_path*, normalizing common state-dict key prefixes.

    Args:
        conv_ckpt_path: Path to a checkpoint (either a raw state dict or a dict
            with a ``model_state_dict`` entry). ``None``/empty skips loading.
        device: Device the encoder is built on (weights are float32).
        unfreeze: When True the encoder parameters stay trainable; by default the
            encoder is frozen (fix: this flag was previously accepted but ignored).

    Returns:
        The encoder in eval mode.

    Raises:
        RuntimeError: if the checkpoint lacks any of the expected conv weights.
    """
    enc = nn.Sequential(
        nn.Conv1d(1, 32, kernel_size=4, stride=2, padding=1),   # L: 256 -> 128
        nn.ReLU(inplace=True),
        nn.Conv1d(32, 64, kernel_size=4, stride=2, padding=1),  # L: 128 -> 64
        nn.ReLU(inplace=True),
        nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1), # L: 64 -> 32
        nn.ReLU(inplace=True),
        nn.Conv1d(128, 4, kernel_size=4, stride=2, padding=1),  # L: 32 -> 16, C: 4
        nn.ReLU(inplace=True),
    ).to(device=device, dtype=torch.float32)

    if conv_ckpt_path:
        ckpt = torch.load(conv_ckpt_path, map_location="cpu")
        raw_sd = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
        # Strip wrapper prefixes added by DDP / torch.compile / the training module.
        norm_sd: Dict[str, torch.Tensor] = {}
        for k, v in raw_sd.items():
            kk = k
            for prefix in ("module.", "_orig_mod.", "enc."):
                if kk.startswith(prefix):
                    kk = kk[len(prefix):]
            norm_sd[kk] = v
        # Conv layers sit at Sequential indices 0/2/4/6 (ReLUs have no weights).
        wanted = {f"{i}.{w}" for i in (0, 2, 4, 6) for w in ("weight", "bias")}
        conv_sd = {k: v for k, v in norm_sd.items() if k in wanted}
        # Fail early with a clear message: with strict=True the old
        # missing/unexpected diagnostic was unreachable (strict load raises).
        absent = sorted(wanted - conv_sd.keys())
        if absent:
            raise RuntimeError(
                f"Conv checkpoint '{conv_ckpt_path}' is missing expected keys: {absent}"
            )
        enc.load_state_dict(conv_sd, strict=True)

    # Honor the (previously ignored) unfreeze flag: frozen unless requested.
    enc.requires_grad_(bool(unfreeze))
    enc.eval()
    return enc
|
| 75 |
+
|
| 76 |
+
def build_wrapper(
    *,
    wrapper_cls: Type[nn.Module] = ECGModelPrefix,
    language_model: nn.Module,
    conv_encoder: nn.Module,
    hidden_size: int,
    num_ecg_special_tokens: int,
    dtype: torch.dtype,
    enc_out_dim: int = 64,
    freeze_encoder: bool = True,
    inference: bool = False,
    projector_name: str = "linear",
) -> ECGModelPrefix:
    """Instantiate the ECG-language wrapper around *language_model*.

    All options are forwarded verbatim to *wrapper_cls*; the default wrapper
    class is preserved so existing call sites keep working.
    """
    wrapper_kwargs = dict(
        enc=conv_encoder,
        hidden_size=hidden_size,
        num_ecg_special_tokens=num_ecg_special_tokens,
        dtype=dtype,
        enc_out_dim=enc_out_dim,
        freeze_encoder=freeze_encoder,
        inference=inference,
        projector_name=projector_name,
    )
    return wrapper_cls(language_model, **wrapper_kwargs)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
__all__ = [
|
| 105 |
+
"attach_lora",
|
| 106 |
+
"build_conv_encoder",
|
| 107 |
+
"build_wrapper",
|
| 108 |
+
]
|
camel_inference/src/camel/model_introspect.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model introspection utilities driven by registry hints.
|
| 2 |
+
|
| 3 |
+
Centralizes resolution of model-internal structures to avoid hardcoded
|
| 4 |
+
attribute names in call sites. Use the hint paths defined in the model
|
| 5 |
+
registry to locate transformer layers and config attributes.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import List, Optional, Sequence
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
def _walk_attr_path(root: object, dotted_path: str) -> Optional[object]:
|
| 13 |
+
cur: object = root
|
| 14 |
+
for part in dotted_path.split("."):
|
| 15 |
+
if not hasattr(cur, part):
|
| 16 |
+
return None
|
| 17 |
+
cur = getattr(cur, part)
|
| 18 |
+
return cur
|
| 19 |
+
|
| 20 |
+
def resolve_layers(model: nn.Module, path_hints: Sequence[str]) -> List[nn.Module]:
    """
    Locate the text transformer's layer stack via dotted attribute hints.

    Hints are tried relative to several candidate roots so PEFT-style wrappers
    still resolve: the model itself, ``model.base_model``, and
    ``model.base_model.model`` when present. If none of the caller-provided
    hints match, a small set of generic fallbacks is attempted
    ("model.language_model.layers", "language_model.layers", "model.layers",
    "layers"). Raises RuntimeError when nothing resolves.
    """
    def walk(root: object, dotted: str) -> Optional[object]:
        # Attribute-path walk; None as soon as any hop is missing.
        node = root
        for name in dotted.split("."):
            if not hasattr(node, name):
                return None
            node = getattr(node, name)
        return node

    def as_module_list(obj: object) -> Optional[List[nn.Module]]:
        # Accept a list/tuple of modules, or any non-string iterable whose
        # materialized (non-empty) contents are all modules.
        if obj is None:
            return None
        if isinstance(obj, (list, tuple)) and all(isinstance(m, nn.Module) for m in obj):
            return list(obj)
        if hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)):
            try:
                items = list(obj)
            except Exception:
                return None
            if items and all(isinstance(m, nn.Module) for m in items):
                return items
        return None

    candidate_roots: List[object] = [model]
    peft_base = getattr(model, "base_model", None)
    if peft_base is not None:
        candidate_roots.append(peft_base)
        inner = getattr(peft_base, "model", None)
        if inner is not None:
            candidate_roots.append(inner)

    fallback_hints = (
        "model.language_model.layers",
        "language_model.layers",
        "model.layers",
        "layers",
    )
    for hint_group in (tuple(path_hints), fallback_hints):
        for root in candidate_roots:
            for hint in hint_group:
                layers = as_module_list(walk(root, hint))
                if layers is not None:
                    return layers

    raise RuntimeError(
        f"Could not resolve transformer layers via provided hints: {list(path_hints)}"
    )
|
| 80 |
+
|
| 81 |
+
def resolve_hidden_size(model: nn.Module, attr_paths: Sequence[str]) -> int:
    """Return the hidden size found via the first dotted config path that
    resolves to a numeric value; raise AttributeError if none do."""
    for dotted in attr_paths:
        node: object = model
        for name in dotted.split("."):
            node = getattr(node, name, None)
            if node is None:
                break
        if isinstance(node, (int, float)):
            return int(node)
    raise AttributeError(
        f"Could not resolve hidden size from any of: {list(attr_paths)}"
    )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
__all__ = [
|
| 93 |
+
"resolve_layers",
|
| 94 |
+
"resolve_hidden_size",
|
| 95 |
+
]
|
camel_inference/src/camel/model_registry.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for loading per-model configuration metadata used across training and inference.
|
| 3 |
+
|
| 4 |
+
The registry is defined in YAML (see model_registry.yaml in this directory) and exposes
|
| 5 |
+
immutable ModelConfig objects for downstream consumers. The intent is to centralize
|
| 6 |
+
model-specific defaults (prompt format, tokenizer quirks, wrapper class path, LoRA constraints, etc.)
|
| 7 |
+
so that adding support for a new backbone primarily involves updating the registry.
|
| 8 |
+
"""
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import importlib
|
| 13 |
+
import dataclasses
|
| 14 |
+
import os
|
| 15 |
+
from collections.abc import Mapping as ABCMapping, Sequence as ABCSequence
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from types import MappingProxyType
|
| 18 |
+
from typing import Any, Dict, Iterable, Mapping, Optional, Tuple
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
class ModelRegistryError(RuntimeError):
    """Raised when the registry file is missing or malformed.

    Also raised by ModelConfig.resolve_wrapper_class when the configured
    wrapper class path cannot be imported or resolved.
    """
|
| 23 |
+
|
| 24 |
+
@dataclasses.dataclass(frozen=True)
class PromptConfig:
    """Prompt-formatting settings for one backbone (registry `prompt` section).

    Built by ModelConfig.prompt_config; field names mirror the registry keys.
    """

    start_of_turn: str  # literal turn-opening marker from the registry
    end_of_turn: str  # literal turn-closing marker from the registry
    roles: Mapping[str, str]  # logical role name -> backbone-specific role string
    enforce_bos: bool  # registry flag `prompt.enforce_bos`
    enforce_eos: bool  # registry flag `prompt.enforce_eos`
    allow_multiple_eos: bool  # optional registry flag, defaults to False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclasses.dataclass(frozen=True)
class TokenizerConfig:
    """Tokenizer settings for one backbone (registry `tokenizer` section).

    Built by ModelConfig.tokenizer_config; field names mirror the registry keys.
    """

    pad_token_strategy: str  # registry key `tokenizer.pad_token_strategy`
    require_bos: bool  # registry key `tokenizer.require_bos`
    require_eos: bool  # registry key `tokenizer.require_eos`
    # Optional in the registry; defaults kept for backward compatibility.
    use_fast: bool = True
    add_prefix_space: bool = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclasses.dataclass(frozen=True)
class ArchitectureConfig:
    """Architecture hints for one backbone (registry `architecture` section).

    Built by ModelConfig.architecture_config.
    """

    wrapper_class: str  # dotted `module.ClassName` path, resolved lazily
    hidden_size_attrs: Tuple[str, ...]  # dotted config paths tried by resolve_hidden_size
    language_model_path_hints: Tuple[str, ...]  # dotted paths tried by resolve_layers
    attn_implementation: str  # registry key `architecture.attn_implementation`
    conv_out_dim: Optional[int] = None  # optional; None when absent or non-numeric
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclasses.dataclass(frozen=True)
class LoRAPolicyConfig:
    """LoRA placement policy for one backbone (registry `lora_policy` section).

    Built by ModelConfig.lora_policy_config; field names mirror the registry keys.
    """

    expect_language_only: bool  # registry flag `lora_policy.expect_language_only`
    allowed_markers: Tuple[str, ...]  # registry list `lora_policy.allowed_markers`
    blocked_markers: Tuple[str, ...]  # registry list `lora_policy.blocked_markers`
    freeze_vision: bool  # registry flag `lora_policy.freeze_vision`
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclasses.dataclass(frozen=True)
class PackingConversationConfig:
    """Conversation-normalization rules (registry `packing.conversation` section).

    Built by ModelConfig.packing_config.
    """

    format_id: str  # copied from `packing.prompt_format`
    user_role_aliases: Tuple[str, ...]  # role strings treated as the user turn
    model_role_aliases: Tuple[str, ...]  # role strings treated as the model turn
    strip_image_from_roles: Tuple[str, ...]  # registry list `strip_image_from_roles`
    merge_system_with_first_user: bool  # registry flag `merge_system_with_first_user`
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclasses.dataclass(frozen=True)
class PackingECGTokensConfig:
    """ECG special-token text for one backbone (registry `packing.ecg_tokens` section).

    Built by ModelConfig.packing_config; field names mirror the registry keys.
    """

    global_start: str  # literal token opening an ECG block
    global_end: str  # literal token closing an ECG block
    lead_start_template: str  # template for a per-lead start token
    lead_end_template: str  # template for a per-lead end token
    canonical_leads: Tuple[str, ...]  # registry list `canonical_leads`
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclasses.dataclass(frozen=True)
class PackingConfig:
    """Top-level packing settings: prompt format plus the two sub-configs.

    Built by ModelConfig.packing_config; `prompt_format` carries the same value
    as `conversation.format_id`.
    """

    prompt_format: str  # registry key `packing.prompt_format`
    conversation: PackingConversationConfig
    ecg_tokens: PackingECGTokensConfig
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclasses.dataclass(frozen=True)
|
| 86 |
+
class ModelConfig:
|
| 87 |
+
"""Typed wrapper over a single model entry in the registry."""
|
| 88 |
+
|
| 89 |
+
name: str
|
| 90 |
+
data: Mapping[str, Any]
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def hf_id(self) -> str:
|
| 94 |
+
return _require_str(self.data, "hf_id", self.name)
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def prompt(self) -> Mapping[str, Any]:
|
| 98 |
+
return _require_mapping(self.data, "prompt", self.name)
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def tokenizer(self) -> Mapping[str, Any]:
|
| 102 |
+
return _require_mapping(self.data, "tokenizer", self.name)
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def architecture(self) -> Mapping[str, Any]:
|
| 106 |
+
return _require_mapping(self.data, "architecture", self.name)
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def lora_policy(self) -> Mapping[str, Any]:
|
| 110 |
+
return _require_mapping(self.data, "lora_policy", self.name)
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def packing(self) -> Mapping[str, Any]:
|
| 114 |
+
return _require_mapping(self.data, "packing", self.name)
|
| 115 |
+
|
| 116 |
+
def prompt_config(self) -> PromptConfig:
|
| 117 |
+
prompt = self.prompt
|
| 118 |
+
roles = _require_mapping(prompt, "roles", self.name, section="prompt")
|
| 119 |
+
return PromptConfig(
|
| 120 |
+
start_of_turn=_require_str(prompt, "start_of_turn", self.name, section="prompt"),
|
| 121 |
+
end_of_turn=_require_str(prompt, "end_of_turn", self.name, section="prompt"),
|
| 122 |
+
roles={k: str(v) for k, v in roles.items()},
|
| 123 |
+
enforce_bos=_require_bool(prompt, "enforce_bos", self.name, section="prompt"),
|
| 124 |
+
enforce_eos=_require_bool(prompt, "enforce_eos", self.name, section="prompt"),
|
| 125 |
+
allow_multiple_eos=_optional_bool(prompt, "allow_multiple_eos", False, self.name, section="prompt"),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def required_prompt_text(self, key: str) -> str:
|
| 129 |
+
return _require_str(self.prompt, key, self.name, section="prompt")
|
| 130 |
+
|
| 131 |
+
def tokenizer_config(self) -> TokenizerConfig:
|
| 132 |
+
tokenizer = self.tokenizer
|
| 133 |
+
pad_strategy = _require_str(tokenizer, "pad_token_strategy", self.name, section="tokenizer")
|
| 134 |
+
require_bos = _require_bool(tokenizer, "require_bos", self.name, section="tokenizer")
|
| 135 |
+
require_eos = _require_bool(tokenizer, "require_eos", self.name, section="tokenizer")
|
| 136 |
+
# Optional fields with defaults for backward compatibility
|
| 137 |
+
use_fast = _optional_bool(tokenizer, "use_fast", True, self.name, section="tokenizer")
|
| 138 |
+
add_prefix_space = _optional_bool(tokenizer, "add_prefix_space", False, self.name, section="tokenizer")
|
| 139 |
+
return TokenizerConfig(
|
| 140 |
+
pad_token_strategy=pad_strategy,
|
| 141 |
+
require_bos=require_bos,
|
| 142 |
+
require_eos=require_eos,
|
| 143 |
+
use_fast=use_fast,
|
| 144 |
+
add_prefix_space=add_prefix_space,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def architecture_config(self) -> ArchitectureConfig:
|
| 148 |
+
arch = self.architecture
|
| 149 |
+
conv_out = arch.get("conv_out_dim") if isinstance(arch, ABCMapping) else None
|
| 150 |
+
try:
|
| 151 |
+
conv_out_int = int(conv_out) if conv_out is not None else None
|
| 152 |
+
except Exception:
|
| 153 |
+
conv_out_int = None
|
| 154 |
+
return ArchitectureConfig(
|
| 155 |
+
wrapper_class=_require_str(arch, "wrapper_class", self.name, section="architecture"),
|
| 156 |
+
hidden_size_attrs=tuple(
|
| 157 |
+
_require_sequence_of_str(arch, "hidden_size_attrs", self.name, section="architecture")
|
| 158 |
+
),
|
| 159 |
+
language_model_path_hints=tuple(
|
| 160 |
+
_require_sequence_of_str(arch, "language_model_path_hints", self.name, section="architecture")
|
| 161 |
+
),
|
| 162 |
+
attn_implementation=_require_str(arch, "attn_implementation", self.name, section="architecture"),
|
| 163 |
+
conv_out_dim=conv_out_int,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
def lora_policy_config(self) -> LoRAPolicyConfig:
|
| 167 |
+
lora = self.lora_policy
|
| 168 |
+
return LoRAPolicyConfig(
|
| 169 |
+
expect_language_only=_require_bool(lora, "expect_language_only", self.name, section="lora_policy"),
|
| 170 |
+
allowed_markers=tuple(
|
| 171 |
+
_require_sequence_of_str(lora, "allowed_markers", self.name, section="lora_policy")
|
| 172 |
+
),
|
| 173 |
+
blocked_markers=tuple(
|
| 174 |
+
_require_sequence_of_str(lora, "blocked_markers", self.name, section="lora_policy")
|
| 175 |
+
),
|
| 176 |
+
freeze_vision=_require_bool(lora, "freeze_vision", self.name, section="lora_policy"),
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
def packing_config(self) -> PackingConfig:
|
| 180 |
+
packing = self.packing
|
| 181 |
+
format_id = _require_str(packing, "prompt_format", self.name, section="packing")
|
| 182 |
+
|
| 183 |
+
conversation = _require_mapping(packing, "conversation", self.name, section="packing")
|
| 184 |
+
user_aliases = tuple(
|
| 185 |
+
_require_sequence_of_str(conversation, "user_role_aliases", self.name, section="packing.conversation")
|
| 186 |
+
)
|
| 187 |
+
model_aliases = tuple(
|
| 188 |
+
_require_sequence_of_str(conversation, "model_role_aliases", self.name, section="packing.conversation")
|
| 189 |
+
)
|
| 190 |
+
strip_roles = tuple(
|
| 191 |
+
_require_sequence_of_str(conversation, "strip_image_from_roles", self.name, section="packing.conversation")
|
| 192 |
+
)
|
| 193 |
+
merge_system = _require_bool(
|
| 194 |
+
conversation, "merge_system_with_first_user", self.name, section="packing.conversation"
|
| 195 |
+
)
|
| 196 |
+
conv_cfg = PackingConversationConfig(
|
| 197 |
+
format_id=format_id,
|
| 198 |
+
user_role_aliases=user_aliases,
|
| 199 |
+
model_role_aliases=model_aliases,
|
| 200 |
+
strip_image_from_roles=strip_roles,
|
| 201 |
+
merge_system_with_first_user=merge_system,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
ecg_tokens = _require_mapping(packing, "ecg_tokens", self.name, section="packing")
|
| 205 |
+
global_start = _require_str(ecg_tokens, "global_start", self.name, section="packing.ecg_tokens")
|
| 206 |
+
global_end = _require_str(ecg_tokens, "global_end", self.name, section="packing.ecg_tokens")
|
| 207 |
+
lead_start_template = _require_str(ecg_tokens, "lead_start_template", self.name, section="packing.ecg_tokens")
|
| 208 |
+
lead_end_template = _require_str(ecg_tokens, "lead_end_template", self.name, section="packing.ecg_tokens")
|
| 209 |
+
canonical_leads = tuple(
|
| 210 |
+
_require_sequence_of_str(ecg_tokens, "canonical_leads", self.name, section="packing.ecg_tokens")
|
| 211 |
+
)
|
| 212 |
+
ecg_cfg = PackingECGTokensConfig(
|
| 213 |
+
global_start=global_start,
|
| 214 |
+
global_end=global_end,
|
| 215 |
+
lead_start_template=lead_start_template,
|
| 216 |
+
lead_end_template=lead_end_template,
|
| 217 |
+
canonical_leads=canonical_leads,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
return PackingConfig(
|
| 221 |
+
prompt_format=format_id,
|
| 222 |
+
conversation=conv_cfg,
|
| 223 |
+
ecg_tokens=ecg_cfg,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
def resolve_wrapper_class(self):
    """Import and return the class named by ``architecture.wrapper_class``.

    The configured path must be of the form ``module.ClassName``; the module
    is imported relative to this package. Raises ModelRegistryError for a
    malformed path, a failed import, or a missing class attribute.
    """
    spec = self.architecture_config().wrapper_class
    if "." not in spec:
        raise ModelRegistryError(
            f"Wrapper class path '{spec}' for model '{self.name}' must be in 'module.ClassName' form."
        )
    mod_path, cls_name = spec.rsplit(".", 1)
    try:
        mod = importlib.import_module("." + mod_path, package=__package__)
    except ImportError as exc:
        raise ModelRegistryError(
            f"Failed to import wrapper module '{mod_path}' for model '{self.name}': {exc}"
        ) from exc
    try:
        return getattr(mod, cls_name)
    except AttributeError as exc:
        raise ModelRegistryError(
            f"Wrapper class '{cls_name}' not found in module '{mod_path}' for model '{self.name}'."
        ) from exc
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _default_registry_path() -> Path:
|
| 249 |
+
return Path(__file__).resolve().with_name("model_registry.yaml")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def load_registry(
    *,
    registry_path: Optional[os.PathLike[str] | str] = None,
    model_overrides: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> "ModelRegistry":
    """
    Load the model registry from YAML.

    Args:
        registry_path: Optional path to the YAML file. Defaults to `model_registry.yaml` alongside this module.
        model_overrides: Optional mapping of model name -> override dict that will be deep-merged
            onto the YAML entry (useful for ad-hoc experimentation).
    """
    path = _default_registry_path() if registry_path is None else Path(registry_path)
    if not path.exists():
        raise ModelRegistryError(f"Model registry file not found at {path}")
    try:
        raw = yaml.safe_load(path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        raise ModelRegistryError(f"Failed to parse model registry YAML: {exc}") from exc

    if not isinstance(raw, ABCMapping):
        raise ModelRegistryError("Model registry root must be a mapping.")

    models = raw.get("models")
    if not (isinstance(models, ABCMapping) and models):
        raise ModelRegistryError("Model registry must define a non-empty 'models' mapping.")

    overrides = model_overrides or {}
    entries: Dict[str, Mapping[str, Any]] = {}
    for name, cfg in models.items():
        if not isinstance(name, str):
            raise ModelRegistryError("Model names must be strings.")
        if not isinstance(cfg, ABCMapping):
            raise ModelRegistryError(f"Model '{name}' entry must be a mapping.")
        # Deep-copy before merging so overrides never mutate the parsed YAML.
        merged = copy.deepcopy(cfg)
        if name in overrides:
            _deep_update(merged, overrides[name])
        _validate_model_entry(name, merged)
        entries[name] = merged

    return ModelRegistry(entries, source_path=path)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class ModelRegistry:
    """In-memory view over the registry."""

    def __init__(self, models: Mapping[str, Mapping[str, Any]], *, source_path: Path):
        # Copy into a plain dict so later mutation of the caller's mapping
        # cannot leak into this view.
        self._models = dict(models)
        self._source_path = Path(source_path)

    @property
    def source_path(self) -> Path:
        """Filesystem path the registry was loaded from."""
        return self._source_path

    def names(self) -> Iterable[str]:
        """All registered model names, in insertion order."""
        return tuple(self._models)

    def get(self, name: str) -> ModelConfig:
        """Return an immutable ModelConfig for `name`; raise for unknown names."""
        try:
            entry = self._models[name]
        except KeyError:
            raise ModelRegistryError(
                f"Unknown model '{name}'. Known models: {sorted(self._models)}"
            ) from None
        return ModelConfig(name=name, data=_deep_freeze(entry))
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _validate_model_entry(name: str, entry: Mapping[str, Any]) -> None:
    """Validate the schema of one registry entry, raising ModelRegistryError on problems."""
    _require_str(entry, "hf_id", name)

    # --- prompt section ---
    prompt = _require_mapping(entry, "prompt", name)
    for key in ("start_of_turn", "end_of_turn"):
        _require_str(prompt, key, name, section="prompt")
    roles = _require_mapping(prompt, "roles", name, section="prompt")
    for role_key in ("user", "model"):
        _require_str(roles, role_key, name, section="prompt.roles")
    for flag in ("enforce_bos", "enforce_eos"):
        _require_bool(prompt, flag, name, section="prompt")

    # --- tokenizer section ---
    tokenizer = _require_mapping(entry, "tokenizer", name)
    _require_str(tokenizer, "pad_token_strategy", name, section="tokenizer")
    for flag in ("require_bos", "require_eos"):
        _require_bool(tokenizer, flag, name, section="tokenizer")

    # --- architecture section ---
    architecture = _require_mapping(entry, "architecture", name)
    _require_str(architecture, "wrapper_class", name, section="architecture")
    if not _require_sequence_of_str(architecture, "hidden_size_attrs", name, section="architecture"):
        raise ModelRegistryError(
            f"Model '{name}' architecture.hidden_size_attrs must contain at least one attribute path."
        )
    _require_sequence_of_str(architecture, "language_model_path_hints", name, section="architecture")
    _require_str(architecture, "attn_implementation", name, section="architecture")
    # conv_out_dim is optional; when present it must be coercible to int
    if "conv_out_dim" in architecture:
        try:
            int(architecture.get("conv_out_dim"))  # type: ignore[arg-type]
        except Exception:
            raise ModelRegistryError(
                f"Model '{name}' field 'architecture.conv_out_dim' must be an integer when provided."
            )

    # --- LoRA policy section ---
    lora_policy = _require_mapping(entry, "lora_policy", name)
    _require_bool(lora_policy, "expect_language_only", name, section="lora_policy")
    allowed = _require_sequence_of_str(lora_policy, "allowed_markers", name, section="lora_policy")
    blocked = _require_sequence_of_str(lora_policy, "blocked_markers", name, section="lora_policy")
    overlap = set(allowed) & set(blocked)
    if overlap:
        raise ModelRegistryError(
            f"Model '{name}' lora_policy.allowed_markers and lora_policy.blocked_markers overlap: {sorted(overlap)}"
        )
    _require_bool(lora_policy, "freeze_vision", name, section="lora_policy")

    # --- packing section ---
    packing = _require_mapping(entry, "packing", name)
    _require_str(packing, "prompt_format", name, section="packing")
    conversation = _require_mapping(packing, "conversation", name, section="packing")
    for key in ("user_role_aliases", "model_role_aliases", "strip_image_from_roles"):
        _require_sequence_of_str(conversation, key, name, section="packing.conversation")
    _require_bool(conversation, "merge_system_with_first_user", name, section="packing.conversation")
    ecg_tokens = _require_mapping(packing, "ecg_tokens", name, section="packing")
    for key in ("global_start", "global_end", "lead_start_template", "lead_end_template"):
        _require_str(ecg_tokens, key, name, section="packing.ecg_tokens")
    if not _require_sequence_of_str(ecg_tokens, "canonical_leads", name, section="packing.ecg_tokens"):
        raise ModelRegistryError(
            f"Model '{name}' packing.ecg_tokens.canonical_leads must contain at least one lead."
        )
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def _require_mapping(
|
| 388 |
+
parent: Mapping[str, Any],
|
| 389 |
+
key: str,
|
| 390 |
+
model_name: str,
|
| 391 |
+
*,
|
| 392 |
+
section: Optional[str] = None,
|
| 393 |
+
) -> Mapping[str, Any]:
|
| 394 |
+
value = parent.get(key)
|
| 395 |
+
if not isinstance(value, ABCMapping):
|
| 396 |
+
loc = f"{section}.{key}" if section else key
|
| 397 |
+
raise ModelRegistryError(f"Model '{model_name}' is missing mapping '{loc}'.")
|
| 398 |
+
return value
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def _require_str(
|
| 402 |
+
parent: Mapping[str, Any],
|
| 403 |
+
key: str,
|
| 404 |
+
model_name: str,
|
| 405 |
+
*,
|
| 406 |
+
section: Optional[str] = None,
|
| 407 |
+
) -> str:
|
| 408 |
+
value = parent.get(key)
|
| 409 |
+
if not isinstance(value, str):
|
| 410 |
+
loc = f"{section}.{key}" if section else key
|
| 411 |
+
raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a string.")
|
| 412 |
+
return value
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _require_bool(
|
| 416 |
+
parent: Mapping[str, Any],
|
| 417 |
+
key: str,
|
| 418 |
+
model_name: str,
|
| 419 |
+
*,
|
| 420 |
+
section: Optional[str] = None,
|
| 421 |
+
) -> bool:
|
| 422 |
+
value = parent.get(key)
|
| 423 |
+
if not isinstance(value, bool):
|
| 424 |
+
loc = f"{section}.{key}" if section else key
|
| 425 |
+
raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a boolean.")
|
| 426 |
+
return value
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def _require_sequence_of_str(
|
| 430 |
+
parent: Mapping[str, Any],
|
| 431 |
+
key: str,
|
| 432 |
+
model_name: str,
|
| 433 |
+
*,
|
| 434 |
+
section: Optional[str] = None,
|
| 435 |
+
) -> Tuple[str, ...]:
|
| 436 |
+
value = parent.get(key)
|
| 437 |
+
if not isinstance(value, ABCSequence) or isinstance(value, (str, bytes)):
|
| 438 |
+
loc = f"{section}.{key}" if section else key
|
| 439 |
+
raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a sequence of strings.")
|
| 440 |
+
items = []
|
| 441 |
+
for item in value:
|
| 442 |
+
if not isinstance(item, str):
|
| 443 |
+
loc = f"{section}.{key}" if section else key
|
| 444 |
+
raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must contain only strings.")
|
| 445 |
+
items.append(item)
|
| 446 |
+
return tuple(items)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def _deep_update(target: Dict[str, Any], updates: Mapping[str, Any]) -> None:
|
| 450 |
+
"""
|
| 451 |
+
Recursively merge `updates` into `target` in-place.
|
| 452 |
+
"""
|
| 453 |
+
for key, value in updates.items():
|
| 454 |
+
if isinstance(value, ABCMapping) and isinstance(target.get(key), dict):
|
| 455 |
+
_deep_update(target[key], value) # type: ignore[arg-type]
|
| 456 |
+
else:
|
| 457 |
+
target[key] = copy.deepcopy(value)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _optional_bool(
|
| 461 |
+
parent: Mapping[str, Any],
|
| 462 |
+
key: str,
|
| 463 |
+
default: bool,
|
| 464 |
+
model_name: str,
|
| 465 |
+
*,
|
| 466 |
+
section: Optional[str] = None,
|
| 467 |
+
) -> bool:
|
| 468 |
+
if key not in parent:
|
| 469 |
+
return default
|
| 470 |
+
value = parent.get(key)
|
| 471 |
+
if isinstance(value, bool):
|
| 472 |
+
return value
|
| 473 |
+
if isinstance(value, int) and value in (0, 1):
|
| 474 |
+
return bool(value)
|
| 475 |
+
if isinstance(value, str):
|
| 476 |
+
normalized = value.strip().lower()
|
| 477 |
+
if normalized in {"true", "yes", "y", "1"}:
|
| 478 |
+
return True
|
| 479 |
+
if normalized in {"false", "no", "n", "0"}:
|
| 480 |
+
return False
|
| 481 |
+
loc = f"{section}.{key}" if section else key
|
| 482 |
+
raise ModelRegistryError(f"Model '{model_name}' field '{loc}' must be a boolean when provided.")
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _deep_freeze(obj: Any) -> Any:
|
| 486 |
+
"""
|
| 487 |
+
Recursively convert mutable containers to immutable/read-only equivalents.
|
| 488 |
+
"""
|
| 489 |
+
if isinstance(obj, dict):
|
| 490 |
+
return MappingProxyType({k: _deep_freeze(v) for k, v in obj.items()})
|
| 491 |
+
if isinstance(obj, list):
|
| 492 |
+
return tuple(_deep_freeze(v) for v in obj)
|
| 493 |
+
if isinstance(obj, tuple):
|
| 494 |
+
return tuple(_deep_freeze(v) for v in obj)
|
| 495 |
+
if isinstance(obj, set):
|
| 496 |
+
return frozenset(_deep_freeze(v) for v in obj)
|
| 497 |
+
return obj
|
camel_inference/src/camel/model_registry.yaml
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Registry
|
| 2 |
+
#
|
| 3 |
+
# This file defines per-backbone configuration consumed by training and inference.
|
| 4 |
+
# Fields overview:
|
| 5 |
+
# - hf_id: Hugging Face model identifier for the frozen language model.
|
| 6 |
+
# - prompt: Chat formatting tokens and roles used to build prompts
|
| 7 |
+
# - start_of_turn / end_of_turn: literal strings delimiting speaker turns
|
| 8 |
+
# - roles.user / roles.model: canonical role names for user/model
|
| 9 |
+
# - enforce_bos / enforce_eos: whether BOS/EOS must appear at the start/end
|
| 10 |
+
# - tokenizer:
|
| 11 |
+
# - pad_token_strategy: how to set pad token ("eos" uses eos as pad)
|
| 12 |
+
# - require_bos / require_eos: tokenizer must expose these tokens
|
| 13 |
+
# - architecture:
|
| 14 |
+
# - wrapper_class: import path to the ECG wrapper class (module.Class)
|
| 15 |
+
# - hidden_size_attrs: ordered config attribute paths to read hidden size
|
| 16 |
+
# - language_model_path_hints: ordered attribute paths to locate transformer layers
|
| 17 |
+
# - attn_implementation: preferred attention backend to use when loading
|
| 18 |
+
# - lora_policy:
|
| 19 |
+
# - expect_language_only: LoRA must live under language/text model stacks
|
| 20 |
+
# - allowed_markers / blocked_markers: substrings used to validate LoRA placement
|
| 21 |
+
# - freeze_vision: freeze LoRA params under blocked stacks during training
|
| 22 |
+
# - packing:
|
| 23 |
+
# - prompt_format: prompt/conversation template id
|
| 24 |
+
# - conversation: conversation role normalization and preprocessing
|
| 25 |
+
# - ecg_tokens: special tokens inserted to mark ECG structure
|
| 26 |
+
models:
|
| 27 |
+
medgemma-27b-it:
|
| 28 |
+
hf_id: "google/medgemma-27b-text-it"
|
| 29 |
+
prompt:
|
| 30 |
+
start_of_turn: "<start_of_turn>"
|
| 31 |
+
end_of_turn: "<end_of_turn>\n"
|
| 32 |
+
roles:
|
| 33 |
+
user: "user"
|
| 34 |
+
model: "model"
|
| 35 |
+
enforce_bos: true
|
| 36 |
+
enforce_eos: true
|
| 37 |
+
tokenizer:
|
| 38 |
+
pad_token_strategy: "eos"
|
| 39 |
+
require_bos: true
|
| 40 |
+
require_eos: true
|
| 41 |
+
use_fast: true
|
| 42 |
+
add_prefix_space: false
|
| 43 |
+
architecture:
|
| 44 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 45 |
+
hidden_size_attrs:
|
| 46 |
+
- "config.hidden_size"
|
| 47 |
+
- "config.text_config.hidden_size"
|
| 48 |
+
language_model_path_hints:
|
| 49 |
+
- "base_model.model.language_model.layers"
|
| 50 |
+
- "model.language_model.layers"
|
| 51 |
+
- "model.layers"
|
| 52 |
+
attn_implementation: "eager"
|
| 53 |
+
conv_out_dim: 64
|
| 54 |
+
lora_policy:
|
| 55 |
+
expect_language_only: true
|
| 56 |
+
allowed_markers:
|
| 57 |
+
- "language_model"
|
| 58 |
+
- ".model.layers."
|
| 59 |
+
- ".layers."
|
| 60 |
+
blocked_markers:
|
| 61 |
+
- "vision"
|
| 62 |
+
- "multi_modal"
|
| 63 |
+
- "projector"
|
| 64 |
+
- ".enc."
|
| 65 |
+
- "encoder_proj"
|
| 66 |
+
freeze_vision: true
|
| 67 |
+
packing:
|
| 68 |
+
prompt_format: "gemma_chat_v1"
|
| 69 |
+
conversation:
|
| 70 |
+
user_role_aliases: ["human", "user"]
|
| 71 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 72 |
+
strip_image_from_roles: ["human"]
|
| 73 |
+
merge_system_with_first_user: true
|
| 74 |
+
ecg_tokens:
|
| 75 |
+
global_start: "<ecg_global_start>"
|
| 76 |
+
global_end: "<ecg_global_end>"
|
| 77 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 78 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 79 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 80 |
+
gemma-12b-it:
|
| 81 |
+
hf_id: "google/gemma-3-12b-it"
|
| 82 |
+
prompt:
|
| 83 |
+
start_of_turn: "<start_of_turn>"
|
| 84 |
+
end_of_turn: "<end_of_turn>\n"
|
| 85 |
+
roles:
|
| 86 |
+
user: "user"
|
| 87 |
+
model: "model"
|
| 88 |
+
enforce_bos: true
|
| 89 |
+
enforce_eos: true
|
| 90 |
+
tokenizer:
|
| 91 |
+
pad_token_strategy: "eos"
|
| 92 |
+
require_bos: true
|
| 93 |
+
require_eos: true
|
| 94 |
+
use_fast: true
|
| 95 |
+
add_prefix_space: false
|
| 96 |
+
architecture:
|
| 97 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 98 |
+
hidden_size_attrs:
|
| 99 |
+
- "config.hidden_size"
|
| 100 |
+
- "config.text_config.hidden_size"
|
| 101 |
+
language_model_path_hints:
|
| 102 |
+
- "base_model.model.language_model.layers"
|
| 103 |
+
- "model.language_model.layers"
|
| 104 |
+
- "model.layers"
|
| 105 |
+
attn_implementation: "eager"
|
| 106 |
+
conv_out_dim: 64
|
| 107 |
+
lora_policy:
|
| 108 |
+
expect_language_only: true
|
| 109 |
+
allowed_markers:
|
| 110 |
+
- "language_model"
|
| 111 |
+
- ".model.layers."
|
| 112 |
+
- ".layers."
|
| 113 |
+
blocked_markers:
|
| 114 |
+
- "vision"
|
| 115 |
+
- "multi_modal"
|
| 116 |
+
- "projector"
|
| 117 |
+
- ".enc."
|
| 118 |
+
- "encoder_proj"
|
| 119 |
+
freeze_vision: true
|
| 120 |
+
packing:
|
| 121 |
+
prompt_format: "gemma_chat_v1"
|
| 122 |
+
conversation:
|
| 123 |
+
user_role_aliases: ["human", "user"]
|
| 124 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 125 |
+
strip_image_from_roles: ["human"]
|
| 126 |
+
merge_system_with_first_user: true
|
| 127 |
+
ecg_tokens:
|
| 128 |
+
global_start: "<ecg_global_start>"
|
| 129 |
+
global_end: "<ecg_global_end>"
|
| 130 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 131 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 132 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 133 |
+
medgemma-4b-it:
|
| 134 |
+
hf_id: "google/medgemma-4b-it"
|
| 135 |
+
prompt:
|
| 136 |
+
start_of_turn: "<start_of_turn>"
|
| 137 |
+
end_of_turn: "<end_of_turn>\n"
|
| 138 |
+
roles:
|
| 139 |
+
user: "user"
|
| 140 |
+
model: "model"
|
| 141 |
+
enforce_bos: true
|
| 142 |
+
enforce_eos: true
|
| 143 |
+
tokenizer:
|
| 144 |
+
pad_token_strategy: "eos"
|
| 145 |
+
require_bos: true
|
| 146 |
+
require_eos: true
|
| 147 |
+
use_fast: true
|
| 148 |
+
add_prefix_space: false
|
| 149 |
+
architecture:
|
| 150 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 151 |
+
hidden_size_attrs:
|
| 152 |
+
- "config.hidden_size"
|
| 153 |
+
- "config.text_config.hidden_size"
|
| 154 |
+
language_model_path_hints:
|
| 155 |
+
- "base_model.model.language_model.layers"
|
| 156 |
+
- "model.language_model.layers"
|
| 157 |
+
- "model.layers"
|
| 158 |
+
attn_implementation: "eager"
|
| 159 |
+
conv_out_dim: 64
|
| 160 |
+
lora_policy:
|
| 161 |
+
expect_language_only: true
|
| 162 |
+
allowed_markers:
|
| 163 |
+
- "language_model"
|
| 164 |
+
- ".model.layers."
|
| 165 |
+
- ".layers."
|
| 166 |
+
blocked_markers:
|
| 167 |
+
- "vision"
|
| 168 |
+
- "multi_modal"
|
| 169 |
+
- "projector"
|
| 170 |
+
- ".enc."
|
| 171 |
+
- "encoder_proj"
|
| 172 |
+
freeze_vision: true
|
| 173 |
+
packing:
|
| 174 |
+
prompt_format: "gemma_chat_v1"
|
| 175 |
+
conversation:
|
| 176 |
+
user_role_aliases: ["human", "user"]
|
| 177 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 178 |
+
strip_image_from_roles: ["human"]
|
| 179 |
+
merge_system_with_first_user: true
|
| 180 |
+
ecg_tokens:
|
| 181 |
+
global_start: "<ecg_global_start>"
|
| 182 |
+
global_end: "<ecg_global_end>"
|
| 183 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 184 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 185 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 186 |
+
qwen3-4b-instruct:
|
| 187 |
+
hf_id: "Qwen/Qwen3-4B-Instruct-2507"
|
| 188 |
+
prompt:
|
| 189 |
+
start_of_turn: "<|im_start|>"
|
| 190 |
+
end_of_turn: "<|im_end|>"
|
| 191 |
+
roles:
|
| 192 |
+
user: "user"
|
| 193 |
+
model: "assistant"
|
| 194 |
+
enforce_bos: false
|
| 195 |
+
enforce_eos: false
|
| 196 |
+
allow_multiple_eos: true
|
| 197 |
+
tokenizer:
|
| 198 |
+
pad_token_strategy: "existing"
|
| 199 |
+
require_bos: false
|
| 200 |
+
require_eos: false
|
| 201 |
+
use_fast: true
|
| 202 |
+
add_prefix_space: false
|
| 203 |
+
architecture:
|
| 204 |
+
wrapper_class: "ecg_qwen_model.ECGQwenPrefix"
|
| 205 |
+
hidden_size_attrs:
|
| 206 |
+
- "config.hidden_size"
|
| 207 |
+
language_model_path_hints:
|
| 208 |
+
- "model.layers"
|
| 209 |
+
- "model.model.layers"
|
| 210 |
+
attn_implementation: "eager"
|
| 211 |
+
conv_out_dim: 64
|
| 212 |
+
lora_policy:
|
| 213 |
+
expect_language_only: true
|
| 214 |
+
allowed_markers:
|
| 215 |
+
- "model.layers."
|
| 216 |
+
- ".layers."
|
| 217 |
+
blocked_markers:
|
| 218 |
+
- "vision"
|
| 219 |
+
- "multi_modal"
|
| 220 |
+
- "projector"
|
| 221 |
+
- ".enc."
|
| 222 |
+
- "encoder_proj"
|
| 223 |
+
freeze_vision: false
|
| 224 |
+
packing:
|
| 225 |
+
prompt_format: "qwen_chat_v1"
|
| 226 |
+
conversation:
|
| 227 |
+
user_role_aliases: ["human", "user"]
|
| 228 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 229 |
+
strip_image_from_roles: ["human"]
|
| 230 |
+
merge_system_with_first_user: true
|
| 231 |
+
ecg_tokens:
|
| 232 |
+
global_start: "<ecg_global_start>"
|
| 233 |
+
global_end: "<ecg_global_end>"
|
| 234 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 235 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 236 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 237 |
+
"qwen3-4b-instruct":
|
| 238 |
+
hf_id: "Qwen/Qwen3-4B"
|
| 239 |
+
prompt:
|
| 240 |
+
start_of_turn: "<|im_start|>"
|
| 241 |
+
end_of_turn: "<|im_end|>"
|
| 242 |
+
roles:
|
| 243 |
+
user: "user"
|
| 244 |
+
model: "assistant"
|
| 245 |
+
enforce_bos: false
|
| 246 |
+
enforce_eos: false
|
| 247 |
+
allow_multiple_eos: true
|
| 248 |
+
tokenizer:
|
| 249 |
+
pad_token_strategy: "existing"
|
| 250 |
+
require_bos: false
|
| 251 |
+
require_eos: false
|
| 252 |
+
use_fast: true
|
| 253 |
+
add_prefix_space: false
|
| 254 |
+
architecture:
|
| 255 |
+
wrapper_class: "ecg_qwen_model.ECGQwenPrefix"
|
| 256 |
+
hidden_size_attrs:
|
| 257 |
+
- "config.hidden_size"
|
| 258 |
+
language_model_path_hints:
|
| 259 |
+
- "model.layers"
|
| 260 |
+
- "model.model.layers"
|
| 261 |
+
attn_implementation: "eager"
|
| 262 |
+
conv_out_dim: 64
|
| 263 |
+
lora_policy:
|
| 264 |
+
expect_language_only: true
|
| 265 |
+
allowed_markers:
|
| 266 |
+
- "model.layers."
|
| 267 |
+
- ".layers."
|
| 268 |
+
blocked_markers:
|
| 269 |
+
- "vision"
|
| 270 |
+
- "multi_modal"
|
| 271 |
+
- "projector"
|
| 272 |
+
- ".enc."
|
| 273 |
+
- "encoder_proj"
|
| 274 |
+
freeze_vision: false
|
| 275 |
+
packing:
|
| 276 |
+
prompt_format: "qwen_chat_v1"
|
| 277 |
+
conversation:
|
| 278 |
+
user_role_aliases: ["human", "user"]
|
| 279 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 280 |
+
strip_image_from_roles: ["human"]
|
| 281 |
+
merge_system_with_first_user: true
|
| 282 |
+
ecg_tokens:
|
| 283 |
+
global_start: "<ecg_global_start>"
|
| 284 |
+
global_end: "<ecg_global_end>"
|
| 285 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 286 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 287 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 288 |
+
gpt-oss-120b:
|
| 289 |
+
hf_id: "openai/gpt-oss-120b"
|
| 290 |
+
prompt:
|
| 291 |
+
start_of_turn: "<|start|>"
|
| 292 |
+
end_of_turn: "<|end|>"
|
| 293 |
+
roles:
|
| 294 |
+
user: "user"
|
| 295 |
+
model: "assistant"
|
| 296 |
+
enforce_bos: false
|
| 297 |
+
enforce_eos: false
|
| 298 |
+
allow_multiple_eos: true
|
| 299 |
+
system_prompt: |-
|
| 300 |
+
You are ChatGPT, a large language model trained by OpenAI.
|
| 301 |
+
Knowledge cutoff: 2024-06
|
| 302 |
+
Current date: 2025-06-28
|
| 303 |
+
|
| 304 |
+
Reasoning: low
|
| 305 |
+
|
| 306 |
+
# Valid channels: final. Channel must be included for every message.
|
| 307 |
+
developer_prompt: |-
|
| 308 |
+
# Instructions
|
| 309 |
+
|
| 310 |
+
You are trained to interpret electrocardiograms (ECGs) and must answer questions about them clearly and accurately.
|
| 311 |
+
tokenizer:
|
| 312 |
+
pad_token_strategy: "existing"
|
| 313 |
+
require_bos: false
|
| 314 |
+
require_eos: false
|
| 315 |
+
use_fast: true
|
| 316 |
+
add_prefix_space: false
|
| 317 |
+
architecture:
|
| 318 |
+
wrapper_class: "ecg_gptoss_model.ECGGptOssPrefix"
|
| 319 |
+
hidden_size_attrs:
|
| 320 |
+
- "config.hidden_size"
|
| 321 |
+
language_model_path_hints:
|
| 322 |
+
- "model.layers"
|
| 323 |
+
- "model.model.layers"
|
| 324 |
+
attn_implementation: "eager"
|
| 325 |
+
conv_out_dim: 64
|
| 326 |
+
lora_policy:
|
| 327 |
+
expect_language_only: true
|
| 328 |
+
allowed_markers:
|
| 329 |
+
- "model.layers."
|
| 330 |
+
- ".layers."
|
| 331 |
+
blocked_markers:
|
| 332 |
+
- "vision"
|
| 333 |
+
- "multi_modal"
|
| 334 |
+
- "projector"
|
| 335 |
+
- ".enc."
|
| 336 |
+
- "encoder_proj"
|
| 337 |
+
freeze_vision: false
|
| 338 |
+
packing:
|
| 339 |
+
prompt_format: "harmony_chat_v1"
|
| 340 |
+
conversation:
|
| 341 |
+
user_role_aliases: ["human", "user"]
|
| 342 |
+
model_role_aliases: ["assistant", "gpt"]
|
| 343 |
+
strip_image_from_roles: ["human"]
|
| 344 |
+
merge_system_with_first_user: false
|
| 345 |
+
ecg_tokens:
|
| 346 |
+
global_start: "<ecg_global_start>"
|
| 347 |
+
global_end: "<ecg_global_end>"
|
| 348 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 349 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 350 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 351 |
+
gemma-1b-it:
|
| 352 |
+
hf_id: "google/gemma-3-1b-it"
|
| 353 |
+
prompt:
|
| 354 |
+
start_of_turn: "<start_of_turn>"
|
| 355 |
+
end_of_turn: "<end_of_turn>\n"
|
| 356 |
+
roles:
|
| 357 |
+
user: "user"
|
| 358 |
+
model: "model"
|
| 359 |
+
enforce_bos: true
|
| 360 |
+
enforce_eos: true
|
| 361 |
+
tokenizer:
|
| 362 |
+
pad_token_strategy: "eos"
|
| 363 |
+
require_bos: true
|
| 364 |
+
require_eos: true
|
| 365 |
+
use_fast: true
|
| 366 |
+
add_prefix_space: false
|
| 367 |
+
architecture:
|
| 368 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 369 |
+
hidden_size_attrs:
|
| 370 |
+
- "config.hidden_size"
|
| 371 |
+
- "config.text_config.hidden_size"
|
| 372 |
+
language_model_path_hints:
|
| 373 |
+
- "base_model.model.language_model.layers"
|
| 374 |
+
- "model.language_model.layers"
|
| 375 |
+
- "model.layers"
|
| 376 |
+
attn_implementation: "eager"
|
| 377 |
+
conv_out_dim: 64
|
| 378 |
+
lora_policy:
|
| 379 |
+
expect_language_only: true
|
| 380 |
+
allowed_markers:
|
| 381 |
+
- "language_model"
|
| 382 |
+
- ".model.layers."
|
| 383 |
+
- ".layers."
|
| 384 |
+
blocked_markers:
|
| 385 |
+
- "vision"
|
| 386 |
+
- "multi_modal"
|
| 387 |
+
- "projector"
|
| 388 |
+
- ".enc."
|
| 389 |
+
- "encoder_proj"
|
| 390 |
+
freeze_vision: true
|
| 391 |
+
packing:
|
| 392 |
+
prompt_format: "gemma_chat_v1"
|
| 393 |
+
conversation:
|
| 394 |
+
user_role_aliases: ["human", "user"]
|
| 395 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 396 |
+
strip_image_from_roles: ["human"]
|
| 397 |
+
merge_system_with_first_user: true
|
| 398 |
+
ecg_tokens:
|
| 399 |
+
global_start: "<ecg_global_start>"
|
| 400 |
+
global_end: "<ecg_global_end>"
|
| 401 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 402 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 403 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 404 |
+
gemma-270m-it:
|
| 405 |
+
hf_id: "google/gemma-3-270m-it"
|
| 406 |
+
prompt:
|
| 407 |
+
start_of_turn: "<start_of_turn>"
|
| 408 |
+
end_of_turn: "<end_of_turn>\n"
|
| 409 |
+
roles:
|
| 410 |
+
user: "user"
|
| 411 |
+
model: "model"
|
| 412 |
+
enforce_bos: true
|
| 413 |
+
enforce_eos: true
|
| 414 |
+
tokenizer:
|
| 415 |
+
pad_token_strategy: "eos"
|
| 416 |
+
require_bos: true
|
| 417 |
+
require_eos: true
|
| 418 |
+
use_fast: true
|
| 419 |
+
add_prefix_space: false
|
| 420 |
+
architecture:
|
| 421 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 422 |
+
hidden_size_attrs:
|
| 423 |
+
- "config.hidden_size"
|
| 424 |
+
- "config.text_config.hidden_size"
|
| 425 |
+
language_model_path_hints:
|
| 426 |
+
- "base_model.model.language_model.layers"
|
| 427 |
+
- "model.language_model.layers"
|
| 428 |
+
- "model.layers"
|
| 429 |
+
attn_implementation: "eager"
|
| 430 |
+
conv_out_dim: 64
|
| 431 |
+
lora_policy:
|
| 432 |
+
expect_language_only: true
|
| 433 |
+
allowed_markers:
|
| 434 |
+
- "language_model"
|
| 435 |
+
- ".model.layers."
|
| 436 |
+
- ".layers."
|
| 437 |
+
blocked_markers:
|
| 438 |
+
- "vision"
|
| 439 |
+
- "multi_modal"
|
| 440 |
+
- "projector"
|
| 441 |
+
- ".enc."
|
| 442 |
+
- "encoder_proj"
|
| 443 |
+
freeze_vision: true
|
| 444 |
+
packing:
|
| 445 |
+
prompt_format: "gemma_chat_v1"
|
| 446 |
+
conversation:
|
| 447 |
+
user_role_aliases: ["human", "user"]
|
| 448 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 449 |
+
strip_image_from_roles: ["human"]
|
| 450 |
+
merge_system_with_first_user: true
|
| 451 |
+
ecg_tokens:
|
| 452 |
+
global_start: "<ecg_global_start>"
|
| 453 |
+
global_end: "<ecg_global_end>"
|
| 454 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 455 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 456 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
| 457 |
+
# NOTE(review): exact duplicate of the "medgemma-27b-it" entry at the top of
# `models:`; yaml.safe_load keeps the last occurrence — consider removing one copy.
medgemma-27b-it:
|
| 458 |
+
hf_id: "google/medgemma-27b-text-it"
|
| 459 |
+
prompt:
|
| 460 |
+
start_of_turn: "<start_of_turn>"
|
| 461 |
+
end_of_turn: "<end_of_turn>\n"
|
| 462 |
+
roles:
|
| 463 |
+
user: "user"
|
| 464 |
+
model: "model"
|
| 465 |
+
enforce_bos: true
|
| 466 |
+
enforce_eos: true
|
| 467 |
+
tokenizer:
|
| 468 |
+
pad_token_strategy: "eos"
|
| 469 |
+
require_bos: true
|
| 470 |
+
require_eos: true
|
| 471 |
+
use_fast: true
|
| 472 |
+
add_prefix_space: false
|
| 473 |
+
architecture:
|
| 474 |
+
wrapper_class: "ecg_gemma_model.ECGGemmaPrefix"
|
| 475 |
+
hidden_size_attrs:
|
| 476 |
+
- "config.hidden_size"
|
| 477 |
+
- "config.text_config.hidden_size"
|
| 478 |
+
language_model_path_hints:
|
| 479 |
+
- "base_model.model.language_model.layers"
|
| 480 |
+
- "model.language_model.layers"
|
| 481 |
+
- "model.layers"
|
| 482 |
+
attn_implementation: "eager"
|
| 483 |
+
conv_out_dim: 64
|
| 484 |
+
lora_policy:
|
| 485 |
+
expect_language_only: true
|
| 486 |
+
allowed_markers:
|
| 487 |
+
- "language_model"
|
| 488 |
+
- ".model.layers."
|
| 489 |
+
- ".layers."
|
| 490 |
+
blocked_markers:
|
| 491 |
+
- "vision"
|
| 492 |
+
- "multi_modal"
|
| 493 |
+
- "projector"
|
| 494 |
+
- ".enc."
|
| 495 |
+
- "encoder_proj"
|
| 496 |
+
freeze_vision: true
|
| 497 |
+
packing:
|
| 498 |
+
prompt_format: "gemma_chat_v1"
|
| 499 |
+
conversation:
|
| 500 |
+
user_role_aliases: ["human", "user"]
|
| 501 |
+
model_role_aliases: ["gpt", "assistant"]
|
| 502 |
+
strip_image_from_roles: ["human"]
|
| 503 |
+
merge_system_with_first_user: true
|
| 504 |
+
ecg_tokens:
|
| 505 |
+
global_start: "<ecg_global_start>"
|
| 506 |
+
global_end: "<ecg_global_end>"
|
| 507 |
+
lead_start_template: "<ecg_lead_{lead_lower}_start>"
|
| 508 |
+
lead_end_template: "<ecg_lead_{lead_lower}_end>"
|
| 509 |
+
canonical_leads: ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
|
camel_inference/src/camel/process_ecg.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List, Dict, Any
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from read_ecg import load_record
|
| 6 |
+
|
| 7 |
+
# Map raw lead labels (uppercased) to canonical 12-lead names.
# NOTE(review): "VF" maps to itself, which is not one of the 12 canonical
# leads — confirm downstream consumers accept it.
_LEAD_SYNONYMS: Dict[str, str] = {
    # Limb
    "I": "I", "II": "II", "III": "III",
    "DI": "I", "DII": "II", "DIII": "III",
    "MLII": "II",
    # Augmented
    "AVR": "aVR", "AVL": "aVL", "AVF": "aVF",
    # Precordial
    "V1": "V1", "V2": "V2", "V3": "V3", "V4": "V4", "V5": "V5", "V6": "V6",
    # Dataset-specific
    "ECG": "I",  # Apnea-ECG
    "ECG1": "I", "ECG2": "II",  # AFDB
    "CM5": "V5", "D3": "V3", "D4": "V4",
    "CM2": "V2", "ML5": "V5",
    "VF": "VF",
}


class NormOp:
    """A single named normalization operation with optional numeric parameters."""

    def __init__(self, name: str, params: Optional[Dict[str, float]] = None):
        self.name = name          # operation identifier, e.g. 'clip'
        self.params = params or {}  # per-op parameters; empty dict when absent


def to_canonical_lead(name: str) -> Optional[str]:
    """Normalize a raw lead label to its canonical name, or None if unknown."""
    if not isinstance(name, str):
        return None
    label = name.strip().upper()
    canonical = _LEAD_SYNONYMS.get(label)
    if canonical is not None:
        return canonical
    # Handle spaced augmented-lead spellings such as "A VR" -> "aVR".
    if label in ("A VR", "A VL", "A VF"):
        return "a" + label.replace(" ", "")[1:]
    return None


def parse_pipeline(spec: Optional[str]) -> List[NormOp]:
    """Parse a comma-separated normalization spec into a list of NormOp.

    Supported tokens: ``nonfinite_to_zero`` and ``clip[:min[:max]]``.
    An empty/None *spec* yields the default pipeline (nonfinite_to_zero
    followed by a clip whose bounds come from stats at apply time).

    Raises:
        ValueError: On an unrecognized operation name.
    """
    if not spec:
        return [NormOp("nonfinite_to_zero"), NormOp("clip", {})]
    ops: List[NormOp] = []
    for raw in spec.split(','):
        raw = raw.strip()
        if not raw:
            continue
        fields = raw.split(':')
        op_name = fields[0].lower()
        if op_name == 'nonfinite_to_zero':
            ops.append(NormOp('nonfinite_to_zero'))
        elif op_name == 'clip':
            lo = float(fields[1]) if len(fields) >= 2 and fields[1] != '' else None
            hi = float(fields[2]) if len(fields) >= 3 and fields[2] != '' else None
            ops.append(NormOp('clip', {'min': lo, 'max': hi}))
        else:
            raise ValueError(f"Unknown op '{op_name}'.")
    return ops


def apply_pre_ops(signal_1d: np.ndarray, ops: List[NormOp]) -> np.ndarray:
    """Apply pre-segmentation ops in place; only 'nonfinite_to_zero' acts here."""
    for op in ops:
        if op.name == 'nonfinite_to_zero':
            # In place: NaN/±inf become 0 so filtering never sees non-finite values.
            np.nan_to_num(signal_1d, nan=0.0, posinf=0.0, neginf=0.0, copy=False)
    return signal_1d


def apply_post_ops(segments: np.ndarray, ops: List[NormOp], lead_name: str,
                   clip_stats: Optional[Dict[str, Dict[str, float]]] = None) -> np.ndarray:
    """Apply post-segmentation ops (clipping) in place and cast to float32.

    When a clip op carries no explicit bounds, per-lead bounds are looked up
    in *clip_stats* (keys ``clip_min``/``clip_max``) if provided; any missing
    bound defaults to ±infinity. Swapped bounds are tolerated.
    """
    data = segments
    for op in ops:
        if op.name != 'clip':
            continue
        lo = op.params.get('min', None)
        hi = op.params.get('max', None)
        if lo is None and hi is None and clip_stats is not None and lead_name in clip_stats:
            lead_stats = clip_stats[lead_name]
            lo = lead_stats.get('clip_min', None)
            hi = lead_stats.get('clip_max', None)
        lo = -np.inf if lo is None else lo
        hi = np.inf if hi is None else hi
        if lo > hi:
            lo, hi = hi, lo
        np.clip(data, lo, hi, out=data)
    return data.astype(np.float32, copy=False)
|
| 88 |
+
|
| 89 |
+
def _segment_data(ecg_signal: np.ndarray, raw_fs: int) -> np.ndarray:
    """
    Split a 1-D signal sampled at ``raw_fs`` into 1-second windows, each
    resampled to 256 samples via linear interpolation.

    A trailing partial window shorter than half a second is dropped.
    Returns a float32 array of shape [n_segments, 256].
    """
    assert raw_fs > 0, f"raw sampling rate must be positive (got {raw_fs})"
    win = int(raw_fs)
    total = int(len(ecg_signal))
    if win <= 0 or total <= 0:
        return np.empty((0, 256), dtype=np.float32)

    # Pre-allocate an upper bound: full windows plus a possible tail.
    bound = total // win + (1 if total % win else 0)
    if bound == 0:
        return np.empty((0, 256), dtype=np.float32)

    result = np.empty((bound, 256), dtype=np.float32)
    target_grid = np.linspace(0, 1, num=256, dtype=np.float32)
    count = 0
    for lo in range(0, total, win):
        window = ecg_signal[lo:min(lo + win, total)]
        if window.shape[0] < win * 0.5:
            continue  # drop windows shorter than half a second
        source_grid = np.linspace(0, 1, num=window.shape[0], dtype=np.float32)
        result[count] = np.interp(target_grid, source_grid, window).astype(np.float32, copy=False)
        count += 1
    return result[:count]
|
| 119 |
+
|
| 120 |
+
def _apply_filters(signal_1d, fs: int):
    """
    Apply ECG signal filters: 50/60 Hz notch filters and 0.3 Hz high-pass filter.

    Removes powerline interference and baseline wander from ECG signals.
    Uses cascaded second-order sections (SOS) for numerical stability via
    zero-phase forward-backward filtering (sosfiltfilt).

    Args:
        signal_1d: 1-D array-like ECG signal.
        fs: Sampling frequency in Hz.

    Returns:
        Filtered signal as float32. Empty input or non-positive ``fs``
        returns the input unchanged (cast to float32).
    """
    # scipy is imported lazily so importing this module does not require it.
    from scipy.signal import iirnotch, butter, sosfiltfilt, tf2sos

    x = np.asarray(signal_1d, dtype=np.float64)
    if x.size == 0 or fs <= 0:
        return x.astype('float32')

    # Design notch filters for powerline interference (skip ones at/above Nyquist).
    Q = 30.0
    nyq = fs / 2.0
    sos_filters = []
    for freq in (50.0, 60.0):
        if freq >= nyq:
            continue
        b, a = iirnotch(freq, Q, fs)
        sos_filters.append(tf2sos(b, a))

    # High-pass filter for baseline wander removal. Guard the normalized
    # cutoff: butter() raises for Wn >= 1 (fs <= 0.6 Hz), which previously
    # crashed on degenerate sampling rates.
    hp_wn = 0.3 / nyq
    if hp_wn < 1.0:
        sos_filters.append(butter(N=2, Wn=hp_wn, btype='highpass', output='sos'))

    # Nothing applicable at this sampling rate: return the signal untouched.
    if not sos_filters:
        return x.astype('float32')

    # Apply all filters in a single cascade.
    sos = np.vstack(sos_filters)
    x = sosfiltfilt(sos, x)
    return x.astype('float32')
|
| 152 |
+
|
| 153 |
+
def get_waveform(device: torch.device, ecg_path: str, start_sec=None, end_sec=None,
                 leads: Optional[List[str]] = None,
                 process: bool = False, norm: str = "nonfinite_to_zero,clip") -> Dict[str, Any]:
    """
    Load an ECG record and convert it to per-lead tensors of 1-second segments.

    Pipeline: raw signal -> (optional) non-finite cleanup + filtering ->
    segmentation into 1 s windows resampled to 256 samples -> clipping ->
    float32 tensors on ``device``.

    Args:
        device: Target torch device for the returned tensors.
        ecg_path: Path to the WFDB record (as accepted by ``load_record``).
        start_sec: Optional start offset in seconds.
        end_sec: Optional end offset in seconds.
        leads: Optional subset of lead names to load.
        process: When True, apply pre-ops and filtering before segmentation.
        norm: Normalization pipeline spec (e.g., "nonfinite_to_zero,clip").

    Returns:
        Mapping from canonical lead name to a float32 tensor of shape
        [n_segments, 256]. Empty when the record cannot be processed.
    """
    ops = parse_pipeline(norm)

    ecg_dict: Dict[str, Any] = {}
    try:
        df, sig_names, original_fs = load_record(ecg_path, start_sec, end_sec, leads)

        # Process each lead in the record.
        for i, raw_name in enumerate(sig_names):
            canon = to_canonical_lead(raw_name)
            if not canon:
                print(f"Skipping unrecognized lead name: {raw_name!r}")
                continue
            x = df[:, i].astype('float32', copy=False)

            if process:
                x = apply_pre_ops(x, ops)            # Handle non-finite values
                x = _apply_filters(x, original_fs)   # Remove noise and baseline wander

            # Segment into 1s windows and normalize.
            segs = _segment_data(x, original_fs)
            segs = apply_post_ops(segs, ops, canon)

            lead_tensor = torch.from_numpy(segs).to(torch.float32).to(device)
            # Bug fix: the original check was inverted (`if not torch.any(isnan)`),
            # so nan_to_num only ran on tensors that had no NaNs. Replace NaNs
            # exactly when they are present.
            if torch.any(lead_tensor.isnan()):
                lead_tensor = lead_tensor.nan_to_num()

            ecg_dict[canon] = lead_tensor

    except Exception as e:
        # Best-effort: skip records that fail to load/process, but say which.
        print(f"Failed to process ECG record '{ecg_path}': {e}")

    return ecg_dict
|
camel_inference/src/camel/projectors.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Projector registry utilities.
|
| 2 |
+
|
| 3 |
+
Provides a lightweight mechanism to swap the adapter architecture that maps the
|
| 4 |
+
conv encoder output to the language-model hidden size. Mirrors the ergonomic
|
| 5 |
+
API used by loss.py: a registry, default implementations, and a simple factory.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Callable, Dict, Iterable
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
# Signature every projector builder must satisfy: (in_dim, out_dim) -> module.
ProjectorBuilder = Callable[[int, int], nn.Module]

# Global mapping from normalized projector name to its builder callable.
_PROJECTOR_REGISTRY: Dict[str, ProjectorBuilder] = {}


def register_projector(name: str) -> Callable[[ProjectorBuilder], ProjectorBuilder]:
    """Decorator to register a projector builder under a unique name."""
    key = name.strip().lower()  # registry keys are case-insensitive

    def _wrap(builder: ProjectorBuilder) -> ProjectorBuilder:
        if not callable(builder):
            raise TypeError("Projector builder must be callable.")
        if key in _PROJECTOR_REGISTRY:
            raise ValueError(f"Projector '{name}' is already registered.")
        _PROJECTOR_REGISTRY[key] = builder
        return builder

    return _wrap


@register_projector("linear")
def _linear_projector(in_dim: int, out_dim: int) -> nn.Module:
    """Single linear adapter (current default)."""
    return nn.Linear(in_dim, out_dim, bias=True)
|
| 34 |
+
|
| 35 |
+
def available_projectors() -> Iterable[str]:
    """Return sorted projector names."""
    return sorted(_PROJECTOR_REGISTRY)
|
| 38 |
+
|
| 39 |
+
def build_projector(name: str, in_dim: int, out_dim: int) -> nn.Module:
    """Instantiate a registered projector by name with the given dimensions."""
    if not _PROJECTOR_REGISTRY:
        raise RuntimeError("No projectors registered.")
    key = (name or "").strip().lower()
    if not key:
        raise ValueError("Projector name must be a non-empty string.")
    builder = _PROJECTOR_REGISTRY.get(key)
    if builder is not None:
        return builder(int(in_dim), int(out_dim))
    raise KeyError(
        f"Unknown projector '{name}'. Available: {', '.join(available_projectors())}"
    )
|
| 52 |
+
|
| 53 |
+
# Public API of this module.
# NOTE(review): register_projector is used as a decorator here but is not
# exported — confirm whether it should be part of the public API.
__all__ = [
    "ProjectorBuilder",
    "available_projectors",
    "build_projector",
]
|
camel_inference/src/camel/prompt_renderers.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompt rendering and span construction helpers."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Any, Callable, Dict, List, Tuple
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
|
| 7 |
+
from camel.assertions import (
|
| 8 |
+
assert_tokenization_cursor_matches,
|
| 9 |
+
assert_model_spans_valid,
|
| 10 |
+
assert_eos_appended,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
def _ensure_trailing_newline(s: str) -> str:
    """Append a newline to *s* unless it already ends with one."""
    return s if s.endswith("\n") else s + "\n"
|
| 17 |
+
|
| 18 |
+
def _chat_v1_wrappers(schema, role: str) -> Tuple[str, str]:
    """Return the (prefix, suffix) literal strings wrapping one chat-v1 turn."""
    opening = f"{schema.prompt.start_of_turn}{role}\n"
    closing = schema.prompt.end_of_turn
    # Guarantee the end-of-turn marker terminates with a newline.
    if not closing.endswith("\n"):
        closing = closing + "\n"
    return opening, closing
|
| 22 |
+
|
| 23 |
+
def _harmony_v1_wrappers(schema, role: str, *, use_return: bool = False) -> Tuple[str, str]:
    """Return (prefix, suffix) for a harmony-v1 turn.

    Non-model roles get a ``<|message|>`` marker after the role; the model
    role may close with ``<|return|>`` when *use_return* is set.
    """
    prompt = schema.prompt
    if role != prompt.model_role:
        return f"{prompt.start_of_turn}{role}<|message|>", str(prompt.end_of_turn)
    closing = "<|return|>" if use_return else str(prompt.end_of_turn)
    return f"{prompt.start_of_turn}{role}", closing
|
| 31 |
+
|
| 32 |
+
def _render_with_wrappers(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    wrapper_fn,
) -> Dict[str, Any]:
    """Tokenize chat *turns*, wrapping each with role prefix/suffix strings.

    Args:
        tokenizer: HF tokenizer; all encoding is done with special tokens off.
        turns: Ordered dicts carrying "role" and "text_block" keys.
        schema: Packing schema; ``schema.prompt`` supplies the BOS/EOS policy
            and the model role name.
        wrapper_fn: Callable ``(schema, role) -> (prefix, suffix)`` producing
            the literal wrapper strings for one turn.

    Returns:
        Dict with:
        - "text_ids": full token-id sequence (optionally BOS/EOS framed),
        - "model_spans_in_text": (start, end) token spans of model turns
          (content + suffix; the role prefix is excluded),
        - "text_preview": the rendered prompt as a plain string.
    """
    tok = tokenizer
    prompt_tokens = schema.prompt

    text_ids: List[int] = []
    # Optionally seed with BOS per the schema's prompt policy.
    if prompt_tokens.require_bos and tok.bos_token_id is not None:
        text_ids.append(tok.bos_token_id)

    model_spans_in_text: List[Tuple[int, int]] = []
    cursor = len(text_ids)  # running token offset, kept in sync with text_ids
    text_preview_parts: List[str] = []

    for turn in turns:
        role = turn["role"]
        text_block = turn["text_block"]
        prefix, suffix = wrapper_fn(schema, role)
        content = text_block
        # Avoid doubling the suffix when the content already ends with it.
        if suffix and content.endswith(suffix):
            content = content[: -len(suffix)]
        # Encode the three pieces separately so span offsets are exact.
        ids_prefix = tok.encode(prefix, add_special_tokens=False)
        ids_content = tok.encode(content, add_special_tokens=False)
        ids_suffix = tok.encode(suffix, add_special_tokens=False)

        text_ids.extend(ids_prefix)
        text_ids.extend(ids_content)
        text_ids.extend(ids_suffix)

        if role == prompt_tokens.model_role:
            # Model spans start after the role prefix and include the suffix.
            s = cursor + len(ids_prefix)
            e = s + len(ids_content) + len(ids_suffix)
            if e > s:
                model_spans_in_text.append((s, e))
        cursor += len(ids_prefix) + len(ids_content) + len(ids_suffix)
        text_preview_parts.append(prefix + content + suffix)

    assert_tokenization_cursor_matches(cursor, len(text_ids))

    if prompt_tokens.require_eos and tok.eos_token_id is not None:
        text_ids.append(tok.eos_token_id)
        # If the prompt ends on a model turn, extend its span to cover EOS.
        if model_spans_in_text and turns[-1]["role"] == prompt_tokens.model_role:
            model_spans_in_text[-1] = (model_spans_in_text[-1][0], len(text_ids))

    assert_eos_appended(text_ids, tok, prompt_tokens.require_eos)
    assert_model_spans_valid(model_spans_in_text, len(text_ids))

    return {
        "text_ids": text_ids,
        "model_spans_in_text": model_spans_in_text,
        "text_preview": "".join(text_preview_parts),
    }
|
| 88 |
+
|
| 89 |
+
def _render_chat_v1(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
) -> Dict[str, Any]:
    """Render chat-v1 turns by delegating to the generic wrapper renderer."""
    return _render_with_wrappers(tokenizer, turns, schema=schema, wrapper_fn=_chat_v1_wrappers)
|
| 101 |
+
|
| 102 |
+
def _render_harmony_v1(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    use_return_for_last_assistant: bool = False,
) -> Dict[str, Any]:
    """Tokenize chat *turns* in the harmony-v1 format.

    Mirrors ``_render_with_wrappers`` but additionally supports closing the
    final assistant turn with ``<|return|>`` instead of the regular
    end-of-turn marker when *use_return_for_last_assistant* is True.

    Returns the same dict shape as ``_render_with_wrappers``:
    "text_ids", "model_spans_in_text", and "text_preview".
    """
    tok = tokenizer
    prompt_tokens = schema.prompt

    text_ids: List[int] = []
    # Optionally seed with BOS per the schema's prompt policy.
    if prompt_tokens.require_bos and tok.bos_token_id is not None:
        text_ids.append(tok.bos_token_id)

    model_spans_in_text: List[Tuple[int, int]] = []
    cursor = len(text_ids)  # running token offset, kept in sync with text_ids
    text_preview_parts: List[str] = []

    # Locate the last assistant turn so only that one gets the <|return|> suffix.
    last_assistant_idx = None
    if use_return_for_last_assistant:
        for idx in range(len(turns) - 1, -1, -1):
            if turns[idx]["role"] == prompt_tokens.model_role:
                last_assistant_idx = idx
                break

    for idx, turn in enumerate(turns):
        role = turn["role"]
        text_block = turn["text_block"]
        use_return = use_return_for_last_assistant and last_assistant_idx is not None and idx == last_assistant_idx
        prefix, suffix = _harmony_v1_wrappers(schema, role, use_return=use_return)
        content = text_block
        # Avoid doubling the suffix when the content already ends with it.
        if suffix and content.endswith(suffix):
            content = content[: -len(suffix)]
        # Encode the three pieces separately so span offsets are exact.
        ids_prefix = tok.encode(prefix, add_special_tokens=False)
        ids_content = tok.encode(content, add_special_tokens=False)
        ids_suffix = tok.encode(suffix, add_special_tokens=False)

        text_ids.extend(ids_prefix)
        text_ids.extend(ids_content)
        text_ids.extend(ids_suffix)

        if role == prompt_tokens.model_role:
            # Model spans start after the role prefix and include the suffix.
            s = cursor + len(ids_prefix)
            e = s + len(ids_content) + len(ids_suffix)
            if e > s:
                model_spans_in_text.append((s, e))
        cursor += len(ids_prefix) + len(ids_content) + len(ids_suffix)
        text_preview_parts.append(prefix + content + suffix)

    assert_tokenization_cursor_matches(cursor, len(text_ids))

    if prompt_tokens.require_eos and tok.eos_token_id is not None:
        text_ids.append(tok.eos_token_id)
        # If the prompt ends on a model turn, extend its span to cover EOS.
        if model_spans_in_text and turns[-1]["role"] == prompt_tokens.model_role:
            model_spans_in_text[-1] = (model_spans_in_text[-1][0], len(text_ids))

    assert_eos_appended(text_ids, tok, prompt_tokens.require_eos)
    assert_model_spans_valid(model_spans_in_text, len(text_ids))

    return {
        "text_ids": text_ids,
        "model_spans_in_text": model_spans_in_text,
        "text_preview": "".join(text_preview_parts),
    }
|
| 166 |
+
|
| 167 |
+
# Dispatch table: prompt-format id -> renderer. The harmony format is handled
# separately in render_prompt_and_spans because its renderer takes an extra
# keyword argument.
_PROMPT_RENDERERS: Dict[str, Callable[[PreTrainedTokenizer, List[Dict[str, str]], Any], Dict[str, Any]]] = {
    "gemma_chat_v1": _render_chat_v1,
    "qwen_chat_v1": _render_chat_v1,
}
|
| 171 |
+
|
| 172 |
+
def render_prompt_and_spans(
    tokenizer: PreTrainedTokenizer,
    turns: List[Dict[str, str]],
    *,
    schema,
    use_return_for_last_assistant: bool = False,
) -> Dict[str, Any]:
    """Render *turns* into token ids and model-turn spans for the schema's format."""
    fmt = str(schema.conversation.format_id)
    if fmt == "harmony_chat_v1":
        # Harmony needs the extra keyword, so it bypasses the dispatch table.
        return _render_harmony_v1(
            tokenizer,
            turns,
            schema=schema,
            use_return_for_last_assistant=use_return_for_last_assistant,
        )
    renderer = _PROMPT_RENDERERS.get(fmt)
    if renderer is None:
        raise ValueError(f"Unknown prompt format '{fmt}'.")
    return renderer(tokenizer, turns, schema=schema)
|
| 191 |
+
|
| 192 |
+
def turn_wrappers(schema, role: str, *, use_return: bool = False) -> Tuple[str, str]:
    """Return the (prefix, suffix) wrapper strings for *role* in the schema's format."""
    fmt = str(schema.conversation.format_id)
    if fmt == "harmony_chat_v1":
        return _harmony_v1_wrappers(schema, role, use_return=use_return)
    if fmt in ("gemma_chat_v1", "qwen_chat_v1"):
        return _chat_v1_wrappers(schema, role)
    raise ValueError(f"Unknown prompt format '{fmt}'.")
|
| 199 |
+
|
| 200 |
+
def assistant_generation_prefix(schema) -> str:
    """Return the literal string that opens a new assistant turn for generation."""
    fmt = str(schema.conversation.format_id)
    base = f"{schema.prompt.start_of_turn}{schema.prompt.model_role}"
    if fmt in ("gemma_chat_v1", "qwen_chat_v1"):
        return base + "\n"  # chat-v1 formats put the content on the next line
    if fmt == "harmony_chat_v1":
        return base
    raise ValueError(f"Unknown prompt format '{fmt}'.")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Public API: the renderer entry point plus the wrapper helpers used by callers.
__all__ = ["render_prompt_and_spans", "turn_wrappers", "assistant_generation_prefix"]
|
camel_inference/src/camel/training_setup.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Runtime and configuration helpers extracted from train_ecg_text.py.
|
| 3 |
+
These utilities keep the training entrypoint concise while preserving the
|
| 4 |
+
original behaviour when preparing distributed state, tokenizer metadata, and
|
| 5 |
+
packing configuration.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Dict, Optional, List
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
|
| 13 |
+
from camel.ecg_text_packing import (
|
| 14 |
+
ECGSpecialTokenCatalog,
|
| 15 |
+
PackingSchema,
|
| 16 |
+
PromptTokens,
|
| 17 |
+
)
|
| 18 |
+
from camel.model_registry import ModelConfig, ModelRegistryError, load_registry
|
| 19 |
+
|
| 20 |
+
def is_main_process() -> bool:
    """Return True for rank 0 (or standalone / non-distributed execution).

    Guards with ``dist.is_available()`` first: on torch builds without
    distributed support, querying initialization state is not usable, and
    standalone execution should always count as the main process.
    """
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0
|
| 23 |
+
|
| 24 |
+
def build_packing_schema(pretrained_model_id: str) -> PackingSchema:
    """
    Construct the packing schema (prompt + conversation rules + ECG tokens)
    for the given backbone using the shared registry.

    Args:
        pretrained_model_id: Registry entry name, or an HF model id which is
            matched against each entry's ``hf_id`` as a fallback.

    Raises:
        ModelRegistryError: If the model cannot be resolved in the registry,
            or its prompt configuration lacks the 'user'/'model' role.
    """
    registry = load_registry()
    cfg: Optional[ModelConfig]
    try:
        cfg = registry.get(pretrained_model_id)
    except ModelRegistryError:
        # Key lookup failed: fall back to matching by HF model id.
        cfg = None
        for name in registry.names():
            candidate = registry.get(name)
            if candidate.hf_id == pretrained_model_id:
                cfg = candidate
                break
    if cfg is None:
        raise ModelRegistryError(
            f"Pretrained model '{pretrained_model_id}' not found in registry at {registry.source_path}"
        )
    prompt_cfg = cfg.prompt_config()
    roles_dict = dict(prompt_cfg.roles or {})
    try:
        user_role = str(roles_dict["user"])
        model_role = str(roles_dict["model"])
    except KeyError as exc:
        # Surface which role is missing in the registry entry.
        missing = exc.args[0]
        raise ModelRegistryError(
            f"Prompt configuration for registry entry '{cfg.name}' is missing the '{missing}' role."
        ) from exc
    prompt_tokens = PromptTokens(
        start_of_turn=prompt_cfg.start_of_turn,
        end_of_turn=prompt_cfg.end_of_turn,
        user_role=user_role,
        model_role=model_role,
        require_bos=prompt_cfg.enforce_bos,
        require_eos=prompt_cfg.enforce_eos,
        allow_multiple_eos=prompt_cfg.allow_multiple_eos,
    )
    packing_cfg = cfg.packing_config()
    conversation_rules = packing_cfg.conversation
    ecg_tokens = packing_cfg.ecg_tokens
    return PackingSchema(
        prompt=prompt_tokens,
        conversation=conversation_rules,
        ecg=ecg_tokens,
    )
|
| 71 |
+
|
| 72 |
+
def initialize_tokenizer(
    model_id: str,
    *,
    trust_remote_code: bool = True,
    use_fast: Optional[bool] = None,
    add_prefix_space: Optional[bool] = None,
) -> AutoTokenizer:
    """
    Instantiate the HF tokenizer, allowing policy to be driven by the registry
    (use_fast/add_prefix_space). If not provided, defaults are use_fast=True,
    add_prefix_space=False.

    Args:
        model_id: Registry entry name or HF model id.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        use_fast: Explicit override of the registry's ``use_fast`` policy.
        add_prefix_space: Explicit override of the registry's
            ``add_prefix_space`` policy.
    """
    # Honor registry defaults when the caller doesn't override them.
    default_use_fast = True
    default_add_prefix_space = False
    try:
        registry = load_registry()
        cfg: Optional[ModelConfig]
        try:
            cfg = registry.get(model_id)
        except ModelRegistryError:
            # Key lookup failed: fall back to matching by HF model id.
            cfg = None
            for name in registry.names():
                candidate = registry.get(name)
                if candidate.hf_id == model_id:
                    cfg = candidate
                    break
        if cfg is not None:
            tcfg = cfg.tokenizer_config()
            default_use_fast = bool(tcfg.use_fast)
            default_add_prefix_space = bool(tcfg.add_prefix_space)
    except Exception:
        # Fall back to built-in defaults if registry is unavailable.
        pass

    return AutoTokenizer.from_pretrained(
        model_id,
        use_fast=default_use_fast if use_fast is None else bool(use_fast),
        add_prefix_space=default_add_prefix_space if add_prefix_space is None else bool(add_prefix_space),
        trust_remote_code=trust_remote_code,
    )
|
| 113 |
+
|
| 114 |
+
def register_ecg_special_tokens(
    tokenizer: AutoTokenizer,
    catalog: ECGSpecialTokenCatalog,
) -> Dict[int, int]:
    """
    Ensure the tokenizer includes the ECG special tokens from the provided catalog.
    Returns a mapping from catalog index to token ID.

    Raises:
        RuntimeError: If a catalog token is still unknown after registration,
            or does not encode to exactly one token id.
    """
    # Add only tokens that are currently unknown to the tokenizer (not present
    # as core specials or regular vocab entries).
    tokens_to_add: List[str] = []
    for token in catalog.tokens:
        tok_id = tokenizer.convert_tokens_to_ids(token)
        if tok_id is None or tok_id == tokenizer.unk_token_id:
            tokens_to_add.append(token)
    if tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": tokens_to_add})
    ecg_special_token_id_map: Dict[int, int] = {}
    for token, catalog_index in catalog.token_to_index.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        if token_id is None or token_id == tokenizer.unk_token_id:
            raise RuntimeError(f"Tokenizer failed to register ECG special token: {token}")
        # Each marker must round-trip to a single id so token offsets in the
        # packed sequence stay aligned with the catalog.
        encoded = tokenizer.encode(token, add_special_tokens=False)
        if len(encoded) != 1 or encoded[0] != token_id:
            raise RuntimeError(f"ECG special token does not map to a single id: {token}")
        ecg_special_token_id_map[catalog_index] = int(token_id)
    return ecg_special_token_id_map
|
camel_inference/src/read_ecg.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import wfdb
|
| 4 |
+
|
| 5 |
+
def load_record(ecg_path, start_sec: Optional[int], end_sec: Optional[int], leads: Optional[list[str]]):
    """
    Read a WFDB record and return (signal, lead_names, fs).

    Args:
        ecg_path: Path to the WFDB record (without file extension).
        start_sec: Optional start offset in seconds (None = record start).
        end_sec: Optional end offset in seconds (None = record end).
        leads: Optional list of lead names to keep; unknown names are skipped
            with a warning.

    Returns:
        Tuple of (signal [n_samples x n_leads], lead names, sampling rate).

    Raises:
        ValueError: If *leads* is given and none of them exist in the record.
    """
    record = wfdb.rdrecord(ecg_path)
    fs = record.fs
    lead_names = record.sig_name

    signal = record.p_signal  # n_samples x n_leads
    if leads:
        kept_signals, kept_leads = [], []
        lead_to_idx = {name: i for i, name in enumerate(lead_names)}
        for l in leads:
            if l in lead_to_idx:
                kept_signals.append(signal[:, lead_to_idx[l]])
                kept_leads.append(l)
            else:
                print(f'Lead {l} does not exist. Skipping.')
        if not kept_signals:
            raise ValueError(f"None of the requested leads were found. requested={leads}, available={lead_names}")

        signal = np.stack(kept_signals, axis=1)
        lead_names = kept_leads

    # Optionally subsample the signal.
    # Bug fix: WFDB headers may store fs as a float, and float indices raise
    # TypeError when slicing — cast to int explicitly.
    start_ind = 0 if start_sec is None else int(start_sec * fs)
    end_ind = len(signal) if end_sec is None else int(end_sec * fs)
    if end_ind > len(signal):
        print(f'ECG is {len(signal) / fs} seconds')
    signal = signal[start_ind:end_ind, :]

    return signal, lead_names, fs
|