| """ |
| Assertions and summaries for ECG language-model wrappers and their LoRA adapters. |
| """ |
| from __future__ import annotations |
|
|
| from typing import Any, Dict, Iterable, List, Mapping, Set |
| import functools |
| import os |
| import torch |
|
|
| from camel.ecg_attention_masks import ECGBlockLayout |
|
|
| _ASSERTIONS_ENABLED = os.getenv("ASSERTIONS") == "1" |
|
|
| def _skip_if_assertions_disabled(func): |
| """Decorator that no-ops assertion helpers when ASSERTIONS env var is not set.""" |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| if not _ASSERTIONS_ENABLED: |
| return None |
| return func(*args, **kwargs) |
|
|
| return wrapper |
|
|
@_skip_if_assertions_disabled
def assert_tensor_dtype(tensor: torch.Tensor, *, expected: torch.dtype, context: str) -> None:
    """Verify that *tensor* carries the *expected* dtype.

    Args:
        tensor: Tensor whose dtype is checked.
        expected: Required dtype.
        context: Caller-supplied label included in the error message.

    Raises:
        AssertionError: if the dtype differs.
    """
    actual = tensor.dtype
    if actual == expected:
        return
    raise AssertionError(f"{context}: expected dtype {expected}, got {actual}")
|
|
@_skip_if_assertions_disabled
def assert_ecg_blocks_consistent(
    *,
    turn_parts: Iterable[Iterable[Dict[str, Any]]],
    ecg_blocks: Iterable[Dict[str, Any]],
) -> None:
    """
    Validate that structured turn parts contain the expected ECG markers per block.

    Checks, for every block referenced by ``turn_parts``:
      - each "special" part either names a declared lead or counts toward the
        block's global marker pair, and carries a non-empty string token;
      - each "ecg" part names a declared lead and a 1-based second within that
        lead's declared segment count, with no duplicates;
      - exactly two global special markers exist per block (start and end);
      - exactly two per-lead special markers exist per declared lead;
      - every declared second 1..N of every lead is covered by an "ecg" part.

    Raises:
        AssertionError: on any structural inconsistency.
    """
    blocks_list = list(ecg_blocks)
    # Per block index: declared segment count per lead name.
    expected_counts: Dict[int, Dict[str, int]] = {}
    # Per block index: number of per-lead special markers seen so far.
    per_lead_special_counts: Dict[int, Dict[str, int]] = {}
    # Per block index: set of 1-based seconds seen per lead (duplicate/coverage).
    per_lead_secs: Dict[int, Dict[str, Set[int]]] = {}
    # Per block index: count of lead-less ("global") special markers.
    global_special_count: Dict[int, int] = {}

    # Phase 1: seed the bookkeeping structures from the declared blocks.
    for idx, blk in enumerate(blocks_list):
        leads = [str(ld) for ld in blk.get("lead_names", [])]
        segs = [int(n) for n in blk.get("segments_per_lead", [])]
        expected_counts[idx] = {ld: int(n) for ld, n in zip(leads, segs)}
        per_lead_special_counts[idx] = {ld: 0 for ld in leads}
        per_lead_secs[idx] = {ld: set() for ld in leads}
        global_special_count[idx] = 0

    # Phase 2: walk every part of every turn and tally markers against the
    # declarations; plain text parts are ignored.
    for turn in turn_parts:
        for part in turn:
            kind = part.get("kind")
            if kind == "text":
                continue
            block_idx = part.get("block_index")
            if block_idx is None or int(block_idx) not in expected_counts:
                raise AssertionError("ECG part references unknown block_index.")
            block_idx = int(block_idx)
            allowed_leads = set(expected_counts[block_idx].keys())
            if kind == "special":
                lead = part.get("lead")
                # A falsy/absent lead marks a block-global special; a set lead
                # marks a per-lead start/end special.
                if lead:
                    if lead not in allowed_leads:
                        raise AssertionError(f"Special token references unknown lead '{lead}'.")
                    per_lead_special_counts[block_idx][lead] = per_lead_special_counts[block_idx].get(lead, 0) + 1
                else:
                    global_special_count[block_idx] = global_special_count.get(block_idx, 0) + 1
                # Every special part must carry a non-empty token literal.
                token_literal = part.get("token")
                if not isinstance(token_literal, str) or len(token_literal) == 0:
                    raise AssertionError("Special turn part lacks a string token literal.")
                continue
            if kind == "ecg":
                lead = part.get("lead")
                if lead not in allowed_leads:
                    raise AssertionError(f"ECG segment references unknown lead '{lead}'.")
                sec_val = part.get("sec")
                expected_total = expected_counts[block_idx].get(lead, 0)
                if expected_total <= 0:
                    raise AssertionError(
                        f"Lead '{lead}' has non-positive declared segments ({expected_total}) but ECG markers are present."
                    )
                try:
                    sec = int(sec_val)
                except Exception as exc:
                    raise AssertionError(f"ECG segment for lead '{lead}' has non-integer sec {sec_val!r}.") from exc
                # Seconds are 1-based and bounded by the declared segment count.
                if sec < 1 or sec > expected_total:
                    raise AssertionError(
                        f"ECG segment for lead '{lead}' has second={sec}, expected within [1,{expected_total}]."
                    )
                if sec in per_lead_secs[block_idx][lead]:
                    raise AssertionError(f"Duplicate ECG segment marker for lead '{lead}' second {sec}.")
                per_lead_secs[block_idx][lead].add(sec)
                continue
            raise AssertionError(f"Unknown turn_parts kind '{kind}'.")

    # Phase 3: verify totals — exactly two global markers per block, two
    # specials per declared lead, and full 1..N second coverage per lead.
    for block_idx, expected in expected_counts.items():
        if global_special_count.get(block_idx, 0) != 2:
            raise AssertionError(
                f"Expected exactly two global ECG markers for block {block_idx}; "
                f"found {global_special_count.get(block_idx, 0)}."
            )
        for lead, expected_total in expected.items():
            expected_specials = per_lead_special_counts[block_idx].get(lead, 0)
            if expected_specials != 2:
                raise AssertionError(
                    f"Lead '{lead}' has {expected_specials} special markers; expected start and end (2 total)."
                )
            seen_secs = per_lead_secs[block_idx].get(lead, set())
            if expected_total != len(seen_secs):
                missing = sorted(set(range(1, expected_total + 1)) - seen_secs)
                raise AssertionError(
                    f"Lead '{lead}' missing ECG segment markers for seconds {missing} (expected {expected_total})."
                )
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_prefix_split_complete(*, offset: int, total_prefix_rows: int) -> None:
    """Check that prefix splitting consumed every available row.

    Raises:
        RuntimeError: if *offset* (rows consumed) differs from *total_prefix_rows*.
    """
    if offset == total_prefix_rows:
        return
    raise RuntimeError(
        f"Prefix split mismatch: consumed {offset} rows but have {total_prefix_rows}"
    )
