tefoteknik committed on
Commit
74e89c5
·
verified ·
1 Parent(s): 9472f1c

Update AGIFORMER with Turkish benchmark

Browse files
Files changed (1) hide show
  1. src/data/turkish_wiki.py +0 -164
src/data/turkish_wiki.py CHANGED
@@ -224,167 +224,3 @@ def get_turkish_wiki_dataloader(batch_size, seq_len, split="train"):
224
 
225
  return loader
226
 
227
- """
228
- Turkish Wikipedia Dataset for byte-level language modeling.
229
- Comparable to enwik8 format for benchmarking.
230
- """
231
- def __init__(self, data_dir="./data", split="train", seq_len=1024, download=True):
232
- super().__init__()
233
- self.data_dir = data_dir
234
- self.split = split
235
- self.seq_len = seq_len
236
-
237
- os.makedirs(data_dir, exist_ok=True)
238
-
239
- # File paths
240
- self.raw_file = os.path.join(data_dir, "trwiki_raw.txt")
241
-
242
- # Download if needed
243
- if download and not os.path.exists(self.raw_file):
244
- self._download_and_process()
245
-
246
- # Load data
247
- if not os.path.exists(self.raw_file):
248
- raise FileNotFoundError(
249
- f"Turkish Wikipedia data not found at {self.raw_file}. "
250
- "Set download=True to download automatically."
251
- )
252
-
253
- with open(self.raw_file, 'rb') as f:
254
- self.data = f.read()
255
-
256
- # Split data (90% train, 5% val, 5% test - same as enwik8)
257
- total_len = len(self.data)
258
- train_len = int(0.9 * total_len)
259
- val_len = int(0.05 * total_len)
260
-
261
- if split == "train":
262
- self.data = self.data[:train_len]
263
- elif split == "val":
264
- self.data = self.data[train_len:train_len + val_len]
265
- elif split == "test":
266
- self.data = self.data[train_len + val_len:]
267
- else:
268
- raise ValueError(f"Invalid split: {split}")
269
-
270
- print(f"Loaded Turkish Wikipedia ({split}): {len(self.data):,} bytes")
271
-
272
- def _download_and_process(self):
273
- """
274
- Download Turkish Wikipedia dump and process to plain text.
275
- Note: This is a simplified version. Full processing requires WikiExtractor.
276
- """
277
- print("Downloading Turkish Wikipedia...")
278
-
279
- # URL to Turkish Wikipedia dump (latest articles)
280
- # Using a small subset for demo - full dump is ~3GB compressed
281
- url = "https://dumps.wikimedia.org/trwiki/latest/trwiki-latest-pages-articles1.xml-p1p187422.bz2"
282
-
283
- compressed_file = os.path.join(self.data_dir, "trwiki.xml.bz2")
284
-
285
- try:
286
- print(f"Downloading from {url}...")
287
- urllib.request.urlretrieve(url, compressed_file)
288
- print("Download complete.")
289
-
290
- # Decompress
291
- import bz2
292
- print("Decompressing...")
293
- with bz2.open(compressed_file, 'rb') as f_in:
294
- xml_content = f_in.read()
295
-
296
- # Extract text from XML
297
- print("Extracting text...")
298
- text = self._extract_text_from_xml(xml_content)
299
-
300
- # Save as raw bytes
301
- with open(self.raw_file, 'wb') as f:
302
- f.write(text.encode('utf-8'))
303
-
304
- print(f"Processed {len(text):,} characters to {self.raw_file}")
305
-
306
- # Cleanup
307
- os.remove(compressed_file)
308
-
309
- except Exception as e:
310
- print(f"Error downloading Turkish Wikipedia: {e}")
311
- print("Please download manually or use a smaller test file.")
312
- raise
313
-
314
- def _extract_text_from_xml(self, xml_content):
315
- """
316
- Simple text extraction from Wikipedia XML.
317
- Removes markup but keeps structure similar to enwik8.
318
- """
319
- # Convert bytes to string
320
- xml_str = xml_content.decode('utf-8', errors='ignore')
321
-
322
- # Clean up (basic - not as sophisticated as WikiExtractor)
323
- # Remove XML tags but keep some structure
324
- text = re.sub(r'<[^>]+>', '', xml_str)
325
-
326
- # Remove empty lines
327
- lines = [line.strip() for line in text.split('\n') if line.strip()]
328
-
329
- return '\n'.join(lines)
330
-
331
- def __len__(self):
332
- # Number of possible sequences
333
- return max(0, len(self.data) - 2 * self.seq_len)
334
-
335
- def __getitem__(self, idx):
336
- """
337
- Returns:
338
- input: (seq_len,) - Context bytes
339
- target: (seq_len,) - Target bytes (next patch)
340
- """
341
- # Input context
342
- start_idx = idx
343
- end_idx = start_idx + self.seq_len
344
-
345
- # Target is shifted by patch_size (4 bytes default)
346
- target_start = start_idx + 4
347
- target_end = target_start + self.seq_len
348
-
349
- # Extract bytes
350
- input_bytes = torch.tensor(
351
- list(self.data[start_idx:end_idx]),
352
- dtype=torch.long
353
- )
354
-
355
- target_bytes = torch.tensor(
356
- list(self.data[target_start:target_end]),
357
- dtype=torch.long
358
- )
359
-
360
- return input_bytes, target_bytes
361
-
362
-
363
def get_turkish_wiki_dataloader(batch_size, seq_len, split="train",
                                data_dir="./data", num_workers=0):
    """Create a DataLoader over the Turkish Wikipedia byte dataset.

    Generalized over the original: ``data_dir`` and ``num_workers`` were
    hard-coded ("./data" and 0); they are now keyword parameters with
    the same defaults, so existing callers are unaffected.

    Args:
        batch_size: Batch size.
        seq_len: Sequence length in bytes.
        split: "train", "val", or "test".
        data_dir: Directory holding (or receiving) the dataset files.
        num_workers: DataLoader worker process count.

    Returns:
        DataLoader yielding (input, target) byte-tensor batches;
        shuffling is enabled only for the train split.
    """
    dataset = TurkishWikiDataset(
        data_dir=data_dir,
        split=split,
        seq_len=seq_len,
        download=True,
    )

    return data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        pin_memory=True,
    )
 
224
 
225
  return loader
226