Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from collections import Counter | |
| from dataclasses import dataclass | |
| from persistentpoker_bench.cards import Card, cards_to_notation, parse_cards | |
| class MemoryCheckResult: | |
| exact_match: bool | |
| matched_instances: int | |
| actual_count: int | |
| believed_count: int | |
| precision: float | |
| recall: float | |
| multiset_accuracy: float | |
| missing_cards: tuple[str, ...] | |
| extra_cards: tuple[str, ...] | |
| def evaluate_memory( | |
| believed_pool: list[str] | tuple[str, ...], | |
| actual_pool: list[Card] | tuple[Card, ...], | |
| ) -> MemoryCheckResult: | |
| believed_cards = parse_cards(list(believed_pool)) | |
| actual_counter = Counter(cards_to_notation(actual_pool)) | |
| believed_counter = Counter(cards_to_notation(believed_cards)) | |
| matched_counter = actual_counter & believed_counter | |
| missing_counter = actual_counter - believed_counter | |
| extra_counter = believed_counter - actual_counter | |
| matched_instances = sum(matched_counter.values()) | |
| actual_count = sum(actual_counter.values()) | |
| believed_count = sum(believed_counter.values()) | |
| precision = matched_instances / believed_count if believed_count else 1.0 | |
| recall = matched_instances / actual_count if actual_count else 1.0 | |
| denominator = max(actual_count + believed_count - matched_instances, 1) | |
| multiset_accuracy = matched_instances / denominator | |
| return MemoryCheckResult( | |
| exact_match=actual_counter == believed_counter, | |
| matched_instances=matched_instances, | |
| actual_count=actual_count, | |
| believed_count=believed_count, | |
| precision=precision, | |
| recall=recall, | |
| multiset_accuracy=multiset_accuracy, | |
| missing_cards=_counter_to_sorted_tuple(missing_counter), | |
| extra_cards=_counter_to_sorted_tuple(extra_counter), | |
| ) | |
| def _counter_to_sorted_tuple(counter: Counter[str]) -> tuple[str, ...]: | |
| expanded: list[str] = [] | |
| for card, count in counter.items(): | |
| expanded.extend([card] * count) | |
| return tuple(sorted(expanded)) | |