|
|
|
|
@_skip_if_assertions_disabled
def assert_prefix_matches_segments(
    *,
    prefix_rows: int,
    segments_per_lead: Iterable[int],
    lead_names: Iterable[str],
    sample_index: int,
    block_index: int,
) -> None:
    """Check that the prefix row count equals the total declared segments.

    Raises:
        RuntimeError: if prefix_rows != sum(segments_per_lead).
    """
    declared_total = sum(int(count) for count in segments_per_lead)
    if prefix_rows == declared_total:
        return
    raise RuntimeError(
        f"Sample {sample_index} block {block_index}: Prefix rows ({prefix_rows}) "
        f"!= sum(segments_per_lead) ({declared_total}). "
        f"lead_names={list(lead_names)} segments_per_lead={list(segments_per_lead)}"
    )
|
|
@_skip_if_assertions_disabled
def assert_ecg_part_bounds(
    *,
    lead: str,
    sec: int,
    lead_to_offset: Mapping[str, int],
    declared_segments: Mapping[str, int],
    total_prefix_rows: int,
    sample_index: int,
    block_index: int,
) -> None:
    """Check that an ECG part's (lead, sec) maps to a valid prefix row.

    Raises:
        RuntimeError: if the lead is unknown, sec is outside 1..nseg, or the
            computed row index falls outside the overall prefix or outside
            this lead's contiguous span.
    """
    if lead not in declared_segments:
        raise RuntimeError(f"Unknown lead {lead} in parts for sample {sample_index} block {block_index}")

    nseg = int(declared_segments[lead])
    # Seconds are 1-based relative to the lead's declared segment count.
    if not (1 <= sec <= nseg):
        raise RuntimeError(
            f"sec out of range for lead {lead}: got {sec}, expected 1..{nseg}"
        )

    lead_start = lead_to_offset[lead]
    lead_end = lead_start + nseg
    row_idx = lead_start + (sec - 1)

    # The row must land inside the overall prefix ...
    if not (0 <= row_idx < total_prefix_rows):
        raise RuntimeError(
            f"Bad (lead,sec)=({lead},{sec}) for sample {sample_index} block {block_index}: "
            f"row_idx {row_idx} not in [0,{total_prefix_rows})"
        )
    # ... and inside this particular lead's span of the prefix.
    if not (lead_start <= row_idx < lead_end):
        raise RuntimeError(
            f"(lead,sec)=({lead},{sec}) maps outside this lead block "
            f"[{lead_start},{lead_end}) (row_idx={row_idx}) for sample {sample_index}"
        )
