File size: 5,274 Bytes
2a2c039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations

import json
import logging
import urllib.request
from dataclasses import dataclass, field
from typing import Any, Iterable

from datasets import Dataset, DatasetDict, load_dataset

from .config import DIABETES_KEYWORDS

logger = logging.getLogger(__name__)


@dataclass
class PubMedQASample:
    """One PubMedQA question/answer pair plus optional PubMed citation metadata."""
    qid: str  # PubMed id when the source record has a numeric "pubid", else the record index as a string
    question: str  # whitespace-normalized question text
    context: str  # whitespace-normalized abstract/context text
    answer: str  # normalized long answer / final decision; "" when the record has none
    authors: str = ""  # populated by _enrich_with_pubmed_metadata (call currently disabled)
    year: str = ""  # publication year, when metadata enrichment runs
    journal: str = ""  # journal name ("source" field of the esummary result)
    title: str = ""  # article title from the esummary result


def _normalize_text(text: str) -> str:
    return " ".join(str(text).split())


def _extract_context_text(record: dict[str, Any]) -> str:
    """Pull a flat, normalized context string out of a PubMedQA record.

    Handles dict-shaped contexts (several known sub-keys), plain lists,
    and plain strings; falls back to the long answer / final decision
    when no usable context field is present.
    """
    context = record.get("context", "")

    if isinstance(context, dict):
        pieces: list[str] = []
        for key in ("contexts", "sentences", "text", "abstract"):
            value = context.get(key)
            if isinstance(value, list):
                pieces.extend(str(item) for item in value)
            elif isinstance(value, str):
                pieces.append(value)
        if pieces:
            return _normalize_text(" ".join(pieces))

    if isinstance(context, list):
        joined = " ".join(str(item) for item in context)
        return _normalize_text(joined)

    if isinstance(context, str):
        return _normalize_text(context)

    # Last resort: reuse the answer-ish fields as context.
    fallback = record.get("long_answer") or record.get("final_decision") or ""
    return _normalize_text(str(fallback))


def _extract_answer_text(record: dict[str, Any]) -> str:
    """Return the first non-empty answer-like field, normalized; "" when none exists."""
    for candidate in ("long_answer", "final_decision", "answer"):
        value = record.get(candidate)
        if not isinstance(value, str):
            continue
        if value.strip():
            return _normalize_text(value)
    return ""


def _is_diabetes_related(question: str, context: str, keywords: Iterable[str]) -> bool:
    corpus = f"{question} {context}".lower()
    return any(keyword.lower() in corpus for keyword in keywords)


def load_diabetes_pubmedqa(
    dataset_name: str,
    max_samples: int = 2000,
    keywords: Iterable[str] = DIABETES_KEYWORDS,
) -> list[PubMedQASample]:
    """Load PubMedQA from the Hub and keep only diabetes-related samples.

    Args:
        dataset_name: Hugging Face dataset identifier.
        max_samples: Stop after collecting this many matching samples.
        keywords: Case-insensitive keywords marking a sample as relevant.

    Returns:
        Up to ``max_samples`` filtered samples; citation metadata fields
        are left empty (enrichment is disabled below).
    """
    import os
    import warnings

    os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

    # Materialize once: `keywords` may be a one-shot iterator, but it is
    # re-scanned for every record in the loop below — a generator would be
    # exhausted after the first record and silently match nothing.
    keyword_list = tuple(keywords)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # PubMedQA requires a config name; prefer artificial/unlabeled for scale.
        for config_name in ("pqa_artificial", "pqa_unlabeled", "pqa_labeled"):
            try:
                raw = load_dataset(dataset_name, config_name)
                break
            except Exception:
                # Any failure just means "try the next config name".
                continue
        else:
            # No config worked; let load_dataset raise its own error if needed.
            raw = load_dataset(dataset_name)
    split = _pick_split(raw)

    filtered: list[PubMedQASample] = []
    for idx, record in enumerate(split):
        question = _normalize_text(str(record.get("question", "")))
        context = _extract_context_text(record)

        # Skip records with no usable question or context text.
        if not question or not context:
            continue

        if not _is_diabetes_related(question, context, keyword_list):
            continue

        filtered.append(
            PubMedQASample(
                qid=str(record.get("pubid", idx)),
                question=question,
                context=context,
                answer=_extract_answer_text(record),
            )
        )

        if len(filtered) >= max_samples:
            break

    # Fetch PubMed metadata (authors, year, journal) in batch
    # _enrich_with_pubmed_metadata(filtered) # Disabled to prevent API timeout and speed up indexing

    return filtered


def _enrich_with_pubmed_metadata(samples: list[PubMedQASample]) -> None:
    """Fetch author/year/journal/title from the NCBI esummary API, in place.

    Samples whose qid is not a numeric PubMed id are skipped. Network or
    parse failures are logged and leave the affected batch unenriched.
    """
    if not samples:
        return
    pubids = [s.qid for s in samples if s.qid.isdigit()]
    if not pubids:
        return
    metadata: dict[str, dict] = {}
    # esummary accepts many ids per call; batch by 200 to keep URLs short.
    for i in range(0, len(pubids), 200):
        batch = pubids[i:i + 200]
        ids_str = ",".join(batch)
        url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=pubmed&id={ids_str}&retmode=json"
        try:
            req = urllib.request.Request(url, headers={"User-Agent": "BioRAG/1.0"})
            # Context manager closes the connection even when json.loads
            # raises (the original left the response unclosed — a leak).
            with urllib.request.urlopen(req, timeout=15) as resp:
                data = json.loads(resp.read())
            result = data.get("result", {})
            for pid in batch:
                if pid in result and isinstance(result[pid], dict):
                    metadata[pid] = result[pid]
        except Exception as e:
            logger.warning("PubMed metadata fetch failed: %s", e)
    for s in samples:
        info = metadata.get(s.qid)
        if not info:
            continue
        authors_list = info.get("authors", [])
        if authors_list:
            # Keep the first three author names, then abbreviate.
            names = [a.get("name", "") for a in authors_list[:3]]
            s.authors = ", ".join(names)
            if len(authors_list) > 3:
                s.authors += " et al."
        pubdate = info.get("pubdate", "")
        if pubdate:
            # pubdate looks like "2020 Jan 15" — the first token is the year;
            # fall back to a character slice for whitespace-only values.
            s.year = pubdate.split()[0] if pubdate.split() else pubdate[:4]
        s.journal = info.get("source", "")
        s.title = info.get("title", "")


def _pick_split(raw: DatasetDict | Dataset) -> Dataset:
    """Return a single split from *raw*, preferring train-like splits."""
    if isinstance(raw, Dataset):
        return raw

    for preferred in ("train", "pqa_labeled", "validation", "test"):
        if preferred in raw:
            return raw[preferred]

    # None of the preferred names exist: take whichever split comes first.
    return raw[next(iter(raw.keys()))]