from __future__ import annotations
import json
import logging
import urllib.request
from dataclasses import dataclass, field
from typing import Any, Iterable
from datasets import Dataset, DatasetDict, load_dataset
from .config import DIABETES_KEYWORDS
logger = logging.getLogger(__name__)
@dataclass
class PubMedQASample:
    """One filtered PubMedQA record, plus optional PubMed citation metadata."""

    qid: str  # PubMed id from the record's "pubid", or the enumeration index when missing
    question: str  # whitespace-normalized question text
    context: str  # flattened abstract/context text
    answer: str  # long answer / final decision; may be "" when no answer field is present
    # The citation fields below are populated by _enrich_with_pubmed_metadata
    # (currently disabled at the call site), so they default to empty strings.
    authors: str = ""
    year: str = ""
    journal: str = ""
    title: str = ""
def _normalize_text(text: str) -> str:
return " ".join(str(text).split())
def _extract_context_text(record: dict[str, Any]) -> str:
    """Flatten a PubMedQA record's context into one normalized string.

    Handles the three shapes the dataset uses (dict of lists/strings, plain
    list, plain string). When no usable context is found, falls back to the
    record's long answer or final decision.
    """
    context = record.get("context", "")
    if isinstance(context, dict):
        parts: list[str] = []
        for key in ("contexts", "sentences", "text", "abstract"):
            value = context.get(key)
            if isinstance(value, list):
                parts.extend(str(item) for item in value)
            elif isinstance(value, str):
                parts.append(value)
        if parts:
            return _normalize_text(" ".join(parts))
    elif isinstance(context, list):
        return _normalize_text(" ".join(str(item) for item in context))
    elif isinstance(context, str):
        return _normalize_text(context)
    # Dict with no recognized keys, or an unexpected type: use the answer text.
    fallback = record.get("long_answer") or record.get("final_decision") or ""
    return _normalize_text(str(fallback))
def _extract_answer_text(record: dict[str, Any]) -> str:
    """Return the first non-blank answer-like field, whitespace-normalized, or ""."""
    for field_name in ("long_answer", "final_decision", "answer"):
        value = record.get(field_name)
        if not isinstance(value, str):
            continue
        if value.strip():
            return _normalize_text(value)
    return ""
def _is_diabetes_related(question: str, context: str, keywords: Iterable[str]) -> bool:
corpus = f"{question} {context}".lower()
return any(keyword.lower() in corpus for keyword in keywords)
def load_diabetes_pubmedqa(
    dataset_name: str,
    max_samples: int = 2000,
    keywords: Iterable[str] = DIABETES_KEYWORDS,
) -> list[PubMedQASample]:
    """Load PubMedQA and keep only diabetes-related question/context pairs.

    Args:
        dataset_name: Hugging Face hub id of the PubMedQA dataset.
        max_samples: Stop after collecting this many matching samples.
        keywords: Case-insensitive keywords matched against question+context.

    Returns:
        Up to ``max_samples`` ``PubMedQASample`` objects (citation metadata
        fields left at their empty defaults).
    """
    import os
    import warnings

    # Silence the noisy HF hub symlink warning on platforms without symlinks.
    os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # PubMedQA requires a config name; prefer artificial/unlabeled for scale.
        raw = None
        for config_name in ("pqa_artificial", "pqa_unlabeled", "pqa_labeled"):
            try:
                raw = load_dataset(dataset_name, config_name)
                break
            except Exception as exc:
                # Config may not exist for this dataset id; record why instead
                # of swallowing silently, then try the next one.
                logger.debug(
                    "Loading config %r of %r failed: %s", config_name, dataset_name, exc
                )
        if raw is None:
            # Last resort: let the default-config load raise if it also fails.
            raw = load_dataset(dataset_name)
    split = _pick_split(raw)
    filtered: list[PubMedQASample] = []
    for idx, record in enumerate(split):
        question = _normalize_text(str(record.get("question", "")))
        context = _extract_context_text(record)
        if not question or not context:
            continue
        if not _is_diabetes_related(question, context, keywords):
            continue
        filtered.append(
            PubMedQASample(
                qid=str(record.get("pubid", idx)),
                question=question,
                context=context,
                answer=_extract_answer_text(record),
            )
        )
        if len(filtered) >= max_samples:
            break
    # Metadata enrichment is intentionally disabled: the extra PubMed API
    # round trips slowed indexing and could time out.
    # _enrich_with_pubmed_metadata(filtered)
    return filtered
def _enrich_with_pubmed_metadata(samples: list[PubMedQASample]) -> None:
"""Fetch author/year/journal from PubMed API for all samples."""
if not samples:
return
pubids = [s.qid for s in samples if s.qid.isdigit()]
if not pubids:
return
metadata: dict[str, dict] = {}
for i in range(0, len(pubids), 200):
batch = pubids[i:i+200]
ids_str = ",".join(batch)
url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=pubmed&id={ids_str}&retmode=json"
try:
req = urllib.request.Request(url, headers={"User-Agent": "BioRAG/1.0"})
resp = urllib.request.urlopen(req, timeout=15)
data = json.loads(resp.read())
result = data.get("result", {})
for pid in batch:
if pid in result and isinstance(result[pid], dict):
metadata[pid] = result[pid]
except Exception as e:
logger.warning("PubMed metadata fetch failed: %s", e)
for s in samples:
info = metadata.get(s.qid)
if not info:
continue
authors_list = info.get("authors", [])
if authors_list:
names = [a.get("name", "") for a in authors_list[:3]]
s.authors = ", ".join(names)
if len(authors_list) > 3:
s.authors += " et al."
pubdate = info.get("pubdate", "")
if pubdate:
s.year = pubdate.split()[0] if pubdate.split() else pubdate[:4]
s.journal = info.get("source", "")
s.title = info.get("title", "")
def _pick_split(raw: DatasetDict | Dataset) -> Dataset:
    """Choose a single split from ``raw``, preferring train-like splits."""
    if isinstance(raw, Dataset):
        return raw
    preferred = ("train", "pqa_labeled", "validation", "test")
    for name in preferred:
        if name in raw:
            return raw[name]
    # None of the preferred splits exist: fall back to the first available one.
    return raw[next(iter(raw.keys()))]
|