# DotCache-Arena: dotcache/tracing.py
# Commit 751ad26 by DeanoCalver: "Initial DotCache Arena Space upload" (18.1 kB)
from __future__ import annotations

from dataclasses import dataclass, fields
@dataclass(slots=True)
class ExecutionTrace:
capture_timings: bool = False
m0_full_page_materializations: int = 0
payload_bytes_read: int = 0
metadata_bytes_read: int = 0
host_to_device_bytes: int = 0
max_temporary_bytes: int = 0
prepared_page_cache_hits: int = 0
prepared_page_cache_misses: int = 0
cache_resident_bytes: int = 0
prepared_page_cache_evictions: int = 0
cache_evicted_bytes: int = 0
prepare_ms_total: float = 0.0
prepare_calls: int = 0
score_ms_total: float = 0.0
score_calls: int = 0
mix_ms_total: float = 0.0
mix_calls: int = 0
softmax_ms_total: float = 0.0
softmax_calls: int = 0
unpack_ms_total: float = 0.0
unpack_calls: int = 0
fwht_ms_total: float = 0.0
fwht_calls: int = 0
chunk_assembly_ms_total: float = 0.0
chunk_assembly_calls: int = 0
grouped_decode_calls: int = 0
grouped_decode_output_only_calls: int = 0
grouped_score_chunk_count: int = 0
grouped_mix_chunk_count: int = 0
grouped_score_chunk_pages_total: int = 0
grouped_mix_chunk_pages_total: int = 0
grouped_score_chunk_pages_max: int = 0
grouped_mix_chunk_pages_max: int = 0
grouped_logits_elements_total: int = 0
grouped_weights_elements_total: int = 0
grouped_output_elements_total: int = 0
grouped_score_packed_cuda_calls: int = 0
grouped_score_fused_two_group64_calls: int = 0
grouped_score_fused_generic_calls: int = 0
grouped_score_generic_calls: int = 0
grouped_mix_packed_cuda_calls: int = 0
grouped_mix_fused_two_group64_calls: int = 0
grouped_mix_fused_generic_calls: int = 0
grouped_mix_generic_calls: int = 0
per_kv_decode_calls: int = 0
per_kv_score_chunk_count: int = 0
per_kv_mix_chunk_count: int = 0
per_kv_score_chunk_pages_total: int = 0
per_kv_mix_chunk_pages_total: int = 0
per_kv_score_chunk_pages_max: int = 0
per_kv_mix_chunk_pages_max: int = 0
per_kv_logits_elements_total: int = 0
per_kv_weights_elements_total: int = 0
per_kv_output_elements_total: int = 0
per_kv_score_fused_two_group64_calls: int = 0
per_kv_score_fused_generic_calls: int = 0
per_kv_score_generic_calls: int = 0
per_kv_mix_fused_two_group64_calls: int = 0
per_kv_mix_fused_generic_calls: int = 0
per_kv_mix_generic_calls: int = 0
def record_page_read(self, payload_bytes: int, metadata_bytes: int) -> None:
self.payload_bytes_read += int(payload_bytes)
self.metadata_bytes_read += int(metadata_bytes)
def record_host_to_device(self, nbytes: int) -> None:
self.host_to_device_bytes += int(nbytes)
def record_temporary(self, nbytes: int) -> None:
self.max_temporary_bytes = max(self.max_temporary_bytes, int(nbytes))
def record_m0_full_page_materialization(self, count: int = 1) -> None:
self.m0_full_page_materializations += int(count)
def record_cache_hit(self, count: int = 1) -> None:
self.prepared_page_cache_hits += int(count)
def record_cache_miss(self, count: int = 1) -> None:
self.prepared_page_cache_misses += int(count)
def observe_cache_resident_bytes(self, nbytes: int) -> None:
self.cache_resident_bytes = max(self.cache_resident_bytes, int(nbytes))
def record_cache_eviction(self, nbytes: int, count: int = 1) -> None:
self.prepared_page_cache_evictions += int(count)
self.cache_evicted_bytes += int(nbytes)
def record_grouped_decode_call(self, *, output_only: bool) -> None:
if output_only:
self.grouped_decode_output_only_calls += 1
return
self.grouped_decode_calls += 1
def record_per_kv_decode_call(self) -> None:
self.per_kv_decode_calls += 1
def record_grouped_score_chunk(
self,
*,
batch_size: int,
query_count: int,
page_count: int,
token_count: int,
) -> None:
self.grouped_score_chunk_count += 1
self.grouped_score_chunk_pages_total += int(page_count)
self.grouped_score_chunk_pages_max = max(self.grouped_score_chunk_pages_max, int(page_count))
self.grouped_logits_elements_total += int(batch_size) * int(query_count) * int(page_count) * int(token_count)
def record_grouped_mix_chunk(
self,
*,
batch_size: int,
query_count: int,
page_count: int,
token_count: int,
head_dim: int,
) -> None:
self.grouped_mix_chunk_count += 1
self.grouped_mix_chunk_pages_total += int(page_count)
self.grouped_mix_chunk_pages_max = max(self.grouped_mix_chunk_pages_max, int(page_count))
self.grouped_weights_elements_total += int(batch_size) * int(query_count) * int(page_count) * int(token_count)
self.grouped_output_elements_total += int(batch_size) * int(query_count) * int(head_dim)
def record_per_kv_score_chunk(
self,
*,
query_count: int,
page_count: int,
token_count: int,
) -> None:
self.per_kv_score_chunk_count += 1
self.per_kv_score_chunk_pages_total += int(page_count)
self.per_kv_score_chunk_pages_max = max(self.per_kv_score_chunk_pages_max, int(page_count))
self.per_kv_logits_elements_total += int(query_count) * int(page_count) * int(token_count)
def record_per_kv_mix_chunk(
self,
*,
query_count: int,
page_count: int,
token_count: int,
head_dim: int,
) -> None:
self.per_kv_mix_chunk_count += 1
self.per_kv_mix_chunk_pages_total += int(page_count)
self.per_kv_mix_chunk_pages_max = max(self.per_kv_mix_chunk_pages_max, int(page_count))
self.per_kv_weights_elements_total += int(query_count) * int(page_count) * int(token_count)
self.per_kv_output_elements_total += int(query_count) * int(head_dim)
def record_grouped_kernel_variant(self, *, section: str, variant: str) -> None:
if section == "score":
if variant == "packed_cuda":
self.grouped_score_packed_cuda_calls += 1
return
if variant == "fused_two_group64":
self.grouped_score_fused_two_group64_calls += 1
return
if variant == "fused_generic":
self.grouped_score_fused_generic_calls += 1
return
if variant == "generic":
self.grouped_score_generic_calls += 1
return
if section == "mix":
if variant == "packed_cuda":
self.grouped_mix_packed_cuda_calls += 1
return
if variant == "fused_two_group64":
self.grouped_mix_fused_two_group64_calls += 1
return
if variant == "fused_generic":
self.grouped_mix_fused_generic_calls += 1
return
if variant == "generic":
self.grouped_mix_generic_calls += 1
return
raise ValueError(f"unknown grouped kernel variant: {section}/{variant}")
def record_per_kv_kernel_variant(self, *, section: str, variant: str) -> None:
if section == "score":
if variant == "fused_two_group64":
self.per_kv_score_fused_two_group64_calls += 1
return
if variant == "fused_generic":
self.per_kv_score_fused_generic_calls += 1
return
if variant == "generic":
self.per_kv_score_generic_calls += 1
return
if section == "mix":
if variant == "fused_two_group64":
self.per_kv_mix_fused_two_group64_calls += 1
return
if variant == "fused_generic":
self.per_kv_mix_fused_generic_calls += 1
return
if variant == "generic":
self.per_kv_mix_generic_calls += 1
return
raise ValueError(f"unknown per_kv kernel variant: {section}/{variant}")
def record_timing(self, section: str, ms: float, count: int = 1) -> None:
if section == "prepare":
self.prepare_ms_total += float(ms)
self.prepare_calls += int(count)
return
if section == "score":
self.score_ms_total += float(ms)
self.score_calls += int(count)
return
if section == "mix":
self.mix_ms_total += float(ms)
self.mix_calls += int(count)
return
if section == "softmax":
self.softmax_ms_total += float(ms)
self.softmax_calls += int(count)
return
if section == "unpack":
self.unpack_ms_total += float(ms)
self.unpack_calls += int(count)
return
if section == "fwht":
self.fwht_ms_total += float(ms)
self.fwht_calls += int(count)
return
if section == "chunk_assembly":
self.chunk_assembly_ms_total += float(ms)
self.chunk_assembly_calls += int(count)
return
raise ValueError(f"unknown timing section: {section}")
def merge(self, other: "ExecutionTrace") -> None:
self.m0_full_page_materializations += other.m0_full_page_materializations
self.payload_bytes_read += other.payload_bytes_read
self.metadata_bytes_read += other.metadata_bytes_read
self.host_to_device_bytes += other.host_to_device_bytes
self.max_temporary_bytes = max(self.max_temporary_bytes, other.max_temporary_bytes)
self.prepared_page_cache_hits += other.prepared_page_cache_hits
self.prepared_page_cache_misses += other.prepared_page_cache_misses
self.cache_resident_bytes = max(self.cache_resident_bytes, other.cache_resident_bytes)
self.prepared_page_cache_evictions += other.prepared_page_cache_evictions
self.cache_evicted_bytes += other.cache_evicted_bytes
self.prepare_ms_total += other.prepare_ms_total
self.prepare_calls += other.prepare_calls
self.score_ms_total += other.score_ms_total
self.score_calls += other.score_calls
self.mix_ms_total += other.mix_ms_total
self.mix_calls += other.mix_calls
self.softmax_ms_total += other.softmax_ms_total
self.softmax_calls += other.softmax_calls
self.unpack_ms_total += other.unpack_ms_total
self.unpack_calls += other.unpack_calls
self.fwht_ms_total += other.fwht_ms_total
self.fwht_calls += other.fwht_calls
self.chunk_assembly_ms_total += other.chunk_assembly_ms_total
self.chunk_assembly_calls += other.chunk_assembly_calls
self.grouped_decode_calls += other.grouped_decode_calls
self.grouped_decode_output_only_calls += other.grouped_decode_output_only_calls
self.grouped_score_chunk_count += other.grouped_score_chunk_count
self.grouped_mix_chunk_count += other.grouped_mix_chunk_count
self.grouped_score_chunk_pages_total += other.grouped_score_chunk_pages_total
self.grouped_mix_chunk_pages_total += other.grouped_mix_chunk_pages_total
self.grouped_score_chunk_pages_max = max(self.grouped_score_chunk_pages_max, other.grouped_score_chunk_pages_max)
self.grouped_mix_chunk_pages_max = max(self.grouped_mix_chunk_pages_max, other.grouped_mix_chunk_pages_max)
self.grouped_logits_elements_total += other.grouped_logits_elements_total
self.grouped_weights_elements_total += other.grouped_weights_elements_total
self.grouped_output_elements_total += other.grouped_output_elements_total
self.grouped_score_packed_cuda_calls += other.grouped_score_packed_cuda_calls
self.grouped_score_fused_two_group64_calls += other.grouped_score_fused_two_group64_calls
self.grouped_score_fused_generic_calls += other.grouped_score_fused_generic_calls
self.grouped_score_generic_calls += other.grouped_score_generic_calls
self.grouped_mix_packed_cuda_calls += other.grouped_mix_packed_cuda_calls
self.grouped_mix_fused_two_group64_calls += other.grouped_mix_fused_two_group64_calls
self.grouped_mix_fused_generic_calls += other.grouped_mix_fused_generic_calls
self.grouped_mix_generic_calls += other.grouped_mix_generic_calls
self.per_kv_decode_calls += other.per_kv_decode_calls
self.per_kv_score_chunk_count += other.per_kv_score_chunk_count
self.per_kv_mix_chunk_count += other.per_kv_mix_chunk_count
self.per_kv_score_chunk_pages_total += other.per_kv_score_chunk_pages_total
self.per_kv_mix_chunk_pages_total += other.per_kv_mix_chunk_pages_total
self.per_kv_score_chunk_pages_max = max(self.per_kv_score_chunk_pages_max, other.per_kv_score_chunk_pages_max)
self.per_kv_mix_chunk_pages_max = max(self.per_kv_mix_chunk_pages_max, other.per_kv_mix_chunk_pages_max)
self.per_kv_logits_elements_total += other.per_kv_logits_elements_total
self.per_kv_weights_elements_total += other.per_kv_weights_elements_total
self.per_kv_output_elements_total += other.per_kv_output_elements_total
self.per_kv_score_fused_two_group64_calls += other.per_kv_score_fused_two_group64_calls
self.per_kv_score_fused_generic_calls += other.per_kv_score_fused_generic_calls
self.per_kv_score_generic_calls += other.per_kv_score_generic_calls
self.per_kv_mix_fused_two_group64_calls += other.per_kv_mix_fused_two_group64_calls
self.per_kv_mix_fused_generic_calls += other.per_kv_mix_fused_generic_calls
self.per_kv_mix_generic_calls += other.per_kv_mix_generic_calls
def to_dict(self) -> dict[str, int | float]:
return {
"m0_full_page_materializations": self.m0_full_page_materializations,
"payload_bytes_read": self.payload_bytes_read,
"metadata_bytes_read": self.metadata_bytes_read,
"host_to_device_bytes": self.host_to_device_bytes,
"max_temporary_bytes": self.max_temporary_bytes,
"prepared_page_cache_hits": self.prepared_page_cache_hits,
"prepared_page_cache_misses": self.prepared_page_cache_misses,
"cache_resident_bytes": self.cache_resident_bytes,
"prepared_page_cache_evictions": self.prepared_page_cache_evictions,
"cache_evicted_bytes": self.cache_evicted_bytes,
"prepare_ms_total": self.prepare_ms_total,
"prepare_calls": self.prepare_calls,
"score_ms_total": self.score_ms_total,
"score_calls": self.score_calls,
"mix_ms_total": self.mix_ms_total,
"mix_calls": self.mix_calls,
"softmax_ms_total": self.softmax_ms_total,
"softmax_calls": self.softmax_calls,
"unpack_ms_total": self.unpack_ms_total,
"unpack_calls": self.unpack_calls,
"fwht_ms_total": self.fwht_ms_total,
"fwht_calls": self.fwht_calls,
"chunk_assembly_ms_total": self.chunk_assembly_ms_total,
"chunk_assembly_calls": self.chunk_assembly_calls,
"grouped_decode_calls": self.grouped_decode_calls,
"grouped_decode_output_only_calls": self.grouped_decode_output_only_calls,
"grouped_score_chunk_count": self.grouped_score_chunk_count,
"grouped_mix_chunk_count": self.grouped_mix_chunk_count,
"grouped_score_chunk_pages_total": self.grouped_score_chunk_pages_total,
"grouped_mix_chunk_pages_total": self.grouped_mix_chunk_pages_total,
"grouped_score_chunk_pages_max": self.grouped_score_chunk_pages_max,
"grouped_mix_chunk_pages_max": self.grouped_mix_chunk_pages_max,
"grouped_logits_elements_total": self.grouped_logits_elements_total,
"grouped_weights_elements_total": self.grouped_weights_elements_total,
"grouped_output_elements_total": self.grouped_output_elements_total,
"grouped_score_packed_cuda_calls": self.grouped_score_packed_cuda_calls,
"grouped_score_fused_two_group64_calls": self.grouped_score_fused_two_group64_calls,
"grouped_score_fused_generic_calls": self.grouped_score_fused_generic_calls,
"grouped_score_generic_calls": self.grouped_score_generic_calls,
"grouped_mix_packed_cuda_calls": self.grouped_mix_packed_cuda_calls,
"grouped_mix_fused_two_group64_calls": self.grouped_mix_fused_two_group64_calls,
"grouped_mix_fused_generic_calls": self.grouped_mix_fused_generic_calls,
"grouped_mix_generic_calls": self.grouped_mix_generic_calls,
"per_kv_decode_calls": self.per_kv_decode_calls,
"per_kv_score_chunk_count": self.per_kv_score_chunk_count,
"per_kv_mix_chunk_count": self.per_kv_mix_chunk_count,
"per_kv_score_chunk_pages_total": self.per_kv_score_chunk_pages_total,
"per_kv_mix_chunk_pages_total": self.per_kv_mix_chunk_pages_total,
"per_kv_score_chunk_pages_max": self.per_kv_score_chunk_pages_max,
"per_kv_mix_chunk_pages_max": self.per_kv_mix_chunk_pages_max,
"per_kv_logits_elements_total": self.per_kv_logits_elements_total,
"per_kv_weights_elements_total": self.per_kv_weights_elements_total,
"per_kv_output_elements_total": self.per_kv_output_elements_total,
"per_kv_score_fused_two_group64_calls": self.per_kv_score_fused_two_group64_calls,
"per_kv_score_fused_generic_calls": self.per_kv_score_fused_generic_calls,
"per_kv_score_generic_calls": self.per_kv_score_generic_calls,
"per_kv_mix_fused_two_group64_calls": self.per_kv_mix_fused_two_group64_calls,
"per_kv_mix_fused_generic_calls": self.per_kv_mix_fused_generic_calls,
"per_kv_mix_generic_calls": self.per_kv_mix_generic_calls,
}