# CAMEL — camel/assertions.py
# Author: Mayank Keoliya
# Bundled camel library, demo data, and app.py update for the HF Space (commit 4a4614d)
"""
Assertions and summaries for ECG language-model wrappers and their LoRA adapters.
"""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Mapping, Set
import functools
import os
import torch
from camel.ecg_attention_masks import ECGBlockLayout
# Assertion helpers are opt-in: export ASSERTIONS=1 to activate them.
_ASSERTIONS_ENABLED = os.getenv("ASSERTIONS") == "1"


def _skip_if_assertions_disabled(func):
    """Decorator that no-ops assertion helpers when ASSERTIONS env var is not set."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # When disabled, every wrapped helper silently returns None.
        return func(*args, **kwargs) if _ASSERTIONS_ENABLED else None

    return wrapper
@_skip_if_assertions_disabled
def assert_tensor_dtype(tensor: torch.Tensor, *, expected: torch.dtype, context: str) -> None:
    """Raise AssertionError unless ``tensor`` has exactly the ``expected`` dtype."""
    actual = tensor.dtype
    if actual == expected:
        return
    raise AssertionError(f"{context}: expected dtype {expected}, got {actual}")
@_skip_if_assertions_disabled
def assert_ecg_blocks_consistent(
    *,
    turn_parts: Iterable[Iterable[Dict[str, Any]]],
    ecg_blocks: Iterable[Dict[str, Any]],
) -> None:
    """
    Validate that structured turn parts contain the expected ECG markers per block.

    For every declared block this verifies:
      - each non-text part references a known ``block_index`` and lead;
      - each lead carries exactly two lead-scoped "special" markers;
      - each block carries exactly two lead-less (global) "special" markers;
      - each lead's "ecg" parts cover seconds 1..segments_per_lead exactly once.

    Raises:
        AssertionError: on any structural inconsistency.
    """
    blocks_list = list(ecg_blocks)
    # Per-block bookkeeping, all keyed by block index.
    expected_counts: Dict[int, Dict[str, int]] = {}  # lead -> declared segment count
    per_lead_special_counts: Dict[int, Dict[str, int]] = {}  # lead -> observed special markers
    per_lead_secs: Dict[int, Dict[str, Set[int]]] = {}  # lead -> observed second indices
    global_special_count: Dict[int, int] = {}  # observed lead-less special markers
    for idx, blk in enumerate(blocks_list):
        leads = [str(ld) for ld in blk.get("lead_names", [])]
        segs = [int(n) for n in blk.get("segments_per_lead", [])]
        expected_counts[idx] = {ld: int(n) for ld, n in zip(leads, segs)}
        per_lead_special_counts[idx] = {ld: 0 for ld in leads}
        per_lead_secs[idx] = {ld: set() for ld in leads}
        global_special_count[idx] = 0
    # Single pass over all parts, tallying observations against the declarations.
    for turn in turn_parts:
        for part in turn:
            kind = part.get("kind")
            if kind == "text":
                continue  # plain text parts carry no ECG structure
            block_idx = part.get("block_index")
            if block_idx is None or int(block_idx) not in expected_counts:
                raise AssertionError("ECG part references unknown block_index.")
            block_idx = int(block_idx)
            allowed_leads = set(expected_counts[block_idx].keys())
            if kind == "special":
                lead = part.get("lead")
                if lead:
                    # Lead-scoped special (one of the lead's start/end pair).
                    if lead not in allowed_leads:
                        raise AssertionError(f"Special token references unknown lead '{lead}'.")
                    per_lead_special_counts[block_idx][lead] = per_lead_special_counts[block_idx].get(lead, 0) + 1
                else:
                    # Lead-less special counts toward the block's global pair.
                    global_special_count[block_idx] = global_special_count.get(block_idx, 0) + 1
                token_literal = part.get("token")
                if not isinstance(token_literal, str) or len(token_literal) == 0:
                    raise AssertionError("Special turn part lacks a string token literal.")
                continue
            if kind == "ecg":
                lead = part.get("lead")
                if lead not in allowed_leads:
                    raise AssertionError(f"ECG segment references unknown lead '{lead}'.")
                sec_val = part.get("sec")
                expected_total = expected_counts[block_idx].get(lead, 0)
                if expected_total <= 0:
                    raise AssertionError(
                        f"Lead '{lead}' has non-positive declared segments ({expected_total}) but ECG markers are present."
                    )
                try:
                    sec = int(sec_val)
                except Exception as exc:  # noqa: BLE001
                    raise AssertionError(f"ECG segment for lead '{lead}' has non-integer sec {sec_val!r}.") from exc
                # Seconds are 1-based and must land inside the declared range.
                if sec < 1 or sec > expected_total:
                    raise AssertionError(
                        f"ECG segment for lead '{lead}' has second={sec}, expected within [1,{expected_total}]."
                    )
                if sec in per_lead_secs[block_idx][lead]:
                    raise AssertionError(f"Duplicate ECG segment marker for lead '{lead}' second {sec}.")
                per_lead_secs[block_idx][lead].add(sec)
                continue
            raise AssertionError(f"Unknown turn_parts kind '{kind}'.")
    # Final tallies: every declared block and lead must be exactly covered.
    for block_idx, expected in expected_counts.items():
        if global_special_count.get(block_idx, 0) != 2:
            raise AssertionError(
                f"Expected exactly two global ECG markers for block {block_idx}; "
                f"found {global_special_count.get(block_idx, 0)}."
            )
        for lead, expected_total in expected.items():
            expected_specials = per_lead_special_counts[block_idx].get(lead, 0)
            if expected_specials != 2:
                raise AssertionError(
                    f"Lead '{lead}' has {expected_specials} special markers; expected start and end (2 total)."
                )
            seen_secs = per_lead_secs[block_idx].get(lead, set())
            if expected_total != len(seen_secs):
                missing = sorted(set(range(1, expected_total + 1)) - seen_secs)
                raise AssertionError(
                    f"Lead '{lead}' missing ECG segment markers for seconds {missing} (expected {expected_total})."
                )
# ---------------- Trainer batch validation helpers -----------------------------------------------
@_skip_if_assertions_disabled
def assert_prefix_split_complete(*, offset: int, total_prefix_rows: int) -> None:
    """Check that prefix splitting consumed every row, no more and no fewer."""
    if offset == total_prefix_rows:
        return
    raise RuntimeError(
        f"Prefix split mismatch: consumed {offset} rows but have {total_prefix_rows}"
    )
@_skip_if_assertions_disabled
def assert_prefix_matches_segments(
    *,
    prefix_rows: int,
    segments_per_lead: Iterable[int],
    lead_names: Iterable[str],
    sample_index: int,
    block_index: int,
) -> None:
    """Check that the prefix row count equals the sum of segments_per_lead.

    Raises:
        RuntimeError: when the counts disagree; the message includes the
            offending sample/block indices and the declared lead layout.
    """
    # Sum first — matches the order in which the iterables are consumed.
    total_segments = sum(int(n) for n in segments_per_lead)
    if prefix_rows == total_segments:
        return
    raise RuntimeError(
        f"Sample {sample_index} block {block_index}: Prefix rows ({prefix_rows}) "
        f"!= sum(segments_per_lead) ({total_segments}). "
        f"lead_names={list(lead_names)} segments_per_lead={list(segments_per_lead)}"
    )
@_skip_if_assertions_disabled
def assert_ecg_part_bounds(
    *,
    lead: str,
    sec: int,
    lead_to_offset: Mapping[str, int],
    declared_segments: Mapping[str, int],
    total_prefix_rows: int,
    sample_index: int,
    block_index: int,
) -> None:
    """Validate that an ECG part (lead, sec) maps to a legal prefix row.

    Raises:
        RuntimeError: when the lead is undeclared, ``sec`` is outside
            ``1..declared_segments[lead]``, or the derived row index falls
            outside either the whole prefix or this lead's own row span.
    """
    if lead not in declared_segments:
        raise RuntimeError(f"Unknown lead {lead} in parts for sample {sample_index} block {block_index}")
    segment_count = int(declared_segments[lead])
    # Seconds are 1-based.
    if sec < 1 or sec > segment_count:
        raise RuntimeError(
            f"sec out of range for lead {lead}: got {sec}, expected 1..{segment_count}"
        )
    lead_start = lead_to_offset[lead]
    lead_end = lead_start + segment_count  # exclusive upper bound of this lead's rows
    absolute_row = lead_start + (sec - 1)
    # Global bound: the row must exist within the packed prefix at all.
    if absolute_row < 0 or absolute_row >= total_prefix_rows:
        raise RuntimeError(
            f"Bad (lead,sec)=({lead},{sec}) for sample {sample_index} block {block_index}: "
            f"row_idx {absolute_row} not in [0,{total_prefix_rows})"
        )
    # Per-lead bound: the row must also fall inside this lead's own span.
    if absolute_row < lead_start or absolute_row >= lead_end:
        raise RuntimeError(
            f"(lead,sec)=({lead},{sec}) maps outside this lead block "
            f"[{lead_start},{lead_end}) (row_idx={absolute_row}) for sample {sample_index}"
        )
@_skip_if_assertions_disabled
def assert_layout_specials_complete(
    *,
    block_layout: ECGBlockLayout,
    lead_names: Iterable[str],
) -> None:
    """Validate that the layout's special-token markers are complete and ordered.

    Per declared lead, and for the global pair: start and end indices must be
    both present or both absent, and when present, start must precede end.

    Raises:
        RuntimeError: on a half-present pair or an out-of-order pair.
    """
    # Lead-scoped start/end pairs.
    for lead in lead_names:
        start = block_layout.lead_start_idx.get(lead)
        end = block_layout.lead_end_idx.get(lead)
        if (start is None) != (end is None):
            raise RuntimeError(f"Lead {lead} missing start/end special (s={start}, e={end})")
        if start is not None and start >= end:
            raise RuntimeError(f"Lead {lead} specials out of order: start={start}, end={end}")
    # Global start/end pair.
    g_start = block_layout.global_start_idx
    g_end = block_layout.global_end_idx
    if (g_start is None) != (g_end is None):
        raise RuntimeError("Global start/end special mismatch")
    if g_start is not None and g_end is not None and g_start >= g_end:
        raise RuntimeError(
            f"Global specials out of order: start={g_start} "
            f"end={g_end}"
        )
# ---------------- Wrapper embedding validations --------------------------------------------------
@_skip_if_assertions_disabled
def assert_wrapper_embed_length(
    *,
    embeddings: torch.Tensor,
    ids: List[int],
    context: str,
) -> None:
    """Ensure embedding sequence length matches token count exactly.

    Guards the critical 1:1 mapping between input token IDs and output
    embeddings.

    Args:
        embeddings: Output embedding tensor (must be at least 1-D).
        ids: Input token ID list.
        context: Description of where this check is being performed.

    Raises:
        RuntimeError: if ids is not a list, embeddings are not at least 1-D,
            or the embedding length differs from the token count.
    """
    if not isinstance(ids, list):
        raise RuntimeError(f"{context}: ids must be a python list of ints")
    if embeddings.dim() < 1:
        raise RuntimeError(f"{context}: embeddings must be at least 1-D, got shape {tuple(embeddings.shape)}")
    seq_len = embeddings.size(0)
    token_count = len(ids)
    if seq_len != token_count:
        raise RuntimeError(f"{context}: embed length {seq_len} != token count {token_count}")
@_skip_if_assertions_disabled
def assert_rest_length_nonnegative(*, rest_length: int) -> None:
    """Validate that rest token length is non-negative.

    A negative value should be impossible in correct code; failing early
    here pinpoints bugs in label construction logic.

    Args:
        rest_length: Length of the ids_rest list.

    Raises:
        ValueError: if rest_length is negative.
    """
    if rest_length >= 0:
        return
    raise ValueError("ids_rest length is negative (internal error).")
# ---------------- Utility assertion helpers ------------------------------------------------------
@_skip_if_assertions_disabled
def assert_sorted_non_overlapping_spans(spans: List[tuple[int, int]], length: int, ctx: str) -> None:
    """Check that spans are in order, pairwise disjoint, and inside [0, length]."""
    last_end = -1
    for idx, (start, end) in enumerate(spans):
        within_bounds = 0 <= start and start <= end and end <= length
        if not within_bounds:
            raise AssertionError(f"{ctx}: span {idx}={(start,end)} out of bounds for length {length}")
        # A start before the previous end means overlap or bad ordering.
        if start < last_end:
            raise AssertionError(f"{ctx}: spans overlap or not sorted at {idx-1},{idx}: prev_end={last_end}, curr_start={start}")
        last_end = end
@_skip_if_assertions_disabled
def assert_equal_int(a: int, b: int, msg: str) -> None:
    """Raise AssertionError with ``msg`` unless the two values are equal as ints."""
    if int(a) == int(b):
        return
    raise AssertionError(f"{msg}: {a} != {b}")
@_skip_if_assertions_disabled
def assert_positive_int(n: int, msg: str) -> None:
    """Raise AssertionError with ``msg`` unless ``n`` is strictly positive as an int."""
    if int(n) > 0:
        return
    raise AssertionError(f"{msg}: expected > 0, got {n}")
# ---------------- Schema and catalog validations -------------------------------------------------
@_skip_if_assertions_disabled
def assert_ecg_catalog_valid(catalog: Any, schema: Any) -> None:
    """Validate the ECG special-token catalog for uniqueness and mapping consistency.

    Checks:
      - all tokens in the catalog are unique;
      - every canonical lead appears in lead_to_indices and lead_to_tokens;
      - token/index mappings agree across tokens, lead maps, and token_to_index;
      - the global start/end markers exist in token_to_index.

    Raises:
        AssertionError: on any duplicate, missing lead, or mapping mismatch.
    """
    all_tokens = catalog.tokens
    # Token list must contain no duplicates.
    if len(all_tokens) != len(set(all_tokens)):
        raise AssertionError("ECG special tokens contain duplicates")
    # Every canonical lead must be mapped, and its mappings must agree.
    for lead in schema.ecg.canonical_leads:
        if lead not in catalog.lead_to_indices or lead not in catalog.lead_to_tokens:
            raise AssertionError(f"Missing lead in catalog: {lead}")
        for kind in ("start", "end"):
            tok = catalog.lead_to_tokens[lead][kind]
            idx = catalog.lead_to_indices[lead][kind]
            if all_tokens[idx] != tok:
                raise AssertionError(f"Catalog mismatch for {lead}:{kind}: tokens[idx] != tok")
            if catalog.token_to_index.get(tok, None) != idx:
                raise AssertionError(f"token_to_index mismatch for {lead}:{kind}")
    # Global markers must be resolvable.
    for marker in (schema.ecg.global_start, schema.ecg.global_end):
        if marker not in catalog.token_to_index:
            raise AssertionError(f"Global ECG token missing from catalog: {marker}")
# ---------------- Conversation and role validations ----------------------------------------------
@_skip_if_assertions_disabled
def assert_normalized_role_canonical(role: str, schema: Any) -> None:
    """Ensure a normalized role is one of the schema's canonical prompt roles."""
    canonical_roles = (schema.prompt.user_role, schema.prompt.model_role)
    if role in canonical_roles:
        return
    raise AssertionError(f"Normalized role '{role}' did not resolve to a canonical prompt role")
