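"""Extract clean text chunks from one shard of a Wikipedia dump and save them to Parquet.

Example invocation (script name and values are illustrative):
    python make_wiki_chunks.py --lang en --shard_id 0 --num_shards 20
"""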
import argparse
import os

import pandas as pd
import tqdm
from datasets import load_dataset

def main():
    parser = argparse.ArgumentParser(description="Chunk one shard of a Wikipedia dump into Parquet")
    parser.add_argument("--lang", type=str, default="en", help="Wikipedia language code")
    parser.add_argument("--shard_id", type=int, required=True, help="index of the shard to process")
    parser.add_argument("--num_shards", type=int, default=20, help="total number of shards")
    parser.add_argument("--max_chunks", type=int, default=15, help="maximum chunks kept per article")
    args = parser.parse_args()

    # 1. Load Shard
    print(f"Loading {args.lang} Wikipedia shard {args.shard_id}...")
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{args.lang}", split='train')
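    # Keep only this worker's slice; running shard_id = 0..num_shards-1 covers the whole split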
    ds_shard = ds.shard(num_shards=args.num_shards, index=args.shard_id)

    # 2. Cleaning & Chunking
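    # The leading "\n" anchors each header to the start of a line, so inline mentions
    # (e.g. "see also the appendix") are not matched by accident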
    STOP_HEADERS = ["\nReferences", "\nSee also", "\nExternal links", "\nNotes", "\nFurther reading", "\nBibliography"]
    wiki_chunks = []
    
    # NOTE: article provenance is not tracked; record article['title'] alongside each chunk if needed
    for article in tqdm.tqdm(ds_shard):
        text = article['text']
        
        # Clean: truncate at the first stop header so end-matter sections are dropped
        clean_text = text
        for header in STOP_HEADERS:
            if header in clean_text:
                clean_text = clean_text.split(header, 1)[0]
        
        # Split into paragraphs, keeping only those longer than 20 words
        paragraphs = [p.strip() for p in clean_text.split('\n\n') if len(p.split()) > 20]
        
        # Cap the number of chunks kept per article (slicing is a no-op for shorter lists)
        paragraphs = paragraphs[:args.max_chunks]
            
        wiki_chunks.extend(paragraphs)

    # 3. Save to Parquet
    # Parquet with snappy compression is compact and fast to reload as a DataFrame
    df = pd.DataFrame({"text": wiki_chunks})
    save_path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{args.shard_id}.parquet"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)  # ensure the output directory exists
    df.to_parquet(save_path, compression='snappy')
    
    print(f"Saved {len(wiki_chunks)} chunks to {save_path}")

if __name__ == "__main__":
    main()
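
# Downstream, the per-shard files can be recombined with, e.g. (path is illustrative):
#   import glob
#   df = pd.concat(pd.read_parquet(p) for p in glob.glob("data/wiki_chunks/wiki_chunks_en_shard_*.parquet"))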