|
|
""" |
|
|
Unit tests for dataset.py module. |
|
|
|
|
|
Tests functions for downloading and extracting the SkillScope dataset. |
|
|
""" |
|
|
import sqlite3
import tempfile
import zipfile
from contextlib import closing
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from hopcroft_skill_classification_tool_competition.dataset import (
    download_skillscope_dataset,
)
|
|
|
|
|
|
|
|
@pytest.mark.unit
class TestDatasetDownload:
    """Unit tests for dataset download functionality."""

    # Patch target for the Hugging Face download call made inside dataset.py.
    _HF_DOWNLOAD = "hopcroft_skill_classification_tool_competition.dataset.hf_hub_download"

    @staticmethod
    def _make_dataset_zip(zip_path, staging_dir, *statements):
        """Write a zip at ``zip_path`` holding one SQLite file named ``skillscope_data.db``.

        The database is built in ``staging_dir`` under a scratch name, populated
        by executing each SQL string in ``statements`` (default: a single dummy
        table), archived, and then deleted so only the zip remains on disk.

        Returns:
            ``zip_path`` as a string, suitable as the mocked ``hf_hub_download``
            return value.
        """
        statements = statements or ("CREATE TABLE test (id INTEGER)",)
        staging_db = Path(staging_dir) / "staging_skillscope_data.db"
        # closing() releases the handle even if a statement fails; an open
        # handle would otherwise block TemporaryDirectory cleanup on Windows.
        with closing(sqlite3.connect(staging_db)) as conn:
            for sql in statements:
                conn.execute(sql)
            conn.commit()
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(staging_db, arcname="skillscope_data.db")
        staging_db.unlink()
        return str(zip_path)

    def test_download_returns_path(self):
        """Test that download function returns a Path object."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)

            with patch(self._HF_DOWNLOAD) as mock_download:
                mock_download.return_value = self._make_dataset_zip(
                    output_dir / "skillscope_data.zip", output_dir
                )
                result = download_skillscope_dataset(output_dir)

            mock_download.assert_called_once()
            assert isinstance(result, Path)
            assert result.exists()
            assert result.name == "skillscope_data.db"

    def test_download_creates_directory(self):
        """Test that download creates output directory if it doesn't exist."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir) / "nonexistent" / "nested" / "dir"
            assert not output_dir.exists()

            with patch(self._HF_DOWNLOAD) as mock_download:
                # Stage the archive outside output_dir so the target really is absent.
                mock_download.return_value = self._make_dataset_zip(
                    Path(tmpdir) / "skillscope_data.zip", tmpdir
                )
                download_skillscope_dataset(output_dir)

            assert output_dir.is_dir()

    def test_download_skips_if_exists(self):
        """Test that download is skipped if database already exists."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)
            db_path = output_dir / "skillscope_data.db"
            with closing(sqlite3.connect(db_path)) as conn:
                conn.execute("CREATE TABLE test (id INTEGER)")

            with patch(self._HF_DOWNLOAD) as mock_download:
                result = download_skillscope_dataset(output_dir)

            mock_download.assert_not_called()
            assert result == db_path

    def test_download_extracts_zip(self):
        """Test that zip file is properly extracted."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)

            with patch(self._HF_DOWNLOAD) as mock_download:
                mock_download.return_value = self._make_dataset_zip(
                    output_dir / "skillscope_data.zip",
                    output_dir,
                    "CREATE TABLE nlbse_tool_competition_data_by_issue (id INTEGER)",
                )
                result = download_skillscope_dataset(output_dir)

            assert result.exists()
            with closing(sqlite3.connect(result)) as conn:
                tables = {
                    row[0]
                    for row in conn.execute(
                        "SELECT name FROM sqlite_master WHERE type='table'"
                    )
                }

            # Look for the table that was packed, not merely "some table":
            # an empty or unrelated database must not pass.
            assert "nlbse_tool_competition_data_by_issue" in tables

    def test_download_cleans_up_zip(self):
        """Test that zip file is deleted after extraction."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)
            zip_path = output_dir / "skillscope_data.zip"

            with patch(self._HF_DOWNLOAD) as mock_download:
                mock_download.return_value = self._make_dataset_zip(zip_path, output_dir)
                download_skillscope_dataset(output_dir)

            assert not zip_path.exists()

    def test_download_raises_on_missing_database(self):
        """Test that error is raised if database not in zip."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)
            zip_path = output_dir / "skillscope_data.zip"
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.writestr("dummy.txt", "dummy content")

            with patch(self._HF_DOWNLOAD) as mock_download:
                mock_download.return_value = str(zip_path)
                with pytest.raises(FileNotFoundError):
                    download_skillscope_dataset(output_dir)
