File size: 7,877 Bytes
"""Tests for dataset creation and loading module."""
import json
import tempfile
from pathlib import Path
import pytest
class TestDatasetCreation:
    """Tests for the private example generators in ``biorlhf.data.dataset``.

    Every generator is imported inside the test that uses it, so an import
    failure is reported by that individual test rather than at collection
    time.
    """

    # Keys that every factual example must contain ("input" may map to an
    # empty string, but the key itself has to be present).
    _REQUIRED_FIELDS = ("instruction", "output", "input")

    # Substrings accepted as evidence that a calibration answer expresses
    # uncertainty. Matched against the lower-cased output text.
    _UNCERTAINTY_MARKERS = ("cannot", "insufficient", "confidence", "needed", "missing")

    # The interaction generator is expected to produce one example per
    # tissue; this suite pins that count at 4.
    _EXPECTED_INTERACTION_EXAMPLES = 4

    def test_generate_factual_examples_import(self):
        """Check that _generate_factual_examples imports and returns a non-empty list."""
        from biorlhf.data.dataset import _generate_factual_examples

        examples = _generate_factual_examples()
        assert isinstance(examples, list)
        assert len(examples) > 0

    def test_factual_examples_structure(self):
        """Check that every factual example carries all required fields."""
        from biorlhf.data.dataset import _generate_factual_examples

        examples = _generate_factual_examples()
        # Without this guard the loop below asserts nothing for an empty
        # result and the test would pass vacuously.
        assert len(examples) > 0, "No factual examples were generated"
        for ex in examples:
            for field in self._REQUIRED_FIELDS:
                assert field in ex, f"Factual example is missing field {field!r}"

    def test_generate_comparison_examples(self):
        """Check comparison example generation, including a 'most sensitive' question."""
        from biorlhf.data.dataset import _generate_comparison_examples

        examples = _generate_comparison_examples()
        assert isinstance(examples, list)
        assert len(examples) > 0
        # At least one instruction must ask the specific comparison question.
        instructions = [ex["instruction"] for ex in examples]
        assert any("most sensitive" in instr.lower() for instr in instructions)

    def test_generate_interaction_examples(self):
        """Check that interaction prediction yields the expected number of examples."""
        from biorlhf.data.dataset import _generate_interaction_examples

        examples = _generate_interaction_examples()
        assert isinstance(examples, list)
        # Should have one example per tissue.
        assert len(examples) == self._EXPECTED_INTERACTION_EXAMPLES

    def test_generate_design_critique_examples(self):
        """Check that experimental design critique examples are generated."""
        from biorlhf.data.dataset import _generate_design_critique_examples

        examples = _generate_design_critique_examples()
        assert isinstance(examples, list)
        assert len(examples) > 0

    def test_generate_mechanistic_examples(self):
        """Check that mechanistic reasoning examples are generated."""
        from biorlhf.data.dataset import _generate_mechanistic_examples

        examples = _generate_mechanistic_examples()
        assert isinstance(examples, list)
        assert len(examples) > 0

    def test_generate_calibration_examples(self):
        """Check that every calibration example's output expresses uncertainty."""
        from biorlhf.data.dataset import _generate_calibration_examples

        examples = _generate_calibration_examples()
        assert isinstance(examples, list)
        assert len(examples) > 0
        # Calibration examples should express uncertainty: each output must
        # contain at least one of the marker substrings (case-insensitive).
        for ex in examples:
            output = ex["output"].lower()
            has_uncertainty = any(marker in output for marker in self._UNCERTAINTY_MARKERS)
            assert has_uncertainty, f"Calibration example should express uncertainty: {ex['output'][:100]}"
