File size: 3,717 Bytes
4f0dc81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from pathlib import Path

from datasets import load_dataset
from loguru import logger
from tqdm import tqdm

from scientific_rag.domain.documents import ScientificPaper
from scientific_rag.domain.types import DataSource
from scientific_rag.settings import settings


class DataLoader:
    """Load scientific papers from a HuggingFace dataset as ``ScientificPaper`` objects.

    The dataset name, dataset configuration (called ``split`` here) and cache
    directory default to the corresponding values in ``settings``.
    """

    def __init__(
        self,
        dataset_name: str | None = None,
        split: str | None = None,
        cache_dir: str | None = None,
    ):
        """Initialize data loader.

        Args:
            dataset_name: HuggingFace dataset name (default from settings).
            split: Dataset configuration passed to ``load_dataset``, e.g. "arxiv"
                or "pubmed" (default from settings). Any value other than
                "arxiv" tags loaded papers as ``DataSource.PUBMED``. To load
                both sources, use ``load_both_sources``.
            cache_dir: Cache directory for downloaded data (default from
                settings). Created, including parents, if it does not exist.
        """
        self.dataset_name = dataset_name or settings.dataset_name
        self.split = split or settings.dataset_split
        # settings.dataset_cache_dir may be a Path; normalize to str.
        self.cache_dir = str(cache_dir or settings.dataset_cache_dir)

        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

    def load_papers(
        self,
        sample_size: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load scientific papers from dataset.

        Args:
            sample_size: Maximum number of papers to load, taken from the start
                of the split (not a random sample). ``None`` falls back to
                ``settings.dataset_sample_size``. If that is also ``None``,
                the whole split is loaded. An explicit ``0`` yields no papers.
            data_split: Data split - "train", "validation", or "test".

        Returns:
            List of ScientificPaper objects. Records that fail to construct
            are logged and skipped, so the list may be shorter than
            ``sample_size``.
        """
        # Fall back to the configured default only when the caller passed
        # nothing; `sample_size or ...` would also discard an explicit 0.
        if sample_size is None:
            sample_size = settings.dataset_sample_size

        logger.info(f"Loading {self.split} papers from {self.dataset_name} ({data_split} split)")

        # NOTE(review): trust_remote_code=True lets `datasets` execute the
        # dataset's loading script fetched from the Hub. Only use this with
        # dataset names you trust.
        dataset = load_dataset(
            self.dataset_name,
            self.split,
            split=data_split,
            cache_dir=self.cache_dir,
            trust_remote_code=True,
        )

        if sample_size is not None:
            logger.info(f"Sampling {sample_size} papers")
            dataset = dataset.select(range(min(sample_size, len(dataset))))

        # The source tag depends only on the configuration, so compute it once.
        source = DataSource.ARXIV if self.split == "arxiv" else DataSource.PUBMED

        papers: list[ScientificPaper] = []
        for idx, item in enumerate(tqdm(dataset, desc="Loading papers")):
            try:
                paper = ScientificPaper(
                    paper_id=f"{self.split}_{idx}",
                    abstract=item.get("abstract", ""),
                    article=item.get("article", ""),
                    section_names=item.get("section_names", ""),
                    source=source,
                )
            except Exception as e:
                # Best-effort: skip a malformed record instead of aborting the load.
                logger.warning(f"Failed to parse paper {idx}: {e}")
                continue
            papers.append(paper)

        logger.info(f"Loaded {len(papers)} papers")
        return papers

    def load_both_sources(
        self,
        sample_size_per_source: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load papers from both ArXiv and PubMed.

        A separate loader is created for each configuration. Each loader shares
        this loader's dataset name and cache directory.

        Args:
            sample_size_per_source: Number of papers per source. ``None`` uses
                the settings default, with the same semantics as
                ``load_papers(sample_size=...)``.
            data_split: Data split - "train", "validation", or "test".

        Returns:
            Combined list of papers: all ArXiv papers first, then PubMed.
        """
        papers: list[ScientificPaper] = []

        for source_split in ("arxiv", "pubmed"):
            loader = DataLoader(
                dataset_name=self.dataset_name,
                split=source_split,
                cache_dir=self.cache_dir,
            )
            papers.extend(loader.load_papers(sample_size_per_source, data_split))

        logger.info(f"Loaded {len(papers)} papers from both sources")
        return papers