|
|
|
|
|
|
|
@pytest.mark.unit
class TestDatasetEdgeCases:
    """Unit tests for edge cases in dataset handling."""

    # Module whose globals are patched in these tests.
    _MODULE = "hopcroft_skill_classification_tool_competition.dataset"

    def test_download_with_none_output_dir(self):
        """Test download with None as output directory (should use default)."""
        with tempfile.TemporaryDirectory() as tmpdir:
            default_dir = Path(tmpdir)
            db_path = default_dir / "skillscope_data.db"
            with closing(sqlite3.connect(db_path)) as conn:
                conn.execute("CREATE TABLE test (id INTEGER)")

            # Substitute a real directory for RAW_DATA_DIR. A MagicMock whose
            # __truediv__ returns the db path for *any* join would hide which
            # path the function actually builds.
            with patch(f"{self._MODULE}.RAW_DATA_DIR", default_dir):
                with patch(f"{self._MODULE}.hf_hub_download") as mock_download:
                    result = download_skillscope_dataset(None)

            assert isinstance(result, Path)
            # The database pre-seeded in the default directory is reused as-is.
            assert result == db_path
            mock_download.assert_not_called()

    def test_download_handles_permission_error(self):
        """Test that a PermissionError raised while fetching the dataset reaches the caller."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)

            with patch(f"{self._MODULE}.hf_hub_download") as mock_download:
                mock_download.side_effect = PermissionError("permission denied")
                # NOTE(review): assumes download_skillscope_dataset does not
                # swallow or re-wrap OSError subclasses; adjust the expected
                # exception if the implementation translates them.
                with pytest.raises(PermissionError):
                    download_skillscope_dataset(output_dir)

            # A failed download must not leave a (partial) database behind.
            assert not (output_dir / "skillscope_data.db").exists()
|
|
|
|
|
|
|
|
@pytest.mark.unit
class TestDatasetIntegration:
    """Integration-like tests for dataset module (still unit-scoped)."""

    def test_download_produces_valid_sqlite_database(self):
        """Test that downloaded file is a valid SQLite database."""
        with tempfile.TemporaryDirectory() as tmpdir:
            output_dir = Path(tmpdir)

            # Build a small database with the competition table and one row.
            staging_db = output_dir / "temp.db"
            with closing(sqlite3.connect(staging_db)) as conn:
                conn.execute(
                    """
                    CREATE TABLE nlbse_tool_competition_data_by_issue (
                        id INTEGER PRIMARY KEY,
                        repo_name TEXT,
                        pr_number INTEGER
                    )
                    """
                )
                conn.execute(
                    "INSERT INTO nlbse_tool_competition_data_by_issue VALUES (?, ?, ?)",
                    (1, "test_repo", 123),
                )
                conn.commit()

            zip_path = output_dir / "skillscope_data.zip"
            with zipfile.ZipFile(zip_path, "w") as zf:
                zf.write(staging_db, arcname="skillscope_data.db")
            staging_db.unlink()

            with patch(
                "hopcroft_skill_classification_tool_competition.dataset.hf_hub_download"
            ) as mock_download:
                mock_download.return_value = str(zip_path)
                result = download_skillscope_dataset(output_dir)

            # closing() keeps the handle from leaking if the query fails (e.g. the
            # table is missing), which would also break temp-dir cleanup on Windows.
            with closing(sqlite3.connect(result)) as conn:
                rows = conn.execute(
                    "SELECT * FROM nlbse_tool_competition_data_by_issue"
                ).fetchall()

            assert rows == [(1, "test_repo", 123)]
|
|
|