|
|
|
|
@_skip_if_assertions_disabled
def assert_layout_specials_complete(
    *,
    block_layout: ECGBlockLayout,
    lead_names: Iterable[str],
) -> None:
    """Validate that layout special-token markers are paired and ordered.

    For every declared lead, and for the global markers, the start/end indices
    must be both present or both absent; when present, start must precede end.

    Raises:
        RuntimeError: on a half-present pair or an out-of-order pair.
    """
    for name in lead_names:
        start = block_layout.lead_start_idx.get(name)
        end = block_layout.lead_end_idx.get(name)
        # Exactly one of the pair missing is a structural error.
        if (start is None) != (end is None):
            raise RuntimeError(f"Lead {name} missing start/end special (s={start}, e={end})")
        if start is not None and not (start < end):
            raise RuntimeError(f"Lead {name} specials out of order: start={start}, end={end}")

    g_start = block_layout.global_start_idx
    g_end = block_layout.global_end_idx
    if (g_start is None) != (g_end is None):
        raise RuntimeError("Global start/end special mismatch")
    if g_start is not None and g_end is not None:
        if not (g_start < g_end):
            raise RuntimeError(
                f"Global specials out of order: start={g_start} "
                f"end={g_end}"
            )
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_wrapper_embed_length(
    *,
    embeddings: torch.Tensor,
    ids: List[int],
    context: str,
) -> None:
    """Ensure embedding sequence length matches token count exactly.

    Guards the 1:1 invariant between input token IDs and output embeddings.

    Args:
        embeddings: Output embedding tensor (must be at least 1-D).
        ids: Input token ID list.
        context: Description of where this check is being performed.

    Raises:
        RuntimeError: if ids is not a list, embeddings are not at least 1-D,
            or the embedding length does not match the token count.
    """
    if not isinstance(ids, list):
        raise RuntimeError(f"{context}: ids must be a python list of ints")
    if embeddings.dim() < 1:
        raise RuntimeError(f"{context}: embeddings must be at least 1-D, got shape {tuple(embeddings.shape)}")
    seq_len = embeddings.size(0)
    token_count = len(ids)
    if seq_len != token_count:
        raise RuntimeError(f"{context}: embed length {seq_len} != token count {token_count}")
