import argparse
import os

import pandas as pd
import tqdm
from datasets import load_dataset


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang", type=str, default="en")
    parser.add_argument("--shard_id", type=int, required=True)
    parser.add_argument("--num_shards", type=int, default=20)
    parser.add_argument("--max_chunks", type=int, default=15)
    args = parser.parse_args()

    print(f"Loading {args.lang} Wikipedia shard {args.shard_id}...")
    ds = load_dataset("wikimedia/wikipedia", f"20231101.{args.lang}", split='train')
    ds_shard = ds.shard(num_shards=args.num_shards, index=args.shard_id)

    # Trailing sections (references, appendices, etc.) carry little running
    # prose, so articles are truncated at the first of these headers.
    STOP_HEADERS = ["\nReferences", "\nSee also", "\nExternal links", "\nNotes", "\nFurther reading", "\nBibliography"]
    wiki_chunks = []

    for article in tqdm.tqdm(ds_shard):
        text = article['text']

        # Truncate at every stop header present; each split keeps only the
        # text preceding that header.
        clean_text = text
        for header in STOP_HEADERS:
            if header in clean_text:
                clean_text = clean_text.split(header)[0]

        # Keep only paragraphs longer than 20 words.
        paragraphs = [p.strip() for p in clean_text.split('\n\n') if len(p.split()) > 20]

        # Cap the number of chunks taken from any single article.
        if len(paragraphs) > args.max_chunks:
            paragraphs = paragraphs[:args.max_chunks]

        wiki_chunks.extend(paragraphs)

    df = pd.DataFrame({"text": wiki_chunks})
    save_path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{args.shard_id}.parquet"
    # Ensure the output directory exists before writing.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    df.to_parquet(save_path, compression='snappy')

    print(f"Saved {len(wiki_chunks)} chunks to {save_path}")


if __name__ == "__main__":
    main()
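
# Example invocations (the script filename below is illustrative):
#
#   python build_wiki_chunks.py --lang en --shard_id 0 --num_shards 20
#
# Shards are independent, so all 20 can be processed in a simple shell loop:
#
#   for i in $(seq 0 19); do python build_wiki_chunks.py --shard_id "$i"; done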