class TestCreateSFTDataset:
    """Tests for the main create_sft_dataset function.

    Each test writes into its own ``tempfile.TemporaryDirectory`` and keeps
    all assertions inside that context, so the output file still exists
    when it is inspected.
    """

    def test_creates_dataset_file(self):
        """Check that create_sft_dataset writes a file and returns a non-empty list."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = Path(tmpdir) / "test_dataset.json"
            result = create_sft_dataset(output_path=output_path)

            assert output_path.exists()
            assert isinstance(result, list)
            assert len(result) > 0

    def test_dataset_format(self):
        """Check that every returned example has a "text" field in instruction format."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = Path(tmpdir) / "test_dataset.json"
            result = create_sft_dataset(output_path=output_path)

            # Without this guard the loop below asserts nothing for an
            # empty result and the test would pass vacuously.
            assert len(result) > 0, "create_sft_dataset returned no examples"

            # Each example should have a "text" field holding both the
            # instruction and the response section markers.
            for ex in result:
                assert "text" in ex
                text = ex["text"]
                assert "### Instruction:" in text, f"Missing instruction marker: {text[:100]}"
                assert "### Response:" in text, f"Missing response marker: {text[:100]}"

    def test_dataset_json_valid(self):
        """Check that the written file parses as JSON with a list at the top level."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = Path(tmpdir) / "test_dataset.json"
            create_sft_dataset(output_path=output_path)

            with open(output_path) as f:
                data = json.load(f)

            assert isinstance(data, list)

    def test_exclude_calibration(self):
        """Check that include_calibration=False yields a smaller dataset."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as tmpdir:
            path_with = Path(tmpdir) / "with_cal.json"
            path_without = Path(tmpdir) / "without_cal.json"
            result_with = create_sft_dataset(output_path=path_with, include_calibration=True)
            result_without = create_sft_dataset(output_path=path_without, include_calibration=False)

            # Dataset with calibration should be larger.
            assert len(result_with) > len(result_without)

    def test_exclude_chain_of_thought(self):
        """Check that include_chain_of_thought=False yields a smaller dataset."""
        from biorlhf.data.dataset import create_sft_dataset

        with tempfile.TemporaryDirectory() as tmpdir:
            path_with = Path(tmpdir) / "with_cot.json"
            path_without = Path(tmpdir) / "without_cot.json"
            result_with = create_sft_dataset(output_path=path_with, include_chain_of_thought=True)
            result_without = create_sft_dataset(output_path=path_without, include_chain_of_thought=False)

            # Dataset with CoT should be larger.
            assert len(result_with) > len(result_without)
class TestLoadDataset:
    """Tests covering load_dataset on a freshly generated SFT dataset file."""

    def test_load_dataset_basic(self):
        """Loading with test_size=0 returns a sized, non-empty object."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as scratch_dir:
            dataset_file = Path(scratch_dir).joinpath("test_dataset.json")
            create_sft_dataset(output_path=dataset_file)

            # No hold-out split requested.
            loaded = load_dataset(dataset_file, test_size=0)

            assert hasattr(loaded, "__len__")
            assert len(loaded) > 0

    def test_load_dataset_with_split(self):
        """Loading with test_size=0.2 exposes "train" and "test", train being larger."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as scratch_dir:
            dataset_file = Path(scratch_dir).joinpath("test_dataset.json")
            create_sft_dataset(output_path=dataset_file)

            # Request a 20% hold-out.
            parts = load_dataset(dataset_file, test_size=0.2)

            assert "train" in parts
            assert "test" in parts
            train_count = len(parts["train"])
            test_count = len(parts["test"])
            assert train_count > test_count

    def test_load_specific_split(self):
        """Asking for split="train" returns a single sized object, not a dict."""
        from biorlhf.data.dataset import create_sft_dataset, load_dataset

        with tempfile.TemporaryDirectory() as scratch_dir:
            dataset_file = Path(scratch_dir).joinpath("test_dataset.json")
            create_sft_dataset(output_path=dataset_file)

            # Only the train portion is requested.
            train_only = load_dataset(dataset_file, split="train", test_size=0.2)

            # A mapping of splits would be a dict; a single split must not be.
            assert not isinstance(train_only, dict)
            assert hasattr(train_only, "__len__")
|