# ---------------- Tokenization and span validations ----------------------------------------------
@_skip_if_assertions_disabled
def assert_tokenization_cursor_matches(cursor: int, ids_length: int) -> None:
    """Ensure the tokenization cursor advanced exactly to the end of text_ids."""
    if cursor == ids_length:
        return
    raise AssertionError(f"cursor ({cursor}) != len(text_ids) ({ids_length})")
@_skip_if_assertions_disabled
def assert_model_spans_valid(model_spans: List[tuple[int, int]], ids_length: int) -> None:
    """Check model spans are sorted, disjoint, in bounds, and non-empty."""
    # Structural check first; then require at least one span.
    assert_sorted_non_overlapping_spans(model_spans, ids_length, ctx="model_spans_in_text")
    if not model_spans:
        raise AssertionError("No model spans found in text ids")
@_skip_if_assertions_disabled
def assert_eos_appended(ids: List[int], tokenizer: Any, require_eos: bool) -> None:
    """Check that EOS terminates ``ids`` when required and the tokenizer has one."""
    if not require_eos:
        return
    # Only consult the tokenizer once we know EOS is required.
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        return
    if not ids or ids[-1] != eos_id:
        raise AssertionError("Required EOS was not appended at the end of text_ids")
# ---------------- u0 parts structure validations -------------------------------------------------
@_skip_if_assertions_disabled
def assert_turn_parts_structure_valid(
    parts: List[Dict[str, Any]],
    ecg_blocks: List[Dict[str, Any]],
    schema: Any,
    catalog: Any,
) -> None:
    """Validate the complete structure of turn parts for all blocks present.

    For each block index referenced by ``parts``:
      - the index must name a valid entry of ``ecg_blocks``;
      - the block's global start/end specials must both appear, start before end;
      - each declared lead must have its start/end specials in order, with the
        "ecg" parts between them covering seconds 1..n exactly and in order;
      - each lead special's ``token_index`` must match the catalog.

    Raises:
        AssertionError: on any missing, misordered, or mismatched marker.
    """
    # Only blocks actually referenced by the parts list are validated.
    block_indices = sorted({int(p.get("block_index")) for p in parts if p.get("block_index") is not None})
    for block_idx in block_indices:
        if block_idx < 0 or block_idx >= len(ecg_blocks):
            raise AssertionError(f"Unknown block_index {block_idx} in turn parts")
        blk = ecg_blocks[block_idx]
        leads_present = [str(ld) for ld in blk.get("lead_names", [])]
        segments_per_lead = [int(n) for n in blk.get("segments_per_lead", [])]
        # All special-token literals belonging to this block, in order of appearance.
        special_tokens = [p.get("token") for p in parts if p.get("kind") == "special" and p.get("block_index") == block_idx]
        if schema.ecg.global_start not in special_tokens:
            raise AssertionError(f"Missing global_start special in block {block_idx}")
        if schema.ecg.global_end not in special_tokens:
            raise AssertionError(f"Missing global_end special in block {block_idx}")
        # Position (within `parts`) of the block's global start marker, or None.
        idx_global_start = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_start),
            None
        )
        # Position of the block's global end marker, or None.
        idx_global_end = next(
            (i for i, p in enumerate(parts)
             if p.get("block_index") == block_idx and p.get("kind") == "special"
             and p.get("token") == schema.ecg.global_end),
            None
        )
        if idx_global_start is None or idx_global_end is None or not (idx_global_start < idx_global_end):
            raise AssertionError(f"Block {block_idx}: missing or misordered global start/end specials")
        for lead, nseg in zip(leads_present, segments_per_lead):
            assert_positive_int(nseg, f"segments_per_lead for {lead}")
            # Positions of this lead's own start/end special markers.
            idx_start = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["start"]),
                None
            )
            idx_end = next(
                (i for i, p in enumerate(parts)
                 if p.get("block_index") == block_idx and p.get("kind") == "special" and p.get("lead") == lead
                 and p.get("token") == catalog.lead_to_tokens[lead]["end"]),
                None
            )
            if idx_start is None or idx_end is None or not (idx_start < idx_end):
                raise AssertionError(f"Lead {lead}: missing or misordered start/end specials in block {block_idx}")
            # ECG parts strictly between the lead's specials must be seconds 1..n, in order.
            secs = [p["sec"] for p in parts[idx_start+1:idx_end] if p.get("kind") == "ecg" and p.get("lead") == lead]
            if secs != list(range(1, int(nseg) + 1)):
                raise AssertionError(f"Lead {lead}: ECG seconds sequence invalid: {secs} vs 1..{nseg}")
            # The token_index recorded on each special must agree with the catalog.
            if parts[idx_start]["token_index"] != catalog.lead_to_indices[lead]["start"]:
                raise AssertionError(f"Lead {lead}: start token_index mismatch")
            if parts[idx_end]["token_index"] != catalog.lead_to_indices[lead]["end"]:
                raise AssertionError(f"Lead {lead}: end token_index mismatch")
