| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
from datasets import Dataset, DatasetDict

from nemo.collections.llm.bert.data.specter import SpecterDataModule
|
|
|
|
@pytest.fixture
def mock_tokenizer():
    """Tokenizer stub: tokenizes text to per-character ordinals, with fixed BOS/EOS ids."""
    tok = MagicMock()
    tok.text_to_ids.side_effect = lambda text: list(map(ord, text))
    tok.bos_id = 1
    tok.eos_id = 2
    return tok
|
|
|
|
@pytest.fixture
def mock_trainer():
    """Minimal trainer stand-in exposing the step counters the data module reads."""
    fake_trainer = MagicMock()
    fake_trainer.global_step, fake_trainer.max_steps = 0, 1000
    return fake_trainer
|
|
|
|
@pytest.fixture
def mock_sampler():
    """Sampler stand-in carrying only the `init_global_step` attribute."""
    fake_sampler = MagicMock()
    fake_sampler.init_global_step = 0
    return fake_sampler
|
|
|
|
@pytest.fixture
def sample_specter_dataset():
    """Tiny SPECTER-style triplet dataset (anchor/positive/negative) with one 'train' split."""
    n_rows = 30
    row_text = {
        "anchor": "Paper A abstract",
        "positive": "Similar paper B abstract",
        "negative": "Unrelated paper C abstract",
    }
    train = Dataset.from_dict({column: [text] * n_rows for column, text in row_text.items()})
    return DatasetDict({'train': train})
|
|
|
|
@pytest.fixture
def temp_dataset_dir():
    """Yield a throwaway directory as a Path; the directory is removed on teardown."""
    with tempfile.TemporaryDirectory() as tmp:
        yield Path(tmp)
|
|
|
|
@pytest.fixture
def specter_data_module(mock_tokenizer, temp_dataset_dir, sample_specter_dataset):
    """Build a SpecterDataModule wired to the fake dataset and a temp root (no network)."""
    download_patch = patch.object(
        SpecterDataModule, '_download_data', return_value=sample_specter_dataset
    )
    root_patch = patch(
        'nemo.collections.llm.bert.data.core.get_dataset_root',
        return_value=temp_dataset_dir,
    )
    with download_patch, root_patch:
        module = SpecterDataModule(
            tokenizer=mock_tokenizer,
            seq_length=512,
            micro_batch_size=4,
            global_batch_size=8,
            force_redownload=True,
        )
        # Point the module at the temp dir so all split files land there.
        module.dataset_root = temp_dataset_dir
        yield module
|
|
|
|
def test_specter_data_module_initialization(specter_data_module):
    """Constructor arguments should be reflected as attributes on the module."""
    dm = specter_data_module
    assert (dm.seq_length, dm.micro_batch_size, dm.global_batch_size) == (512, 4, 8)
    assert dm.force_redownload is True
    assert dm.delete_raw is True
|
|
|
|
def test_preprocess_and_split_data(specter_data_module, temp_dataset_dir, sample_specter_dataset):
    """Preprocessing should emit all three split files with the expected record schema."""
    specter_data_module._preprocess_and_split_data(sample_specter_dataset)

    # All three jsonl splits must be written to the dataset root.
    splits = {name: temp_dataset_dir / f"{name}.jsonl" for name in ("training", "validation", "test")}
    for split_path in splits.values():
        assert split_path.exists()

    # The first training record must carry the query/pos_doc/neg_doc schema,
    # with neg_doc wrapped in a single-element list.
    with splits["training"].open("r") as f:
        lines = f.readlines()
    assert lines
    record = json.loads(lines[0])
    for key in ("query", "pos_doc", "neg_doc"):
        assert key in record
    assert isinstance(record["neg_doc"], list)
    assert len(record["neg_doc"]) == 1
|
|
|
|
def test_force_redownload(specter_data_module, temp_dataset_dir, sample_specter_dataset):
    """With force_redownload=True, prepare_data must rewrite already-existing splits.

    The marker file's mtime is pushed into the past before the second call, because
    a strict ``new_mtime > original_mtime`` comparison is flaky when both
    ``prepare_data()`` calls complete within a single filesystem timestamp tick
    (coarse mtime resolution on some filesystems).
    """
    # First run creates the split files.
    specter_data_module.prepare_data()

    marker_file = temp_dataset_dir / "training.jsonl"
    assert marker_file.exists()

    # Backdate the file so any rewrite is guaranteed to yield a strictly newer mtime.
    past = marker_file.stat().st_mtime - 10
    os.utime(marker_file, (past, past))
    original_mtime = marker_file.stat().st_mtime

    # Second run with the redownload flag forced on must regenerate the file.
    specter_data_module.force_redownload = True
    specter_data_module.prepare_data()

    assert marker_file.exists()
    new_mtime = marker_file.stat().st_mtime
    assert new_mtime > original_mtime, "File modification time should be newer after redownload"
|
|
|
|
def test_delete_raw(specter_data_module, temp_dataset_dir, sample_specter_dataset):
    """With delete_raw=True, preprocessing should remove leftover raw artifacts."""
    specter_data_module.prepare_data()

    # Plant a fake raw artifact that preprocessing is expected to clean up.
    raw_file = temp_dataset_dir / "raw_data.txt"
    raw_file.touch()

    specter_data_module.delete_raw = True
    specter_data_module._preprocess_and_split_data(sample_specter_dataset)

    # The planted raw file must have been deleted.
    assert not raw_file.exists()
|
|