|
|
@_skip_if_assertions_disabled
def assert_rest_length_nonnegative(*, rest_length: int) -> None:
    """Validate that the rest-token length is non-negative.

    A negative length should never occur in correct code; failing fast here
    surfaces bugs in label construction logic early.

    Args:
        rest_length: Length of the ids_rest list.

    Raises:
        ValueError: if rest_length is negative.
    """
    if rest_length >= 0:
        return
    raise ValueError("ids_rest length is negative (internal error).")
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_sorted_non_overlapping_spans(spans: List[tuple[int, int]], length: int, ctx: str) -> None:
    """Check spans are in-bounds, sorted by start, and mutually non-overlapping."""
    previous_end = -1
    for pos, (start, end) in enumerate(spans):
        if not (0 <= start <= end <= length):
            raise AssertionError(f"{ctx}: span {pos}={(start,end)} out of bounds for length {length}")
        # Touching spans (start == previous end) are allowed; true overlap is not.
        if start < previous_end:
            raise AssertionError(f"{ctx}: spans overlap or not sorted at {pos-1},{pos}: prev_end={previous_end}, curr_start={start}")
        previous_end = end
|
|
|
|
@_skip_if_assertions_disabled
def assert_equal_int(a: int, b: int, msg: str) -> None:
    """Assert two values are equal after int() coercion, with a labeled message."""
    if int(a) == int(b):
        return
    raise AssertionError(f"{msg}: {a} != {b}")
|
|
|
|
@_skip_if_assertions_disabled
def assert_positive_int(n: int, msg: str) -> None:
    """Assert that int(n) is strictly greater than zero."""
    if int(n) > 0:
        return
    raise AssertionError(f"{msg}: expected > 0, got {n}")
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_ecg_catalog_valid(catalog: Any, schema: Any) -> None:
    """Validate the ECG special-token catalog for uniqueness and consistency.

    Checks:
    - All tokens are unique.
    - Every canonical lead has entries in lead_to_indices and lead_to_tokens.
    - Token-to-index mappings agree across all catalog structures.
    - Global start/end markers are present in the catalog.

    Raises:
        AssertionError: on any inconsistency.
    """
    tokens = catalog.tokens
    if len(set(tokens)) != len(tokens):
        raise AssertionError("ECG special tokens contain duplicates")

    for lead in schema.ecg.canonical_leads:
        if lead not in catalog.lead_to_indices or lead not in catalog.lead_to_tokens:
            raise AssertionError(f"Missing lead in catalog: {lead}")
        for marker in ("start", "end"):
            tok = catalog.lead_to_tokens[lead][marker]
            idx = catalog.lead_to_indices[lead][marker]
            # The flat token list and the per-lead maps must agree.
            if tokens[idx] != tok:
                raise AssertionError(f"Catalog mismatch for {lead}:{marker}: tokens[idx] != tok")
            if catalog.token_to_index.get(tok, None) != idx:
                raise AssertionError(f"token_to_index mismatch for {lead}:{marker}")

    for marker_token in (schema.ecg.global_start, schema.ecg.global_end):
        if marker_token not in catalog.token_to_index:
            raise AssertionError(f"Global ECG token missing from catalog: {marker_token}")
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_normalized_role_canonical(role: str, schema: Any) -> None:
    """Ensure a normalized role is one of the schema's canonical prompt roles."""
    canonical_roles = (schema.prompt.user_role, schema.prompt.model_role)
    if role in canonical_roles:
        return
    raise AssertionError(f"Normalized role '{role}' did not resolve to a canonical prompt role")
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_tokenization_cursor_matches(cursor: int, ids_length: int) -> None:
    """Ensure the tokenization cursor agrees with the actual text_ids length."""
    if cursor == ids_length:
        return
    raise AssertionError(f"cursor ({cursor}) != len(text_ids) ({ids_length})")
