Artvv commited on
Commit
97b2358
·
verified ·
1 Parent(s): 035ac0e

Upload src/persistentpoker_bench/leaderboard.py with huggingface_hub

Browse files
src/persistentpoker_bench/leaderboard.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ from collections import defaultdict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ from persistentpoker_bench.tournament import TournamentResult, flatten_tournament_match_transcript
9
+
10
+
11
+ @dataclass(frozen=True, slots=True)
12
+ class LeaderboardRow:
13
+ track: str
14
+ provider: str
15
+ model_id: str
16
+ display_name: str
17
+ matches_played: int
18
+ hands_played: int
19
+ win_rate: float
20
+ average_final_stack: float
21
+ average_chip_delta: float
22
+ survival_rate: float
23
+ bust_rate: float
24
+ memory_accuracy: float
25
+ parsing_success_rate: float
26
+ reset_rate: float
27
+ average_pool_size: float
28
+ total_input_tokens: int
29
+ total_output_tokens: int
30
+ total_cached_input_tokens: int
31
+ estimated_total_cost: float | None
32
+
33
+
34
+ def build_leaderboard_rows(tournament_result: TournamentResult) -> tuple[LeaderboardRow, ...]:
35
+ aggregates: dict[tuple[str, str], dict[str, object]] = defaultdict(
36
+ lambda: {
37
+ "track": tournament_result.track.value,
38
+ "provider": "",
39
+ "model_id": "",
40
+ "display_name": "",
41
+ "matches_played": 0,
42
+ "hands_played": 0,
43
+ "win_rate_sum": 0.0,
44
+ "average_final_stack_sum": 0.0,
45
+ "average_chip_delta_sum": 0.0,
46
+ "survival_rate_sum": 0.0,
47
+ "bust_rate_sum": 0.0,
48
+ "memory_accuracy_sum": 0.0,
49
+ "parsing_success_rate_sum": 0.0,
50
+ "reset_rate_sum": 0.0,
51
+ "average_pool_size_sum": 0.0,
52
+ "total_input_tokens": 0,
53
+ "total_output_tokens": 0,
54
+ "total_cached_input_tokens": 0,
55
+ "estimated_total_cost": 0.0,
56
+ "have_cost": False,
57
+ }
58
+ )
59
+
60
+ for match_record in tournament_result.match_records:
61
+ transcript_by_player = _group_transcript_by_player(flatten_tournament_match_transcript(match_record))
62
+ for entrant in match_record.entrants:
63
+ key = (entrant.registered_model.provider, entrant.registered_model.model_id)
64
+ aggregate = aggregates[key]
65
+ player_events = transcript_by_player.get(entrant.seat_name, ())
66
+ aggregate["provider"] = entrant.registered_model.provider
67
+ aggregate["model_id"] = entrant.registered_model.model_id
68
+ aggregate["display_name"] = entrant.registered_model.display_name
69
+ aggregate["matches_played"] = int(aggregate["matches_played"]) + 1
70
+ aggregate["hands_played"] = int(aggregate["hands_played"]) + match_record.metrics.hands_played
71
+ aggregate["win_rate_sum"] = float(aggregate["win_rate_sum"]) + match_record.metrics.win_rate_by_player.get(
72
+ entrant.seat_name,
73
+ 0.0,
74
+ )
75
+ aggregate["average_final_stack_sum"] = float(aggregate["average_final_stack_sum"]) + match_record.metrics.final_stacks_by_player.get(
76
+ entrant.seat_name,
77
+ 0,
78
+ )
79
+ aggregate["average_chip_delta_sum"] = float(aggregate["average_chip_delta_sum"]) + match_record.metrics.chip_delta_by_player.get(
80
+ entrant.seat_name,
81
+ 0,
82
+ )
83
+ aggregate["survival_rate_sum"] = float(aggregate["survival_rate_sum"]) + (
84
+ 1.0 if entrant.seat_name in match_record.metrics.surviving_players else 0.0
85
+ )
86
+ aggregate["bust_rate_sum"] = float(aggregate["bust_rate_sum"]) + (
87
+ 1.0 if entrant.seat_name in match_record.metrics.busted_players else 0.0
88
+ )
89
+ aggregate["memory_accuracy_sum"] = float(aggregate["memory_accuracy_sum"]) + _average_memory_accuracy(player_events)
90
+ aggregate["parsing_success_rate_sum"] = float(aggregate["parsing_success_rate_sum"]) + _parsing_success_rate(player_events)
91
+ aggregate["reset_rate_sum"] = float(aggregate["reset_rate_sum"]) + _reset_rate(player_events)
92
+ aggregate["average_pool_size_sum"] = float(aggregate["average_pool_size_sum"]) + match_record.metrics.average_pool_size
93
+ aggregate["total_input_tokens"] = int(aggregate["total_input_tokens"]) + _usage_sum(player_events, "prompt_tokens")
94
+ aggregate["total_output_tokens"] = int(aggregate["total_output_tokens"]) + _usage_sum(player_events, "completion_tokens")
95
+ aggregate["total_cached_input_tokens"] = int(aggregate["total_cached_input_tokens"]) + _usage_sum(player_events, "cached_tokens")
96
+ cost = _usage_cost_sum(player_events)
97
+ if cost is not None:
98
+ aggregate["estimated_total_cost"] = float(aggregate["estimated_total_cost"]) + cost
99
+ aggregate["have_cost"] = True
100
+
101
+ rows: list[LeaderboardRow] = []
102
+ for aggregate in aggregates.values():
103
+ matches_played = int(aggregate["matches_played"])
104
+ rows.append(
105
+ LeaderboardRow(
106
+ track=str(aggregate["track"]),
107
+ provider=str(aggregate["provider"]),
108
+ model_id=str(aggregate["model_id"]),
109
+ display_name=str(aggregate["display_name"]),
110
+ matches_played=matches_played,
111
+ hands_played=int(aggregate["hands_played"]),
112
+ win_rate=float(aggregate["win_rate_sum"]) / matches_played if matches_played else 0.0,
113
+ average_final_stack=float(aggregate["average_final_stack_sum"]) / matches_played if matches_played else 0.0,
114
+ average_chip_delta=float(aggregate["average_chip_delta_sum"]) / matches_played if matches_played else 0.0,
115
+ survival_rate=float(aggregate["survival_rate_sum"]) / matches_played if matches_played else 0.0,
116
+ bust_rate=float(aggregate["bust_rate_sum"]) / matches_played if matches_played else 0.0,
117
+ memory_accuracy=float(aggregate["memory_accuracy_sum"]) / matches_played if matches_played else 1.0,
118
+ parsing_success_rate=float(aggregate["parsing_success_rate_sum"]) / matches_played if matches_played else 1.0,
119
+ reset_rate=float(aggregate["reset_rate_sum"]) / matches_played if matches_played else 0.0,
120
+ average_pool_size=float(aggregate["average_pool_size_sum"]) / matches_played if matches_played else 0.0,
121
+ total_input_tokens=int(aggregate["total_input_tokens"]),
122
+ total_output_tokens=int(aggregate["total_output_tokens"]),
123
+ total_cached_input_tokens=int(aggregate["total_cached_input_tokens"]),
124
+ estimated_total_cost=(
125
+ float(aggregate["estimated_total_cost"]) if bool(aggregate["have_cost"]) else None
126
+ ),
127
+ )
128
+ )
129
+
130
+ rows.sort(
131
+ key=lambda row: (
132
+ -row.average_chip_delta,
133
+ -row.average_final_stack,
134
+ -row.survival_rate,
135
+ -row.win_rate,
136
+ -row.memory_accuracy,
137
+ row.estimated_total_cost or 0.0,
138
+ row.display_name,
139
+ )
140
+ )
141
+ return tuple(rows)
142
+
143
+
144
+ def export_leaderboard_csv(rows: tuple[LeaderboardRow, ...], path: str | Path) -> Path:
145
+ destination = Path(path)
146
+ destination.parent.mkdir(parents=True, exist_ok=True)
147
+ fieldnames = [
148
+ "track",
149
+ "provider",
150
+ "model_id",
151
+ "display_name",
152
+ "matches_played",
153
+ "hands_played",
154
+ "win_rate",
155
+ "average_final_stack",
156
+ "average_chip_delta",
157
+ "survival_rate",
158
+ "bust_rate",
159
+ "memory_accuracy",
160
+ "parsing_success_rate",
161
+ "reset_rate",
162
+ "average_pool_size",
163
+ "total_input_tokens",
164
+ "total_output_tokens",
165
+ "total_cached_input_tokens",
166
+ "estimated_total_cost",
167
+ ]
168
+ with destination.open("w", encoding="utf-8", newline="") as handle:
169
+ writer = csv.DictWriter(handle, fieldnames=fieldnames)
170
+ writer.writeheader()
171
+ for row in rows:
172
+ writer.writerow(
173
+ {
174
+ "track": row.track,
175
+ "provider": row.provider,
176
+ "model_id": row.model_id,
177
+ "display_name": row.display_name,
178
+ "matches_played": row.matches_played,
179
+ "hands_played": row.hands_played,
180
+ "win_rate": row.win_rate,
181
+ "average_final_stack": row.average_final_stack,
182
+ "average_chip_delta": row.average_chip_delta,
183
+ "survival_rate": row.survival_rate,
184
+ "bust_rate": row.bust_rate,
185
+ "memory_accuracy": row.memory_accuracy,
186
+ "parsing_success_rate": row.parsing_success_rate,
187
+ "reset_rate": row.reset_rate,
188
+ "average_pool_size": row.average_pool_size,
189
+ "total_input_tokens": row.total_input_tokens,
190
+ "total_output_tokens": row.total_output_tokens,
191
+ "total_cached_input_tokens": row.total_cached_input_tokens,
192
+ "estimated_total_cost": row.estimated_total_cost,
193
+ }
194
+ )
195
+ return destination
196
+
197
+
198
+ def _group_transcript_by_player(
199
+ transcript: tuple[dict[str, object], ...],
200
+ ) -> dict[str, tuple[dict[str, object], ...]]:
201
+ grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
202
+ for event in transcript:
203
+ player_name = event.get("player_name")
204
+ if isinstance(player_name, str):
205
+ grouped[player_name].append(event)
206
+ return {player_name: tuple(events) for player_name, events in grouped.items()}
207
+
208
+
209
+ def _average_memory_accuracy(events: tuple[dict[str, object], ...]) -> float:
210
+ if not events:
211
+ return 1.0
212
+ scores = [
213
+ float(event["memory"]["multiset_accuracy"])
214
+ for event in events
215
+ if isinstance(event.get("memory"), dict) and "multiset_accuracy" in event["memory"]
216
+ ]
217
+ return sum(scores) / len(scores) if scores else 1.0
218
+
219
+
220
+ def _parsing_success_rate(events: tuple[dict[str, object], ...]) -> float:
221
+ if not events:
222
+ return 1.0
223
+ successful = sum(1 for event in events if event.get("parse_mode"))
224
+ return successful / len(events)
225
+
226
+
227
+ def _reset_rate(events: tuple[dict[str, object], ...]) -> float:
228
+ if not events:
229
+ return 0.0
230
+ resets = sum(1 for event in events if event.get("winner_pool_decision") == "reset")
231
+ return resets / len(events)
232
+
233
+
234
+ def _usage_sum(events: tuple[dict[str, object], ...], key: str) -> int:
235
+ total = 0
236
+ for event in events:
237
+ usage = event.get("usage")
238
+ if isinstance(usage, dict):
239
+ value = usage.get(key)
240
+ if isinstance(value, int):
241
+ total += value
242
+ return total
243
+
244
+
245
+ def _usage_cost_sum(events: tuple[dict[str, object], ...]) -> float | None:
246
+ total = 0.0
247
+ have_cost = False
248
+ for event in events:
249
+ usage = event.get("usage")
250
+ if isinstance(usage, dict):
251
+ value = usage.get("estimated_cost")
252
+ if isinstance(value, int | float):
253
+ total += float(value)
254
+ have_cost = True
255
+ return total if have_cost else None