Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| class ExecutionTrace: | |
| capture_timings: bool = False | |
| m0_full_page_materializations: int = 0 | |
| payload_bytes_read: int = 0 | |
| metadata_bytes_read: int = 0 | |
| host_to_device_bytes: int = 0 | |
| max_temporary_bytes: int = 0 | |
| prepared_page_cache_hits: int = 0 | |
| prepared_page_cache_misses: int = 0 | |
| cache_resident_bytes: int = 0 | |
| prepared_page_cache_evictions: int = 0 | |
| cache_evicted_bytes: int = 0 | |
| prepare_ms_total: float = 0.0 | |
| prepare_calls: int = 0 | |
| score_ms_total: float = 0.0 | |
| score_calls: int = 0 | |
| mix_ms_total: float = 0.0 | |
| mix_calls: int = 0 | |
| softmax_ms_total: float = 0.0 | |
| softmax_calls: int = 0 | |
| unpack_ms_total: float = 0.0 | |
| unpack_calls: int = 0 | |
| fwht_ms_total: float = 0.0 | |
| fwht_calls: int = 0 | |
| chunk_assembly_ms_total: float = 0.0 | |
| chunk_assembly_calls: int = 0 | |
| grouped_decode_calls: int = 0 | |
| grouped_decode_output_only_calls: int = 0 | |
| grouped_score_chunk_count: int = 0 | |
| grouped_mix_chunk_count: int = 0 | |
| grouped_score_chunk_pages_total: int = 0 | |
| grouped_mix_chunk_pages_total: int = 0 | |
| grouped_score_chunk_pages_max: int = 0 | |
| grouped_mix_chunk_pages_max: int = 0 | |
| grouped_logits_elements_total: int = 0 | |
| grouped_weights_elements_total: int = 0 | |
| grouped_output_elements_total: int = 0 | |
| grouped_score_packed_cuda_calls: int = 0 | |
| grouped_score_fused_two_group64_calls: int = 0 | |
| grouped_score_fused_generic_calls: int = 0 | |
| grouped_score_generic_calls: int = 0 | |
| grouped_mix_packed_cuda_calls: int = 0 | |
| grouped_mix_fused_two_group64_calls: int = 0 | |
| grouped_mix_fused_generic_calls: int = 0 | |
| grouped_mix_generic_calls: int = 0 | |
| per_kv_decode_calls: int = 0 | |
| per_kv_score_chunk_count: int = 0 | |
| per_kv_mix_chunk_count: int = 0 | |
| per_kv_score_chunk_pages_total: int = 0 | |
| per_kv_mix_chunk_pages_total: int = 0 | |
| per_kv_score_chunk_pages_max: int = 0 | |
| per_kv_mix_chunk_pages_max: int = 0 | |
| per_kv_logits_elements_total: int = 0 | |
| per_kv_weights_elements_total: int = 0 | |
| per_kv_output_elements_total: int = 0 | |
| per_kv_score_fused_two_group64_calls: int = 0 | |
| per_kv_score_fused_generic_calls: int = 0 | |
| per_kv_score_generic_calls: int = 0 | |
| per_kv_mix_fused_two_group64_calls: int = 0 | |
| per_kv_mix_fused_generic_calls: int = 0 | |
| per_kv_mix_generic_calls: int = 0 | |
| def record_page_read(self, payload_bytes: int, metadata_bytes: int) -> None: | |
| self.payload_bytes_read += int(payload_bytes) | |
| self.metadata_bytes_read += int(metadata_bytes) | |
| def record_host_to_device(self, nbytes: int) -> None: | |
| self.host_to_device_bytes += int(nbytes) | |
| def record_temporary(self, nbytes: int) -> None: | |
| self.max_temporary_bytes = max(self.max_temporary_bytes, int(nbytes)) | |
| def record_m0_full_page_materialization(self, count: int = 1) -> None: | |
| self.m0_full_page_materializations += int(count) | |
| def record_cache_hit(self, count: int = 1) -> None: | |
| self.prepared_page_cache_hits += int(count) | |
| def record_cache_miss(self, count: int = 1) -> None: | |
| self.prepared_page_cache_misses += int(count) | |
| def observe_cache_resident_bytes(self, nbytes: int) -> None: | |
| self.cache_resident_bytes = max(self.cache_resident_bytes, int(nbytes)) | |
| def record_cache_eviction(self, nbytes: int, count: int = 1) -> None: | |
| self.prepared_page_cache_evictions += int(count) | |
| self.cache_evicted_bytes += int(nbytes) | |
| def record_grouped_decode_call(self, *, output_only: bool) -> None: | |
| if output_only: | |
| self.grouped_decode_output_only_calls += 1 | |
| return | |
| self.grouped_decode_calls += 1 | |
| def record_per_kv_decode_call(self) -> None: | |
| self.per_kv_decode_calls += 1 | |
| def record_grouped_score_chunk( | |
| self, | |
| *, | |
| batch_size: int, | |
| query_count: int, | |
| page_count: int, | |
| token_count: int, | |
| ) -> None: | |
| self.grouped_score_chunk_count += 1 | |
| self.grouped_score_chunk_pages_total += int(page_count) | |
| self.grouped_score_chunk_pages_max = max(self.grouped_score_chunk_pages_max, int(page_count)) | |
| self.grouped_logits_elements_total += int(batch_size) * int(query_count) * int(page_count) * int(token_count) | |
| def record_grouped_mix_chunk( | |
| self, | |
| *, | |
| batch_size: int, | |
| query_count: int, | |
| page_count: int, | |
| token_count: int, | |
| head_dim: int, | |
| ) -> None: | |
| self.grouped_mix_chunk_count += 1 | |
| self.grouped_mix_chunk_pages_total += int(page_count) | |
| self.grouped_mix_chunk_pages_max = max(self.grouped_mix_chunk_pages_max, int(page_count)) | |
| self.grouped_weights_elements_total += int(batch_size) * int(query_count) * int(page_count) * int(token_count) | |
| self.grouped_output_elements_total += int(batch_size) * int(query_count) * int(head_dim) | |
| def record_per_kv_score_chunk( | |
| self, | |
| *, | |
| query_count: int, | |
| page_count: int, | |
| token_count: int, | |
| ) -> None: | |
| self.per_kv_score_chunk_count += 1 | |
| self.per_kv_score_chunk_pages_total += int(page_count) | |
| self.per_kv_score_chunk_pages_max = max(self.per_kv_score_chunk_pages_max, int(page_count)) | |
| self.per_kv_logits_elements_total += int(query_count) * int(page_count) * int(token_count) | |
| def record_per_kv_mix_chunk( | |
| self, | |
| *, | |
| query_count: int, | |
| page_count: int, | |
| token_count: int, | |
| head_dim: int, | |
| ) -> None: | |
| self.per_kv_mix_chunk_count += 1 | |
| self.per_kv_mix_chunk_pages_total += int(page_count) | |
| self.per_kv_mix_chunk_pages_max = max(self.per_kv_mix_chunk_pages_max, int(page_count)) | |
| self.per_kv_weights_elements_total += int(query_count) * int(page_count) * int(token_count) | |
| self.per_kv_output_elements_total += int(query_count) * int(head_dim) | |
| def record_grouped_kernel_variant(self, *, section: str, variant: str) -> None: | |
| if section == "score": | |
| if variant == "packed_cuda": | |
| self.grouped_score_packed_cuda_calls += 1 | |
| return | |
| if variant == "fused_two_group64": | |
| self.grouped_score_fused_two_group64_calls += 1 | |
| return | |
| if variant == "fused_generic": | |
| self.grouped_score_fused_generic_calls += 1 | |
| return | |
| if variant == "generic": | |
| self.grouped_score_generic_calls += 1 | |
| return | |
| if section == "mix": | |
| if variant == "packed_cuda": | |
| self.grouped_mix_packed_cuda_calls += 1 | |
| return | |
| if variant == "fused_two_group64": | |
| self.grouped_mix_fused_two_group64_calls += 1 | |
| return | |
| if variant == "fused_generic": | |
| self.grouped_mix_fused_generic_calls += 1 | |
| return | |
| if variant == "generic": | |
| self.grouped_mix_generic_calls += 1 | |
| return | |
| raise ValueError(f"unknown grouped kernel variant: {section}/{variant}") | |
| def record_per_kv_kernel_variant(self, *, section: str, variant: str) -> None: | |
| if section == "score": | |
| if variant == "fused_two_group64": | |
| self.per_kv_score_fused_two_group64_calls += 1 | |
| return | |
| if variant == "fused_generic": | |
| self.per_kv_score_fused_generic_calls += 1 | |
| return | |
| if variant == "generic": | |
| self.per_kv_score_generic_calls += 1 | |
| return | |
| if section == "mix": | |
| if variant == "fused_two_group64": | |
| self.per_kv_mix_fused_two_group64_calls += 1 | |
| return | |
| if variant == "fused_generic": | |
| self.per_kv_mix_fused_generic_calls += 1 | |
| return | |
| if variant == "generic": | |
| self.per_kv_mix_generic_calls += 1 | |
| return | |
| raise ValueError(f"unknown per_kv kernel variant: {section}/{variant}") | |
| def record_timing(self, section: str, ms: float, count: int = 1) -> None: | |
| if section == "prepare": | |
| self.prepare_ms_total += float(ms) | |
| self.prepare_calls += int(count) | |
| return | |
| if section == "score": | |
| self.score_ms_total += float(ms) | |
| self.score_calls += int(count) | |
| return | |
| if section == "mix": | |
| self.mix_ms_total += float(ms) | |
| self.mix_calls += int(count) | |
| return | |
| if section == "softmax": | |
| self.softmax_ms_total += float(ms) | |
| self.softmax_calls += int(count) | |
| return | |
| if section == "unpack": | |
| self.unpack_ms_total += float(ms) | |
| self.unpack_calls += int(count) | |
| return | |
| if section == "fwht": | |
| self.fwht_ms_total += float(ms) | |
| self.fwht_calls += int(count) | |
| return | |
| if section == "chunk_assembly": | |
| self.chunk_assembly_ms_total += float(ms) | |
| self.chunk_assembly_calls += int(count) | |
| return | |
| raise ValueError(f"unknown timing section: {section}") | |
| def merge(self, other: "ExecutionTrace") -> None: | |
| self.m0_full_page_materializations += other.m0_full_page_materializations | |
| self.payload_bytes_read += other.payload_bytes_read | |
| self.metadata_bytes_read += other.metadata_bytes_read | |
| self.host_to_device_bytes += other.host_to_device_bytes | |
| self.max_temporary_bytes = max(self.max_temporary_bytes, other.max_temporary_bytes) | |
| self.prepared_page_cache_hits += other.prepared_page_cache_hits | |
| self.prepared_page_cache_misses += other.prepared_page_cache_misses | |
| self.cache_resident_bytes = max(self.cache_resident_bytes, other.cache_resident_bytes) | |
| self.prepared_page_cache_evictions += other.prepared_page_cache_evictions | |
| self.cache_evicted_bytes += other.cache_evicted_bytes | |
| self.prepare_ms_total += other.prepare_ms_total | |
| self.prepare_calls += other.prepare_calls | |
| self.score_ms_total += other.score_ms_total | |
| self.score_calls += other.score_calls | |
| self.mix_ms_total += other.mix_ms_total | |
| self.mix_calls += other.mix_calls | |
| self.softmax_ms_total += other.softmax_ms_total | |
| self.softmax_calls += other.softmax_calls | |
| self.unpack_ms_total += other.unpack_ms_total | |
| self.unpack_calls += other.unpack_calls | |
| self.fwht_ms_total += other.fwht_ms_total | |
| self.fwht_calls += other.fwht_calls | |
| self.chunk_assembly_ms_total += other.chunk_assembly_ms_total | |
| self.chunk_assembly_calls += other.chunk_assembly_calls | |
| self.grouped_decode_calls += other.grouped_decode_calls | |
| self.grouped_decode_output_only_calls += other.grouped_decode_output_only_calls | |
| self.grouped_score_chunk_count += other.grouped_score_chunk_count | |
| self.grouped_mix_chunk_count += other.grouped_mix_chunk_count | |
| self.grouped_score_chunk_pages_total += other.grouped_score_chunk_pages_total | |
| self.grouped_mix_chunk_pages_total += other.grouped_mix_chunk_pages_total | |
| self.grouped_score_chunk_pages_max = max(self.grouped_score_chunk_pages_max, other.grouped_score_chunk_pages_max) | |
| self.grouped_mix_chunk_pages_max = max(self.grouped_mix_chunk_pages_max, other.grouped_mix_chunk_pages_max) | |
| self.grouped_logits_elements_total += other.grouped_logits_elements_total | |
| self.grouped_weights_elements_total += other.grouped_weights_elements_total | |
| self.grouped_output_elements_total += other.grouped_output_elements_total | |
| self.grouped_score_packed_cuda_calls += other.grouped_score_packed_cuda_calls | |
| self.grouped_score_fused_two_group64_calls += other.grouped_score_fused_two_group64_calls | |
| self.grouped_score_fused_generic_calls += other.grouped_score_fused_generic_calls | |
| self.grouped_score_generic_calls += other.grouped_score_generic_calls | |
| self.grouped_mix_packed_cuda_calls += other.grouped_mix_packed_cuda_calls | |
| self.grouped_mix_fused_two_group64_calls += other.grouped_mix_fused_two_group64_calls | |
| self.grouped_mix_fused_generic_calls += other.grouped_mix_fused_generic_calls | |
| self.grouped_mix_generic_calls += other.grouped_mix_generic_calls | |
| self.per_kv_decode_calls += other.per_kv_decode_calls | |
| self.per_kv_score_chunk_count += other.per_kv_score_chunk_count | |
| self.per_kv_mix_chunk_count += other.per_kv_mix_chunk_count | |
| self.per_kv_score_chunk_pages_total += other.per_kv_score_chunk_pages_total | |
| self.per_kv_mix_chunk_pages_total += other.per_kv_mix_chunk_pages_total | |
| self.per_kv_score_chunk_pages_max = max(self.per_kv_score_chunk_pages_max, other.per_kv_score_chunk_pages_max) | |
| self.per_kv_mix_chunk_pages_max = max(self.per_kv_mix_chunk_pages_max, other.per_kv_mix_chunk_pages_max) | |
| self.per_kv_logits_elements_total += other.per_kv_logits_elements_total | |
| self.per_kv_weights_elements_total += other.per_kv_weights_elements_total | |
| self.per_kv_output_elements_total += other.per_kv_output_elements_total | |
| self.per_kv_score_fused_two_group64_calls += other.per_kv_score_fused_two_group64_calls | |
| self.per_kv_score_fused_generic_calls += other.per_kv_score_fused_generic_calls | |
| self.per_kv_score_generic_calls += other.per_kv_score_generic_calls | |
| self.per_kv_mix_fused_two_group64_calls += other.per_kv_mix_fused_two_group64_calls | |
| self.per_kv_mix_fused_generic_calls += other.per_kv_mix_fused_generic_calls | |
| self.per_kv_mix_generic_calls += other.per_kv_mix_generic_calls | |
| def to_dict(self) -> dict[str, int | float]: | |
| return { | |
| "m0_full_page_materializations": self.m0_full_page_materializations, | |
| "payload_bytes_read": self.payload_bytes_read, | |
| "metadata_bytes_read": self.metadata_bytes_read, | |
| "host_to_device_bytes": self.host_to_device_bytes, | |
| "max_temporary_bytes": self.max_temporary_bytes, | |
| "prepared_page_cache_hits": self.prepared_page_cache_hits, | |
| "prepared_page_cache_misses": self.prepared_page_cache_misses, | |
| "cache_resident_bytes": self.cache_resident_bytes, | |
| "prepared_page_cache_evictions": self.prepared_page_cache_evictions, | |
| "cache_evicted_bytes": self.cache_evicted_bytes, | |
| "prepare_ms_total": self.prepare_ms_total, | |
| "prepare_calls": self.prepare_calls, | |
| "score_ms_total": self.score_ms_total, | |
| "score_calls": self.score_calls, | |
| "mix_ms_total": self.mix_ms_total, | |
| "mix_calls": self.mix_calls, | |
| "softmax_ms_total": self.softmax_ms_total, | |
| "softmax_calls": self.softmax_calls, | |
| "unpack_ms_total": self.unpack_ms_total, | |
| "unpack_calls": self.unpack_calls, | |
| "fwht_ms_total": self.fwht_ms_total, | |
| "fwht_calls": self.fwht_calls, | |
| "chunk_assembly_ms_total": self.chunk_assembly_ms_total, | |
| "chunk_assembly_calls": self.chunk_assembly_calls, | |
| "grouped_decode_calls": self.grouped_decode_calls, | |
| "grouped_decode_output_only_calls": self.grouped_decode_output_only_calls, | |
| "grouped_score_chunk_count": self.grouped_score_chunk_count, | |
| "grouped_mix_chunk_count": self.grouped_mix_chunk_count, | |
| "grouped_score_chunk_pages_total": self.grouped_score_chunk_pages_total, | |
| "grouped_mix_chunk_pages_total": self.grouped_mix_chunk_pages_total, | |
| "grouped_score_chunk_pages_max": self.grouped_score_chunk_pages_max, | |
| "grouped_mix_chunk_pages_max": self.grouped_mix_chunk_pages_max, | |
| "grouped_logits_elements_total": self.grouped_logits_elements_total, | |
| "grouped_weights_elements_total": self.grouped_weights_elements_total, | |
| "grouped_output_elements_total": self.grouped_output_elements_total, | |
| "grouped_score_packed_cuda_calls": self.grouped_score_packed_cuda_calls, | |
| "grouped_score_fused_two_group64_calls": self.grouped_score_fused_two_group64_calls, | |
| "grouped_score_fused_generic_calls": self.grouped_score_fused_generic_calls, | |
| "grouped_score_generic_calls": self.grouped_score_generic_calls, | |
| "grouped_mix_packed_cuda_calls": self.grouped_mix_packed_cuda_calls, | |
| "grouped_mix_fused_two_group64_calls": self.grouped_mix_fused_two_group64_calls, | |
| "grouped_mix_fused_generic_calls": self.grouped_mix_fused_generic_calls, | |
| "grouped_mix_generic_calls": self.grouped_mix_generic_calls, | |
| "per_kv_decode_calls": self.per_kv_decode_calls, | |
| "per_kv_score_chunk_count": self.per_kv_score_chunk_count, | |
| "per_kv_mix_chunk_count": self.per_kv_mix_chunk_count, | |
| "per_kv_score_chunk_pages_total": self.per_kv_score_chunk_pages_total, | |
| "per_kv_mix_chunk_pages_total": self.per_kv_mix_chunk_pages_total, | |
| "per_kv_score_chunk_pages_max": self.per_kv_score_chunk_pages_max, | |
| "per_kv_mix_chunk_pages_max": self.per_kv_mix_chunk_pages_max, | |
| "per_kv_logits_elements_total": self.per_kv_logits_elements_total, | |
| "per_kv_weights_elements_total": self.per_kv_weights_elements_total, | |
| "per_kv_output_elements_total": self.per_kv_output_elements_total, | |
| "per_kv_score_fused_two_group64_calls": self.per_kv_score_fused_two_group64_calls, | |
| "per_kv_score_fused_generic_calls": self.per_kv_score_fused_generic_calls, | |
| "per_kv_score_generic_calls": self.per_kv_score_generic_calls, | |
| "per_kv_mix_fused_two_group64_calls": self.per_kv_mix_fused_two_group64_calls, | |
| "per_kv_mix_fused_generic_calls": self.per_kv_mix_fused_generic_calls, | |
| "per_kv_mix_generic_calls": self.per_kv_mix_generic_calls, | |
| } | |