from pathlib import Path

from datasets import load_dataset
from loguru import logger
from tqdm import tqdm

from scientific_rag.domain.documents import ScientificPaper
from scientific_rag.domain.types import DataSource
from scientific_rag.settings import settings


class DataLoader:
    def __init__(
        self,
        dataset_name: str | None = None,
        split: str | None = None,
        cache_dir: str | None = None,
    ):
        """Initialize data loader.

        Args:
            dataset_name: HuggingFace dataset name (default from settings)
            split: Dataset config - "arxiv" or "pubmed" (default from settings);
                use load_both_sources() to combine the two
            cache_dir: Cache directory for downloaded data (default from settings)
        """
        self.dataset_name = dataset_name or settings.dataset_name
        self.split = split or settings.dataset_split
        self.cache_dir = str(cache_dir or settings.dataset_cache_dir)
        # Ensure the cache directory exists before any download starts.
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

    def load_papers(
        self,
        sample_size: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load scientific papers from the dataset.

        Args:
            sample_size: Number of papers to load (defaults to
                settings.dataset_sample_size; all papers are loaded
                when both are None)
            data_split: Data split - "train", "validation", or "test"

        Returns:
            List of ScientificPaper objects
        """
        # Explicit None check so an intentional sample_size=0 is not
        # silently replaced by the settings default.
        if sample_size is None:
            sample_size = settings.dataset_sample_size
        logger.info(f"Loading {self.split} papers from {self.dataset_name} ({data_split} split)")

        dataset = load_dataset(
            self.dataset_name,
            self.split,
            split=data_split,
            cache_dir=self.cache_dir,
            trust_remote_code=True,
        )

        if sample_size is not None:
            logger.info(f"Sampling {sample_size} papers")
            dataset = dataset.select(range(min(sample_size, len(dataset))))

        papers = []
        for idx, item in enumerate(tqdm(dataset, desc="Loading papers")):
            try:
                paper = ScientificPaper(
                    paper_id=f"{self.split}_{idx}",
                    abstract=item.get("abstract", ""),
                    article=item.get("article", ""),
                    section_names=item.get("section_names", ""),
                    source=DataSource.ARXIV if self.split == "arxiv" else DataSource.PUBMED,
                )
                papers.append(paper)
            except Exception as e:
                # Skip malformed records instead of failing the whole load.
                logger.warning(f"Failed to parse paper {idx}: {e}")
                continue

        logger.info(f"Loaded {len(papers)} papers")
        return papers

    def load_both_sources(
        self,
        sample_size_per_source: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load papers from both ArXiv and PubMed.

        Args:
            sample_size_per_source: Number of papers per source
            data_split: Data split - "train", "validation", or "test"

        Returns:
            Combined list of papers from both sources
        """
        papers = []

        arxiv_loader = DataLoader(
            dataset_name=self.dataset_name,
            split="arxiv",
            cache_dir=self.cache_dir,
        )
        papers.extend(arxiv_loader.load_papers(sample_size_per_source, data_split))

        pubmed_loader = DataLoader(
            dataset_name=self.dataset_name,
            split="pubmed",
            cache_dir=self.cache_dir,
        )
        papers.extend(pubmed_loader.load_papers(sample_size_per_source, data_split))

        logger.info(f"Loaded {len(papers)} papers from both sources")
        return papers
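

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of how this loader might be driven, assuming the
# settings defaults point at the HF "scientific_papers" dataset. The
# dataset name, the 100-paper sample size, and the 50-per-source value
# below are hypothetical choices for demonstration, not values taken
# from the project settings.
if __name__ == "__main__":
    loader = DataLoader(dataset_name="scientific_papers", split="arxiv")

    # Load a small sample from the train split of the "arxiv" config.
    papers = loader.load_papers(sample_size=100, data_split="train")
    logger.info(f"First paper id: {papers[0].paper_id if papers else 'n/a'}")

    # Or pull a balanced sample from both the "arxiv" and "pubmed" configs.
    combined = loader.load_both_sources(sample_size_per_source=50)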