@_skip_if_assertions_disabled
def assert_turn_content_ends_with_eot(text_block: str, end_of_turn: str) -> None:
    """Ensure turn content terminates with the given end-of-turn suffix."""
    if text_block.endswith(end_of_turn):
        return
    raise AssertionError("Turn content must end with end_of_turn suffix")
# ---------------- Per-sample packing validations -------------------------------------------------
@_skip_if_assertions_disabled
def assert_leads_canonical_and_ordered(leads_present: List[str], canonical_leads: tuple) -> None:
    """Validate all leads are canonical and unique (order is not enforced here)."""
    lead_list = list(leads_present)
    unknown = [ld for ld in lead_list if ld not in canonical_leads]
    if unknown:
        raise AssertionError(f"Non-canonical lead found in leads_present: {lead_list}")
    if len(lead_list) != len(set(lead_list)):
        raise AssertionError(f"Duplicate lead detected in leads_present: {lead_list}")
@_skip_if_assertions_disabled
def assert_waveform_shapes_valid(
    leads_present: List[str],
    segments_per_lead: List[int],
    waveform_segments: Dict[str, Any],
) -> None:
    """Validate waveform tensor shapes against declared segment counts.

    Each lead's waveform must be a 2-D array of shape [T, 256] with T equal
    to its segments_per_lead entry.

    Raises:
        AssertionError: on a non-positive segment count, a wrong-shaped
            waveform, or a T that disagrees with segments_per_lead.
    """
    for lead, declared in zip(leads_present, segments_per_lead):
        assert_positive_int(declared, f"segments_per_lead[{lead}]")
        waveform = waveform_segments[lead]
        # Expect exactly [T, 256] — 256 samples per one-second segment.
        if waveform.ndim != 2 or waveform.shape[1] != 256:
            raise AssertionError(f"Waveform for {lead} must be [T,256], got {tuple(waveform.shape)}")
        assert_equal_int(waveform.shape[0], declared, f"Waveform seconds vs segments_per_lead for {lead}")
# Public API. NOTE(review): earlier revisions exported additional helpers
# (adapter/LoRA snapshot and mask validators) that are not defined in this
# bundled file; listing undefined names in __all__ makes
# `from camel.assertions import *` raise AttributeError, so only names this
# module actually defines are exported.
__all__ = [
    "assert_tensor_dtype",
    "assert_ecg_blocks_consistent",
    "assert_prefix_split_complete",
    "assert_prefix_matches_segments",
    "assert_ecg_part_bounds",
    "assert_layout_specials_complete",
    "assert_wrapper_embed_length",
    "assert_rest_length_nonnegative",
    "assert_sorted_non_overlapping_spans",
    "assert_equal_int",
    "assert_positive_int",
    "assert_ecg_catalog_valid",
    "assert_normalized_role_canonical",
    "assert_tokenization_cursor_matches",
    "assert_model_spans_valid",
    "assert_eos_appended",
    "assert_turn_parts_structure_valid",
    "assert_turn_content_ends_with_eot",
    "assert_leads_canonical_and_ordered",
    "assert_waveform_shapes_valid",
]