|
|
|
|
@_skip_if_assertions_disabled
def assert_model_spans_valid(model_spans: List[tuple[int, int]], ids_length: int) -> None:
    """Validate model spans: sorted, non-overlapping, in-bounds, and non-empty."""
    # Delegate the structural span checks, then require at least one span.
    assert_sorted_non_overlapping_spans(model_spans, ids_length, ctx="model_spans_in_text")
    if not model_spans:
        raise AssertionError("No model spans found in text ids")
|
|
|
|
@_skip_if_assertions_disabled
def assert_eos_appended(ids: List[int], tokenizer: Any, require_eos: bool) -> None:
    """Validate that the final token is EOS when an EOS is required.

    The check is skipped entirely when EOS is not required or the tokenizer
    defines no EOS id.
    """
    if require_eos and tokenizer.eos_token_id is not None:
        ends_with_eos = bool(ids) and ids[-1] == tokenizer.eos_token_id
        if not ends_with_eos:
            raise AssertionError("Required EOS was not appended at the end of text_ids")
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_turn_parts_structure_valid(
    parts: List[Dict[str, Any]],
    ecg_blocks: List[Dict[str, Any]],
    schema: Any,
    catalog: Any,
) -> None:
    """Validate the complete structure of turn parts for all blocks present.

    For each block index referenced by ``parts``, verifies:
      - the index addresses a valid entry of ``ecg_blocks``;
      - the global start/end special tokens are both present and ordered
        (start before end);
      - each declared lead has ordered start/end specials whose enclosed
        "ecg" parts cover seconds exactly 1..nseg, in ascending order;
      - each lead special's ``token_index`` matches the catalog.

    Args:
        parts: Flat list of turn-part dicts (keys seen here: kind, token,
            lead, sec, block_index, token_index).
        ecg_blocks: Declared blocks with "lead_names" and "segments_per_lead".
        schema: Object exposing ``ecg.global_start`` / ``ecg.global_end`` literals.
        catalog: Object exposing ``lead_to_tokens`` / ``lead_to_indices`` maps.

    Raises:
        AssertionError: on any missing, misordered, or mismatched marker.
    """
    # Only validate blocks actually referenced by the parts stream.
    block_indices = sorted({int(p.get("block_index")) for p in parts if p.get("block_index") is not None})
    for block_idx in block_indices:
        if block_idx < 0 or block_idx >= len(ecg_blocks):
            raise AssertionError(f"Unknown block_index {block_idx} in turn parts")
        blk = ecg_blocks[block_idx]
        leads_present = [str(ld) for ld in blk.get("lead_names", [])]
        segments_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]

        # Both global marker literals must occur among this block's specials.
        special_tokens = [p.get("token") for p in parts if p.get("kind") == "special" and p.get("block_index") == block_idx]
        if schema.ecg.global_start not in special_tokens:
            raise AssertionError(f"Missing global_start special in block {block_idx}")
        if schema.ecg.global_end not in special_tokens:
            raise AssertionError(f"Missing global_end special in block {block_idx}")

        # First positions of the global markers within the full parts list.
        idx_global_start = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_start),
            None
        )
        idx_global_end = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_end),
            None
        )
        if idx_global_start is None or idx_global_end is None or not (idx_global_start < idx_global_end):
            raise AssertionError(f"Block {block_idx}: missing or misordered global start/end specials")

        for lead, nseg in zip(leads_present, segments_per_lead):
            assert_positive_int(nseg, f"segments_per_lead for {lead}")
            # Positions of this lead's start/end specials, matched by the
            # catalog's token literal for that lead.
            idx_start = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["start"]),
                None
            )
            idx_end = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["end"]),
                None
            )
            if idx_start is None or idx_end is None or not (idx_start < idx_end):
                raise AssertionError(f"Lead {lead}: missing or misordered start/end specials in block {block_idx}")
            # The ecg parts strictly between the lead's specials must be exactly
            # seconds 1..nseg, in ascending order, with no gaps or duplicates.
            secs = [p["sec"] for p in parts[idx_start+1:idx_end] if p.get("kind") == "ecg" and p.get("lead") == lead]
            if secs != list(range(1, int(nseg) + 1)):
                raise AssertionError(f"Lead {lead}: ECG seconds sequence invalid: {secs} vs 1..{nseg}")
            # The specials' token indices must agree with the catalog mapping.
            if parts[idx_start]["token_index"] != catalog.lead_to_indices[lead]["start"]:
                raise AssertionError(f"Lead {lead}: start token_index mismatch")
            if parts[idx_end]["token_index"] != catalog.lead_to_indices[lead]["end"]:
                raise AssertionError(f"Lead {lead}: end token_index mismatch")
