#!/usr/bin/env python3
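"""Download a slice of the PubMed annual baseline, convert it to a JSONL
corpus, and build a Lucene index over it with pyserini."""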

import os
import gzip
import json
import re
import subprocess
import xml.etree.ElementTree as ET

from tqdm import tqdm
from urllib.request import urlopen, urlretrieve


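# Directory listing for the NCBI annual baseline; archives are named pubmed<YY>n<NNNN>.xml.gz.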
PUBMED_DATASET_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/pubmed/baseline"

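# Number of baseline archives main() downloads by default.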
PUBMED_FILE_LIMIT = 10


def get_pubmed_dataset_size():
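    """Count the unique .xml.gz archives listed on the baseline index page (0 on failure)."""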
    try:
        with urlopen(PUBMED_DATASET_BASE_URL) as response:
            html = response.read().decode("utf-8")

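            # Match names like pubmed25n0001.xml.gz; the negative lookahead skips
            # companion checksum entries such as pubmed25n0001.xml.gz.md5.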
            files = re.findall(r"(pubmed\d+n\d+)\.xml\.gz(?!\.)", html)
            unique_files = set(files)

            return len(unique_files)

    except Exception as e:
        print(f"Unable to count PubMed files: {e}")

        return 0


def download_pubmed_xml(output_dir, num_files=1, year='25'):
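    """Download the first num_files baseline archives into output_dir,
    skipping any already present, and return their local paths."""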
    os.makedirs(output_dir, exist_ok=True)

    total_dataset_size = get_pubmed_dataset_size()

    files = []
    pbar = tqdm(total=num_files, desc=f"Downloading {num_files} of {total_dataset_size} PubMed baseline files")

    for i in range(1, num_files + 1):
        filename = f"pubmed{year}n{i:04d}.xml.gz"
        filepath = os.path.join(output_dir, filename)

        if not os.path.exists(filepath):
            urlretrieve(f"{PUBMED_DATASET_BASE_URL}/{filename}", filepath)

        pbar.update(1)

        files.append(filepath)

    pbar.close()

    return files


def parse_pubmed_to_jsonl(xml_files, output_jsonl):
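    """Write one JSON document per article, using the 'id'/'contents'
    layout that pyserini's JsonCollection expects."""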
    with open(output_jsonl, 'w') as out:
        for xml_file in xml_files:
            print(f"Parsing {xml_file}...")
            with gzip.open(xml_file, 'rt', encoding='utf-8') as f:
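                # ET.parse loads the whole decompressed file into memory;
                # for large-scale runs, ET.iterparse would stream instead.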
                tree = ET.parse(f)
                root = tree.getroot()

                for article in tqdm(root.findall('.//PubmedArticle')):
                    pmid_elem = article.find('.//PMID')
                    title_elem = article.find('.//ArticleTitle')
                    abstract_elem = article.find('.//Abstract/AbstractText')

                    if pmid_elem is not None:
                        # .text can be None (e.g. titles with inline markup), so fall back to ""
                        title = (title_elem.text or "") if title_elem is not None else ""
                        abstract = (abstract_elem.text or "") if abstract_elem is not None else ""

                        doc = {
                            'id': pmid_elem.text,
                            'title': title,
                            'contents': f"{title} {abstract}".strip()
                        }
                        out.write(json.dumps(doc) + '\n')


def download_pubmed(output_jsonl, num_files=1):
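    """Produce the JSONL corpus at output_jsonl, downloading and parsing the
    raw XML archives (kept in a sibling pubmed-xml/ directory) if needed."""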
    if os.path.exists(output_jsonl):
        print(f"Already downloaded PubMed dataset: {output_jsonl}")

        return

    xml_dir = os.path.join(os.path.dirname(output_jsonl), '../pubmed-xml')
    xml_files = download_pubmed_xml(xml_dir, num_files=num_files)
    parse_pubmed_to_jsonl(xml_files, output_jsonl)


def build_index_cmd(input_file, index_dir):
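    """Return the pyserini indexing command; JsonCollection indexes a whole
    directory, so the parent directory of input_file is passed as --input."""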
    return [
        "python", "-m", "pyserini.index.lucene",
        "--collection", "JsonCollection",
        "--input", os.path.dirname(input_file),
        "--index", index_dir,
        "--generator", "DefaultLuceneDocumentGenerator",
        "--threads", "32",
        "--storePositions",
        "--storeDocvectors",
        "--storeRaw",
    ]


def build_index(input_file, index_dir, cmd_generator=build_index_cmd):
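    """Run the indexing command unless index_dir already contains files."""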
    if os.path.exists(index_dir) and os.listdir(index_dir):
        print(f"Skipping existing index: {index_dir}")

        return

    os.makedirs(os.path.dirname(index_dir) or '.', exist_ok=True)

    cmd = cmd_generator(input_file, index_dir)

    subprocess.run(cmd, check=True)


def main(base_data_dir="data", base_index_dir="indexes", num_files=1):
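    """Download num_files PubMed baseline archives and build a Lucene index over them."""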
    corpus_jsonl = os.path.join(base_data_dir, "pubmed", "corpus.jsonl")
    index_dir = os.path.join(base_index_dir, "pubmed")

    download_pubmed(corpus_jsonl, num_files=num_files)

    build_index(corpus_jsonl, index_dir)


if __name__ == "__main__":
    main(num_files=PUBMED_FILE_LIMIT)
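

# A minimal sketch of querying the resulting index (assumes pyserini is
# installed; the query string is illustrative, not from this script):
#
#   from pyserini.search.lucene import LuceneSearcher
#
#   searcher = LuceneSearcher("indexes/pubmed")
#   hits = searcher.search("insulin resistance", k=10)
#   for hit in hits:
#       print(hit.docid, hit.score)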