Spaces:
Sleeping
Sleeping
File size: 2,055 Bytes
665e120 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | from __future__ import annotations
from collections import Counter
from dataclasses import dataclass
from persistentpoker_bench.cards import Card, cards_to_notation, parse_cards
@dataclass(frozen=True, slots=True)
class MemoryCheckResult:
    """Comparison of a believed card pool against the actual card pool.

    All counts are multiset counts: a card present twice in a pool counts
    twice. Instances are built by ``evaluate_memory`` in this module.
    """

    # True when both pools hold the same cards with the same multiplicities.
    exact_match: bool
    # Number of card instances present in both pools (multiset intersection size).
    matched_instances: int
    # Total number of card instances in the actual pool.
    actual_count: int
    # Total number of card instances in the believed pool.
    believed_count: int
    # matched_instances / believed_count (fraction of beliefs that are correct).
    precision: float
    # matched_instances / actual_count (fraction of actual cards remembered).
    recall: float
    # matched_instances / size of the multiset union (a Jaccard-style score).
    multiset_accuracy: float
    # Sorted card strings in the actual pool but not believed, one entry per missing copy.
    missing_cards: tuple[str, ...]
    # Sorted card strings believed but absent from the actual pool, one entry per extra copy.
    extra_cards: tuple[str, ...]
def evaluate_memory(
    believed_pool: list[str] | tuple[str, ...],
    actual_pool: list[Card] | tuple[Card, ...],
) -> MemoryCheckResult:
    """Score a believed card pool against the actual pool as multisets.

    ``believed_pool`` is parsed with ``parse_cards``. Both pools are then
    converted with ``cards_to_notation``, so cards are compared by the strings
    that function produces. Duplicates are counted individually.

    Args:
        believed_pool: Card strings the player believes are in the pool.
        actual_pool: The cards actually in the pool.

    Returns:
        A ``MemoryCheckResult`` holding precision, recall, multiset accuracy
        (intersection size over union size), and the sorted missing/extra
        cards. Every ratio is 1.0 when its denominator is zero. In particular,
        two empty pools score 1.0 on all metrics, consistent with
        ``exact_match`` being True.
    """
    believed_cards = parse_cards(list(believed_pool))
    actual_counter = Counter(cards_to_notation(actual_pool))
    believed_counter = Counter(cards_to_notation(believed_cards))

    # Counter multiset algebra: `&` keeps the minimum count per card, and
    # `-` keeps only positive differences.
    matched_counter = actual_counter & believed_counter
    missing_counter = actual_counter - believed_counter
    extra_counter = believed_counter - actual_counter

    matched_instances = sum(matched_counter.values())
    actual_count = sum(actual_counter.values())
    believed_count = sum(believed_counter.values())
    # |A ∪ B| = |A| + |B| - |A ∩ B| holds for multisets as well as sets.
    union_count = actual_count + believed_count - matched_instances

    # One convention for empty denominators: with nothing to get wrong, score 1.0.
    precision = matched_instances / believed_count if believed_count else 1.0
    recall = matched_instances / actual_count if actual_count else 1.0
    # The union is empty only when both pools are empty, which is an exact
    # match. Reporting 0.0 there would contradict exact_match, precision
    # and recall.
    multiset_accuracy = matched_instances / union_count if union_count else 1.0

    return MemoryCheckResult(
        exact_match=actual_counter == believed_counter,
        matched_instances=matched_instances,
        actual_count=actual_count,
        believed_count=believed_count,
        precision=precision,
        recall=recall,
        multiset_accuracy=multiset_accuracy,
        missing_cards=_counter_to_sorted_tuple(missing_counter),
        extra_cards=_counter_to_sorted_tuple(extra_counter),
    )
def _counter_to_sorted_tuple(counter: Counter[str]) -> tuple[str, ...]:
expanded: list[str] = []
for card, count in counter.items():
expanded.extend([card] * count)
return tuple(sorted(expanded))
|