|
|
|
|
@_skip_if_assertions_disabled
def assert_turn_content_ends_with_eot(text_block: str, end_of_turn: str) -> None:
    """Ensure the rendered turn text terminates with the end-of-turn suffix."""
    if text_block.endswith(end_of_turn):
        return
    raise AssertionError("Turn content must end with end_of_turn suffix")
|
|
|
|
| |
|
|
@_skip_if_assertions_disabled
def assert_leads_canonical_and_ordered(leads_present: List[str], canonical_leads: tuple) -> None:
    """Validate all leads are canonical and unique (order is explicit and not enforced here)."""
    leads = list(leads_present)
    non_canonical = [ld for ld in leads if ld not in canonical_leads]
    if non_canonical:
        raise AssertionError(f"Non-canonical lead found in leads_present: {leads}")
    if len(leads) != len(set(leads)):
        raise AssertionError(f"Duplicate lead detected in leads_present: {leads}")
|
|
|
|
@_skip_if_assertions_disabled
def assert_waveform_shapes_valid(
    leads_present: List[str],
    segments_per_lead: List[int],
    waveform_segments: Dict[str, Any],
) -> None:
    """Validate per-lead waveform shapes against declared segment counts.

    Each lead's waveform must be 2-D with shape [T, 256], where T equals the
    declared segments_per_lead entry for that lead.
    """
    for lead, declared in zip(leads_present, segments_per_lead):
        assert_positive_int(declared, f"segments_per_lead[{lead}]")
        waveform = waveform_segments[lead]
        # Enforce rank 2 and a fixed 256-sample second dimension.
        if waveform.ndim != 2 or waveform.shape[1] != 256:
            raise AssertionError(f"Waveform for {lead} must be [T,256], got {tuple(waveform.shape)}")
        assert_equal_int(waveform.shape[0], declared, f"Waveform seconds vs segments_per_lead for {lead}")
|
|
|
|
# Public API. Only names actually defined in this module may appear here:
# stale entries (e.g. helpers that were moved elsewhere) break
# ``from <module> import *`` with an AttributeError at import time.
__all__ = [
    "assert_tensor_dtype",
    "assert_ecg_blocks_consistent",
    "assert_prefix_split_complete",
    "assert_prefix_matches_segments",
    "assert_ecg_part_bounds",
    "assert_layout_specials_complete",
    "assert_wrapper_embed_length",
    "assert_rest_length_nonnegative",
    "assert_sorted_non_overlapping_spans",
    "assert_equal_int",
    "assert_positive_int",
    "assert_ecg_catalog_valid",
    "assert_normalized_role_canonical",
    "assert_tokenization_cursor_matches",
    "assert_model_spans_valid",
    "assert_eos_appended",
    "assert_turn_parts_structure_valid",
    "assert_turn_content_ends_with_eot",
    "assert_leads_canonical_and_ordered",
    "assert_waveform_shapes_valid",
]