DarshanScripts commited on
Commit
ef90a8e
·
verified ·
1 Parent(s): a1593f9

Upload stratego\datasets\builder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego//datasets//builder.py +108 -0
stratego//datasets//builder.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stratego/datasets/builder.py
2
+ """
3
+ Build Hugging Face Datasets from Stratego game CSV logs.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import os
8
+ import csv
9
+ from pathlib import Path
10
+ from typing import Optional, List, Dict, Any
11
+
12
+ try:
13
+ from datasets import Dataset
14
+ HF_AVAILABLE = True
15
+ except ImportError:
16
+ HF_AVAILABLE = False
17
+ Dataset = None
18
+
19
+
20
+ class GameDatasetBuilder:
21
+ """
22
+ Builds a Hugging Face Dataset from CSV game logs.
23
+ """
24
+
25
+ def __init__(self, logs_dir: str = "logs/games"):
26
+ if not HF_AVAILABLE:
27
+ raise ImportError(
28
+ "Hugging Face datasets not installed. "
29
+ "Run: pip install datasets huggingface_hub"
30
+ )
31
+ self.logs_dir = Path(logs_dir)
32
+ self.moves: List[Dict[str, Any]] = []
33
+
34
+ def _parse_csv_file(self, csv_path: Path) -> List[Dict[str, Any]]:
35
+ """Parse a single game CSV file into move records."""
36
+ moves = []
37
+ game_id = csv_path.stem
38
+
39
+ try:
40
+ with open(csv_path, "r", encoding="utf-8") as f:
41
+ reader = csv.DictReader(f)
42
+ for row in reader:
43
+ move_record = {
44
+ "game_id": game_id,
45
+ "turn": int(row.get("turn", 0)),
46
+ "player": int(row.get("player", 0)),
47
+ "model_name": row.get("model_name", "unknown"),
48
+ "move": row.get("move", ""),
49
+ "from_pos": row.get("from_pos", ""),
50
+ "to_pos": row.get("to_pos", ""),
51
+ "piece_type": row.get("piece_type", ""),
52
+ # New training-relevant fields
53
+ "board_state": row.get("board_state", ""),
54
+ "available_moves": row.get("available_moves", ""),
55
+ "move_direction": row.get("move_direction", ""),
56
+ "target_piece": row.get("target_piece", ""),
57
+ "battle_outcome": row.get("battle_outcome", ""),
58
+ "prompt_name": row.get("prompt_name", ""),
59
+ "game_type": row.get("game_type", "standard"),
60
+ "board_size": int(row.get("board_size", 10)) if row.get("board_size") else 10,
61
+ "game_winner": row.get("game_winner", ""),
62
+ "game_result": row.get("game_result", ""),
63
+ }
64
+ moves.append(move_record)
65
+ except Exception as e:
66
+ print(f"Error parsing {csv_path}: {e}")
67
+
68
+ return moves
69
+
70
+ def scan_logs(self) -> int:
71
+ """Scan logs directory and load all CSV files."""
72
+ self.moves = []
73
+
74
+ if not self.logs_dir.exists():
75
+ return 0
76
+
77
+ csv_files = list(self.logs_dir.glob("*.csv"))
78
+
79
+ for csv_path in csv_files:
80
+ game_moves = self._parse_csv_file(csv_path)
81
+ self.moves.extend(game_moves)
82
+
83
+ return len(csv_files)
84
+
85
+ def build(self) -> "Dataset":
86
+ """Build a Dataset from all game logs."""
87
+ if not self.moves:
88
+ self.scan_logs()
89
+
90
+ if not self.moves:
91
+ raise ValueError("No moves found in logs directory.")
92
+
93
+ return Dataset.from_list(self.moves)
94
+
95
+
96
+ def build_dataset_from_logs(logs_dir: str = "logs/games") -> "Dataset":
97
+ """
98
+ Build a dataset from game logs.
99
+
100
+ Args:
101
+ logs_dir: Path to directory containing game CSV files
102
+
103
+ Returns:
104
+ Dataset with all moves
105
+ """
106
+ builder = GameDatasetBuilder(logs_dir)
107
+ builder.scan_logs()
108
+ return builder.build()