from pathlib import Path

from datasets import load_dataset
from loguru import logger
from tqdm import tqdm

from scientific_rag.domain.documents import ScientificPaper
from scientific_rag.domain.types import DataSource
from scientific_rag.settings import settings


class DataLoader:
    """Loads scientific papers from a HuggingFace dataset into domain objects."""

    def __init__(
        self,
        dataset_name: str | None = None,
        split: str | None = None,
        cache_dir: str | None = None,
    ):
        """Initialize data loader.

        Args:
            dataset_name: HuggingFace dataset name (default from settings)
            split: Dataset subset, "arxiv" or "pubmed" (default from settings);
                use load_both_sources() to combine them
            cache_dir: Cache directory for downloaded data (default from settings)
        """
        self.dataset_name = dataset_name or settings.dataset_name
        self.split = split or settings.dataset_split
        self.cache_dir = str(cache_dir or settings.dataset_cache_dir)
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

    def load_papers(
        self,
        sample_size: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load scientific papers from the dataset.

        Args:
            sample_size: Number of papers to load (None for all)
            data_split: Data split - "train", "validation", or "test"

        Returns:
            List of ScientificPaper objects
        """
        sample_size = sample_size or settings.dataset_sample_size
        logger.info(f"Loading {self.split} papers from {self.dataset_name} ({data_split} split)")

        dataset = load_dataset(
            self.dataset_name,
            self.split,
            split=data_split,
            cache_dir=self.cache_dir,
            trust_remote_code=True,
        )

        if sample_size is not None:
            logger.info(f"Sampling {sample_size} papers")
            dataset = dataset.select(range(min(sample_size, len(dataset))))

        papers = []
        for idx, item in enumerate(tqdm(dataset, desc="Loading papers")):
            try:
                paper = ScientificPaper(
                    paper_id=f"{self.split}_{idx}",
                    abstract=item.get("abstract", ""),
                    article=item.get("article", ""),
                    section_names=item.get("section_names", ""),
                    source=DataSource.ARXIV if self.split == "arxiv" else DataSource.PUBMED,
                )
                papers.append(paper)
            except Exception as e:
                logger.warning(f"Failed to parse paper {idx}: {e}")
                continue

        logger.info(f"Loaded {len(papers)} papers")
        return papers
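
    # Usage sketch (hypothetical values): fetch a capped arxiv sample from
    # the train split. A runnable demo lives at the bottom of this file.
    #
    #   loader = DataLoader(split="arxiv")
    #   papers = loader.load_papers(sample_size=100, data_split="train")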

    def load_both_sources(
        self,
        sample_size_per_source: int | None = None,
        data_split: str = "train",
    ) -> list[ScientificPaper]:
        """Load papers from both ArXiv and PubMed.

        Args:
            sample_size_per_source: Number of papers per source
            data_split: Data split - "train", "validation", or "test"

        Returns:
            Combined list of papers from both sources
        """
        papers = []

        arxiv_loader = DataLoader(
            dataset_name=self.dataset_name,
            split="arxiv",
            cache_dir=self.cache_dir,
        )
        papers.extend(arxiv_loader.load_papers(sample_size_per_source, data_split))

        pubmed_loader = DataLoader(
            dataset_name=self.dataset_name,
            split="pubmed",
            cache_dir=self.cache_dir,
        )
        papers.extend(pubmed_loader.load_papers(sample_size_per_source, data_split))

        logger.info(f"Loaded {len(papers)} papers from both sources")
        return papers
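

# Illustrative usage: a minimal sketch, not part of the original module. It
# assumes `settings` supplies the defaults referenced above (dataset_name,
# dataset_split, dataset_cache_dir, dataset_sample_size) and that the
# configured dataset exposes "arxiv" and "pubmed" subsets with "abstract",
# "article", and "section_names" fields, as load_papers() expects.
if __name__ == "__main__":
    # Pull a small sample from each source and combine them; dataset_name
    # and cache_dir fall back to the values in settings.
    loader = DataLoader()
    combined = loader.load_both_sources(sample_size_per_source=5, data_split="train")
    logger.info(f"Combined sample size: {len(combined)}")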