Artvv's picture
Upload src/persistentpoker_bench/memory_check.py with huggingface_hub
665e120 verified
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from persistentpoker_bench.cards import Card, cards_to_notation, parse_cards
@dataclass(frozen=True, slots=True)
class MemoryCheckResult:
exact_match: bool
matched_instances: int
actual_count: int
believed_count: int
precision: float
recall: float
multiset_accuracy: float
missing_cards: tuple[str, ...]
extra_cards: tuple[str, ...]
def evaluate_memory(
believed_pool: list[str] | tuple[str, ...],
actual_pool: list[Card] | tuple[Card, ...],
) -> MemoryCheckResult:
believed_cards = parse_cards(list(believed_pool))
actual_counter = Counter(cards_to_notation(actual_pool))
believed_counter = Counter(cards_to_notation(believed_cards))
matched_counter = actual_counter & believed_counter
missing_counter = actual_counter - believed_counter
extra_counter = believed_counter - actual_counter
matched_instances = sum(matched_counter.values())
actual_count = sum(actual_counter.values())
believed_count = sum(believed_counter.values())
precision = matched_instances / believed_count if believed_count else 1.0
recall = matched_instances / actual_count if actual_count else 1.0
denominator = max(actual_count + believed_count - matched_instances, 1)
multiset_accuracy = matched_instances / denominator
return MemoryCheckResult(
exact_match=actual_counter == believed_counter,
matched_instances=matched_instances,
actual_count=actual_count,
believed_count=believed_count,
precision=precision,
recall=recall,
multiset_accuracy=multiset_accuracy,
missing_cards=_counter_to_sorted_tuple(missing_counter),
extra_cards=_counter_to_sorted_tuple(extra_counter),
)
def _counter_to_sorted_tuple(counter: Counter[str]) -> tuple[str, ...]:
expanded: list[str] = []
for card, count in counter.items():
expanded.extend([card] * count)
return tuple(sorted(expanded))