Sunxt25 committed on
Commit
ac4ed54
·
verified ·
1 Parent(s): e25c15f

Delete data.py

Browse files
Files changed (1) hide show
  1. data.py +0 -155
data.py DELETED
@@ -1,155 +0,0 @@
1
- """
2
- Data loading utilities for the Chess Challenge using color+piece/from/to tokenizer.
3
- """
4
-
5
- from __future__ import annotations
6
- from typing import Dict, Iterator, List, Optional
7
- import torch
8
- from torch.utils.data import Dataset
9
-
10
class ChessDataset(Dataset):
    """PyTorch Dataset over chess games encoded with a color+piece/from/to tokenizer.

    Each item is a dict of fixed-length tensors suitable for causal-LM
    training: ``input_ids``, ``attention_mask``, and ``labels`` (padding
    positions excluded from the loss via -100).
    """

    def __init__(
        self,
        tokenizer,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        max_length: int = 256,
        max_samples: Optional[int] = None,
    ):
        """Load the games dataset from the Hugging Face Hub.

        Args:
            tokenizer: Tokenizer with a ``bos_token`` attribute and a HF-style
                ``__call__`` returning ``input_ids`` / ``attention_mask``.
            dataset_name: Hub dataset identifier.
            split: Dataset split to load.
            column: Column holding the game move string.
            max_length: Fixed token length (truncate/pad to this).
            max_samples: Optional cap on the number of games kept.
        """
        from datasets import load_dataset

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

        loaded = load_dataset(dataset_name, split=split)
        if max_samples is not None:
            keep = min(max_samples, len(loaded))
            loaded = loaded.select(range(keep))
        self.data = loaded

    def __len__(self) -> int:
        """Number of games available."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize game *idx* into fixed-length training tensors."""
        raw_game = self.data[idx][self.column]

        # BOS marker goes in front; the tokenizer itself splits the game text
        # into color+piece / from-square / to-square tokens.
        text = self.tokenizer.bos_token + " " + raw_game

        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        token_ids = encoded["input_ids"].squeeze(0)
        mask = encoded["attention_mask"].squeeze(0)

        # The model shifts labels internally, so labels mirror the inputs;
        # padded positions are masked out of the loss with -100.
        targets = token_ids.clone()
        targets[mask == 0] = -100

        return {
            "input_ids": token_ids,
            "attention_mask": mask,
            "labels": targets,
        }
66
-
67
-
68
class ChessDataCollator:
    """Collate pre-tokenized chess examples by stacking their tensors."""

    def __init__(self, tokenizer, max_length: int = 256):
        # Kept for interface parity with HF collators; collation itself only
        # stacks tensors that are already padded to a fixed length.
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Stack per-example tensors into batch tensors, key by key."""
        return {
            key: torch.stack([example[key] for example in features])
            for key in ("input_ids", "attention_mask", "labels")
        }
84
-
85
-
86
def create_train_val_datasets(
    tokenizer,
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_length: int = 256,
    train_samples: Optional[int] = None,
    val_samples: int = 5000,
    val_ratio: float = 0.05,
):
    """Create contiguous train/validation ``ChessDataset`` splits.

    The underlying Hub dataset is loaded exactly once. (The original called
    ``ChessDataset.__init__`` for each split, which re-ran ``load_dataset`` on
    the full dataset twice more only to discard the result.)

    Args:
        tokenizer: Tokenizer passed through to each ``ChessDataset``.
        dataset_name: Hub dataset identifier.
        max_length: Fixed token length per game.
        train_samples: Optional cap on training examples; when ``None``,
            ``(1 - val_ratio)`` of the full dataset is used for training.
        val_samples: Upper bound on validation examples.
        val_ratio: Validation fraction used when ``train_samples`` is ``None``.

    Returns:
        ``(train_dataset, val_dataset)`` tuple of ``ChessDataset`` instances.
    """
    from datasets import load_dataset

    full_dataset = load_dataset(dataset_name, split="train")
    total = len(full_dataset)

    if train_samples is not None:
        n_train = min(train_samples, total - val_samples)
    else:
        n_train = int(total * (1 - val_ratio))

    n_val = min(val_samples, total - n_train)

    def _wrap(subset) -> ChessDataset:
        # Build a ChessDataset around an already-loaded subset, bypassing
        # __init__ so the Hub dataset is not downloaded/loaded again.
        wrapped = ChessDataset.__new__(ChessDataset)
        wrapped.tokenizer = tokenizer
        wrapped.max_length = max_length
        wrapped.column = "text"  # matches ChessDataset's default column
        wrapped.data = subset
        return wrapped

    train_dataset = _wrap(full_dataset.select(range(n_train)))
    val_dataset = _wrap(full_dataset.select(range(n_train, n_train + n_val)))

    return train_dataset, val_dataset
116
-
117
-
118
def stream_games(dataset_name: str = "dlouapre/lichess_2025-01_1M", split: str = "train", column: str = "text") -> Iterator[str]:
    """Lazily yield one game string at a time via HF streaming mode.

    Streaming avoids materializing the full dataset in memory; games are
    fetched and yielded on demand.
    """
    from datasets import load_dataset

    stream = load_dataset(dataset_name, split=split, streaming=True)
    yield from (record[column] for record in stream)
125
-
126
-
127
def analyze_dataset_statistics(dataset_name: str = "dlouapre/lichess_2025-01_1M", max_samples: int = 10000) -> Dict:
    """Analyze chess dataset statistics over up to ``max_samples`` games.

    Args:
        dataset_name: Hub dataset identifier (must have a ``"text"`` column
            of whitespace-separated move tokens).
        max_samples: Maximum number of games to include in the analysis.

    Returns:
        Dict with game-length stats, unique-move counts, and the most common
        moves and 4-move openings. Zeroed/empty values when no games are
        selected.
    """
    from collections import Counter
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    game_lengths: List[int] = []
    move_counts: Counter = Counter()
    opening_moves: Counter = Counter()

    for example in dataset:
        moves = example["text"].strip().split()
        game_lengths.append(len(moves))
        move_counts.update(moves)
        # First 4 move tokens identify the opening (only for long-enough games).
        if len(moves) >= 4:
            opening = " ".join(moves[:4])
            opening_moves[opening] += 1

    # Guard: with no games selected (e.g. max_samples=0 or an empty dataset),
    # the original raised ZeroDivisionError on the average and ValueError on
    # min()/max(). Return a well-formed, zeroed summary instead.
    if not game_lengths:
        return {
            "total_games": 0,
            "avg_game_length": 0.0,
            "min_game_length": 0,
            "max_game_length": 0,
            "unique_moves": 0,
            "most_common_moves": [],
            "most_common_openings": [],
        }

    return {
        "total_games": len(dataset),
        "avg_game_length": sum(game_lengths) / len(game_lengths),
        "min_game_length": min(game_lengths),
        "max_game_length": max(game_lengths),
        "unique_moves": len(move_counts),
        "most_common_moves": move_counts.most_common(20),
        "most_common_openings": opening_moves.most_common(10),
    }